# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn

from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.ao.quantization.fuse_modules import fuse_modules

import torch.ao.quantization._equalize as _equalize

import copy

class TestEqualizeEager(QuantizationTestCase):
    def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
        ''' Checks that the channel ranges of tensor1 and tensor2 are the same,
        which is an indication that equalization has been applied correctly
        '''
        output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
        input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)

        # ensure the output channel ranges of tensor1 are the same as the
        # input channel ranges of tensor2
        self.assertEqual(output_channel_tensor1, input_channel_tensor2)

    def getModule(self, model, name):
        ''' Given the (possibly dotted) name of a submodule of model,
        return that submodule
        '''
        curr = model
        for subname in name.split('.'):
            curr = curr._modules[subname]
        return curr

    def test_cross_layer_equalization(self):
        ''' Applies _equalize.cross_layer_equalization on two modules and
        checks that their channel ranges are equivalent afterwards
        '''
        module1 = nn.Conv2d(3, 4, 2)
        module2 = nn.Linear(4, 4)

        module1_output_channel_axis = 0
        module2_input_channel_axis = 1

        _equalize.cross_layer_equalization(module1, module2)

        mod_tensor1, mod_tensor2 = module1.weight, module2.weight

        self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)

    def test_converged(self):
        ''' Sanity checks that _equalize.converged works as expected:
        identical module dictionaries should return True, and dictionaries
        whose modules' weights differ widely should return False
        '''
        module1 = nn.Linear(3, 3)
        module2 = nn.Linear(3, 3)

        module1.weight = nn.parameter.Parameter(torch.ones(module1.weight.size()))
        module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))

        # input is a dictionary
        dictionary_1 = {'linear1': module1}
        dictionary_2 = {'linear1': module2}
        self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
        self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))

    def test_equalize(self):
        ''' First checks that _equalize.equalize can handle multiple
        module pairs as input, then checks the correctness of the function
        by ensuring that the equalized and unequalized versions of the model
        yield the same output given the same input
        '''
        class ChainModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.linear2 = nn.Linear(4, 5)
                self.linear3 = nn.Linear(5, 6)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.linear3(x)
                return x

        chain1 = ChainModule()
        chain2 = copy.deepcopy(chain1)

        _equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
        linear1 = self.getModule(chain1, 'linear1')
        linear2 = self.getModule(chain1, 'linear2')
        linear3 = self.getModule(chain1, 'linear3')

        self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
        self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)

        input = torch.randn(20, 3)
        self.assertEqual(chain1(input), chain2(input))

    def test_equalize_fused_convrelu(self):
        ''' Checks that eager mode equalization supports fused
        ConvReLU2d models

        A model with 3 ConvReLU2d modules is constructed. Next, the conv2d
        and relu layers are fused together and adjacent conv2d layers have
        cross-layer equalization applied. Finally, we ensure that the
        channels have been equalized and that the equalized and unequalized
        versions of the model yield the same output given the same input
        '''
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.conv2 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.conv3 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float)

            def forward(self, x):
                x = self.conv1(x)
                x = self.relu1(x)
                x = self.conv2(x)
                x = self.relu2(x)
                x = self.conv3(x)
                x = self.relu3(x)
                return x

        model = M()

        fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']])
        fused_model2 = copy.deepcopy(fused_model1)

        _equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6)
        # after fusion, 'conv1' is a fused ConvReLU2d module; index 0 is the Conv2d
        conv1 = self.getModule(fused_model1, 'conv1')[0]
        conv2 = self.getModule(fused_model1, 'conv2')[0]
        conv3 = self.getModule(fused_model1, 'conv3')[0]

        self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
        self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)

        input = torch.randn(3, 3, 1, 1)
        self.assertEqual(fused_model1(input), fused_model2(input))
        self.assertEqual(fused_model1(input), model(input))

    def test_equalize_fused_linearrelu(self):
        ''' Checks that eager mode equalization supports fused
        LinearReLU models

        A model with 3 LinearReLU modules is constructed. Next, the linear
        and relu layers are fused together and adjacent linear layers have
        cross-layer equalization applied. Finally, we ensure that the
        channels have been equalized and that the equalized and unequalized
        versions of the model yield the same output given the same input
        '''
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.linear2 = nn.Linear(4, 5)
                self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.linear3 = nn.Linear(5, 6)
                self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float)

            def forward(self, x):
                x = self.linear1(x)
                x = self.relu1(x)
                x = self.linear2(x)
                x = self.relu2(x)
                x = self.linear3(x)
                x = self.relu3(x)
                return x

        model = M()

        fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']])
        fused_model2 = copy.deepcopy(fused_model1)

        _equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
        # after fusion, 'linear1' is a fused LinearReLU module; index 0 is the Linear
        linear1 = self.getModule(fused_model1, 'linear1')[0]
        linear2 = self.getModule(fused_model1, 'linear2')[0]
        linear3 = self.getModule(fused_model1, 'linear3')[0]

        self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
        self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)

        input = torch.randn(20, 3)
        self.assertEqual(fused_model1(input), fused_model2(input))
        self.assertEqual(fused_model1(input), model(input))