xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/squeezenet.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2import torch.nn as nn
3import torch.nn.init as init
4
5
class Fire(nn.Module):
    """SqueezeNet "fire" module: a 1x1 squeeze conv feeding two parallel expand convs.

    The squeeze conv maps ``inplanes`` channels down to ``squeeze_planes``.
    A 1x1 expand conv and a 3x3 (padding=1) expand conv both consume that
    ReLU'd result. Their own ReLU'd outputs are concatenated along dim 1,
    giving ``expand1x1_planes + expand3x3_planes`` output channels. Every conv
    here is stride 1 with "same" padding, so the spatial size is unchanged.
    """

    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super().__init__()
        # Stored for reference only; forward() does not read it.
        self.inplanes = inplanes

        # Submodules are registered squeeze -> expand1x1 -> expand3x3. This
        # order fixes both the parameter-init RNG sequence and state_dict order.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, 1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        # Channel order matters: the 1x1 branch comes first, then the 3x3 branch.
        return torch.cat((branch_1x1, branch_3x3), dim=1)
28
29
class SqueezeNet(nn.Module):
    """SqueezeNet 1.0 / 1.1 classifier built from :class:`Fire` modules.

    Args:
        version: ``1.0`` or ``1.1``; selects which ``features`` stack is built.
            Any other value raises ``ValueError``.
        num_classes: number of output channels of the final 1x1 conv, and the
            width of the tensor returned by ``forward``.
        ceil_mode: forwarded to every ``nn.MaxPool2d`` in ``features``.

    The first conv is ``Conv2d(3, ...)``, so inputs must have 3 channels.

    The classifier ends in a fixed ``AvgPool2d(13)`` followed by
    ``view(N, num_classes)``, so the pooled map must come out 1x1. For a
    224x224 input this holds in two cases:

    * version 1.1, with either ``ceil_mode``;
    * version 1.0, only with ``ceil_mode=True``.

    Version 1.0 with ``ceil_mode=False`` produces a 12x12 feature map, which
    ``AvgPool2d(13)`` rejects.
    """

    def __init__(self, version=1.0, num_classes=1000, ceil_mode=False):
        super().__init__()
        if version not in (1.0, 1.1):
            raise ValueError(
                f"Unsupported SqueezeNet version {version}: 1.0 or 1.1 expected"
            )
        self.num_classes = num_classes
        if version == 1.0:
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(512, 64, 256, 256),
            )
        else:
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=ceil_mode),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AvgPool2d(13)
        )

        # Re-initialize every conv: N(0, 0.01) for the final conv, Kaiming-uniform
        # for all others, and zero biases. The nn.init functions already run
        # under no_grad, so the Parameters are passed directly rather than via
        # the discouraged ``.data`` attribute. Values and RNG order are unchanged.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Return ``(N, num_classes)`` outputs for a batch ``x`` of 3-channel images."""
        x = self.features(x)
        x = self.classifier(x)
        # When the input size suits AvgPool2d(13), the classifier output is
        # (N, num_classes, 1, 1); collapse it to (N, num_classes).
        return x.view(x.size(0), self.num_classes)
89