# mypy: allow-untyped-defs
import inspect
import warnings
from functools import wraps
from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple, TypeVar
from typing_extensions import ParamSpec

import torch
import torch._prims_common as utils
from torch._prims_common import (
    CustomOutParamAnnotation,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    Number,
    NumberType,
    ShapeType,
    TensorLike,
    TensorLikeType,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_unflatten


_T = TypeVar("_T")
_P = ParamSpec("_P")


@overload
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass


@overload
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass


@overload
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass


@overload
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
    pass


# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(a, dtype):
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return a.to(dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)  # type: ignore[arg-type]
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None
    raise ValueError(f"Received type {type(a)} that is neither a tensor nor a number!")

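# Illustrative behavior of _maybe_convert_to_dtype (a rough sketch, assuming the
# usual dtype-to-Python-type mapping in torch._prims_common):
#
#   _maybe_convert_to_dtype(torch.ones(2, dtype=torch.float32), torch.float64)
#       -> the same values as a float64 tensor (via Tensor.to)
#   _maybe_convert_to_dtype(2, torch.float32)      -> 2.0 (a Python number)
#   _maybe_convert_to_dtype((1, 2), torch.float32) -> (1.0, 2.0), converted elementwise
#   _maybe_convert_to_dtype(None, torch.float32)   -> None (passthrough for optional args)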

def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    if not isinstance(a, Number):
        msg = f"Found unknown type {type(a)} when trying to convert scalars!"
        raise ValueError(msg)
    if not utils.is_weakly_lesser_type(type(a), typ):
        msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
        raise ValueError(msg)

    return typ(a)

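# For example (a sketch): _maybe_convert_to_type(2, float) returns 2.0, while
# _maybe_convert_to_type(2.5, int) raises ValueError because float -> int is not
# a weakly lesser-to-greater conversion.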

def _annotation_has_type(*, typ, annotation):
    if hasattr(annotation, "__args__"):
        for a in annotation.__args__:
            if _annotation_has_type(typ=typ, annotation=a):
                return True
        return False

    return typ is annotation

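# For instance (a sketch): _annotation_has_type(typ=TensorLikeType,
# annotation=Optional[TensorLikeType]) is True, because the Optional's __args__
# contain TensorLikeType, while _annotation_has_type(typ=int, annotation=str)
# is False.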

class elementwise_type_promotion_wrapper:
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    The return_dtype will be coerced to the wrapped function's dtype arg if it is available and
    not None.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
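
    A minimal usage sketch (my_add below is hypothetical, not an existing ref):

        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        )
        def my_add(a, b):
            return torch.add(a, b)

    Calling my_add with mixed dtypes first converts a and b to the common
    compute dtype and then casts the result to the promoted result dtype.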
105    """
106
107    def __init__(
108        self,
109        *,
110        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
111        type_promoting_args: Optional[Sequence[str]] = None,
112    ):
113        self.type_promoting_arg_names = type_promoting_args
114        self.type_promotion_kind = type_promotion_kind
115
116    def __call__(self, fn: Callable) -> Callable:
117        sig = inspect.signature(fn)
118
119        @wraps(fn)
120        def _fn(*args, **kwargs):
121            bound = sig.bind(*args, **kwargs)
122            type_promoting_args = tuple(
123                bound.arguments[x]
124                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
125                if x in bound.arguments.keys()
126            )
127
128            flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
129            compute_dtype, result_dtype = utils.elementwise_dtypes(
130                *flattened_type_promoting_args,
131                type_promotion_kind=self.type_promotion_kind,
132            )
133
134            promoted_args = {
135                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
136                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
137                if x in bound.arguments.keys()
138            }
139            bound.arguments.update(promoted_args)
140
141            result = fn(**bound.arguments)
142
143            # Override the return_dtype if a dtype arg is present and not None
144            if "dtype" in bound.arguments:
145                maybe_dtype = bound.arguments["dtype"]
146                if maybe_dtype:  # dtype cannot be None
147                    result_dtype = maybe_dtype
148
149            if isinstance(result, TensorLike):
150                return _maybe_convert_to_dtype(result, result_dtype)
151            if isinstance(result, Sequence):
152                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
153            raise AssertionError(f"Unhandled result type: {type(result)}")
154
155        _fn.__signature__ = sig  # type: ignore[attr-defined]
156        return _fn
157
158
159# Returns True if resize is necessary
160def _resize_output_check(out: TensorLikeType, shape: ShapeType):
161    # If the shapes are correct there's nothing to do
162    if utils.same_shape(out.shape, shape):
163        return False
164    if out.numel() != 0:
        msg = (
            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
            f"which does not match the required output shape {str(shape)}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
        )
        warnings.warn(msg)
    return True


# TODO: handle tuples of tensors
def _maybe_resize_out(
    out: TensorLikeType,
    shape: ShapeType,
    memory_format: Optional[torch.memory_format] = None,
):
    if _resize_output_check(out, shape):
        return out.resize_(shape, memory_format=memory_format)
    else:
        return out


def is_cpu_scalar(x: TensorLikeType) -> bool:
    return x.dim() == 0 and x.device.type == "cpu"


def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if not is_cpu_scalar(copy_from) and copy_from.device != copy_to.device:
        msg = (
            f"Attempting to copy from device {copy_from.device} "
            f"to device {copy_to.device}, but cross-device copies are not allowed!"
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        torch._check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        torch._check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)

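# Together, _maybe_resize_out and _safe_copy_out implement the usual out=
# handling: resize the out tensor to the result shape if needed (warning when a
# nonzero-element out is resized), then copy the result into it with a
# device/dtype safety check. A rough sketch of the pattern:
#
#   result = some_ref(*args)                      # some_ref is hypothetical
#   _maybe_resize_out(out, result.shape)
#   _safe_copy_out(copy_from=result, copy_to=out)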


def out_wrapper(
    *out_names: str,
    exact_dtype: bool = False,
    pass_is_out: bool = False,
    preserve_memory_format: bool = False,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # The wrapped function needs to convert the output parameters to ensure
    # compatibility between the Python API (which always uses "out" as the
    # parameter name and may be a tuple) and the Aten API (which may have
    # multiple output parameters and use different parameter names such as
    # "grad_input", "indices" or "values".)

    default_out_names = ("out",)
    if len(out_names) == 0:
        # Use the default out name
        out_names = default_out_names

    is_tensor = len(out_names) == 1

    def maybe_compute_memory_format(t):
        return utils.suggest_memory_format(t) if preserve_memory_format else None

    def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args: _P.args, out=None, **kwargs: _P.kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr
            if pass_is_out:
                result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
            else:
                result = fn(*args, **kwargs)
            assert (
                isinstance(result, TensorLike)
                and is_tensor
                or isinstance(result, Tuple)  # type: ignore[arg-type]
                and len(result) == len(out_names)  # type: ignore[arg-type]
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, out will be a FakeTensor while
                # result will be a plain meta tensor, which is
                # perfectly harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(
                        out, result.shape, maybe_compute_memory_format(result)  # type: ignore[union-attr]
                    )
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    torch._check_type(
                        len(out) == len(result),  # type: ignore[arg-type]
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",  # type: ignore[arg-type]
                    )
                    for r, o in zip(result, out):  # type: ignore[arg-type]
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # Mark that the function now returns a tuple
        assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
            sig.empty,
            out_type,
        )
        params = *sig.parameters.values(), out_param

        # If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear
        # after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by
        # Parameter.kind guarantees that all the parameters are in legal order.
        params = sorted(params, key=lambda p: p.kind)

        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type

        # In the special case of having a single tensor out parameter with a
        # name other than out, add a special annotation to name the parameter
        if is_tensor and out_names != default_out_names:
            _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]

        # Add an indicator attribute that can be used in special cases
        # where having a function wrapped by `out_wrapper` is not desirable e.g.
        # jit
        _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"  # type: ignore[attr-defined]

        return _fn

    return _out_wrapper

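# A minimal usage sketch for out_wrapper (my_ref is hypothetical, not an
# existing decomposition):
#
#   @out_wrapper()
#   def my_ref(a: TensorLikeType) -> TensorLikeType:
#       return torch.neg(a)
#
# The decorated my_ref gains a keyword-only out=None parameter; when out is
# passed, the result is resized and copied into it and out is returned,
# otherwise the computed result is returned directly. Passing multiple names,
# e.g. out_wrapper("values", "indices"), makes the wrapper expect and return a
# named tuple of results.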

def _maybe_remove_out_wrapper(fn: Callable):
    return inspect.unwrap(
        fn,
        stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
    )

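# Roughly: inspect.unwrap follows the __wrapped__ chain while the stop predicate
# is false, so this peels off out_wrapper layers (identified by the indicator
# attribute above) and returns the first function that is not wrapped by
# out_wrapper.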

def backwards_not_supported(prim):
    def redispatch_prim(args, kwargs):
        with torch._C._AutoDispatchBelowAutograd():
            old = torch._C._dispatch_tls_is_dispatch_key_excluded(
                torch._C.DispatchKey.ADInplaceOrView
            )
            return prim(*args, **kwargs)

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # TODO: There is a subtle bug here: prims like copy_to
            # return their input argument after mutating it; and custom
            # autograd function will incorrectly turn the result into
            # a view which will fail test_python_ref_executor tests.
            # At the moment, we sidestep this by observing that the
            # unit tests don't ever try to run the executor with
            # autograd, so we don't exercise the buggy case, but if
            # you ever want to feed autograd through this, be aware
            # of it!  We need a way of properly implementing autograd
            # for mutating operations in Python to do this.
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl

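# Illustrative use (a sketch): given a primitive-like callable, e.g.
#
#   my_prim_with_autograd = backwards_not_supported(my_prim)  # my_prim is hypothetical
#
# the returned wrapper redispatches below autograd for the forward computation
# and raises "backwards not supported on prim" if a backward pass reaches it.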

# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], Number):
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
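

# Illustrative use (a sketch): if neg_ref is a unary reference that expects a
# tensor, then elementwise_unary_scalar_wrapper(neg_ref)(2.0) first wraps 2.0 in
# a tensor of the default float dtype, calls neg_ref, and returns the Python
# scalar -2.0 via Tensor.item(). (neg_ref here is hypothetical.)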
431