xref: /aosp_15_r20/external/pytorch/torch/_refs/_conversions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3import torch._prims_common as utils
4
5# Utilities should come BEFORE this import
6from torch._decomp import register_decomposition
7from torch._prims_common import TensorLikeType
8from torch._prims_common.wrappers import out_wrapper
9from torch._refs import _broadcast_shapes
10
11
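# Data conversion references.
#
# Note: this module breaks the usual _refs naming scheme, in which
# _refs.foo.bar is the reference for torch.foo.bar. The dtype-conversion
# definitions below (bool, int, float, ...) are method-style functions that
# take ``self`` and call ``self.to(dtype, ...)``. They live here rather than in
# _refs/__init__.py so that their names do not clash with Python builtin types
# (like int) in that module.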
18
# Public names exported by this module (order preserved).
__all__ = [
    # dtype conversion methods
    "bfloat16", "bool", "byte", "cdouble", "cfloat", "chalf", "char",
    "double", "float", "half", "int", "long", "short",
    # misc: complex-number construction
    "complex", "polar",
]
38
39
def _make_conversion_method(name: str, dtype: torch.dtype):
    """Create a method-style dtype conversion function.

    Args:
        name: Name given to the generated function; used for both its
            ``__name__`` and ``__qualname__``.
        dtype: Target dtype passed to ``self.to``.

    Returns:
        A function ``fn(self, memory_format=torch.preserve_format)`` that
        returns ``self.to(dtype, memory_format=memory_format)``.
    """

    def fn(
        self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
    ) -> TensorLikeType:
        return self.to(dtype, memory_format=memory_format)  # type: ignore[call-overload]

    fn.__name__ = name
    # Keep __qualname__ in sync with __name__. Otherwise every generated
    # function reports itself as "_make_conversion_method.<locals>.fn" in
    # reprs/tracebacks and cannot be located by name (e.g. by pickle).
    fn.__qualname__ = name
    return fn
48
49
# One method-style conversion per dtype: e.g. ``float(t)`` returns
# ``t.to(torch.float32, memory_format=...)``. Dtypes are spelled with their
# canonical names; aliases such as torch.double / torch.half / torch.long are
# the very same dtype objects.
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
bool = _make_conversion_method("bool", torch.bool)
byte = _make_conversion_method("byte", torch.uint8)
cdouble = _make_conversion_method("cdouble", torch.complex128)
cfloat = _make_conversion_method("cfloat", torch.complex64)
chalf = _make_conversion_method("chalf", torch.complex32)
char = _make_conversion_method("char", torch.int8)
double = _make_conversion_method("double", torch.float64)
float = _make_conversion_method("float", torch.float32)
half = _make_conversion_method("half", torch.float16)
int = _make_conversion_method("int", torch.int32)
long = _make_conversion_method("long", torch.int64)
short = _make_conversion_method("short", torch.int16)
75
76
@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Combine ``real`` and ``imag`` into a single complex tensor.

    Both inputs must have the same dtype, chosen from float16, float32 and
    float64. The result has the corresponding complex dtype, the broadcast
    shape of the two inputs, and the layout and device of ``real``.
    """
    supported = (torch.float32, torch.float64, torch.float16)

    # Validate dtypes: first that each input uses a supported floating dtype,
    # then that both inputs agree.
    torch._check(
        all(t.dtype in supported for t in (real, imag)),
        lambda: (
            "Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        ),
    )
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} "
            f"but got scalar type {imag.dtype} for second argument"
        ),
    )

    out_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    # NOTE: pin_memory is not propagated from ``real`` (not yet implemented).
    out = real.new_empty(
        out_shape,
        dtype=out_dtype,
        layout=real.layout,
        device=real.device,
    )
    # Copy each input into its component of the broadcast-shaped output.
    out.real = real
    out.imag = imag
    return out
109
110
@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
    """Build a complex tensor from magnitude ``abs`` and phase ``angle``.

    The output is first allocated with ``torch.complex(abs, angle)``, which
    performs the input validation and shape broadcasting. Its real and
    imaginary parts are then overwritten with ``abs * cos(angle)`` and
    ``abs * sin(angle)``.

    The parameter name ``abs`` shadows the builtin; it is kept to match the
    operator's argument name.
    """
    # Allocate (and validate) before any trig math, so invalid inputs fail
    # with torch.complex's error message.
    out = torch.complex(abs, angle)
    out.real = abs * torch.cos(angle)
    out.imag = abs * torch.sin(angle)
    return out
120