xref: /aosp_15_r20/external/pytorch/test/quantization/eager/test_equalize_eager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3import torch
4import torch.nn as nn
5
6from torch.testing._internal.common_quantization import QuantizationTestCase
7from torch.ao.quantization.fuse_modules import fuse_modules
8
9import torch.ao.quantization._equalize as _equalize
10
11import copy
12
13class TestEqualizeEager(QuantizationTestCase):
14    def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
15        ''' Checks the channel ranges of tensor1, tensor2 are the same,
16        which is an indication that equalization has been applied correctly
17        '''
18        output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
19        input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)
20
21        # ensuring the channels ranges of tensor1's input is the same as
22        # tensor2's output
23        self.assertEqual(output_channel_tensor1, input_channel_tensor2)
24
25    def getModule(self, model, name):
26        ''' Given the name is a submodule to a model, return the submodule
27        '''
28        curr = model
29        name = name.split('.')
30        for subname in name:
31            curr = curr._modules[subname]
32        return curr
33
34    def test_cross_layer_equalization(self):
35        ''' applies _equalize.cross_layer_equalization on two modules and checks
36        to make sure channels ranges are equivalent
37        '''
38        module1 = nn.Conv2d(3, 4, 2)
39        module2 = nn.Linear(4, 4)
40
41        module1_output_channel_axis = 0
42        module2_input_channel_axis = 1
43
44        _equalize.cross_layer_equalization(module1, module2)
45
46        mod_tensor1, mod_tensor2 = module1.weight, module2.weight
47
48        self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)
49
50    def test_converged(self):
51        ''' Sanity checks on _equalize.converged working
52        identical modules should return true
53        modules with high difference in weights should return false
54        '''
55        module1 = nn.Linear(3, 3)
56        module2 = nn.Linear(3, 3)
57
58        module1.weight = nn.parameter.Parameter(torch.ones(module1.weight.size()))
59        module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))
60
61        # input is a dictionary
62        dictionary_1 = {'linear1': module1}
63        dictionary_2 = {'linear1': module2}
64        self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
65        self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))
66
67    def test_equalize(self):
68        ''' First checks to see if _equalize.equalize can handle multiple
69        pair modules as input
70        then checks correctness of the function by ensuring the equalized
71        and unequalized versions of the model yield the same output
72        given the same input
73        '''
74        class ChainModule(nn.Module):
75            def __init__(self) -> None:
76                super().__init__()
77                self.linear1 = nn.Linear(3, 4)
78                self.linear2 = nn.Linear(4, 5)
79                self.linear3 = nn.Linear(5, 6)
80
81            def forward(self, x):
82                x = self.linear1(x)
83                x = self.linear2(x)
84                x = self.linear3(x)
85                return x
86        chain1 = ChainModule()
87        chain2 = copy.deepcopy(chain1)
88
89        _equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
90        linear1 = self.getModule(chain1, 'linear1')
91        linear2 = self.getModule(chain1, 'linear2')
92        linear3 = self.getModule(chain1, 'linear3')
93
94        self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
95        self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
96
97        input = torch.randn(20, 3)
98        self.assertEqual(chain1(input), chain2(input))
99
100    def test_equalize_fused_convrelu(self):
101        ''' Checks to see if eager mode equalization supports fused
102        ConvReLU2d models
103
104        A model with 3 ConvReLU2d is constructed. Next, the conv2d and relu
105        layers are fused together and adjacent conv2d layers have cross-layer
106        equalization applied. Finally, we ensure that the channels have been
107        equalized and that the equalized and unequalized versions of the model
108        yield the same output given the same input
109        '''
110        class M(nn.Module):
111            def __init__(self) -> None:
112                super().__init__()
113                self.conv1 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
114                self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
115                self.conv2 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
116                self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
117                self.conv3 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
118                self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float)
119
120            def forward(self, x):
121                x = self.conv1(x)
122                x = self.relu1(x)
123                x = self.conv2(x)
124                x = self.relu2(x)
125                x = self.conv3(x)
126                x = self.relu3(x)
127                return x
128
129        model = M()
130
131        fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']])
132        fused_model2 = copy.deepcopy(fused_model1)
133
134        _equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6)
135        conv1 = self.getModule(fused_model1, 'conv1')[0]
136        conv2 = self.getModule(fused_model1, 'conv2')[0]
137        conv3 = self.getModule(fused_model1, 'conv3')[0]
138
139        self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
140        self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)
141
142        input = torch.randn(3, 3, 1, 1)
143        self.assertEqual(fused_model1(input), fused_model2(input))
144        self.assertEqual(fused_model1(input), model(input))
145
146    def test_equalize_fused_linearrelu(self):
147        ''' Checks to see if eager mode equalization supports fused
148        LinearReLU models
149
150        A model with 3 LinearReLU is constructed. Next, the linear and relu
151        layers are fused together and adjacent linear layers have cross-layer
152        equalization applied. Finally, we ensure that the channels have been
153        equalized and that the equalized and unequalized versions of the model
154        yield the same output given the same input
155        '''
156        class M(nn.Module):
157            def __init__(self) -> None:
158                super().__init__()
159                self.linear1 = nn.Linear(3, 4)
160                self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
161                self.linear2 = nn.Linear(4, 5)
162                self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
163                self.linear3 = nn.Linear(5, 6)
164                self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float)
165
166            def forward(self, x):
167                x = self.linear1(x)
168                x = self.relu1(x)
169                x = self.linear2(x)
170                x = self.relu2(x)
171                x = self.linear3(x)
172                x = self.relu3(x)
173                return x
174
175        model = M()
176
177        fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']])
178        fused_model2 = copy.deepcopy(fused_model1)
179
180        _equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
181        linear1 = self.getModule(fused_model1, 'linear1')[0]
182        linear2 = self.getModule(fused_model1, 'linear2')[0]
183        linear3 = self.getModule(fused_model1, 'linear3')[0]
184
185        self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
186        self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
187
188        input = torch.randn(20, 3)
189        self.assertEqual(fused_model1(input), fused_model2(input))
190        self.assertEqual(fused_model1(input), model(input))
191