# mypy: allow-untyped-defs
import torch
import torch._prims_common as utils

# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
from torch._prims_common import TensorLikeType
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes


# Data conversion references.
#
# Note: this module breaks the usual _refs to torch naming scheme where
# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
# part of _refs/__init__.py to avoid name clashes with Python builtin types
# (like int).
#
# Consequently, the builtins ``bool``, ``int``, ``float`` and ``complex`` are
# shadowed by the definitions below for the rest of this module; code placed
# after those assignments must not rely on the builtin versions.

__all__ = [
    # dtypes
    "bfloat16",
    "bool",
    "byte",
    "cdouble",
    "cfloat",
    "chalf",
    "char",
    "double",
    "float",
    "half",
    "int",
    "long",
    "short",
    # misc
    "complex",
    "polar",
]


def _make_conversion_method(name: str, dtype: torch.dtype):
    """Build a ``Tensor.<name>()``-style conversion reference.

    The returned function converts its ``self`` argument to ``dtype`` by
    calling ``self.to(dtype, memory_format=memory_format)``.

    Args:
        name: Public name of the generated function. It is written to both
            ``__name__`` and ``__qualname__`` so that ``repr()`` shows the
            public name and, when the function is bound at module level under
            that same name (as every use below is), it can be pickled by
            reference.
        dtype: Target dtype of the conversion.

    Returns:
        A function ``fn(self, memory_format=torch.preserve_format)``.
    """

    def fn(
        self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
    ) -> TensorLikeType:
        return self.to(dtype, memory_format=memory_format)  # type: ignore[call-overload]

    fn.__name__ = name
    # Keep __qualname__ in sync with __name__: repr() and pickle use the
    # qualified name, which otherwise points at this factory's locals
    # ("_make_conversion_method.<locals>.fn").
    fn.__qualname__ = name
    fn.__doc__ = f"Return ``self`` converted to ``{dtype}``."
    return fn


bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)

bool = _make_conversion_method("bool", torch.bool)

byte = _make_conversion_method("byte", torch.uint8)

cdouble = _make_conversion_method("cdouble", torch.cdouble)

cfloat = _make_conversion_method("cfloat", torch.cfloat)

chalf = _make_conversion_method("chalf", torch.complex32)

char = _make_conversion_method("char", torch.int8)

double = _make_conversion_method("double", torch.double)

float = _make_conversion_method("float", torch.float)

half = _make_conversion_method("half", torch.half)

int = _make_conversion_method("int", torch.int)

long = _make_conversion_method("long", torch.long)

short = _make_conversion_method("short", torch.short)


@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Reference for ``torch.complex``: build a complex tensor from its parts.

    ``real`` and ``imag`` must both be float16, float32 or float64, and must
    share the same dtype. The result has the broadcast shape of the two
    inputs and the complex dtype corresponding to ``real.dtype``. It uses
    ``real``'s layout and device. Its real and imaginary components are
    filled from ``real`` and ``imag`` respectively.

    Raises:
        RuntimeError: (via ``torch._check``) if either input dtype is not
            Half/Float/Double, or if the two dtypes differ.
    """
    allowed_dtypes = (torch.float32, torch.float64, torch.float16)
    torch._check(
        real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
        lambda: (
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        ),
    )
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        ),
    )
    result_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    common_shape = _broadcast_shapes(real.shape, imag.shape)
    result = real.new_empty(
        common_shape,
        dtype=result_dtype,
        layout=real.layout,
        device=real.device,
        # pin_memory=real.is_pinned(),  # NYI
    )
    # Assigning through the .real / .imag views writes (and broadcasts) each
    # input into the corresponding component of the freshly allocated result.
    result.real = real
    result.imag = imag
    return result


@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
    """Reference for ``torch.polar``.

    Returns ``abs * cos(angle) + 1j * abs * sin(angle)``, i.e. converts polar
    coordinates (magnitude ``abs``, phase ``angle``) to a complex tensor.
    """
    # torch.complex is used only to obtain an output with the broadcast shape
    # and complex dtype, and to apply its input dtype checks. Its values are
    # overwritten immediately below.
    result = torch.complex(abs, angle)
    result.real = abs * torch.cos(angle)
    result.imag = abs * torch.sin(angle)
    return result