# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Callable, Optional
from typing_extensions import deprecated

import torch
from torch._library.utils import Kernel, RegistrationHandle


class FakeImplHolder:
    """A holder to which one can register a fake impl."""

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        self.kernel: Optional[Kernel] = None
        self.lib: Optional[torch.library.Library] = None

    def register(self, func: Callable, source: str) -> RegistrationHandle:
        """Register a fake impl.

        Returns a RegistrationHandle that one can use to de-register this
        fake impl.
        """
        if self.kernel is not None:
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
                f"already has a fake impl registered at "
                f"{self.kernel.source}."
            )
        if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call "
                f"register_fake."
            )

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            self.qualname, "CompositeImplicitAutograd"
        ):
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need a fake "
                f"impl; instead, the operator will decompose into its "
                f"constituents and those can have fake impls defined on them."
            )

        # Store the kernel in this holder
        self.kernel = Kernel(func, source)

        # Also register the fake impl to the Meta key
        if self.lib is None:
            ns = self.qualname.split("::")[0]
            self.lib = torch.library.Library(ns, "FRAGMENT")  # noqa: TOR901
        meta_kernel = construct_meta_kernel(self.qualname, self)
        self.lib.impl(self.qualname, meta_kernel, "Meta")

        def deregister_fake_class():
            if self.lib:
                self.lib._destroy()
                self.lib = None
            self.kernel = None

        return RegistrationHandle(deregister_fake_class)


def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable:
    # Wrap the registered fake impl as a Meta-key kernel. get_ctx() is not
    # available under the Meta key, so any call to it raises a descriptive error.
    assert fake_impl_holder.kernel is not None

    @functools.wraps(fake_impl_holder.kernel.func)
    def meta_kernel(*args, **kwargs):
        assert fake_impl_holder.kernel is not None
        source = fake_impl_holder.kernel.source

        def error_on_ctx():
            raise RuntimeError(
                f"Attempted to call get_ctx() for the meta implementation "
                f"for {qualname} (implemented at {source}). "
                f"You have presumably called get_ctx() because the operator "
                f"has a data-dependent output shape; if so, there is no "
                f"such meta implementation and this error is the correct "
                f"behavior."
            )

        with set_ctx_getter(error_on_ctx):
            return fake_impl_holder.kernel(*args, **kwargs)

    return meta_kernel


def get_none():
    return None


# The getter consulted by torch.library.get_ctx(); set_ctx_getter (below)
# temporarily swaps it in while a fake impl is being run.
global_ctx_getter: Callable = get_none


@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
    global global_ctx_getter
    prev = global_ctx_getter
    try:
        global_ctx_getter = ctx_getter
        yield
    finally:
        global_ctx_getter = prev


class FakeImplCtx:
    """
    Context object for writing fake implementations for custom operators.
119 """ 120 121 def __init__(self, _fake_mode, _op): 122 self._fake_mode = _fake_mode 123 self._shape_env = _fake_mode.shape_env 124 self._op = _op 125 126 @deprecated( 127 "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead", 128 category=FutureWarning, 129 ) 130 def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: 131 return self.new_dynamic_size(min=min, max=max) 132 133 def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: 134 """Constructs a new symint (symbolic int) representing a data-dependent value. 135 136 This is useful for writing the fake implementation (which is necessary 137 for torch.compile) for a CustomOp where an output Tensor has a size 138 that depends on the data of the input Tensors. 139 140 Args: 141 min (int): A statically known inclusive lower bound for this symint. Default: 0 142 max (Optional[int]): A statically known inclusive upper bound for this 143 symint. Default: None 144 145 .. warning: 146 147 It is important that the ``min`` and ``max`` (if not None) values are set 148 correctly, otherwise, there will be undefined behavior under 149 torch.compile. The default value of ``min`` is 2 due to torch.compile 150 specializing on 0/1 sizes. 151 152 You must also verify that your implementation on concrete Tensors 153 (e.g. CPU/CUDA) only returns Tensors where the size that corresponds 154 to the symint also has respects these constraint. 155 The easiest way to do this is to add an assertion in the CPU/CUDA/etc 156 implementation that the size follows these bounds. 157 158 Example:: 159 160 >>> # An operator with data-dependent output shape 161 >>> lib = torch.library.Library("mymodule", "FRAGMENT") 162 >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") 163 >>> 164 >>> @torch.library.register_fake("mymodule::custom_nonzero") 165 >>> def _(x): 166 >>> # Number of nonzero-elements is data-dependent. 167 >>> # Since we cannot peek at the data in an fake impl, 168 >>> # we use the ctx object to construct a new symint that 169 >>> # represents the data-dependent size. 170 >>> ctx = torch.library.get_ctx() 171 >>> nnz = ctx.new_dynamic_size() 172 >>> shape = [nnz, x.dim()] 173 >>> result = x.new_empty(shape, dtype=torch.int64) 174 >>> return result 175 >>> 176 >>> @torch.library.impl(lib, "custom_nonzero", "CPU") 177 >>> def _(x): 178 >>> x_np = x.numpy() 179 >>> res = np.stack(np.nonzero(x_np), axis=1) 180 >>> return torch.tensor(res, device=x.device) 181 182 """ 183 if ( 184 self._shape_env is None 185 or not self._shape_env.allow_dynamic_output_shape_ops 186 ): 187 raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) 188 189 if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): 190 raise ValueError( 191 f"ctx.new_dynamic_size(min={min}, max={max}): expected " 192 f"min and max to be statically known ints but got SymInt. " 193 f"This is not supported." 194 ) 195 196 if min < 0: 197 raise ValueError( 198 f"ctx.new_dynamic_size(min={min}, ...): expected min to be " 199 f"greater than or equal to 0: this API can only create " 200 f"non-negative sizes." 201 ) 202 203 result = self._shape_env.create_unbacked_symint() 204 torch.fx.experimental.symbolic_shapes._constrain_range_for_size( 205 result, min=min, max=max 206 ) 207 return result 208