# mypy: allow-untyped-defs
"""
This file provides a number of "global" variables/handlers that are actually
thread local and dynamically scoped, with Inductor patching them to various
implementations depending on the situation.

These handlers are interacted with in a fairly stylized way.  Typically,
we will import V from this module::

    from .virtualized import V

Various handlers are accessible as attributes on this module; for example,
you might access ``V.graph.sizevars.size_hint`` to resolve a size hint
associated with a number.

There are a few distinct usage patterns for virtualized global variables:

1. Implicit argument passing.  Examples: ``V.current_node``, ``V.aot_compilation``.
   Use ``V.set_current_node`` to change what the current node is while we're
   executing some region of code, so code inside that region can query
   ``V.current_node`` to find out what it is.  This is often more convenient
   than manually threading the current node as an argument through all call
   stacks.
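
   A minimal sketch (``node`` and ``lower_node`` are hypothetical stand-ins
   for an actual ``torch.fx.Node`` and its lowering)::

       with V.set_current_node(node):
           lower_node()  # code called from here can read V.current_node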
23
242. Per-compilation global state.  Examples: ``V.fake_mode``, ``V.graph``.  For a
25   given ``compile_fx`` invocation, these typically don't change, but they are
26   associated with some internal state so they cannot just be global functions.
27   We install these objects at the beginning of compilation and then you can

3. Alternate define-by-run interpretations.  Examples: ``V.ops``, ``V.kernel``.
   A commonly used IR in Inductor is define-by-run: instead of maintaining
   explicit syntax data structures, we represent loop bodies as callable
   functions, which internally invoke operations defined on ``V.ops``.  To
   perform semantic analysis on these operations, or to print or code-generate
   them, we dynamically patch ``V.ops`` with an alternate handler that has the
   intended semantics, and then run the callable function.  For example, to
   extract a traditional (FX) graph representation of the define-by-run IR,
   simply install a handler that records each ``ops`` call to a graph.
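
   A minimal sketch of swapping interpretations (``inner_fn`` and ``index``
   are hypothetical; ``MockHandler`` renders each ``ops`` call as a string)::

       with V.set_ops_handler(MockHandler()):
           print(inner_fn(index))  # e.g. "sin(load(buf0, i0))"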

   TODO: Define a parent class / protocol that defines all of the operations
   V.ops is expected to support.

It is typically an error to access a virtualized global without having
installed an appropriate handler (you will get a ``NullHandler``), although in
some cases we provide a default implementation.

One last thing: although most virtualized globals are accessed via ``V``,
``ops`` is ubiquitous enough to have its own top-level variable, so you will
typically see ``ops.constant(...)`` rather than ``V.ops.constant(...)``.  In
fact, these are not equivalent: the former interface supports arithmetic
overloads like ``x + y`` instead of forcing ``ops.add(x, y)``, so it should be
preferred.
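
For example, the two spellings below denote the same operation, but only the
first reads like ordinary arithmetic (``x`` here is a hypothetical
``OpsValue`` produced by an earlier ``ops`` call)::

    one = ops.constant(1.0, torch.float32)
    y = x * x + one                  # preferred
    y = ops.add(ops.mul(x, x), one)  # equivalent, but noisier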

Some operators are seemingly unused, but they are implicitly used by
ops_wrapper.  In particular, we typically have an operator for every basic
pointwise PyTorch operation supported.
"""

from __future__ import annotations

from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union

from .ops_handler import (  # noqa: F401
    KernelFormatterHandler,
    MockHandler,
    OpsHandler,
    ReductionType,
    StoreMode,
    WrapperHandler,
)


if TYPE_CHECKING:
    import torch
    from torch._inductor.codegen.cpp_utils import LocalBufferContext
    from torch._inductor.debug import DebugContext
    from torch._inductor.graph import GraphLowering
    from torch._inductor.loop_body import InterpreterShim
    from torch._subclasses import FakeTensorMode

threadlocal = local()

T = TypeVar("T")


class NullHandler:
    """
    Sentinel indicating that a global variable is unset, à la ``None``.
    Typically, attempting to access the global variable before it's set is an
    error, but with ``NullHandler`` it won't fail until you try to access an
    attribute on it.
    """


class Virtualized(Generic[T]):
    """
    Implements a global variable that redirects to a thread-local variable.
    (NB: construct this class to create the global variable; this is not
    a singleton class!)

    This allows us to swap in different op implementations in codegen.

    NB: Despite the fact that we typically call these "handlers" (e.g.,
    NullHandler is the default value of the variable), we sometimes use these
    variables to store other things, like booleans.
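
    A minimal sketch (``_my_flag`` is a hypothetical variable)::

        _my_flag: Virtualized[bool] = Virtualized("my_flag", NullHandler)
        with _my_flag._set_handler(True):
            assert _my_flag._get_handler() is True
        # on exit, the prior value (here, a NullHandler) is restored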
106    """

    def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
        self._key: str = f"__torchinductor_{vname}"
        self._default = default

    def _set_handler(self, value: T) -> AbstractContextManager[None]:
        # Install the new handler immediately; the returned context manager
        # restores the prior handler on exit.
        prior = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def ctx():
            try:
                yield
            finally:
                self._set_handler(prior)

        return ctx()

    def _get_handler(self) -> T:
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            # TODO: To be honest, I feel we probably should just error in this
            # case, instead of making a null handler that will probably error
            # when you getattr on it
            return self._default()  # type: ignore[return-value]

    def __getattr__(self, name: str) -> Any:
        # Fall through to the currently installed handler.
        return getattr(self._get_handler(), name)


class NullKernelHandler(NullHandler):
    """
    We need to access ``V.kernel.removed_buffers`` in the DeferredLine class
    when there is no kernel in the context; this happens when codegening the
    wrapper.  Initialize ``removed_buffers`` and ``inplaced_to_remove``
    explicitly, so we don't need to call ``getattr`` with a default value,
    which is prone to typos in the attribute name.
    """

    def __init__(self):
        super().__init__()
        self.removed_buffers = set()
        self.inplaced_to_remove = set()
        self.index_dtype = "tl.int64"


_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
# NB: _kernel defaults to NullKernelHandler (not NullHandler), so e.g.
# V.kernel.removed_buffers is safe to read even with no kernel installed.
_kernel: Virtualized[NullKernelHandler] = Virtualized(
    "kernel", NullKernelHandler
)  # TODO: improve type
_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
    "local_buffer_context", NullHandler
)


class OpsValue:
    """The return type of most ops calls.

    This exists so we can overload magic methods, and write mathematical
    expressions much more fluently.  So instead of

        ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)

    we can write

        (_Ap2 * x - _Ap3) * x * x + _1

    """

    value: Any

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"OpsValue({self.value!r})"

    def __add__(self, other):
        return ops.add(self, other)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __neg__(self):
        return ops.neg(self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __pow__(self, other):
        return ops.pow(self, other)

    def __lt__(self, other):
        return ops.lt(self, other)

    def __le__(self, other):
        return ops.le(self, other)

    def __eq__(self, other):
        return ops.eq(self, other)

    def __ne__(self, other):
        return ops.ne(self, other)

    def __gt__(self, other):
        return ops.gt(self, other)

    def __ge__(self, other):
        return ops.ge(self, other)

    def __and__(self, other):
        return ops.bitwise_and(self, other)

    def __or__(self, other):
        return ops.bitwise_or(self, other)

    def __xor__(self, other):
        return ops.bitwise_xor(self, other)

    def __invert__(self):
        return ops.bitwise_not(self)

    def __rshift__(self, n):
        return ops.bitwise_right_shift(self, n)

    def __lshift__(self, n):
        return ops.bitwise_left_shift(self, n)


class OpsWrapper:
    """This wraps any returned IR values into an ``OpsValue`` instance, so that
    we can overload the magic methods for writing mathematical expressions
    fluently.
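
    For example (illustrative), ``ops.sin(x)`` unwraps ``x`` if it is an
    ``OpsValue``, calls ``sin`` on the installed handler, and wraps the
    result, so ``ops.sin(x) + y`` dispatches to ``OpsValue.__add__`` (and
    hence ``ops.add``) regardless of the handler's underlying return type.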
259    """

    def __getattr__(self, name):
        # Proxy an arbitrary op: unwrap any OpsValue arguments, dispatch to
        # the currently installed ops handler, and wrap the result.
        def inner(*args, **kwargs):
            new_args = [OpsWrapper._unwrap(a) for a in args]
            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

        return inner

    @staticmethod
    def _unwrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsWrapper._unwrap(v) for v in x)
        if isinstance(x, OpsValue):
            return x.value
        return x

    @staticmethod
    def _wrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsValue(v) for v in x)
        return OpsValue(x)

    @staticmethod
    def indirect_indexing(index, size, check=True, wrap_neg=True):
        # Returns a sympy value, not an IR value
        index = OpsWrapper._unwrap(index)
        return _ops.indirect_indexing(index, size, check, wrap_neg)


ops = OpsWrapper()


class _V:
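    """Namespace through which all virtualized globals are read and installed.

    A minimal sketch (``graph`` and ``expr`` are hypothetical)::

        with V.set_graph_handler(graph):
            hint = V.graph.sizevars.size_hint(expr)
    """
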
    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
    set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
    get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
    set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
    set_debug_handler: Callable[[Any], Any] = _debug._set_handler
    set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
    set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler
    get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
    set_current_node: Callable[[Any], Any] = _current_node._set_handler
    get_current_node: Callable[[], Any] = _current_node._get_handler
    set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
    get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler

    @property
    def ops(self) -> OpsHandler[Any]:
        """The operator handler specific to the current codegen task"""
        return _ops._get_handler()

    @property
    def graph(self) -> GraphLowering:
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def real_inputs(self):
        """Non-fake example inputs"""
        return _real_inputs._get_handler()

    @property
    def fake_mode(self):
        """The fake tensor mode used for compilation"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        return _debug._get_handler()

    @property
    def interpreter(self):
        return _interpreter._get_handler()

    @property
    def aot_compilation(self):
        return _aot_compilation._get_handler()

    @property
    def current_node(self):
        return _current_node._get_handler()

    @property
    def local_buffer_context(self):
        return _local_buffer_context._get_handler()


V = _V()