# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
"""Dispatch tables for unary operations on MaskedTensor.

For every supported ``torch.ops.aten`` unary op, a wrapper is generated that
applies the op to the MaskedTensor's underlying data (handling strided,
sparse COO, and sparse CSR layouts) and re-attaches the original mask.
"""

import torch

from .core import _map_mt_args_kwargs, _wrap_result


__all__ = []  # type: ignore[var-annotated]


# Names of aten unary ops that MaskedTensor supports out-of-place.
UNARY_NAMES = [
    "abs",
    "absolute",
    "acos",
    "arccos",
    "acosh",
    "arccosh",
    "angle",
    "asin",
    "arcsin",
    "asinh",
    "arcsinh",
    "atan",
    "arctan",
    "atanh",
    "arctanh",
    "bitwise_not",
    "ceil",
    "clamp",
    "clip",
    "conj_physical",
    "cos",
    "cosh",
    "deg2rad",
    "digamma",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "exp2",
    "expm1",
    "fix",
    "floor",
    "frac",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "logit",
    "i0",
    "isnan",
    "nan_to_num",
    "neg",
    "negative",
    "positive",
    "pow",
    "rad2deg",
    "reciprocal",
    "round",
    "rsqrt",
    "sigmoid",
    "sign",
    "sgn",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]

# In-place variants ("op_"). The excluded ops have no in-place form because
# they cannot write their result dtype/value back into the input in place
# (e.g. "isnan" returns bool, "angle" may change dtype).
INPLACE_UNARY_NAMES = [
    n + "_"
    for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
]

# Explicitly tracking functions we know are currently not supported
# This might be due to missing code gen or because of complex semantics
UNARY_NAMES_UNSUPPORTED = [
    "atan2",
    "arctan2",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "copysign",
    "float_power",
    "fmod",
    "frexp",
    "gradient",
    "imag",
    "ldexp",
    "lerp",
    "logical_not",
    "hypot",
    "igamma",
    "igammac",
    "mvlgamma",
    "nextafter",
    "polygamma",
    "real",
    "remainder",
    "true_divide",
    "xlogy",
]


def _unary_helper(fn, args, kwargs, inplace):
    """Apply the unary aten op ``fn`` to the MaskedTensor in ``args[0]``.

    The mask is carried over unchanged; only the data tensor is transformed.
    Sparse COO/CSR layouts are unpacked so ``fn`` runs on the values only,
    then repacked with the original indices.

    Args:
        fn: the aten op to apply to the data tensor.
        args: positional arguments; ``args[0]`` is the MaskedTensor. Extra
            Tensor arguments are rejected (scalars, e.g. clamp bounds, are OK).
        kwargs: must be empty (not supported).
        inplace: if True, write the result back into ``args[0]`` and return
            it; otherwise wrap the result in a new MaskedTensor.
    """
    if len(kwargs) != 0:
        raise ValueError(
            "MaskedTensor unary ops require that len(kwargs) == 0. "
            "If you need support for this, please open an issue on Github."
        )
    for a in args[1:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor unary ops do not support additional Tensor arguments"
            )

    mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    if args[0].layout == torch.sparse_coo:
        # Coalesce once so indices()/values() are well-defined, then apply
        # fn to the values only and rebuild the sparse tensor.
        data_args[0] = data_args[0].coalesce()
        s = data_args[0].size()
        i = data_args[0].indices()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size=s)

    elif args[0].layout == torch.sparse_csr:
        # CSR values can be transformed directly; row/col structure is reused.
        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_csr_tensor(crow, col, v)

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    else:
        return _wrap_result(result_data, mask_args[0])


def _torch_unary(fn_name):
    """Return an out-of-place MaskedTensor wrapper for aten op ``fn_name``."""
    fn = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        return _unary_helper(fn, args, kwargs, inplace=False)

    return unary_fn


def _torch_inplace_unary(fn_name):
    """Return an in-place MaskedTensor wrapper for aten op ``fn_name``."""
    fn = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        return _unary_helper(fn, args, kwargs, inplace=True)

    return unary_fn


# aten op -> MaskedTensor wrapper, for dispatch in _apply_native_unary.
NATIVE_UNARY_MAP = {
    getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
}
NATIVE_INPLACE_UNARY_MAP = {
    getattr(torch.ops.aten, name): _torch_inplace_unary(name)
    for name in INPLACE_UNARY_NAMES
}

NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())


def _is_native_unary(fn):
    """Return True if ``fn`` is a unary aten op MaskedTensor can dispatch."""
    return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS


def _apply_native_unary(fn, *args, **kwargs):
    """Dispatch ``fn`` through the unary maps; NotImplemented if unsupported."""
    if fn in NATIVE_UNARY_FNS:
        return NATIVE_UNARY_MAP[fn](*args, **kwargs)
    if fn in NATIVE_INPLACE_UNARY_FNS:
        return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs)
    return NotImplemented