# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Reference Python implementations for the C++ ops.
import operator
from functools import reduce

import torch
from functorch._C import dim as _C

from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map

DimList = _C.DimList


# operators that are pointwise in all of their tensor arguments
pointwise = set(op_properties.pointwise)


def prod(x):
    return reduce(operator.mul, x, 1)


def _wrap_dim(d, N, keepdim):
    from . import Dim

    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d


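# A minimal sketch of the convention above: plain integer dims are normalized
# to negative indices (counted from the end) so they stay stable when batch
# levels are prepended. The function below is illustrative only.
def _example_wrap_dim():
    # for a 3-d tensor, dim 1 becomes -2
    assert _wrap_dim(1, 3, False) == -2
    # already-negative dims pass through unchanged
    assert _wrap_dim(-1, 3, False) == -1

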
def _dims(d, N, keepdim, single_dim):
    from . import Dim

    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)


def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    from . import DimensionMismatchError

    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        idx, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
            )
        new_size = lhs_size // rhs_so_far
        d.size = new_size
    elif len(not_bound) > 1:
        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
        )
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
            )


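# A minimal sketch of the inference rule above, using a hypothetical stand-in
# for Dim (the real Dim lives in the C extension): with exactly one unbound
# dim, its size is whatever factor of lhs_size the bound dims leave over.
def _example_bind_dims_to_size():
    class _FakeDim:  # stand-in for illustration only
        def __init__(self, size=None):
            self.size = size

        @property
        def is_bound(self):
            return self.size is not None

    a, b = _FakeDim(3), _FakeDim()
    _bind_dims_to_size(12, (a, b), "example")
    assert b.size == 4  # 12 // 3

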
def _tensor_levels(inp):
    from . import _Tensor

    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True


def _match_levels(v, from_levels, to_levels):
    # permute and view `v` so its dimensions line up with `to_levels`,
    # inserting size-1 dimensions for levels it does not have
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v


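# A small runnable sketch of level matching, using plain ints as levels
# (anything supporting .index works; real callers pass llists of Dims and
# negative ints):
def _example_match_levels():
    v = torch.rand(3, 4)  # levels (-2, -1)
    out = _match_levels(v, [-2, -1], [0, -1, -2])
    assert out.shape == (1, 4, 3)  # the new level 0 becomes a size-1 dim

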
# make a single dimension positional but do not permute it;
# used to implement multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
    from . import Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched


def seq(a, b):
    # Dims compare by identity; everything else compares by equality
    from . import Dim

    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b


class isin:
    # mixin providing __contains__/index that compare elements with `seq`,
    # so Dims are matched by identity rather than tensor equality
    def __contains__(self, item):
        for x in self:
            if seq(item, x):
                return True
        return False

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass


class ltuple(isin, tuple):
    pass


empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    from . import _Tensor, Tensor, TensorLike
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if (
            isinstance(lhs, _Tensor)
            and isinstance(rhs, _Tensor)
            and lhs.ndim == 0
            and rhs.ndim == 0
        ):
            return DelayedMulTensor(lhs, rhs)
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        # pointwise ops: align every tensor argument to a common set of
        # levels, then run the op directly on the plain tensors
        result_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if (
                    isinstance(f, _Tensor)
                    and not f._has_device
                    and device_holding_tensor is not None
                ):
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(
                    t, result_levels, device_holding_tensor is not None
                )
            return t

        return tree_map(wrap, result)
    else:
        # fallback: run the op on the batched (vmap-style) representation

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t

        with _enable_layers(all_dims):
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            # print("END", orig)
            return tree_map(wrap, result)


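# A hedged usage sketch of the pointwise path above. It assumes the compiled
# functorch.dim extension is importable; `dims` and `order` are part of its
# public API, though the exact surface may differ by version.
def _example_pointwise_usage():
    from functorch.dim import dims

    i, j = dims(2)
    x = torch.rand(3, 4)[i, j]  # bind first-class dims i, j
    y = torch.rand(4)[j]
    z = x + y                   # dispatches through the pointwise branch
    return z.order(i, j)        # back to a positional (3, 4) tensor

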
def positional(self, *dims):
    # move the given first-class dims (or dim packs) to the front as
    # ordinary positional dimensions, in the order given
    from . import Dim, DimensionBindError, Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            view.append(ptensor.size(d))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(
                f"tensor of dimensions {self.dims} does not contain dim {d}"
            ) from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims) :])
    return result


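# A hedged sketch of `positional` (exposed publicly as `order`): passing a
# tuple of dims flattens them into a single positional dimension. Assumes the
# compiled functorch.dim extension.
def _example_positional_pack():
    from functorch.dim import dims

    i, j = dims(2)
    x = torch.rand(3, 4)[i, j]
    flat = x.order((i, j))  # one positional dim of size 3 * 4
    assert flat.shape == (12,)

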
def _contains_dim(input):
    from . import Dim

    for i in input:
        if isinstance(i, Dim):
            return True
    return False


def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]


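# A hedged sketch of expanding with Dim objects: each Dim (which must already
# have a size) is prepended as a new first-class dimension. Assumes the
# compiled functorch.dim extension.
def _example_expand_dim():
    from functorch.dim import dims

    i, j = dims(2)
    _ = torch.rand(4)[j]   # binds j to size 4
    t = torch.rand(3)[i]   # tensor carrying first-class dim i
    tj = t.expand(j)       # first-class dim j of size 4 prepended
    assert tj.order(j, i).shape == (4, 3)

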
_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    # fetch an argument that may be passed positionally or by keyword
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    # overwrite an argument wherever the caller actually passed it
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(
                    orig(self._batchtensor, *args, **kwargs), self._has_device
                )
        keepdim = (
            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
        )
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            # unpack so that ops that only take a single dim argument work
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn


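# A hedged sketch of how `_def` below is used: the package registers wrapped
# reductions at import time (e.g. `_def("sum")`, `_def("mean")`), after which
# first-class dims can be passed where an integer dim is expected:
#
#   i, j = dims(2)
#   x = torch.rand(3, 4)[i, j]
#   s = x.sum(i)   # reduces over i; the result still carries j
#
# The exact set of registered ops lives in the package __init__, not here.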
def _def(name, *args, **kwargs):
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))


no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    # track the dims used in an indexing expression and how often each occurs
    def __init__(self) -> None:
        self.dims = llist()
        self.count = []

    def record(self, d):
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)
        else:
            self.count[self.dims.index(d)] += 1

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]


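# A small runnable sketch of the tracker (plain values work here because
# `seq` falls back to == for non-Dims):
def _example_dim_tracker():
    t = dim_tracker()
    t.record("i")
    t.record("j")
    t.record("i")
    assert t["i"] == 2 and t["j"] == 1

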
def t__getitem__(self, input):
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail out to the original implementation if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound tensor list, and determine its size, bind dim list
    #   (remember that None does not count toward the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
    # this handles bool indexing, as well as some other simple cases

    is_simple = (
        not isinstance(input, Dim)
        and not isinstance(input, (tuple, list))
        # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly
        and not (isinstance(input, TensorLike) and input.ndim == 0)
    )

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    # can further optimize this case
    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    dims_indexed = 0
    expanding_object = None
    dimlists = []
    for i, s in enumerate(input):
        if s is ... or (isinstance(s, DimList) and not s.is_bound):
            if expanding_object is not None:
                msg = (
                    "at most one ... or unbound dimension list can exist in indexing list but"
                    f" found 2 at offsets {i} and {expanding_object}"
                )
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
            dimlists.append(i)
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
        )
    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object : expanding_object + 1] = [
                no_slice
            ] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)
    # flatten the dim lists into the indexing
    for i in reversed(dimlists):
        input[i : i + 1] = input[i]
    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    dims_seen.record(d)
                _bind_dims_to_size(sz, idx, f"offset {i}")
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    for i in reversed(dim_packs):
        input[i : i + 1] = input[i]

    # currently:
    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
    # self may have first-class dims as well

    # to index:
    # drop the first-class dims from self; they just become direct indices of their positions

    # figure out the dimensions of the indexing tensors: the union of all the dims in the tensors in the index.
    # these dimensions will appear and need to be bound at the first place a tensor occurs

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # index into ptensor rather than self, which has first-class dimensions
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)


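# A hedged sketch of the indexing semantics above (assumes the compiled
# functorch.dim extension): a list of dims used as one index is a "dim pack"
# that factors that dimension, with at most one size inferred.
def _example_getitem_dim_pack():
    from functorch.dim import dims

    i, j = dims(2)
    _ = torch.rand(3)[i]   # binds i to 3
    t = torch.rand(12)
    x = t[[i, j]]          # dim pack: j's size is inferred as 12 // 3
    assert j.size == 4
    return x.order(i, j)   # positional (3, 4) view of t

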
# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
    if isinstance(dim, int):
        return torch.stack(tensors, dim, out=out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            pt = pt.move_dim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))


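# A hedged usage sketch of `stack` along a first-class dim (assumes the
# compiled functorch.dim extension and that this `stack` is exported as part
# of the package API):
#
#   i, s = dims(2)
#   rows = [torch.rand(3)[i] for _ in range(5)]
#   stacked = stack(rows, s, i)   # new dim s indexes the stacked tensors
#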
_orig_split = torch.Tensor.split


def split(self, split_size_or_sections, dim=0):
    from . import _Tensor, Dim

    if isinstance(split_size_or_sections, int) or any(
        isinstance(t, int) for t in split_size_or_sections
    ):
        if isinstance(dim, Dim):
            raise ValueError(
                "when dim is specified as a Dim object, split sizes must also be dimensions."
            )
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    total_bound_size = 0
    unbound = []
    sizes = []
    for i, d in enumerate(split_size_or_sections):
        if d.is_bound:
            sizes.append(d.size)
            total_bound_size += d.size
        else:
            sizes.append(0)
            unbound.append(i)

    if unbound:
        assert (
            total_bound_size <= size
        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        chunk_size = -(-remaining_size // len(unbound))  # ceiling division
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz
    else:
        assert (
            total_bound_size == size
        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
    return tuple(
        t.index(dim, d)
        for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
    )
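

# A hedged usage sketch of `split` with Dim sections (assumes the compiled
# functorch.dim extension and that this `split` is installed as the Tensor
# method):
#
#   i, j = dims(2)
#   t = torch.rand(10)
#   a, b = t.split((i, j), dim=0)   # unbound sizes are inferred (5 and 5)
#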