# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
    CJvpInterpreterPtr,
    CVmapInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
    RandomnessType,
    TransformType,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled


"""
This file contains the functorch integration with PyDispatcher.

PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).

Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking at the top of the stack for an interpreter,
selecting the interpreter to dispatch on, etc., in Python. This leads to a
simpler design.

The main difference between C++ functorch and PyDispatcher's functorch logic
is that:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
  DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
  and asks it to execute the rule associated with the Interpreter.

In C++ we do the ping-pong because e.g. vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""
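

# For example, here is a hedged sketch (not part of this module, and the
# operator below is hypothetical) of what "register a batching rule directly
# to a transform" looks like from the user's side. It assumes the
# HigherOrderOperator.py_impl registration path, which fills in the
# op.functorch_table consulted by the interpreters below:
#
#     from torch._ops import HigherOrderOperator
#
#     my_op = HigherOrderOperator("my_op")
#
#     @my_op.py_impl(TransformType.Vmap)
#     def my_op_vmap_rule(interpreter, x):
#         # `interpreter` is the VmapInterpreter at the top of the stack,
#         # passed in by dispatch_functorch (defined at the bottom of this
#         # file). Lowering temporarily pops it so the inner call is handled
#         # by the next transform on the stack.
#         with interpreter.lower():
#             return my_op(x)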


# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    def __init__(self, cptr: Any):
        self._cptr = cptr

    # Process an operation. E.g. for vmap, this means invoking a batching rule.
    # Conceptually this is analogous to Interpreter::process in C++
    @abstractmethod
    def process(self, op, args, kwargs):
        pass

    # Lower an operation from this Interpreter to the next Interpreter on the stack.
    # Concretely, this involves temporarily popping the current Interpreter.
    # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
    def lower(self):
        return temporarily_pop_interpreter_stack()

    def level(self):
        return self._cptr.level()

    def key(self):
        return self._cptr.key()

    def get_state(self):
        raise NotImplementedError

    def check_state(self, state):
        return state == self.get_state()


@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    try:
        saved = pop_dynamic_layer_stack()
        yield
    finally:
        push_dynamic_layer_stack(saved)


@contextlib.contextmanager
def temporarily_clear_interpreter_stack():
    stack = []
    try:
        while torch._C._functorch.peek_interpreter_stack() is not None:
            stack.append(pop_dynamic_layer_stack())
        yield list(stack)
    finally:
        while stack:
            push_dynamic_layer_stack(stack.pop())


@contextlib.contextmanager
def temporarily_restore_interpreter_stack(stack):
    pushed = []
    try:
        for s in reversed(stack):
            push_dynamic_layer_stack(s)
            pushed.append(s)
        yield
    finally:
        for s in reversed(pushed):
            # TODO: would be nice to assert that the layers are the same, but
            # Python object identity is not preserved
            pop_dynamic_layer_stack()


class VmapInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Vmap]
        return kernel(self, *args, **kwargs)

    def batch_size(self):
        return self._cptr.batchSize()

    def randomness(self):
        typ = self._cptr.randomness()
        if typ == RandomnessType.Error:
            return "error"
        elif typ == RandomnessType.Same:
            return "same"
        elif typ == RandomnessType.Different:
            return "different"
        raise RuntimeError(f"Unknown RandomnessType: {typ}")

    def get_state(self):
        return (self.key().name, self.level(), self.randomness())


@contextlib.contextmanager
def nested(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts


class GradInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_grad_mode = self.prev_grad_mode()
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        return self._cptr.prevGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_grad_mode())
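

# Illustrative only (a hedged sketch, not exercised by this module): while a
# function is being transformed by torch.func.grad, the top of the
# DynamicLayerStack is a Grad interpreter, so the retrieval helpers defined
# later in this file surface it as a GradInterpreter whose get_state() is the
# (key name, level, prev_grad_mode) triple consumed by compare_functorch_state:
#
#     def f(x):
#         interp = retrieve_current_functorch_interpreter()
#         assert isinstance(interp, GradInterpreter)
#         state = interp.get_state()  # e.g. ("Grad", <level>, True)
#         return x.sin().sum()
#
#     torch.func.grad(f)(torch.randn(3))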


class JvpInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        return self._cptr.prevFwdGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_fwd_grad_mode())


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()

    def get_state(self):
        return (self.key().name, self.level())


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
        return GradInterpreter(cinterpreter)
    if key == TransformType.Vmap:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
    return coerce_cinterpreter(interpreter)


def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
    cis = torch._C._functorch.get_interpreter_stack()
    if cis is None:
        return []
    return [coerce_cinterpreter(ci) for ci in cis]
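

# A hedged sketch of what the helpers above return under nested transforms
# (the exact levels and tuple contents depend on the nesting, so the values
# are indicative only): inside torch.func.vmap(torch.func.grad(f))(x), the
# stack holds one interpreter per active transform, so
#
#     states = [i.get_state() for i in retrieve_all_functorch_interpreters()]
#
# contains one Vmap state (key name, level, randomness) and one Grad state
# (key name, level, prev_grad_mode). compare_functorch_state below validates
# such a saved list against the current stack.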


def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
    # There are four possible cases covered here:
    # 1. Current stack empty AND stack when generated not empty -> Invalidate
    # 2. Current stack not empty AND stack when generated empty -> Invalidate
    # 3. Current stack and generated stack empty -> Valid FX graph
    # 4. Current stack and generated stack not empty -> Valid if both states match
    peek = torch._C._functorch.peek_interpreter_stack()
    if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0):
        return False

    cis = retrieve_all_functorch_interpreters()
    return len(cis) == len(states) and all(
        ci.check_state(state) for ci, state in zip(cis, states)
    )


def dispatch_functorch(op, args, kwargs):
    interpreter = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
    )
    return interpreter.process(op, args, kwargs)
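

# A hedged sketch of how a caller might combine the helpers above; the caching
# workflow is assumed for illustration and is not implemented in this file.
# Record the interpreter states when a graph is generated, then use
# compare_functorch_state to decide whether the cached graph is still valid:
#
#     saved = [i.get_state() for i in retrieve_all_functorch_interpreters()]
#     ...  # trace and cache a graph under the current transforms
#     if not compare_functorch_state(saved):
#         ...  # the functorch transform context changed; re-trace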