# Owner(s): ["module: optimizer"]

import itertools
import pickle

import torch
from torch.optim.swa_utils import (
    AveragedModel,
    get_ema_multi_avg_fn,
    get_swa_multi_avg_fn,
    update_bn,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    load_tests,
    parametrize,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests


class TestSWAUtils(TestCase):
    class SWATestDNN(torch.nn.Module):
        def __init__(self, input_features):
            super().__init__()
            self.n_features = 100
            self.fc1 = torch.nn.Linear(input_features, self.n_features)
            self.bn = torch.nn.BatchNorm1d(self.n_features)

        def compute_preactivation(self, x):
            return self.fc1(x)

        def forward(self, x):
            x = self.fc1(x)
            x = self.bn(x)
            return x

    class SWATestCNN(torch.nn.Module):
        def __init__(self, input_channels):
            super().__init__()
            self.n_features = 10
            self.conv1 = torch.nn.Conv2d(
                input_channels, self.n_features, kernel_size=3, padding=1
            )
            self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)

        def compute_preactivation(self, x):
            return self.conv1(x)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn(x)
            return x

    def _test_averaged_model(self, net_device, swa_device, ema):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Conv2d(5, 2, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 10),
        ).to(net_device)

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == swa_device)
            self.assertTrue(p_avg.device == net_device)
        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

    def _run_averaged_steps(self, dnn, swa_device, ema):
        ema_decay = 0.999
        if ema:
            averaged_dnn = AveragedModel(
                dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay)
            )
        else:
            averaged_dnn = AveragedModel(
                dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn()
            )

        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]

        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if ema:
                    # Accumulate the closed form of the EMA recurrence
                    # avg <- decay * avg + (1 - decay) * p: the snapshot from
                    # step i contributes with weight decay**(n_updates - i - 1),
                    # scaled by (1 - decay) except for the first step, which
                    # initializes the average.
                    p_avg += (
                        p.detach()
                        * ema_decay ** (n_updates - i - 1)
                        * ((1 - ema_decay) if i > 0 else 1.0)
                    )
                else:
                    # Equal-weight (SWA) average over all n_updates snapshots.
                    p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        return averaged_params, averaged_dnn

    @parametrize("ema", [True, False])
    def test_averaged_model_all_devices(self, ema):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu, ema)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu, ema)
            self._test_averaged_model(cpu, cuda, ema)
            self._test_averaged_model(cuda, cuda, ema)
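
    # Illustrative sketch (added, not part of the original suite): the closed
    # form accumulated in _run_averaged_steps above is just the unrolled EMA
    # recurrence avg <- decay * avg + (1 - decay) * p, with the average
    # initialized to the first snapshot. A minimal scalar check of that
    # identity, assuming nothing beyond plain tensor arithmetic:
    def test_ema_closed_form_matches_recurrence_sketch(self):
        decay = 0.999
        xs = [torch.randn(()) for _ in range(10)]
        # Unrolled recurrence, starting from the first snapshot
        avg = xs[0].clone()
        for x in xs[1:]:
            avg = decay * avg + (1 - decay) * x
        # Closed form, mirroring the weights used in _run_averaged_steps
        n = len(xs)
        closed = sum(
            x * decay ** (n - i - 1) * ((1 - decay) if i > 0 else 1.0)
            for i, x in enumerate(xs)
        )
        self.assertEqual(avg, closed)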

    @parametrize("ema", [True, False])
    def test_averaged_model_mixed_device(self, ema):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        dnn[0].cuda()
        dnn[1].cpu()

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_averaged_model_state_dict(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        averaged_dnn = AveragedModel(dnn)
        averaged_dnn2 = AveragedModel(dnn)
        n_updates = 10
        for _ in range(n_updates):
            for p in dnn.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_dnn.update_parameters(dnn)
        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
            self.assertEqual(p_swa, p_swa2)
        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

    def test_averaged_model_default_avg_fn_picklable(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5),
            torch.nn.Linear(5, 5),
        )
        averaged_dnn = AveragedModel(dnn)
        pickle.dumps(averaged_dnn)

    @parametrize("use_multi_avg_fn", [True, False])
    @parametrize("use_buffers", [True, False])
    def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
        # Test AveragedModel with an EMA avg_fn/multi_avg_fn, with and without
        # buffer averaging.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        decay = 0.9

        if use_multi_avg_fn:
            averaged_dnn = AveragedModel(
                dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers
            )
        else:

            def avg_fn(p_avg, p, n_avg):
                return decay * p_avg + (1 - decay) * p

            averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)

        if use_buffers:
            dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
        else:
            dnn_params = list(dnn.parameters())

        averaged_params = [
            torch.zeros_like(param)
            for param in dnn_params
            if param.size() != torch.Size([])
        ]

        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn_params, averaged_params):
                if p.size() == torch.Size([]):
                    # Skip scalar buffers such as num_batches_tracked.
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * decay + p * (1 - decay)).clone()
                    )
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        if use_buffers:
            for p_avg, p_swa in zip(
                averaged_params,
                itertools.chain(
                    averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
                ),
            ):
                self.assertEqual(p_avg, p_swa)
        else:
            for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
                self.assertEqual(p_avg, p_swa)
            for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
                self.assertEqual(b_avg, b_swa)
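
    # Illustrative sketch (added, not part of the original suite): the SWA
    # update used by get_swa_multi_avg_fn is, to our understanding, the
    # incremental mean avg <- avg + (p - avg) / (n + 1) over n previously
    # averaged snapshots, so after all updates it should equal the plain
    # arithmetic mean that _run_averaged_steps accumulates. A minimal scalar
    # check of that identity:
    def test_swa_incremental_mean_matches_arithmetic_mean_sketch(self):
        xs = [torch.randn(()) for _ in range(10)]
        avg = xs[0].clone()
        for n, x in enumerate(xs[1:], start=1):
            avg = avg + (x - avg) / (n + 1)
        self.assertEqual(avg, sum(xs) / len(xs))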

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl_x:
            x = x[0]
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        # Biased (population) variance: E[x^2] - E[x]^2.
        preactivation_var = preactivation_var - preactivation_mean**2

        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

        def _reset_bn(module):
            if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)

        # Reset batch norm and run update_bn again
        dnn.apply(_reset_bn)
        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
        # Use the dl_x loader instead of dl_xy
        dnn.apply(_reset_bn)
        update_bn(dl_x, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

    def test_update_bn_dnn(self):
        # Test update_bn for a fully-connected network with BatchNorm1d
        objects, input_features = 100, 5
        x = torch.rand(objects, input_features)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        dnn = self.SWATestDNN(input_features=input_features)
        dnn.train()
        self._test_update_bn(dnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            dnn = self.SWATestDNN(input_features=input_features)
            dnn.train()
            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(dnn.training)

    def test_update_bn_cnn(self):
        # Test update_bn for a convolutional network with BatchNorm2d
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.train()
        self._test_update_bn(cnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            cnn = self.SWATestCNN(input_channels=input_channels)
            cnn.train()
            self._test_update_bn(cnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(cnn.training)
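
    # Illustrative sketch (added, not part of the original suite): update_bn
    # recomputes BatchNorm running statistics from scratch, which relies on
    # _BatchNorm tracking a cumulative moving average of batch statistics when
    # momentum is None. A minimal check that a momentum=None BatchNorm1d
    # recovers the per-feature batch mean after a single training-mode pass:
    def test_bn_cumulative_moving_average_sketch(self):
        bn = torch.nn.BatchNorm1d(3, momentum=None)
        bn.train()
        x = torch.randn(50, 3)
        bn(x)
        self.assertEqual(bn.running_mean, x.mean(dim=0))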

    def test_bn_update_eval_momentum(self):
        # Check that update_bn preserves eval mode
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        ds_x = torch.utils.data.TensorDataset(x)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.eval()
        update_bn(dl_x, cnn)
        self.assertFalse(cnn.training)

        # Check that momentum is preserved
        self.assertEqual(cnn.bn.momentum, 0.3)


instantiate_parametrized_tests(TestSWAUtils)


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")