# Owner(s): ["module: optimizer"]

import itertools
import pickle

import torch
from torch.optim.swa_utils import (
    AveragedModel,
    get_ema_multi_avg_fn,
    get_swa_multi_avg_fn,
    update_bn,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    load_tests,
    parametrize,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
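

# For orientation, a minimal sketch of how the utilities under test are
# typically combined in a training loop. The model/optimizer/loader names are
# hypothetical; this helper is illustrative only and is never called by the
# tests below.
def _swa_usage_sketch(model, optimizer, loader):
    # Keep an equal-weight (SWA) running average of the model parameters;
    # get_ema_multi_avg_fn(decay) would give an exponential average instead.
    swa_model = AveragedModel(model, multi_avg_fn=get_swa_multi_avg_fn())
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        swa_model.update_parameters(model)
    # Re-estimate BatchNorm running statistics for the averaged model.
    update_bn(loader, swa_model)
    return swa_model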


class TestSWAUtils(TestCase):
    class SWATestDNN(torch.nn.Module):
        def __init__(self, input_features):
            super().__init__()
            self.n_features = 100
            self.fc1 = torch.nn.Linear(input_features, self.n_features)
            self.bn = torch.nn.BatchNorm1d(self.n_features)

        def compute_preactivation(self, x):
            return self.fc1(x)

        def forward(self, x):
            x = self.fc1(x)
            x = self.bn(x)
            return x

    class SWATestCNN(torch.nn.Module):
        def __init__(self, input_channels):
            super().__init__()
            self.n_features = 10
            self.conv1 = torch.nn.Conv2d(
                input_channels, self.n_features, kernel_size=3, padding=1
            )
            self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)

        def compute_preactivation(self, x):
            return self.conv1(x)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn(x)
            return x

    def _test_averaged_model(self, net_device, swa_device, ema):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Conv2d(5, 2, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 10),
        ).to(net_device)

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == swa_device)
            self.assertTrue(p_avg.device == net_device)
        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

    def _run_averaged_steps(self, dnn, swa_device, ema):
        ema_decay = 0.999
        if ema:
            averaged_dnn = AveragedModel(
                dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay)
            )
        else:
            averaged_dnn = AveragedModel(
                dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn()
            )

        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]

        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if ema:
                    # Accumulate the expected EMA in unrolled form: the value
                    # seen at step i ends up weighted by
                    # ema_decay ** (n_updates - i - 1), scaled by
                    # (1 - ema_decay) except for the first value, which seeds
                    # the average.
                    p_avg += (
                        p.detach()
                        * ema_decay ** (n_updates - i - 1)
                        * ((1 - ema_decay) if i > 0 else 1.0)
                    )
                else:
                    # Equal-weight (SWA) average over all n_updates values.
                    p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        return averaged_params, averaged_dnn
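
    # The unrolled weights accumulated above are equivalent to the usual
    # recursive EMA update, avg <- decay * avg + (1 - decay) * p, with the
    # average seeded by the first value. A minimal reference sketch
    # (hypothetical helper, not used by the tests):
    @staticmethod
    def _ema_reference(values, decay=0.999):
        avg = values[0].clone()
        for v in values[1:]:
            avg = decay * avg + (1 - decay) * v
        return avg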

    @parametrize("ema", [True, False])
    def test_averaged_model_all_devices(self, ema):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu, ema)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu, ema)
            self._test_averaged_model(cpu, cuda, ema)
            self._test_averaged_model(cuda, cuda, ema)

    @parametrize("ema", [True, False])
    def test_averaged_model_mixed_device(self, ema):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        dnn[0].cuda()
        dnn[1].cpu()

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that each averaged parameter stays on its source device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_averaged_model_state_dict(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        averaged_dnn = AveragedModel(dnn)
        averaged_dnn2 = AveragedModel(dnn)
        n_updates = 10
        for i in range(n_updates):
            for p in dnn.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_dnn.update_parameters(dnn)
        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
            self.assertEqual(p_swa, p_swa2)
        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

    def test_averaged_model_default_avg_fn_picklable(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5),
            torch.nn.Linear(5, 5),
        )
        averaged_dnn = AveragedModel(dnn)
        pickle.dumps(averaged_dnn)

    @parametrize("use_multi_avg_fn", [True, False])
    @parametrize("use_buffers", [True, False])
    def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
        # Test AveragedModel with an EMA avg_fn/multi_avg_fn, with and
        # without averaging the buffers.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        decay = 0.9

        if use_multi_avg_fn:
            averaged_dnn = AveragedModel(
                dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers
            )
        else:

            def avg_fn(p_avg, p, n_avg):
                return decay * p_avg + (1 - decay) * p

            averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)

        if use_buffers:
            dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
        else:
            dnn_params = list(dnn.parameters())

        averaged_params = [
            torch.zeros_like(param)
            for param in dnn_params
            if param.size() != torch.Size([])
        ]

        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn_params, averaged_params):
                if p.size() == torch.Size([]):
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * decay + p * (1 - decay)).clone()
                    )
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        if use_buffers:
            for p_avg, p_swa in zip(
                averaged_params,
                itertools.chain(
                    averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
                ),
            ):
                self.assertEqual(p_avg, p_swa)
        else:
            for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
                self.assertEqual(p_avg, p_swa)
            for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
                self.assertEqual(b_avg, b_swa)

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl_x:
            x = x[0]
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        preactivation_var = preactivation_var - preactivation_mean**2

        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

        def _reset_bn(module):
            if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)

        # reset batch norm and run update_bn again
        dnn.apply(_reset_bn)
        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
        # using the dl_x loader instead of dl_xy
        dnn.apply(_reset_bn)
        update_bn(dl_x, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
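
    # For reference, a condensed sketch of what torch.optim.swa_utils.update_bn
    # does, paraphrased from its documented behavior: reset BatchNorm running
    # stats, switch to cumulative averaging (momentum=None), run one forward
    # pass over the loader, then restore momentum and the train/eval mode.
    # This hypothetical helper is illustrative only and is not used by the
    # tests.
    @staticmethod
    def _update_bn_sketch(loader, model, device=None):
        bn_momenta = {}
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.reset_running_stats()
                bn_momenta[module] = module.momentum
                module.momentum = None  # None -> cumulative moving average
        was_training = model.training
        model.train()  # BN layers only update running stats in train mode
        for batch in loader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]  # loaders may yield (x,) or (x, y) tuples
            if device is not None:
                batch = batch.to(device)
            model(batch)
        for module, momentum in bn_momenta.items():
            module.momentum = momentum
        model.train(was_training)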

    def test_update_bn_dnn(self):
        # Test update_bn for a fully-connected network with BatchNorm1d
        objects, input_features = 100, 5
        x = torch.rand(objects, input_features)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        dnn = self.SWATestDNN(input_features=input_features)
        dnn.train()
        self._test_update_bn(dnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            dnn = self.SWATestDNN(input_features=input_features)
            dnn.train()
            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(dnn.training)

    def test_update_bn_cnn(self):
        # Test update_bn for a convolutional network with BatchNorm2d
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.train()
        self._test_update_bn(cnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            cnn = self.SWATestCNN(input_channels=input_channels)
            cnn.train()
            self._test_update_bn(cnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(cnn.training)

    def test_bn_update_eval_momentum(self):
        # check that update_bn preserves eval mode
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        ds_x = torch.utils.data.TensorDataset(x)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.eval()
        update_bn(dl_x, cnn)
        self.assertFalse(cnn.training)

        # check that momentum is preserved
        self.assertEqual(cnn.bn.momentum, 0.3)


instantiate_parametrized_tests(TestSWAUtils)


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")
331