import torch
import torch.nn as nn
import torch.nn.init as init


class Fire(nn.Module):
    """SqueezeNet "Fire" module.

    A 1x1 "squeeze" convolution (+ ReLU) feeds two parallel "expand"
    convolutions, 1x1 and 3x3 (each + ReLU), whose outputs are concatenated
    along dim 1 (the channel dim of NCHW tensors).

    Args:
        inplanes: number of input channels.
        squeeze_planes: output channels of the 1x1 squeeze conv.
        expand1x1_planes: output channels of the 1x1 expand conv.
        expand3x3_planes: output channels of the 3x3 expand conv.

    The output has ``expand1x1_planes + expand3x3_planes`` channels and the
    same spatial size as the input (stride 1 everywhere; the 3x3 conv uses
    ``padding=1``).
    """

    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super().__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(
            squeeze_planes, expand3x3_planes, kernel_size=3, padding=1
        )
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Squeeze ``x``, then concatenate both expand branches on dim 1."""
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat(
            [
                self.expand1x1_activation(self.expand1x1(x)),
                self.expand3x3_activation(self.expand3x3(x)),
            ],
            1,
        )


class SqueezeNet(nn.Module):
    """SqueezeNet image classifier, versions 1.0 and 1.1.

    Args:
        version: ``1.0`` or ``1.1``; selects the feature-extractor layout.
        num_classes: number of output scores per sample.
        ceil_mode: forwarded to every ``nn.MaxPool2d`` in the feature
            extractor (controls rounding of the pooled spatial size).

    Raises:
        ValueError: if ``version`` is neither 1.0 nor 1.1.

    ``forward`` takes a 3-channel NCHW batch and returns a tensor of shape
    ``(batch, num_classes)``. The classifier ends in ReLU followed by global
    average pooling, so the scores are non-negative.
    """

    def __init__(self, version=1.0, num_classes=1000, ceil_mode=False):
        super().__init__()
        if version not in (1.0, 1.1):
            raise ValueError(
                f"Unsupported SqueezeNet version {version}: 1.0 or 1.1 expected"
            )
        self.num_classes = num_classes
        if version == 1.0:
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(512, 64, 256, 256),
            )
        else:
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        # Global average pooling collapses whatever spatial extent the
        # features produce down to 1x1. A fixed AvgPool2d(13) only fits a
        # 13x13 map: e.g. version 1.0 with ceil_mode=False on 224x224 input
        # yields 12x12, which AvgPool2d(13) rejects. For a 13x13 map both
        # compute the same full-window mean, and neither has parameters.
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Same module visiting order as before, so RNG consumption (and thus
        # seeded initial weights) is unchanged. The init functions mutate
        # parameters in place under no_grad; no need to touch ``.data``.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.features(x)
        x = self.classifier(x)
        return x.view(x.size(0), self.num_classes)