# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Callable, Optional
from typing_extensions import deprecated

import torch
from torch._library.utils import Kernel, RegistrationHandle


class FakeImplHolder:
12    """A holder where one can register an fake impl to."""

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        self.kernel: Optional[Kernel] = None
        self.lib: Optional[torch.library.Library] = None

    def register(self, func: Callable, source: str) -> RegistrationHandle:
20        """Register an fake impl.

        Returns a RegistrationHandle that one can use to de-register this
        fake impl.
        """
        if self.kernel is not None:
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
28                f"already has an fake impl registered at "
29                f"{self.kernel.source}."
30            )
31        if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
32            raise RuntimeError(
33                f"register_fake(...): the operator {self.qualname} "
34                f"already has an DispatchKey::Meta implementation via a "
35                f"pre-existing torch.library or TORCH_LIBRARY registration. "
36                f"Please either remove that registration or don't call "
37                f"register_fake."
38            )
39
40        if torch._C._dispatch_has_kernel_for_dispatch_key(
41            self.qualname, "CompositeImplicitAutograd"
42        ):
43            raise RuntimeError(
44                f"register_fake(...): the operator {self.qualname} "
45                f"already has an implementation for this device type via a "
46                f"pre-existing registration to "
47                f"DispatchKey::CompositeImplicitAutograd."
48                f"CompositeImplicitAutograd operators do not need an fake "
49                f"impl; "
50                f"instead, the operator will decompose into its constituents "
51                f"and those "
52                f"can have fake impls defined on them."
            )

        # Store the kernel in this holder
        self.kernel = Kernel(func, source)

        # Also register the fake impl to the Meta key
        if self.lib is None:
            ns = self.qualname.split("::")[0]
            self.lib = torch.library.Library(ns, "FRAGMENT")  # noqa: TOR901
        meta_kernel = construct_meta_kernel(self.qualname, self)
        self.lib.impl(self.qualname, meta_kernel, "Meta")

        def deregister_fake_class():
            if self.lib:
                self.lib._destroy()
                self.lib = None
            self.kernel = None

        return RegistrationHandle(deregister_fake_class)
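

# Illustrative usage sketch (hypothetical operator name and fake impl; not part
# of this module's API). The returned handle undoes the registration via
# ``deregister_fake_class`` above:
#
#   holder = FakeImplHolder("mylib::my_op")
#   handle = holder.register(my_fake_impl, source="mylib/_fake_impls.py")
#   ...
#   handle.destroy()  # clears the stored kernel and destroys the Meta library fragment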


def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable:
    assert fake_impl_holder.kernel is not None

    @functools.wraps(fake_impl_holder.kernel.func)
    def meta_kernel(*args, **kwargs):
        assert fake_impl_holder.kernel is not None
        source = fake_impl_holder.kernel.source

        def error_on_ctx():
            raise RuntimeError(
                f"Attempted to call get_ctx() for the meta implementation "
85                f"for {qualname} (implemented at {source})"
86                f"You have presumably called get_ctx() because the operator "
87                f"has a data-dependent output shape; if so, there is no "
88                f"such meta implementation and this error is the correct "
89                f"behavior."
90            )
91
92        with set_ctx_getter(error_on_ctx):
93            return fake_impl_holder.kernel(*args, **kwargs)
94
95    return meta_kernel
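

# A minimal sketch of what the wrapper above produces (hypothetical operator;
# assumes torch.library.get_ctx() resolves through ``global_ctx_getter`` below):
#
#   # A fake impl may call get_ctx() under FakeTensorMode, but running the op
#   # on "meta" tensors dispatches to ``meta_kernel``, where get_ctx() resolves
#   # to ``error_on_ctx`` and raises the RuntimeError above:
#   torch.ops.mylib.my_op(torch.empty(3, device="meta"))  # raises RuntimeError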


def get_none():
    return None


global_ctx_getter: Callable = get_none


@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
    global global_ctx_getter
    prev = global_ctx_getter
    try:
        global_ctx_getter = ctx_getter
        yield
    finally:
        global_ctx_getter = prev
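

# A minimal sketch of the intended flow (hypothetical call site; the real one
# lives in the fake-tensor machinery that invokes registered fake impls):
#
#   ctx = FakeImplCtx(fake_mode, op)
#   with set_ctx_getter(lambda: ctx):
#       # While active, torch.library.get_ctx() returns ``ctx`` via
#       # ``global_ctx_getter``, so a user's fake impl can call
#       # ``ctx.new_dynamic_size()`` for data-dependent output sizes.
#       result = fake_impl(*args, **kwargs)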


class FakeImplCtx:
    """
    Context object for writing fake implementations for custom operators.
    """

    def __init__(self, _fake_mode, _op):
        self._fake_mode = _fake_mode
        self._shape_env = _fake_mode.shape_env
        self._op = _op

    @deprecated(
        "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
        category=FutureWarning,
    )
    def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
        return self.new_dynamic_size(min=min, max=max)

    def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
        """Constructs a new symint (symbolic int) representing a data-dependent value.

        This is useful for writing the fake implementation (which is necessary
        for torch.compile) for a CustomOp where an output Tensor has a size
        that depends on the data of the input Tensors.

        Args:
            min (int): A statically known inclusive lower bound for this symint. Default: 0
            max (Optional[int]): A statically known inclusive upper bound for this
                symint. Default: None

        .. warning::

            It is important that the ``min`` and ``max`` (if not None) values are
            set correctly; otherwise, there will be undefined behavior under
            torch.compile. Because torch.compile specializes on sizes 0 and 1,
            prefer ``min=2`` whenever the size is guaranteed to be at least 2.

            You must also verify that your implementation on concrete Tensors
            (e.g. CPU/CUDA) only returns Tensors where the size that corresponds
            to the symint also respects these constraints.
            The easiest way to do this is to add an assertion in the CPU/CUDA/etc
            implementation that the size follows these bounds.

        Example::

            >>> import numpy as np
            >>>
            >>> # An operator with data-dependent output shape
            >>> lib = torch.library.Library("mymodule", "FRAGMENT")
            >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
            >>>
            >>> @torch.library.register_fake("mymodule::custom_nonzero")
            >>> def _(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in a fake impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
            >>> def _(x):
            >>>     x_np = x.numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)

        """
        if (
            self._shape_env is None
            or not self._shape_env.allow_dynamic_output_shape_ops
        ):
            raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)

        if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, max={max}): expected "
                f"min and max to be statically known ints but got SymInt. "
                f"This is not supported."
            )

        if min < 0:
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
                f"greater than or equal to 0: this API can only create "
                f"non-negative sizes."
            )

        result = self._shape_env.create_unbacked_symint()
        torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
            result, min=min, max=max
        )
        return result