# xref: /aosp_15_r20/external/pytorch/functorch/dim/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import dis
import inspect
from typing import Sequence, Union

import functorch._C
import torch
from functorch._C import dim as _C

from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type


_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
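
# A minimal usage sketch (comments only, not executed here): first-class dims
# are created with `dims`. The count can be given explicitly, or inferred from
# the unpacking target on the left-hand side of the assignment:
#
#   batch, channel = dims(2)   # explicit count
#   i, j = dims()              # count inferred from the assignment target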


class DimensionMismatchError(Exception):
    pass


class DimensionBindError(Exception):
    pass


from . import op_properties


# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)

use_c = True
if not use_c:
    from . import reference


class _Tensor:
    # fast path around the slow wrapping/unwrapping logic for simple queries
    # used by the implementation...

    @property
    def dims(self):
        return tuple(d for d in self._levels if isinstance(d, Dim))

    def dim(self):
        return self.ndim

    if use_c:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)
    else:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand

    index = _C._instancemethod(_C.index)

    def __repr__(self):
        tensor, levels, ndim = self._tensor, self._levels, self.ndim
        return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"


TensorLike = (_Tensor, torch.Tensor)


class Dim(_C.Dim, _Tensor):
    # note that _C.Dim comes before _Tensor because we want the Dim API for things like size to take precedence.
    # Tensor defines __format__, but we want Dims to print with their own special formatting
    __format__ = object.__format__
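
# A minimal usage sketch (comments only, not executed here): a Dim is bound to
# a size the first time it is used, for example when it indexes a positional
# dimension of a tensor. The shapes below are assumptions for illustration:
#
#   i, j = dims(2)
#   x = torch.randn(3, 4)
#   x_ij = x[i, j]    # binds i.size == 3 and j.size == 4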


class Tensor(_Tensor, _C.Tensor):
    if not use_c:
        from_batched = staticmethod(_C.Tensor_from_batched)
    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    n = dims()
    return stack(tensors, n, dim).index([n, dim], new_dim)
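
# A minimal usage sketch (comments only, not executed here) of cat above: the
# inputs are stacked along a fresh dim `n`, and then `n` and the existing dim
# `dim` are flattened together into `new_dim`, which amounts to concatenation
# along `dim`. The sizes below are assumptions for illustration:
#
#   i, new_i = dims(2)
#   parts = [torch.randn(3, 4)[i] for _ in range(2)]
#   whole = cat(parts, i, new_i)   # new_i covers 2 * 3 = 6 elements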


if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        orig = getattr(torch.Tensor, name)
        setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split

# note: there is no python reference implementation for __setitem__
t__setitem__ = _C._instancemethod(_C.__setitem__)
# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__

_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__

torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
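
# A minimal usage sketch (comments only, not executed here): with the
# torch.Tensor API wrapped onto _Tensor above, ordinary operators accept
# tensors carrying first-class dims, and pointwise ops broadcast by dim
# identity rather than by position (shapes below are assumptions):
#
#   i, j = dims(2)
#   a = torch.randn(3)[i]
#   b = torch.randn(4)[j]
#   c = a * b    # carries both dims i and j, an outer-product-style broadcast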

if use_c:
    _Tensor.order = _C._instancemethod(_C.order)
else:
    _Tensor.order = reference.positional
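
# A minimal usage sketch (comments only, not executed here): `order` is the
# inverse of indexing with dims; it turns first-class dims back into
# positional dimensions (shapes below are assumptions):
#
#   i, j = dims(2)
#   x = torch.randn(3, 4)
#   y = x[i, j].order(j, i)   # plain positional tensor of shape (4, 3)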

_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)
_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
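
# A minimal usage sketch (comments only, not executed here): the wrapped
# reductions accept Dim objects wherever the underlying torch.Tensor methods
# take an integer `dim` argument (shapes below are assumptions):
#
#   i, j = dims(2)
#   x_ij = torch.randn(3, 4)[i, j]
#   total = x_ij.sum(i)       # reduces over dim i, leaving dim j
#   probs = x_ij.softmax(i)   # via the _def("softmax", ...) wrapper above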

# ops to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce
182