# Owner(s): ["module: optimizer"] import itertools import pickle import torch from torch.optim.swa_utils import ( AveragedModel, get_ema_multi_avg_fn, get_swa_multi_avg_fn, update_bn, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, load_tests, parametrize, TestCase, ) # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests class TestSWAUtils(TestCase): class SWATestDNN(torch.nn.Module): def __init__(self, input_features): super().__init__() self.n_features = 100 self.fc1 = torch.nn.Linear(input_features, self.n_features) self.bn = torch.nn.BatchNorm1d(self.n_features) def compute_preactivation(self, x): return self.fc1(x) def forward(self, x): x = self.fc1(x) x = self.bn(x) return x class SWATestCNN(torch.nn.Module): def __init__(self, input_channels): super().__init__() self.n_features = 10 self.conv1 = torch.nn.Conv2d( input_channels, self.n_features, kernel_size=3, padding=1 ) self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3) def compute_preactivation(self, x): return self.conv1(x) def forward(self, x): x = self.conv1(x) x = self.bn(x) return x def _test_averaged_model(self, net_device, swa_device, ema): dnn = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2), torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Conv2d(5, 2, kernel_size=3), torch.nn.ReLU(), torch.nn.Linear(5, 5), torch.nn.ReLU(), torch.nn.Linear(5, 10), ).to(net_device) averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema) for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): self.assertEqual(p_avg, p_swa) # Check that AveragedModel is on the correct device self.assertTrue(p_swa.device == swa_device) self.assertTrue(p_avg.device == net_device) self.assertTrue(averaged_dnn.n_averaged.device == swa_device) def _run_averaged_steps(self, dnn, swa_device, ema): ema_decay = 0.999 if ema: averaged_dnn = AveragedModel( dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay) ) else: averaged_dnn = AveragedModel( dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn() ) averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] n_updates = 10 for i in range(n_updates): for p, p_avg in zip(dnn.parameters(), averaged_params): p.detach().add_(torch.randn_like(p)) if ema: p_avg += ( p.detach() * ema_decay ** (n_updates - i - 1) * ((1 - ema_decay) if i > 0 else 1.0) ) else: p_avg += p.detach() / n_updates averaged_dnn.update_parameters(dnn) return averaged_params, averaged_dnn @parametrize("ema", [True, False]) def test_averaged_model_all_devices(self, ema): cpu = torch.device("cpu") self._test_averaged_model(cpu, cpu, ema) if torch.cuda.is_available(): cuda = torch.device(0) self._test_averaged_model(cuda, cpu, ema) self._test_averaged_model(cpu, cuda, ema) self._test_averaged_model(cuda, cuda, ema) @parametrize("ema", [True, False]) def test_averaged_model_mixed_device(self, ema): if not torch.cuda.is_available(): return dnn = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) ) dnn[0].cuda() dnn[1].cpu() averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema) for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): self.assertEqual(p_avg, p_swa) # Check that AveragedModel is on the correct device self.assertTrue(p_avg.device == p_swa.device) def test_averaged_model_state_dict(self): dnn = 
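    # A note on the EMA reference built in _run_averaged_steps (explanatory
    # only, not part of the test): AveragedModel copies the model parameters
    # on the very first update, so the running average follows the recurrence
    #   avg_i = decay * avg_{i-1} + (1 - decay) * p_i,  with  avg_0 = p_0,
    # which unrolls to the closed-form weight applied to each p_i above:
    #   avg_{n-1} = decay**(n-1) * p_0
    #             + sum_{i=1}^{n-1} (1 - decay) * decay**(n-1-i) * p_i
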
    @parametrize("ema", [True, False])
    def test_averaged_model_all_devices(self, ema):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu, ema)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu, ema)
            self._test_averaged_model(cpu, cuda, ema)
            self._test_averaged_model(cuda, cuda, ema)

    @parametrize("ema", [True, False])
    def test_averaged_model_mixed_device(self, ema):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        dnn[0].cuda()
        dnn[1].cpu()

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_averaged_model_state_dict(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        averaged_dnn = AveragedModel(dnn)
        averaged_dnn2 = AveragedModel(dnn)
        n_updates = 10
        for i in range(n_updates):
            for p in dnn.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_dnn.update_parameters(dnn)
        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
            self.assertEqual(p_swa, p_swa2)
        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

    def test_averaged_model_default_avg_fn_picklable(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5),
            torch.nn.Linear(5, 5),
        )
        averaged_dnn = AveragedModel(dnn)
        pickle.dumps(averaged_dnn)

    @parametrize("use_multi_avg_fn", [True, False])
    @parametrize("use_buffers", [True, False])
    def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
        # Test AveragedModel with an EMA avg_fn, with and without averaging
        # the buffers.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        decay = 0.9

        if use_multi_avg_fn:
            averaged_dnn = AveragedModel(
                dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers
            )
        else:

            def avg_fn(p_avg, p, n_avg):
                return decay * p_avg + (1 - decay) * p

            averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)

        if use_buffers:
            dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
        else:
            dnn_params = list(dnn.parameters())

        # Skip scalar buffers (e.g. BatchNorm's num_batches_tracked) in the
        # reference computation.
        averaged_params = [
            torch.zeros_like(param)
            for param in dnn_params
            if param.size() != torch.Size([])
        ]

        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn_params, averaged_params):
                if p.size() == torch.Size([]):
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * decay + p * (1 - decay)).clone()
                    )
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        if use_buffers:
            for p_avg, p_swa in zip(
                averaged_params,
                itertools.chain(
                    averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
                ),
            ):
                self.assertEqual(p_avg, p_swa)
        else:
            for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
                self.assertEqual(p_avg, p_swa)

            # Without use_buffers the buffers are kept in sync with the source
            # model rather than averaged, so they must match exactly.
            for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
                self.assertEqual(b_avg, b_swa)

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl_x:
            x = x[0]
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        preactivation_var = preactivation_var - preactivation_mean**2

        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
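        # The checks above rely on update_bn resetting each BatchNorm module's
        # running statistics and temporarily setting its momentum to None, so
        # the running mean/var accumulate as a cumulative average over all
        # batches in the loader. running_var is built from per-batch
        # estimates, so it only approximately matches the full-dataset
        # variance computed by hand, which is presumably why the variance
        # check uses a loose atol.
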
        def _reset_bn(module):
            if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)

        # reset batch norm and run update_bn again
        dnn.apply(_reset_bn)
        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

        # using the dl_x loader instead of dl_xy
        dnn.apply(_reset_bn)
        update_bn(dl_x, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

    def test_update_bn_dnn(self):
        # Test update_bn for a fully-connected network with BatchNorm1d
        objects, input_features = 100, 5
        x = torch.rand(objects, input_features)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        dnn = self.SWATestDNN(input_features=input_features)
        dnn.train()
        self._test_update_bn(dnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            dnn = self.SWATestDNN(input_features=input_features)
            dnn.train()
            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(dnn.training)

    def test_update_bn_cnn(self):
        # Test update_bn for a convolutional network with BatchNorm2d
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.train()
        self._test_update_bn(cnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            cnn = self.SWATestCNN(input_channels=input_channels)
            cnn.train()
            self._test_update_bn(cnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(cnn.training)

    def test_bn_update_eval_momentum(self):
        # check that update_bn preserves eval mode
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        ds_x = torch.utils.data.TensorDataset(x)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.eval()
        update_bn(dl_x, cnn)
        self.assertFalse(cnn.training)

        # check that momentum is preserved
        self.assertEqual(cnn.bn.momentum, 0.3)


instantiate_parametrized_tests(TestSWAUtils)


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")