import dis
import inspect
from typing import Sequence, Union

import functorch._C
import torch
from functorch._C import dim as _C

from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type


# Let the C extension patch the tensor class first; everything below in this
# module executes after this call.
_C._patch_tensor_class()

# Re-export the dim factory and dim-list types provided by the C extension.
dims = _C.dims
DimList = _C.DimList
dimlists = _C.dimlists


class DimensionMismatchError(Exception):
    """Exception class for dimension-mismatch errors."""


class DimensionBindError(Exception):
    """Exception class for errors while binding dimensions."""


from . import op_properties  # noqa: E402

# Lookup table of pointwise ops. Every value is True; a dict is used rather
# than a set so that the C extension does not need bindings for set.
pointwise = {op: True for op in op_properties.pointwise}

# Selects the C implementation (True) or the pure-Python reference (False).
use_c = True
if not use_c:
    from . import reference


class _Tensor:
    """Shared base for first-class-dim tensors and dims.

    The members defined here are fast paths for simple queries used by the
    implementation, avoiding the slower generic wrapping/unwrapping logic.
    """

    @property
    def dims(self):
        """Tuple of the ``Dim`` objects found in ``self._levels``, in order."""
        return tuple(lvl for lvl in self._levels if isinstance(lvl, Dim))

    def dim(self):
        """Return ``self.ndim``."""
        return self.ndim

    if not use_c:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand
    else:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)

    # ``index`` always comes from the C extension, whichever backend is used.
    index = _C._instancemethod(_C.index)

    def __repr__(self):
        tensor = self._tensor
        levels = self._levels
        ndim = self.ndim
        # Integer (positional) levels are shifted by ``ndim`` before display;
        # presumably they are stored as negative offsets -- confirm.
        display_levels = tuple(
            lvl + ndim if isinstance(lvl, int) else lvl for lvl in levels
        )
        sizes = tuple(tensor.size())
        return f"{tensor}\nwith dims={display_levels} sizes={sizes}"


TensorLike = (_Tensor, torch.Tensor)


class Dim(_C.Dim, _Tensor):
    """A first-class dimension.

    ``_C.Dim`` is listed before ``_Tensor`` so that the Dim API (for example
    ``size``) takes precedence over the tensor-like fallbacks.
    """

    # Tensor defines ``__format__``; Dims should print with their own special
    # formatting instead, so fall back to the plain object implementation.
    __format__ = object.__format__


class Tensor(_Tensor, _C.Tensor):
    """Tensor carrying first-class dimensions."""

    if not use_c:
        from_batched = staticmethod(_C.Tensor_from_batched)
    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    """Concatenate ``tensors`` along ``dim``, yielding the dimension ``new_dim``.

    The tensors are stacked along a fresh dim ``n``; ``index`` then merges the
    pair ``[n, dim]`` into ``new_dim``.
    """
    # NOTE(review): ``dims()`` is implemented in C and may inspect the
    # caller's assignment target (e.g. to name the new dim). Keep this as a
    # plain ``n = dims()`` assignment -- confirm before refactoring.
    n = dims()
    stacked = stack(tensors, n, dim)
    return stacked.index([n, dim], new_dim)


if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        """Set ``_Tensor.<name>`` to ``torch.Tensor.<name>`` wrapped by ``_wrap``.

        Extra positional and keyword arguments are forwarded to ``_wrap``.
        """
        torch_method = getattr(torch.Tensor, name)
        wrapped = _wrap(torch_method, *args, **kwargs)
        setattr(_Tensor, name, _C._instancemethod(wrapped))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split

# __setitem__ always comes from the C extension; there is no Python reference.
t__setitem__ = _C._instancemethod(_C.__setitem__)

# Indexing is installed on _Tensor only. torch.Tensor.__getitem__/__setitem__
# are intentionally not assigned here: they are patched in the C API, because
# otherwise torch.Tensor would no longer be considered a sequence and things
# would break.
_Tensor.__getitem__ = t__getitem__
_Tensor.__setitem__ = t__setitem__

torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
# Remove the ``ndim`` attribute now present on _Tensor (presumably added by
# wrap_type above) -- TODO confirm why it must not stay on the base class.
del _Tensor.ndim

_Tensor.order = _C._instancemethod(_C.order) if use_c else reference.positional

# torch.Tensor methods re-exposed on _Tensor through ``_def``. Each entry is
# (method name, keyword options forwarded to ``_def``); the registration order
# is significant and kept stable.
for _name, _opts in (
    ("mean", {}),
    ("sum", {}),
    ("all", {}),
    ("amax", {}),
    ("amin", {}),
    ("aminmax", {}),
    ("any", {}),
    ("count_nonzero", {}),
    ("logsumexp", {}),
    ("nanmean", {}),
    ("nansum", {}),
    ("prod", {}),
    ("std", {"keepdim_offset": 2}),
    ("var", {"keepdim_offset": 2}),
    ("max", {"single_dim": True}),
    ("min", {"single_dim": True}),
    ("argmax", {"single_dim": True}),
    ("argmin", {"single_dim": True}),
    ("kthvalue", {"single_dim": True}),
    ("median", {"single_dim": True}),
    ("nanmedian", {"single_dim": True}),
    ("mode", {"single_dim": True}),
    ("sort", {"reduce": False}),
    ("argsort", {"reduce": False}),
    ("unbind", {"single_dim": True}),
    ("chunk", {"dim_offset": 1, "reduce": False}),
    ("cummax", {"single_dim": True, "reduce": False}),
    ("cummin", {"single_dim": True, "reduce": False}),
    ("cumprod", {"single_dim": True, "reduce": False}),
    ("cumprod_", {"single_dim": True, "reduce": False}),
    ("cumsum", {"single_dim": True, "reduce": False}),
    ("cumsum_", {"single_dim": True, "reduce": False}),
    ("logcumsumexp", {"single_dim": True, "reduce": False}),
    ("renorm", {"dim_offset": 1, "single_dim": True, "reduce": False}),
    ("softmax", {"single_dim": True, "reduce": False}),
):
    _def(_name, **_opts)
del _name, _opts

softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)

# TODO: ops still to handle, because they need special binding logic for dims:
#   cross, diag_embed, diagonal, diagonal_scatter, diff, nanquantile,
#   quantile, roll, rot90, topk (new dims on output)
# Possibly subsumed by in-place indexing:
#   index_add_, index_add, index_copy, index_copy_, index_fill, index_fill_,
#   index_select, scatter, scatter_, scatter_add, scatter_add_, scatter_reduce