Tutorial on GoogLeNet-based image classification

2018-06-26 15:50:29

本文旨在通过案例来学习 GoogLeNet 及其 Inception 结构的定义,并介绍这类复杂模型的保存与读取方法。

1. GoogLeNet 的结构:

  class Inception(nn.Module):
      """Inception block: four parallel branches concatenated along channels.

      Every branch preserves the spatial size of its input (1x1 convolutions
      use no padding, 3x3 convolutions and the 3x3 max-pool use padding=1 with
      stride 1), so the four outputs can be concatenated on dimension 1.

      Args:
          in_planes: Number of channels of the input tensor.
          kernel_1_x: Output channels of the 1x1 branch (``b1``).
          kernel_3_in: Channels after the 1x1 reduction of the 3x3 branch.
          kernel_3_x: Output channels of the 3x3 branch (``b2``).
          kernel_5_in: Channels after the 1x1 reduction of the third branch.
          kernel_5_x: Output channels of the third branch (``b3``).
          pool_planes: Output channels of the pooling branch (``b4``).

      The output has ``kernel_1_x + kernel_3_x + kernel_5_x + pool_planes``
      channels.
      """

      def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x,
                   kernel_5_in, kernel_5_x, pool_planes):
          super(Inception, self).__init__()
          unit = self._conv_bn_relu

          # Layers are built in the same order and kept flat inside each
          # nn.Sequential so that the state_dict keys (b1.0.weight,
          # b2.3.weight, ...) stay exactly as in the original definition.

          # 1x1 conv branch
          self.b1 = nn.Sequential(*unit(in_planes, kernel_1_x, 1))

          # 1x1 conv -> 3x3 conv branch
          self.b2 = nn.Sequential(*(
              unit(in_planes, kernel_3_in, 1)
              + unit(kernel_3_in, kernel_3_x, 3, padding=1)
          ))

          # 1x1 conv -> 3x3 conv -> 3x3 conv branch. Two stacked 3x3
          # convolutions are used here in place of a single 5x5 convolution.
          self.b3 = nn.Sequential(*(
              unit(in_planes, kernel_5_in, 1)
              + unit(kernel_5_in, kernel_5_x, 3, padding=1)
              + unit(kernel_5_x, kernel_5_x, 3, padding=1)
          ))

          # 3x3 max-pool -> 1x1 conv branch
          self.b4 = nn.Sequential(*(
              [nn.MaxPool2d(3, stride=1, padding=1)]
              + unit(in_planes, pool_planes, 1)
          ))

      @staticmethod
      def _conv_bn_relu(in_channels, out_channels, kernel_size, padding=0):
          """Return the [Conv2d, BatchNorm2d, ReLU] layers of one conv unit."""
          return [
              nn.Conv2d(in_channels, out_channels,
                        kernel_size=kernel_size, padding=padding),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(True),
          ]

      def forward(self, x):
          """Run all four branches on ``x`` and concatenate them on dim 1."""
          branches = (self.b1, self.b2, self.b3, self.b4)
          return torch.cat([branch(x) for branch in branches], 1)
  class GoogLeNet(nn.Module):
      """GoogLeNet-style classifier built from a stem and nine Inception blocks.

      The stem and every Inception block keep the spatial size; only the two
      ``max_pool`` applications (stride 2) reduce it. The final feature map
      has 1024 channels, is averaged to 1x1 and fed to a linear classifier.

      Args:
          num_classes: Number of output logits. Defaults to 10, the value
              that was previously hard-coded.

      Input is expected to be ``(N, 3, H, W)`` (the stem convolution takes
      3 input channels); the output is ``(N, num_classes)``.
      """

      def __init__(self, num_classes=10):
          super(GoogLeNet, self).__init__()
          self.pre_layers = nn.Sequential(
              nn.Conv2d(3, 192, kernel_size=3, padding=1),
              nn.BatchNorm2d(192),
              nn.ReLU(True),
          )

          # Each block's in_planes equals the previous block's total output
          # channels (the sum of its four branch widths).
          self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
          self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

          # Shared by both down-sampling steps in forward(); it has no
          # parameters, so reusing one instance is safe.
          self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)

          self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
          self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
          self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
          self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
          self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

          self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
          self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

          # Global average pooling. The original AvgPool2d(8, stride=1) only
          # produced the 1x1 map the classifier needs when the final feature
          # map was exactly 8x8 (e.g. 32x32 inputs); adaptive pooling gives
          # the same result in that case and also accepts other input sizes.
          self.avgpool = nn.AdaptiveAvgPool2d(1)
          self.linear = nn.Linear(1024, num_classes)

      def forward(self, x):
          """Return class logits of shape ``(N, num_classes)`` for batch ``x``."""
          x = self.pre_layers(x)

          x = self.a3(x)
          x = self.b3(x)
          x = self.max_pool(x)

          x = self.a4(x)
          x = self.b4(x)
          x = self.c4(x)
          x = self.d4(x)
          x = self.e4(x)
          x = self.max_pool(x)

          x = self.a5(x)
          x = self.b5(x)

          x = self.avgpool(x)
          x = x.view(x.size(0), -1)  # flatten (N, 1024, 1, 1) -> (N, 1024)
          return self.linear(x)

2. 保存和加载模型:

  # Option 1: save and load the entire model object.
  # NOTE(review): this pickles the whole module, and torch.load unpickles the
  # file, so only load files that come from a trusted source. Recent PyTorch
  # releases default torch.load to weights_only=True, in which case loading a
  # whole pickled model needs weights_only=False passed explicitly -- confirm
  # against the PyTorch version in use.
  torch.save(model_object, 'model.pkl')
  model = torch.load('model.pkl')

  # Option 2 (recommended): save and load only the model parameters.
  # The model object must already be constructed before its parameters are
  # loaded back into it.
  torch.save(model_object.state_dict(), 'params.pkl')
  model_object.load_state_dict(torch.load('params.pkl'))

