# Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


10"""
11This is an example of a pure-python version of autograd implemented by
12@zdevito.  It represents a rather challenging test case for TorchDynamo
13to push the limits of what it can do.
14"""
15
16
17_name: int = 0
18
19
def fresh_name() -> str:
    """Create a new unique name for a variable: v0, v1, v2, ..."""
    global _name
    r = f"v{_name}"
    _name += 1
    return r


class Variable:
    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes.
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None):
        return Variable(value, name)

    def __repr__(self):
        return repr(self.value)

    # This performs a pointwise multiplication of two Variables, tracking gradients
    def __mul__(self, rhs: "Variable") -> "Variable":
        # operator_mul is defined later in this file
        return operator_mul(self, rhs)

    def __add__(self, rhs: "Variable") -> "Variable":
        return operator_add(self, rhs)

    def sum(self, name: Optional[str] = None) -> "Variable":
        return operator_sum(self, name)

    def expand(self, sizes: List[int]) -> "Variable":
        return operator_expand(self, sizes)

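# A minimal usage sketch (not executed by the tests; it relies on the
# operator_* functions defined below and a fresh tape):
#
#   reset_tape()
#   a = Variable.constant(torch.ones(2), name="a")
#   b = Variable.constant(torch.ones(2), name="b")
#   loss = (a * b).sum()  # each op appends a TapeEntry to gradient_tape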

class TapeEntry(NamedTuple):
    # names of the inputs to the original computation
    inputs: List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # apply chain rule
    propagate: "Callable[[List[Variable]], List[Variable]]"


gradient_tape: List[TapeEntry] = []
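# Every forward operator appends one TapeEntry to this list; grad() below
# walks it in reverse, invoking each propagate closure to apply the chain rule.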


def reset_tape():
    gradient_tape.clear()
    global _name
    _name = 0


def grad(L, desired_results: List[Variable]) -> List[Variable]:
    # this map holds dL/dX for all values X
    dL_d: Dict[str, Variable] = {}
    # It starts by initializing the 'seed' dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))
    # print(f'd{L.name} ------------------------')

    # Look up the dL_d entries. If a variable is never used to compute the
    # loss, we treat its gradient as None rather than materializing zeros.
    def gather_grad(entries: List[str]):
        return [dL_d[entry] if entry in dL_d else None for entry in entries]

    # propagate the gradient information backward
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(dL_doutput is None for dL_doutput in dL_doutputs):
            # Optimize for the case where some gradient pathways are zero:
            # if every output gradient is None, all input gradients are None
            # as well, so the propagate call can be skipped entirely.
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input.
        # Each use of a variable produces some gradient dL_dinput for that
        # use. The multivariate chain rule tells us it is safe to sum
        # all the contributions together.
        for input, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input not in dL_d:
                dL_d[input] = dL_dinput
            else:
                dL_d[input].value += dL_dinput.value

    # print some information to understand the values of each intermediate
    # for name, value in dL_d.items():
    #    print(f'd{L.name}_d{name} = {value.name}')
    # print(f'------------------------')

    return gather_grad(desired.name for desired in desired_results)

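# A worked sketch of the full round trip (not executed by the tests; assumes
# a fresh tape and the operator_* functions below):
#
#   reset_tape()
#   a = Variable.constant(torch.tensor([1.0, 2.0]), name="a")
#   b = Variable.constant(torch.tensor([3.0, 4.0]), name="b")
#   loss = (a * b).sum()
#   da, db = grad(loss, [a, b])
#   # by the product rule: da.value == b.value and db.value == a.value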

def operator_mul(self: Variable, rhs: Variable) -> Variable:
    if isinstance(rhs, float) and rhs == 1.0:
        # peephole optimization
        return self

    # define forward
    r = Variable(self.value * rhs.value)
    # print(f'{r.name} = {self.name} * {rhs.name}')

    # record what the inputs and outputs of the op were
    inputs = [self.name, rhs.name]
    outputs = [r.name]

    # define backprop
    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs

        dr_dself = rhs  # partial derivative of r = self*rhs
        dr_drhs = self  # partial derivative of r = self*rhs

        # chain rule propagation from outputs to inputs of multiply
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        dL_dinputs = [dL_dself, dL_drhs]
        return dL_dinputs

    # finally, we record the compute we did on the tape
    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
    return r


def operator_add(self: Variable, rhs: Variable) -> Variable:
    # Add follows a similar pattern to Mul, but it doesn't end up
    # capturing any variables.
    r = Variable(self.value + rhs.value)
    # print(f'{r.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        dr_dself = 1.0
        dr_drhs = 1.0
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        return [dL_dself, dL_drhs]

    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
    )
    return r

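# Note that propagate above multiplies by the float 1.0, which hits the
# peephole optimization in operator_mul, so the backward pass of add records
# no extra multiplies on the tape.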
166
167def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
168    r = Variable(torch.sum(self.value), name=name)
169    # print(f'{r.name} = {self.name}.sum()')
170
171    def propagate(dL_doutputs: List[Variable]):
172        (dL_dr,) = dL_doutputs
173        size = self.value.size()
174        return [dL_dr.expand(*size)]
175
176    gradient_tape.append(
177        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
178    )
179    return r
180

def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
    assert self.value.dim() == 0  # only works for scalars
    r = Variable(self.value.expand(sizes))
    # print(f'{r.name} = {self.name}.expand({sizes})')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        return [dL_dr.sum()]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r

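# Note: sum and expand are adjoint to each other: the backward of sum is an
# expand, and the backward of expand is a sum. A quick check (a sketch, not
# run by the tests):
#
#   reset_tape()
#   x = Variable.constant(torch.ones(()))
#   y = x.expand([3]).sum()
#   (dx,) = grad(y, [x])
#   # dx.value == torch.tensor(3.0)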

def simple(a, b):
    t = a + b
    return t * b
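
# By the product rule, simple(a, b) = (a + b) * b has analytic gradients
# d/da = b and d/db = a + 2*b.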


class TestPythonAutograd(TestCase):
    def _common(self, fn, expected_ops):
        # Compile fn with Dynamo, run it on two sets of inputs, check both
        # results against eager mode, and assert that a single frame with
        # expected_ops ops was captured.
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])

        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # calling the compiled grad separately forces a second frame
        grad2 = opt_grad(loss2, [v1, v2])

        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)


if __name__ == "__main__":
    run_tests()