# Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
to push the limits of what it can do.
"""


_name: int = 0


def fresh_name() -> str:
    """create a new unique name for a variable: v0, v1, v2"""
    global _name
    r = f"v{_name}"
    _name += 1
    return r


class Variable:
    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes.
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None):
        return Variable(value, name)

    def __repr__(self):
        return repr(self.value)

    # This performs a pointwise multiplication of a Variable, tracking gradients
    def __mul__(self, rhs: "Variable") -> "Variable":
        # defined later in the notebook
        return operator_mul(self, rhs)

    def __add__(self, rhs: "Variable") -> "Variable":
        return operator_add(self, rhs)

    def sum(self, name: Optional[str] = None) -> "Variable":
        return operator_sum(self, name)

    def expand(self, sizes: List[int]) -> "Variable":
        return operator_expand(self, sizes)


class TapeEntry(NamedTuple):
    # names of the inputs to the original computation
    inputs: List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # apply chain rule
    propagate: "Callable[[List[Variable]], List[Variable]]"


gradient_tape: List[TapeEntry] = []


def reset_tape():
    gradient_tape.clear()
    global _name
    _name = 0


def grad(L: Variable, desired_results: List[Variable]) -> List[Variable]:
    # this map holds dL/dX for all values X
    dL_d: Dict[str, Variable] = {}
    # It starts by initializing the 'seed' dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))
    # print(f'd{L.name} ------------------------')

    # look up the dL_d entries. If a variable is never used to compute the loss,
    # we consider its gradient None; see the note below about zeros for more information.
    def gather_grad(entries: List[str]):
        return [dL_d[entry] if entry in dL_d else None for entry in entries]

    # propagate the gradient information backward
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(dL_doutput is None for dL_doutput in dL_doutputs):
            # optimize for the case where some gradient pathways are zero. See
            # the note below for more details.
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input.
        # Each use of a variable produces some gradient dL_dinput for that
        # use. The multivariate chain rule tells us it is safe to sum
        # all the contributions together.
        for input, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input not in dL_d:
                dL_d[input] = dL_dinput
            else:
                dL_d[input].value += dL_dinput.value

    # print some information to understand the values of each intermediate
    # for name, value in dL_d.items():
    #     print(f'd{L.name}_d{name} = {value.name}')
    # print(f'------------------------')

    return gather_grad(desired.name for desired in desired_results)


def operator_mul(self: Variable, rhs: Variable) -> Variable:
    if isinstance(rhs, float) and rhs == 1.0:
        # peephole optimization
        return self

    # define forward
    r = Variable(self.value * rhs.value)
    # print(f'{r.name} = {self.name} * {rhs.name}')

    # record what the inputs and outputs of the op were
    inputs = [self.name, rhs.name]
    outputs = [r.name]

    # define backprop
    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs

        dr_dself = rhs  # partial derivative of r = self * rhs
        dr_drhs = self  # partial derivative of r = self * rhs

        # chain rule propagation from outputs to inputs of multiply
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        dL_dinputs = [dL_dself, dL_drhs]
        return dL_dinputs

    # finally, we record the compute we did on the tape
    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
    return r


def operator_add(self: Variable, rhs: Variable) -> Variable:
    # Add follows a similar pattern to Mul, but it doesn't end up
    # capturing any variables.
    r = Variable(self.value + rhs.value)
    # print(f'{r.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        dr_dself = 1.0
        dr_drhs = 1.0
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        return [dL_dself, dL_drhs]

    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
    )
    return r


def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
    r = Variable(torch.sum(self.value), name=name)
    # print(f'{r.name} = {self.name}.sum()')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        size = self.value.size()
        return [dL_dr.expand(*size)]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r


def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
    assert self.value.dim() == 0  # only works for scalars
    r = Variable(self.value.expand(sizes))
    # print(f'{r.name} = {self.name}.expand({sizes})')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        return [dL_dr.sum()]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r


def simple(a, b):
    t = a + b
    return t * b


class TestPythonAutograd(TestCase):
    def _common(self, fn, expected_ops):
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])

        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # force two frames
        grad2 = opt_grad(loss2, [v1, v2])

        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)


if __name__ == "__main__":
    run_tests()
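

# A minimal usage sketch of the tape-based autograd defined above: it builds two
# leaf Variables, runs simple() forward while each op records a TapeEntry on
# gradient_tape, then replays the tape with grad(). The helper name
# `_example_usage` and the concrete tensor values are illustrative assumptions;
# nothing in the test suite calls this function.
def _example_usage():
    reset_tape()
    a = Variable.constant(torch.ones(2), name="a")
    b = Variable.constant(torch.full((2,), 3.0), name="b")
    # forward: loss = sum((a + b) * b); add, mul, and sum are taped
    loss = simple(a, b).sum()
    # backward: walk gradient_tape in reverse to get d(loss)/da and d(loss)/db
    da, db = grad(loss, [a, b])
    # analytically: d(loss)/da = b -> [3., 3.]; d(loss)/db = a + 2*b -> [7., 7.]
    return da, db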