xref: /aosp_15_r20/external/pytorch/torch/masked/maskedtensor/unary.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3
4import torch
5
6from .core import _map_mt_args_kwargs, _wrap_result
7
8
# This module exports no public names: an empty ``__all__`` means
# ``from ... import *`` brings nothing into the importer's namespace.
__all__ = list()  # type: ignore[var-annotated]
10
11
# Names of torch.ops.aten ops handled by the generic MaskedTensor unary path.
UNARY_NAMES = [
    "abs",
    "absolute",
    "acos",
    "arccos",
    "acosh",
    "arccosh",
    "angle",
    "asin",
    "arcsin",
    "asinh",
    "arcsinh",
    "atan",
    "arctan",
    "atanh",
    "arctanh",
    "bitwise_not",
    "ceil",
    "clamp",
    "clip",
    "conj_physical",
    "cos",
    "cosh",
    "deg2rad",
    "digamma",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "exp2",
    "expm1",
    "fix",
    "floor",
    "frac",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "logit",
    "i0",
    "isnan",
    "nan_to_num",
    "neg",
    "negative",
    "positive",
    "pow",
    "rad2deg",
    "reciprocal",
    "round",
    "rsqrt",
    "sigmoid",
    "sign",
    "sgn",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]

# In-place ("<name>_") variants of UNARY_NAMES. angle, positive, signbit and
# isnan are excluded -- presumably because aten has no in-place overload for
# them (TODO confirm). The list is built in UNARY_NAMES order rather than via a
# set difference, so its order (and that of the dispatch tables derived from
# it) does not depend on per-process string-hash randomization.
INPLACE_UNARY_NAMES = [
    n + "_"
    for n in UNARY_NAMES
    if n not in {"angle", "positive", "signbit", "isnan"}
]
81
# Functions we know MaskedTensor does not currently support, tracked
# explicitly. Reasons vary: missing code generation, or semantics that don't
# reduce to "apply elementwise to the data, keep the mask" (for example, ops
# that take a second tensor operand, which the unary helper rejects).
UNARY_NAMES_UNSUPPORTED = """
    atan2 arctan2
    bitwise_left_shift bitwise_right_shift
    copysign float_power fmod frexp gradient
    imag ldexp lerp logical_not hypot
    igamma igammac mvlgamma nextafter polygamma
    real remainder true_divide xlogy
""".split()
109
110
def _unary_helper(fn, args, kwargs, inplace):
    """Apply the aten op ``fn`` to a MaskedTensor's data, keeping its mask.

    ``args[0]`` is the MaskedTensor operand. ``_map_mt_args_kwargs`` replaces
    each MaskedTensor in ``args`` by its data (respectively its mask); any
    further positional arguments are forwarded to ``fn``. The mask of
    ``args[0]`` is reused unchanged for the result.

    For ``torch.sparse_coo`` and ``torch.sparse_csr`` inputs, ``fn`` is applied
    only to the stored values, and the result is rebuilt with the input's
    indices and size.

    Args:
        fn: the ``torch.ops.aten`` op to apply.
        args: positional arguments of the intercepted call.
        kwargs: keyword arguments of the intercepted call; must be empty.
        inplace: if True, store the result back into ``args[0]`` via
            ``_set_data_mask`` and return ``args[0]``; otherwise return
            ``_wrap_result(result_data, mask)``.

    Raises:
        ValueError: if any keyword arguments are given.
        TypeError: if any positional argument after the first is a Tensor.
    """
    if kwargs:
        raise ValueError(
            "MaskedTensor unary ops require that len(kwargs) == 0. "
            "If you need support for this, please open an issue on Github."
        )
    if any(torch.is_tensor(a) for a in args[1:]):
        raise TypeError(
            "MaskedTensor unary ops do not support additional Tensor arguments"
        )

    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    data = data_args[0]
    if args[0].layout == torch.sparse_coo:
        # Coalesce so each index appears once and indices()/values() line up;
        # fn then only sees the stored values.
        data = data.coalesce()
        data_args[0] = data.values()
        result_data = torch.sparse_coo_tensor(
            data.indices(), fn(*data_args), size=data.size()
        )
    elif args[0].layout == torch.sparse_csr:
        data_args[0] = data.values()
        # Pass size explicitly: when omitted, sparse_csr_tensor infers the
        # smallest shape that holds the stored entries, silently dropping
        # trailing all-empty columns so the result no longer matches the mask.
        result_data = torch.sparse_csr_tensor(
            data.crow_indices(),
            data.col_indices(),
            fn(*data_args),
            size=data.size(),
        )
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    return _wrap_result(result_data, mask_args[0])
149
150
def _torch_unary(fn_name):
    """Build the MaskedTensor handler for the out-of-place aten op ``fn_name``.

    The op is looked up on ``torch.ops.aten`` once, when the handler is built.
    The returned callable forwards its positional and keyword arguments to
    ``_unary_helper`` with ``inplace=False``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        return _unary_helper(aten_op, args, kwargs, inplace=False)

    return unary_fn
158
159
def _torch_inplace_unary(fn_name):
    """Build the MaskedTensor handler for the in-place aten op ``fn_name``.

    The op is looked up on ``torch.ops.aten`` once, when the handler is built.
    The returned callable forwards its positional and keyword arguments to
    ``_unary_helper`` with ``inplace=True``, so the result is written back
    into the first argument.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        return _unary_helper(aten_op, args, kwargs, inplace=True)

    return unary_fn
167
168
def _aten_dispatch_table(names, make_impl):
    """Map ``torch.ops.aten.<name>`` to ``make_impl(name)`` for each name, in order."""
    return {getattr(torch.ops.aten, name): make_impl(name) for name in names}


# aten op -> MaskedTensor implementation, for out-of-place and in-place ops.
NATIVE_UNARY_MAP = _aten_dispatch_table(UNARY_NAMES, _torch_unary)
NATIVE_INPLACE_UNARY_MAP = _aten_dispatch_table(
    INPLACE_UNARY_NAMES, _torch_inplace_unary
)

# The registered aten ops as plain lists (iterating a dict yields its keys).
NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP)
NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP)
179
180
def _is_native_unary(fn):
    """Return True if ``fn`` has a MaskedTensor unary implementation here.

    Membership is checked against the dispatch dicts, whose keys are exactly
    the entries of ``NATIVE_UNARY_FNS`` / ``NATIVE_INPLACE_UNARY_FNS``. This
    gives a hash lookup instead of a linear list scan on every call. ``fn`` is
    expected to be hashable, like the aten ops used as keys.
    """
    return fn in NATIVE_UNARY_MAP or fn in NATIVE_INPLACE_UNARY_MAP
183
184
def _apply_native_unary(fn, *args, **kwargs):
    """Run the MaskedTensor unary implementation registered for ``fn``.

    ``NATIVE_UNARY_MAP`` (out-of-place ops) is consulted first, then
    ``NATIVE_INPLACE_UNARY_MAP``. Each table is probed with a single hash
    lookup instead of a list scan followed by an indexing lookup.

    Returns:
        The implementation's result, or ``NotImplemented`` if ``fn`` is in
        neither table.
    """
    impl = NATIVE_UNARY_MAP.get(fn)
    if impl is None:
        impl = NATIVE_INPLACE_UNARY_MAP.get(fn)
    if impl is None:
        return NotImplemented
    return impl(*args, **kwargs)
191