# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter


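# A minimal Optimizer subclass whose step() mutates params in place, so the
# tests below can compare eager and compiled execution directly.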
class MyOptimizer(torch.optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, {})

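    # Collect every param in the group and report whether any is complex-valued.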
    def _init_group(self, params, group):
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

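    # Arbitrary update chosen to make the effect observable: decrement the
    # first param when any param is complex, otherwise increment it.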
    def step(self):
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1


class End2EndTests(torch._dynamo.test_case.TestCase):
    # https://github.com/pytorch/torchdynamo/issues/1604
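    # Compile a full training iteration (forward, loss, backward, optimizer
    # step) while optimizing directly over an input tensor that requires grad.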
    def test_optimizing_over_tensor_with_requires_grad(self):
        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        self.assertEqual(cnts.frame_count, 2)

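    # An optimizer and loss closure constructed inside a compiled function
    # should remain usable by optimizer.step() afterwards.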
    def test_state_dict(self):
        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        optimizer.step(fn)

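    # Eager and fullgraph-compiled step() should mutate parameters identically
    # for both real and complex dtypes.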
    def test_init_group(self):
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()

            self.assertEqual(params, opt_params)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()