# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
    CJvpInterpreterPtr,
    CVmapInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
    RandomnessType,
    TransformType,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled


"""
This file contains the functorch integration with PyDispatcher.

PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).

Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking at the top of the stack for an interpreter,
selecting the interpreter to dispatch on, etc., in Python. This leads to a
simpler design.

The main difference between C++ functorch and PyDispatcher's functorch logic is:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
  DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
  and asks it to execute the rule associated with the Interpreter.

In C++ we do the ping-pong because, e.g., vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""


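# Illustrative sketch of the flow described above (the names below are defined
# in this file; the op and rule themselves are hypothetical and elided):
#
#     interpreter = retrieve_current_functorch_interpreter()
#     rule = op.functorch_table[interpreter.key()]
#     result = rule(interpreter, *args, **kwargs)   # what interpreter.process does
#
# A rule that wants to re-dispatch to the remaining transforms does so under
# `with interpreter.lower():`, which temporarily pops the current interpreter
# off the DynamicLayerStack.
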
# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    def __init__(self, cptr: Any):
        self._cptr = cptr

    # Process an operation. E.g., for vmap, this means invoking a batching rule.
    # Conceptually this is analogous to Interpreter::process in C++
    @abstractmethod
    def process(self, op, args, kwargs):
        pass

    # Lower an operation from this Interpreter to the next Interpreter on the stack.
    # Concretely, this involves temporarily popping the current Interpreter.
    # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
    def lower(self):
        return temporarily_pop_interpreter_stack()

    def level(self):
        return self._cptr.level()

    def key(self):
        return self._cptr.key()

    def get_state(self):
        raise NotImplementedError

    def check_state(self, state):
        return state == self.get_state()


@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    try:
        saved = pop_dynamic_layer_stack()
        yield
    finally:
        push_dynamic_layer_stack(saved)


@contextlib.contextmanager
def temporarily_clear_interpreter_stack():
    stack = []
    try:
        while torch._C._functorch.peek_interpreter_stack() is not None:
            stack.append(pop_dynamic_layer_stack())
        yield list(stack)
    finally:
        while stack:
            push_dynamic_layer_stack(stack.pop())


@contextlib.contextmanager
def temporarily_restore_interpreter_stack(stack):
    pushed = []
    try:
        for s in reversed(stack):
            push_dynamic_layer_stack(s)
            pushed.append(s)
        yield
    finally:
        for s in reversed(pushed):
            # TODO: would be nice to assert that the layers are the same, but
            # Python object identity is not preserved
            pop_dynamic_layer_stack()


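# Usage sketch for the stack context managers above (hypothetical caller; the
# real callers live elsewhere in PyTorch):
#
#     with temporarily_clear_interpreter_stack() as saved:
#         # The DynamicLayerStack is empty here.
#         with temporarily_restore_interpreter_stack(saved):
#             # The captured interpreters are active again here.
#             ...
#
# temporarily_clear_interpreter_stack pops interpreters top-first, and
# temporarily_restore_interpreter_stack pushes them back bottom-first, so the
# original ordering is preserved.
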
class VmapInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Vmap]
        return kernel(self, *args, **kwargs)

    def batch_size(self):
        return self._cptr.batchSize()

    def randomness(self):
        typ = self._cptr.randomness()
        if typ == RandomnessType.Error:
            return "error"
        elif typ == RandomnessType.Same:
            return "same"
        elif typ == RandomnessType.Different:
            return "different"
        raise RuntimeError(f"Unknown RandomnessType: {typ}")

    def get_state(self):
        return (self.key().name, self.level(), self.randomness())


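# Sketch of the shape of a vmap rule dispatched through VmapInterpreter above
# (hypothetical op and rule; real rules are registered into op.functorch_table):
#
#     def my_op_vmap_rule(interpreter: VmapInterpreter, *args, **kwargs):
#         # interpreter.batch_size(), interpreter.level(), and
#         # interpreter.randomness() describe the vmap transform being applied.
#         # After handling any batched inputs (elided), re-dispatch to the
#         # remaining transforms:
#         with interpreter.lower():
#             return my_op(*args, **kwargs)
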
@contextlib.contextmanager
def nested(*contexts):
    # Enter the given context managers left-to-right and exit them in reverse.
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts


class GradInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has a custom lower because of the no_grad interaction.
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_grad_mode = self.prev_grad_mode()
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        return self._cptr.prevGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_grad_mode())


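# Sketch of the effect of GradInterpreter.lower() above when the grad transform
# was entered with grad mode disabled (prev_grad_mode == False):
#
#     with interpreter.lower():
#         # Equivalent to nested(torch.no_grad(), temporarily_pop_interpreter_stack()):
#         # the current Grad interpreter is popped and autograd is disabled while
#         # the remaining transforms run.
#         ...
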
class JvpInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # JvpInterpreter has a custom lower because of the no_fwd_grad interaction.
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        return self._cptr.prevFwdGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_fwd_grad_mode())


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()

    def get_state(self):
        return (self.key().name, self.level())


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
        return GradInterpreter(cinterpreter)
    if key == TransformType.Vmap:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
    return coerce_cinterpreter(interpreter)


def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
    cis = torch._C._functorch.get_interpreter_stack()
    if cis is None:
        return []
    return [coerce_cinterpreter(ci) for ci in cis]


def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
    # There are four possible cases covered here:
    # 1. Current stack empty AND stack when generated not empty -> Invalidate
    # 2. Current stack not empty AND stack when generated empty -> Invalidate
    # 3. Current stack and generated stack empty -> Valid FX graph
    # 4. Current stack and generated stack not empty -> Valid if both states match
    peek = torch._C._functorch.peek_interpreter_stack()
    if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0):
        return False

    cis = retrieve_all_functorch_interpreters()
    return len(cis) == len(states) and all(
        ci.check_state(state) for ci, state in zip(cis, states)
    )


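# Usage sketch for compare_functorch_state above (hypothetical caller, e.g. a
# cached/compiled-graph validity check; the exact callers are elided):
#
#     states = [i.get_state() for i in retrieve_all_functorch_interpreters()]
#     ...  # later, before reusing whatever was generated under those transforms:
#     if not compare_functorch_state(states):
#         ...  # invalidate / regenerate
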
def dispatch_functorch(op, args, kwargs):
    interpreter = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
    )
    return interpreter.process(op, args, kwargs)