# mypy: allow-untyped-defs
"""
This file provides a number of "global" variables/handlers that are actually
thread local and dynamically scoped, with Inductor patching them to various
implementations depending on the situation.

These handlers are interacted with in a fairly stylized way. Typically,
we will import V from this module::

    from .virtualized import V

Various handlers are then accessible as attributes on ``V``; for example,
you might access ``V.graph.sizevars.size_hint`` to resolve a size hint
associated with a number.

There are a few distinct usage patterns for virtualized global variables:

1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``.
   Use ``V.set_current_node`` to change what the current node is while we're
   executing some region of code, so code inside that region can query ``V.current_node``
   to find out what it is. This is often more convenient than manually threading
   the current node as an argument through all call stacks.

2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a
   given ``compile_fx`` invocation, these typically don't change, but they are
   associated with some internal state so they cannot just be global functions.
   We install these objects at the beginning of compilation and can then
   conveniently access them without having to pass them around.

3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``.
   A commonly used IR in Inductor is define-by-run: instead of maintaining
   explicit syntax data structures, we represent loop bodies as callable
   functions, which internally invoke operations defined on ``V.ops``. To
   perform semantic analysis on, print, or generate code for these operations,
   we dynamically patch ``V.ops`` with an alternate handler with the intended
   semantics and then run the callable function. For example, to extract a
   traditional (FX) graph representation of the define-by-run IR, simply
   install a handler that records each ``ops`` call to a graph (see the sketch
   at the end of this docstring).

   TODO: Define a parent class / protocol that defines all of the operations
   ``V.ops`` is expected to support.

It is typically an error to access a virtualized global without having installed
an appropriate handler (you will get a NullHandler), although in some cases we
provide a default implementation.

One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is
ubiquitous enough to have its own top-level variable, so you will typically see
``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not
equivalent; the former interface supports arithmetic overloads like ``x + y``
instead of forcing ``ops.add(x, y)``, so it should be preferred.

Some operators are seemingly unused, but they are implicitly used by ``ops_wrapper``.
In particular, we typically have an operator for every basic pointwise PyTorch
operation that is supported.
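
As a minimal sketch of pattern 3 (illustrative only; ``MockHandler`` simply
renders each call as a string, and is in fact the default ``ops`` handler)::

    def body():
        # dispatched to whatever handler is currently installed on V.ops
        return ops.add("x", "y")

    with V.set_ops_handler(MockHandler()):
        print(body())  # a textual rendering of the add

Installing a different handler, for example one that records each call into an
FX graph, reinterprets the same ``body`` without modifying it.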
56""" 57 58from __future__ import annotations 59 60from contextlib import AbstractContextManager, contextmanager 61from threading import local 62from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union 63 64from .ops_handler import ( # noqa: F401 65 KernelFormatterHandler, 66 MockHandler, 67 OpsHandler, 68 ReductionType, 69 StoreMode, 70 WrapperHandler, 71) 72 73 74if TYPE_CHECKING: 75 import torch 76 from torch._inductor.codegen.cpp_utils import LocalBufferContext 77 from torch._inductor.debug import DebugContext 78 from torch._inductor.graph import GraphLowering 79 from torch._inductor.loop_body import InterpreterShim 80 from torch._subclasses import FakeTensorMode 81 82threadlocal = local() 83 84T = TypeVar("T") 85 86 87class NullHandler: 88 """ 89 Sentinel indicating that a global variable is unset ala None. Typically, 90 attempting to access the global variable before it's set is an error, but with 91 NullHandler it won't fail until you try to access an attribute on it. 92 """ 93 94 95class Virtualized(Generic[T]): 96 """ 97 Implements a global variable that redirects via thread local variable 98 (NB: construct this class to create the global variable; this is not 99 a singleton class!) 100 101 This allows us to swap in different op implementations in codegen. 102 103 NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is 104 the default value of the variable), we sometimes use these variables to 105 store other things, like booleans. 106 """ 107 108 def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]): 109 self._key: str = f"__torchinductor_{vname}" 110 self._default = default 111 112 def _set_handler(self, value: T) -> AbstractContextManager[None]: 113 prior = self._get_handler() 114 setattr(threadlocal, self._key, value) 115 116 @contextmanager 117 def ctx(): 118 try: 119 yield 120 finally: 121 self._set_handler(prior) 122 123 return ctx() 124 125 def _get_handler(self) -> T: 126 try: 127 return getattr(threadlocal, self._key) 128 except AttributeError: 129 # TODO: To be honest, I feel we probably should just error in this 130 # case, instead of making a null handler that will probably error 131 # when you getattr on it 132 return self._default() # type: ignore[return-value] 133 134 def __getattr__(self, name: str) -> Any: 135 return getattr(self._get_handler(), name) 136 137 138class NullKernelHandler(NullHandler): 139 """ 140 We need access `V.kernel.removed_buffers` in DeferredLine class when there 141 is no kernel in the context. This happens when codegening the wrapper. 142 Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't 143 need call 'getattr' with default value which is error prone to typo in 144 attribute name. 
    """

    def __init__(self):
        super().__init__()
        self.removed_buffers = set()
        self.inplaced_to_remove = set()
        self.index_dtype = "tl.int64"


_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
    "kernel", NullKernelHandler
)  # TODO: improve type
_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
    "local_buffer_context", NullHandler
)


class OpsValue:
    """The return type of most ops calls.

    This exists so we can overload magic methods and write mathematical
    expressions much more fluently. So instead of

        ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)

    we can write

        (_Ap2 * x - _Ap3) * x * x + _1

    """

    value: Any

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"OpsValue({self.value!r})"

    def __add__(self, other):
        return ops.add(self, other)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __neg__(self):
        return ops.neg(self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __pow__(self, other):
        return ops.pow(self, other)

    def __lt__(self, other):
        return ops.lt(self, other)

    def __le__(self, other):
        return ops.le(self, other)

    def __eq__(self, other):
        return ops.eq(self, other)

    def __ne__(self, other):
        return ops.ne(self, other)

    def __gt__(self, other):
        return ops.gt(self, other)

    def __ge__(self, other):
        return ops.ge(self, other)

    def __and__(self, other):
        return ops.bitwise_and(self, other)

    def __or__(self, other):
        return ops.bitwise_or(self, other)

    def __xor__(self, other):
        return ops.bitwise_xor(self, other)

    def __invert__(self):
        return ops.bitwise_not(self)

    def __rshift__(self, n):
        return ops.bitwise_right_shift(self, n)

    def __lshift__(self, n):
        return ops.bitwise_left_shift(self, n)


class OpsWrapper:
    """This wraps any returned IR values into an `OpsValue` instance, so that we
    can overload the magic methods for writing mathematical expressions fluently.
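
    For example (a sketch; with the default `MockHandler` each call renders as
    a string, while real codegen handlers return their own IR values)::

        a = ops.constant(1, torch.float32)  # an OpsValue wrapping the result
        b = ops.constant(2, torch.float32)
        c = a + b                           # dispatches to ops.add(a, b)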
    """

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            new_args = [OpsWrapper._unwrap(a) for a in args]
            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

        return inner

    @staticmethod
    def _unwrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsWrapper._unwrap(v) for v in x)
        if isinstance(x, OpsValue):
            return x.value
        return x

    @staticmethod
    def _wrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsValue(v) for v in x)
        return OpsValue(x)

    @staticmethod
    def indirect_indexing(index, size, check=True, wrap_neg=True):
        # Returns a sympy value, not an IR value
        index = OpsWrapper._unwrap(index)
        return _ops.indirect_indexing(index, size, check, wrap_neg)


ops = OpsWrapper()


class _V:
    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
    set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
    get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
    set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
    set_debug_handler: Callable[[Any], Any] = _debug._set_handler
    set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
    set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler
    get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
    set_current_node: Callable[[Any], Any] = _current_node._set_handler
    get_current_node: Callable[[], Any] = _current_node._get_handler
    set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
    get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler

    @property
    def ops(self) -> OpsHandler[Any]:
        """The operator handler specific to the current codegen task"""
        return _ops._get_handler()

    @property
    def graph(self) -> GraphLowering:
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def real_inputs(self):
        """Non-fake example inputs"""
        return _real_inputs._get_handler()

    @property
    def fake_mode(self):
        """The FakeTensorMode for the current compilation"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        return _debug._get_handler()

    @property
    def interpreter(self):
        return _interpreter._get_handler()

    @property
    def aot_compilation(self):
        return _aot_compilation._get_handler()

    @property
    def current_node(self):
        return _current_node._get_handler()

    @property
    def local_buffer_context(self):
        return _local_buffer_context._get_handler()


V = _V()
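
# A minimal usage sketch (illustrative only; the real callers are Inductor's
# compile_fx and codegen internals). The set_* methods install a value
# immediately and return a context manager that restores the prior value:
#
#     with V.set_ops_handler(MockHandler()):
#         ...  # ops.* calls in this block are dispatched to MockHandler
#
#     with V.set_current_node(node):
#         assert V.current_node is node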