# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# reference python implementations for C ops
import operator
from functools import reduce

import torch
from functorch._C import dim as _C

from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map


DimList = _C.DimList

# The ops listed in op_properties.pointwise, held in a set so the
# `orig in pointwise` check in __torch_function__ is a hash lookup.
pointwise = set(op_properties.pointwise)


def prod(x):
    """Return the product of the values in iterable ``x`` (1 if it is empty)."""
    return reduce(operator.mul, x, 1)


def _wrap_dim(d, N, keepdim):
    """Normalize a single dimension argument into a level.

    Args:
        d: a first-class ``Dim`` or an int dimension index.
        N: the dimension count that int indices are relative to.
        keepdim: must be False when ``d`` is a ``Dim``.

    Returns:
        ``d`` itself for a ``Dim``; ``d - N`` (the equivalent negative index)
        for a non-negative int; a negative int unchanged.
    """
    from . import Dim

    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d


def _dims(d, N, keepdim, single_dim):
    """Normalize a dim argument (one Dim/int, or an iterable of them) to an ltuple of levels.

    Each element is passed through ``_wrap_dim``. When ``single_dim`` is True,
    ``d`` must be a single ``Dim`` or int.
    """
    from . import Dim

    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)


def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    """Bind the dims in ``rhs`` so that the product of their sizes is ``lhs_size``.

    If exactly one dim in ``rhs`` is unbound, its size is inferred. If all are
    bound, their product is checked against ``lhs_size``.

    Args:
        lhs_size: the size the dims of ``rhs`` must multiply to.
        rhs: sequence of dims (each exposing ``is_bound`` and ``size``).
        lhs_debug: description of the left-hand side, used in error messages.

    Raises:
        DimensionMismatchError: if the inferred size does not divide evenly,
            if more than one dim is unbound, or if the bound sizes disagree.
    """
    from . import DimensionMismatchError

    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        _, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
            )
        d.size = lhs_size // rhs_so_far
    elif len(not_bound) > 1:
        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
        )
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
            )


def _tensor_levels(inp):
    """Return ``(plain_tensor, levels, has_device)`` for ``inp``.

    A ``_Tensor`` reports its stored fields. Any other tensor is treated as
    fully positional with levels ``-ndim .. -1`` (negative, counted from the
    end) and ``has_device=True``.
    """
    from . import _Tensor

    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True


def _match_levels(v, from_levels, to_levels):
    """Permute/view tensor ``v`` (laid out as ``from_levels``) into ``to_levels`` order.

    Levels in ``to_levels`` that ``v`` lacks become size-1 axes, so the result
    broadcasts against other tensors laid out as ``to_levels``. Every level in
    ``from_levels`` is expected to appear in ``to_levels``. Otherwise the
    ``permute`` call receives too few axes.
    """
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v


# make a single dimension positional but do not permute it,
# used to do multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
    """Turn first-class ``dim`` of ``self`` into a positional dim without moving it.

    Positional levels are negative ints counted from the end (-1 is the last
    positional dim), matching ``_tensor_levels`` and ``positional``.

    Args:
        self: a first-class-dim tensor (reads ``_tensor``, ``_levels``,
            ``_has_device``).
        dim: the ``Dim`` to make positional.
        expand_dim: if ``dim`` is not one of ``self``'s levels, prepend it as
            a new leading axis of size ``dim.size`` instead of re-raising
            ``ValueError``.

    Returns:
        ``(tensor, index)``, where ``index`` is the position of ``dim`` among
        the result's positional dims, counted from the front.
    """
    from . import Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    # Positional levels before `idx` now have one more positional dim after
    # them, so each moves one step further from the end.
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    # End-relative numbering: the new positional dim's level is fixed by how
    # many positional dims FOLLOW it, so that all levels stay distinct.
    n_after = sum(1 for l in levels[idx + 1 :] if isinstance(l, int))
    levels[idx] = -n_after - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched


def seq(a, b):
    """Level equality: Dims compare by identity, other levels with ``==``.

    A ``Dim`` never equals a non-``Dim``. This avoids routing comparisons
    through ``Dim.__eq__``.
    """
    from . import Dim

    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b


class isin:
    """Mixin for list/tuple subclasses: ``in`` and ``index`` use ``seq`` instead of ``==``."""

    def __contains__(self, item):
        return any(seq(item, x) for x in self)

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass


class ltuple(isin, tuple):
    pass


# Shared read-only default for __torch_function__'s kwargs; never mutated.
empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    """Dispatch torch function ``orig`` over tensors that may carry first-class dims.

    ``self`` is the class (bound by ``@classmethod``). ``cls`` sits in the
    protocol's ``types`` slot and is not used.

    - ``lhs * rhs`` on two 0-dim ``_Tensor``s returns a ``DelayedMulTensor``.
    - Pointwise ops (``orig in pointwise``): each tensor argument is aligned
      to the union of all argument levels via ``_match_levels``. ``orig``
      runs on plain tensors, and tensor results are wrapped with
      ``Tensor.from_positional``.
    - Any other op runs on the batched tensors under
      ``_enable_layers(all_dims)``, and tensor results are wrapped with
      ``Tensor.from_batched``.

    ``_Tensor`` arguments without a device are moved to the device of the
    last argument that has one.
    """
    from . import _Tensor, Tensor, TensorLike
    from .delayed_mul_tensor import DelayedMulTensor

    # The __torch_function__ protocol permits kwargs=None.
    if kwargs is None:
        kwargs = empty_dict

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if (
            isinstance(lhs, _Tensor)
            and isinstance(rhs, _Tensor)
            and lhs.ndim == 0
            and rhs.ndim == 0
        ):
            return DelayedMulTensor(lhs, rhs)
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if (
                    isinstance(f, _Tensor)
                    and not f._has_device
                    and device_holding_tensor is not None
                ):
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(
                    t, result_levels, device_holding_tensor is not None
                )
            return t

        return tree_map(wrap, result)
    else:

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t

        with _enable_layers(all_dims):
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            return tree_map(wrap, result)


def positional(self, *dims):
    """Move the given dims of ``self`` to the front as positional dimensions.

    Each entry of ``dims`` may be:

    - a ``Dim``;
    - a ``DimList``, in which each dim becomes its own positional dim;
    - an int naming an existing positional dim;
    - any other iterable of dims, flattened into one positional dim whose size
      is the product of their sizes (via ``reshape``).

    Raises:
        DimensionBindError: if a requested dim is not one of ``self``'s levels.
    """
    from . import Dim, DimensionBindError, Tensor

    def _missing(d):
        return DimensionBindError(
            f"tensor of dimensions {self.dims} does not contain dim {d}"
        )

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            # `d` is now a positional level. ptensor can interleave first-class
            # and positional axes, so find d's physical axis rather than
            # indexing ptensor with the level directly.
            try:
                axis = levels.index(d)
            except ValueError as e:
                raise _missing(d) from e
            view.append(ptensor.size(axis))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    # Move each requested level to the front (in request order), using a
    # 0 placeholder for the now-positional slot.
    permute = list(range(len(levels)))
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise _missing(d) from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    # Renumber every positional level as -1, -2, ... counted from the end.
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims) :])
    return result


def _contains_dim(input):
    """Return True if any element of ``input`` is a first-class ``Dim``."""
    from . import Dim

    return any(isinstance(i, Dim) for i in input)


def expand(self, *sizes):
    """``Tensor.expand`` that also accepts first-class dims.

    If no argument is a ``Dim``, the call is forwarded through
    ``__torch_function__``. Otherwise ``sizes`` are treated as dims: ``self``
    is expanded with new leading axes of those dims' sizes, which are then
    indexed by the dims themselves.
    """
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]


# Sentinel meaning "argument not supplied" (None can be a real value).
_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    """Return the argument at positional ``offset`` or keyword ``name``, else ``default``."""
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    """Overwrite the argument at positional ``offset`` (``args`` must be a list) or keyword ``name``."""
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    """Wrap the dim-taking ``torch.Tensor`` method ``orig`` for first-class-dim tensors.

    Args:
        orig: the ``torch.Tensor`` method to wrap.
        dim_offset: positional index (after ``self``) of the dim argument.
        keepdim_offset: positional index of ``keepdim``; read only when ``reduce``.
        dim_name: keyword name of the dim argument.
        single_dim: the op accepts only one dim. A non-``Dim`` argument then
            takes the fallback path.
        reduce: the op removes the dims it acts on (unless ``keepdim``).

    Returns:
        A function ``fn(self, *args, **kwargs)``. With no dim argument (or
        ``single_dim`` with a non-``Dim``), it calls ``orig`` on the batched
        tensor under ``_enable_layers(self.dims)``. Otherwise it translates the
        dims into physical axis indices of ``self._tensor``, calls ``orig`` on
        that tensor, and wraps tensor results with the resulting levels.
    """
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                return Tensor.from_batched(
                    orig(self._batchtensor, *args, **kwargs), self._has_device
                )
        keepdim = (
            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
        )
        ptensor, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            # so that dims that really only take a single argument work...
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            r = orig(ptensor, *args, **kwargs)
            return tree_map(wrap, r)

    return fn


def _def(name, *args, **kwargs):
    """Install ``_wrap(torch.Tensor.<name>, *args, **kwargs)`` as ``_Tensor.<name>``."""
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))


# A full slice, i.e. the `:` index.
no_slice = slice(None)

# Unpatched torch.Tensor.__getitem__, captured at import time.
_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    """Ordered record of dims seen while indexing, with a use count per dim.

    Dims are compared with ``seq`` (identity for ``Dim``) via ``llist``.
    """

    def __init__(self) -> None:
        self.dims = llist()
        self.count = []

    def record(self, d):
        """Count one more use of ``d``, registering it on first sight."""
        try:
            self.count[self.dims.index(d)] += 1
        except ValueError:
            self.dims.append(d)
            self.count.append(1)

    def __getitem__(self, d):
        """Return how many times ``d`` was recorded (ValueError if never)."""
        return self.count[self.dims.index(d)]


def t__getitem__(self, input):
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail to original example if we have a single non-Dim tensor, or a non-tensor
    # * locate ...
or an unbound tensor list, and determine its size, bind dim list 404*da0073e9SAndroid Build Coastguard Worker # (remember that None does not count to the total dim count) 405*da0073e9SAndroid Build Coastguard Worker # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim, 406*da0073e9SAndroid Build Coastguard Worker # produce the re-view if needed 407*da0073e9SAndroid Build Coastguard Worker # * for each single-use dim index, replace with no_slice and mark that it will be added 408*da0073e9SAndroid Build Coastguard Worker # (keep track of whether we have to call super) 409*da0073e9SAndroid Build Coastguard Worker # * call super if needed 410*da0073e9SAndroid Build Coastguard Worker # * if we have dims to bind, bind them (it will help if we eliminated ... and None before) 411*da0073e9SAndroid Build Coastguard Worker # this handles bool indexing handling, as well as some other simple cases. 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker is_simple = ( 414*da0073e9SAndroid Build Coastguard Worker not isinstance(input, Dim) 415*da0073e9SAndroid Build Coastguard Worker and not isinstance(input, (tuple, list)) 416*da0073e9SAndroid Build Coastguard Worker and 417*da0073e9SAndroid Build Coastguard Worker # WAR for functorch bug where zero time tensors in getitem are not handled correctly. 
418*da0073e9SAndroid Build Coastguard Worker not (isinstance(input, TensorLike) and input.ndim == 0) 419*da0073e9SAndroid Build Coastguard Worker ) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker if is_simple: 422*da0073e9SAndroid Build Coastguard Worker if isinstance(self, _Tensor): 423*da0073e9SAndroid Build Coastguard Worker return _Tensor.__torch_function__(_orig_getitem, None, (self, input)) 424*da0073e9SAndroid Build Coastguard Worker else: 425*da0073e9SAndroid Build Coastguard Worker return _orig_getitem(self, input) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker # can further optimize this case 428*da0073e9SAndroid Build Coastguard Worker if not isinstance(input, tuple): 429*da0073e9SAndroid Build Coastguard Worker input = [input] 430*da0073e9SAndroid Build Coastguard Worker else: 431*da0073e9SAndroid Build Coastguard Worker input = list(input) 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker dims_indexed = 0 434*da0073e9SAndroid Build Coastguard Worker expanding_object = None 435*da0073e9SAndroid Build Coastguard Worker dimlists = [] 436*da0073e9SAndroid Build Coastguard Worker for i, s in enumerate(input): 437*da0073e9SAndroid Build Coastguard Worker if s is ... or isinstance(s, DimList) and not s.is_bound: 438*da0073e9SAndroid Build Coastguard Worker if expanding_object is not None: 439*da0073e9SAndroid Build Coastguard Worker msg = ( 440*da0073e9SAndroid Build Coastguard Worker "at most one ... 
or unbound dimension list can exist in indexing list but" 441*da0073e9SAndroid Build Coastguard Worker f" found 2 at offsets {i} and {expanding_object}" 442*da0073e9SAndroid Build Coastguard Worker ) 443*da0073e9SAndroid Build Coastguard Worker raise DimensionBindError(msg) 444*da0073e9SAndroid Build Coastguard Worker expanding_object = i 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker if isinstance(s, DimList): 447*da0073e9SAndroid Build Coastguard Worker dims_indexed += len(s) if s.is_bound else 0 448*da0073e9SAndroid Build Coastguard Worker dimlists.append(i) 449*da0073e9SAndroid Build Coastguard Worker elif s is not None and s is not ...: 450*da0073e9SAndroid Build Coastguard Worker dims_indexed += 1 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker ndim = self.ndim 453*da0073e9SAndroid Build Coastguard Worker if dims_indexed > ndim: 454*da0073e9SAndroid Build Coastguard Worker raise IndexError( 455*da0073e9SAndroid Build Coastguard Worker f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions." 
456*da0073e9SAndroid Build Coastguard Worker ) 457*da0073e9SAndroid Build Coastguard Worker if expanding_object is not None: 458*da0073e9SAndroid Build Coastguard Worker expanding_ndims = ndim - dims_indexed 459*da0073e9SAndroid Build Coastguard Worker obj = input[expanding_object] 460*da0073e9SAndroid Build Coastguard Worker if obj is ...: 461*da0073e9SAndroid Build Coastguard Worker input[expanding_object : expanding_object + 1] = [ 462*da0073e9SAndroid Build Coastguard Worker no_slice 463*da0073e9SAndroid Build Coastguard Worker ] * expanding_ndims 464*da0073e9SAndroid Build Coastguard Worker else: 465*da0073e9SAndroid Build Coastguard Worker obj.bind_len(expanding_ndims) 466*da0073e9SAndroid Build Coastguard Worker # flatten the dimslists into the indexing 467*da0073e9SAndroid Build Coastguard Worker for i in reversed(dimlists): 468*da0073e9SAndroid Build Coastguard Worker input[i : i + 1] = input[i] 469*da0073e9SAndroid Build Coastguard Worker dims_indexed = 0 470*da0073e9SAndroid Build Coastguard Worker requires_view = False 471*da0073e9SAndroid Build Coastguard Worker size = self.size() 472*da0073e9SAndroid Build Coastguard Worker view_sizes = [] 473*da0073e9SAndroid Build Coastguard Worker dims_seen = dim_tracker() 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker def add_dims(t): 476*da0073e9SAndroid Build Coastguard Worker if not isinstance(t, _Tensor): 477*da0073e9SAndroid Build Coastguard Worker return 478*da0073e9SAndroid Build Coastguard Worker for d in t.dims: 479*da0073e9SAndroid Build Coastguard Worker dims_seen.record(d) 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker add_dims(self) 482*da0073e9SAndroid Build Coastguard Worker dim_packs = [] 483*da0073e9SAndroid Build Coastguard Worker for i, idx in enumerate(input): 484*da0073e9SAndroid Build Coastguard Worker if idx is None: 485*da0073e9SAndroid Build Coastguard Worker input[i] = no_slice 486*da0073e9SAndroid Build 
Coastguard Worker view_sizes.append(1) 487*da0073e9SAndroid Build Coastguard Worker requires_view = True 488*da0073e9SAndroid Build Coastguard Worker else: 489*da0073e9SAndroid Build Coastguard Worker sz = size[dims_indexed] 490*da0073e9SAndroid Build Coastguard Worker if isinstance(idx, Dim): 491*da0073e9SAndroid Build Coastguard Worker idx.size = sz 492*da0073e9SAndroid Build Coastguard Worker dims_seen.record(idx) 493*da0073e9SAndroid Build Coastguard Worker view_sizes.append(sz) 494*da0073e9SAndroid Build Coastguard Worker elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): 495*da0073e9SAndroid Build Coastguard Worker for d in idx: 496*da0073e9SAndroid Build Coastguard Worker dims_seen.record(idx) 497*da0073e9SAndroid Build Coastguard Worker _bind_dims_to_size(sz, idx, f"offset {i}") 498*da0073e9SAndroid Build Coastguard Worker view_sizes.extend(d.size for d in idx) 499*da0073e9SAndroid Build Coastguard Worker requires_view = True 500*da0073e9SAndroid Build Coastguard Worker dim_packs.append(i) 501*da0073e9SAndroid Build Coastguard Worker else: 502*da0073e9SAndroid Build Coastguard Worker add_dims(idx) 503*da0073e9SAndroid Build Coastguard Worker view_sizes.append(sz) 504*da0073e9SAndroid Build Coastguard Worker dims_indexed += 1 505*da0073e9SAndroid Build Coastguard Worker if requires_view: 506*da0073e9SAndroid Build Coastguard Worker self = self.view(*view_sizes) 507*da0073e9SAndroid Build Coastguard Worker for i in reversed(dim_packs): 508*da0073e9SAndroid Build Coastguard Worker input[i : i + 1] = input[i] 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker # currenty: 511*da0073e9SAndroid Build Coastguard Worker # input is flat, containing either Dim, or Tensor, or something valid for standard indexing 512*da0073e9SAndroid Build Coastguard Worker # self may have first-class dims as well. 
513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker # to index: 515*da0073e9SAndroid Build Coastguard Worker # drop the first class dims from self, they just become direct indices of their positions 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index. 518*da0073e9SAndroid Build Coastguard Worker # these dimensions will appear and need to be bound at the first place tensor occures 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker if isinstance(self, _Tensor): 521*da0073e9SAndroid Build Coastguard Worker ptensor_self, levels = self._tensor, list(self._levels) 522*da0073e9SAndroid Build Coastguard Worker # indices to ptensor rather than self which has first-class dimensions 523*da0073e9SAndroid Build Coastguard Worker input_it = iter(input) 524*da0073e9SAndroid Build Coastguard Worker flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels] 525*da0073e9SAndroid Build Coastguard Worker has_device = self._has_device 526*da0073e9SAndroid Build Coastguard Worker to_pad = 0 527*da0073e9SAndroid Build Coastguard Worker else: 528*da0073e9SAndroid Build Coastguard Worker ptensor_self, flat_inputs = self, input 529*da0073e9SAndroid Build Coastguard Worker to_pad = ptensor_self.ndim - len(flat_inputs) 530*da0073e9SAndroid Build Coastguard Worker has_device = True 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker result_levels = [] 533*da0073e9SAndroid Build Coastguard Worker index_levels = [] 534*da0073e9SAndroid Build Coastguard Worker tensor_insert_point = None 535*da0073e9SAndroid Build Coastguard Worker to_expand = {} 536*da0073e9SAndroid Build Coastguard Worker requires_getindex = False 537*da0073e9SAndroid Build Coastguard Worker for i, inp in enumerate(flat_inputs): 538*da0073e9SAndroid Build 
Coastguard Worker if isinstance(inp, Dim) and dims_seen[inp] == 1: 539*da0073e9SAndroid Build Coastguard Worker flat_inputs[i] = no_slice 540*da0073e9SAndroid Build Coastguard Worker result_levels.append(inp) 541*da0073e9SAndroid Build Coastguard Worker elif isinstance(inp, TensorLike): 542*da0073e9SAndroid Build Coastguard Worker requires_getindex = True 543*da0073e9SAndroid Build Coastguard Worker if tensor_insert_point is None: 544*da0073e9SAndroid Build Coastguard Worker tensor_insert_point = len(result_levels) 545*da0073e9SAndroid Build Coastguard Worker ptensor, levels, _ = _tensor_levels(inp) 546*da0073e9SAndroid Build Coastguard Worker to_expand[i] = levels 547*da0073e9SAndroid Build Coastguard Worker flat_inputs[i] = ptensor 548*da0073e9SAndroid Build Coastguard Worker for l in levels: 549*da0073e9SAndroid Build Coastguard Worker if l not in index_levels: 550*da0073e9SAndroid Build Coastguard Worker index_levels.append(l) 551*da0073e9SAndroid Build Coastguard Worker else: 552*da0073e9SAndroid Build Coastguard Worker requires_getindex = True 553*da0073e9SAndroid Build Coastguard Worker result_levels.append(0) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker if tensor_insert_point is not None: 556*da0073e9SAndroid Build Coastguard Worker result_levels[tensor_insert_point:tensor_insert_point] = index_levels 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker for i, levels in to_expand.items(): 559*da0073e9SAndroid Build Coastguard Worker flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker if requires_getindex: 562*da0073e9SAndroid Build Coastguard Worker result = _orig_getitem(ptensor_self, flat_inputs) 563*da0073e9SAndroid Build Coastguard Worker else: 564*da0073e9SAndroid Build Coastguard Worker result = ptensor_self 565*da0073e9SAndroid Build Coastguard Worker 
566*da0073e9SAndroid Build Coastguard Worker next_positional = -1 567*da0073e9SAndroid Build Coastguard Worker if to_pad > 0: 568*da0073e9SAndroid Build Coastguard Worker result_levels.extend([0] * to_pad) 569*da0073e9SAndroid Build Coastguard Worker for i, r in enumerate(reversed(result_levels)): 570*da0073e9SAndroid Build Coastguard Worker if isinstance(r, int): 571*da0073e9SAndroid Build Coastguard Worker result_levels[-1 - i] = next_positional 572*da0073e9SAndroid Build Coastguard Worker next_positional -= 1 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker return Tensor.from_positional(result, result_levels, has_device) 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker# XXX - dim is optional and can be the outer-most dimension... 578*da0073e9SAndroid Build Coastguard Workerdef stack(tensors, new_dim, dim=0, out=None): 579*da0073e9SAndroid Build Coastguard Worker if isinstance(dim, int): 580*da0073e9SAndroid Build Coastguard Worker return torch.stack(tensors, dim, out).index(dim, new_dim) 581*da0073e9SAndroid Build Coastguard Worker index = None 582*da0073e9SAndroid Build Coastguard Worker if out is not None: 583*da0073e9SAndroid Build Coastguard Worker out, index = _positional_no_permute(out, dim, expand_dim=True) 584*da0073e9SAndroid Build Coastguard Worker ptensors = [] 585*da0073e9SAndroid Build Coastguard Worker for t in tensors: 586*da0073e9SAndroid Build Coastguard Worker pt, pi = _positional_no_permute(t, dim, expand_dim=True) 587*da0073e9SAndroid Build Coastguard Worker if index is not None and pi != index: 588*da0073e9SAndroid Build Coastguard Worker pt = pt.move_dim(pi, index) 589*da0073e9SAndroid Build Coastguard Worker else: 590*da0073e9SAndroid Build Coastguard Worker index = pi 591*da0073e9SAndroid Build Coastguard Worker ptensors.append(pt) 592*da0073e9SAndroid Build Coastguard Worker pr = torch.stack(ptensors, index, 
out=out) 593*da0073e9SAndroid Build Coastguard Worker return pr.index((index, index + 1), (new_dim, dim)) 594*da0073e9SAndroid Build Coastguard Worker 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker_orig_split = torch.Tensor.split 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Workerdef split(self, split_size_or_sections, dim=0): 600*da0073e9SAndroid Build Coastguard Worker from . import _Tensor, Dim 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker if isinstance(split_size_or_sections, int) or any( 603*da0073e9SAndroid Build Coastguard Worker isinstance(t, int) for t in split_size_or_sections 604*da0073e9SAndroid Build Coastguard Worker ): 605*da0073e9SAndroid Build Coastguard Worker if isinstance(dim, Dim): 606*da0073e9SAndroid Build Coastguard Worker raise ValueError( 607*da0073e9SAndroid Build Coastguard Worker "when dim is specified as a Dim object, split sizes must also be dimensions." 
608*da0073e9SAndroid Build Coastguard Worker ) 609*da0073e9SAndroid Build Coastguard Worker return _orig_split(self, split_size_or_sections, dim=dim) 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker if isinstance(dim, Dim): 612*da0073e9SAndroid Build Coastguard Worker assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}" 613*da0073e9SAndroid Build Coastguard Worker self, dim = _positional_no_permute(self, dim) 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker size = self.size(dim) 616*da0073e9SAndroid Build Coastguard Worker total_bound_size = 0 617*da0073e9SAndroid Build Coastguard Worker unbound = [] 618*da0073e9SAndroid Build Coastguard Worker sizes = [] 619*da0073e9SAndroid Build Coastguard Worker for i, d in enumerate(split_size_or_sections): 620*da0073e9SAndroid Build Coastguard Worker if d.is_bound: 621*da0073e9SAndroid Build Coastguard Worker sizes.append(d.size) 622*da0073e9SAndroid Build Coastguard Worker total_bound_size += d.size 623*da0073e9SAndroid Build Coastguard Worker else: 624*da0073e9SAndroid Build Coastguard Worker sizes.append(0) 625*da0073e9SAndroid Build Coastguard Worker unbound.append(i) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker if unbound: 628*da0073e9SAndroid Build Coastguard Worker assert ( 629*da0073e9SAndroid Build Coastguard Worker total_bound_size <= size 630*da0073e9SAndroid Build Coastguard Worker ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" 631*da0073e9SAndroid Build Coastguard Worker remaining_size = size - total_bound_size 632*da0073e9SAndroid Build Coastguard Worker chunk_size = -(-remaining_size // len(unbound)) 633*da0073e9SAndroid Build Coastguard Worker for u in unbound: 634*da0073e9SAndroid Build Coastguard Worker sz = min(chunk_size, remaining_size) 635*da0073e9SAndroid Build Coastguard Worker 
split_size_or_sections[u].size = sz 636*da0073e9SAndroid Build Coastguard Worker sizes[u] = sz 637*da0073e9SAndroid Build Coastguard Worker remaining_size -= sz 638*da0073e9SAndroid Build Coastguard Worker else: 639*da0073e9SAndroid Build Coastguard Worker assert ( 640*da0073e9SAndroid Build Coastguard Worker total_bound_size == size 641*da0073e9SAndroid Build Coastguard Worker ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" 642*da0073e9SAndroid Build Coastguard Worker return tuple( 643*da0073e9SAndroid Build Coastguard Worker t.index(dim, d) 644*da0073e9SAndroid Build Coastguard Worker for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) 645*da0073e9SAndroid Build Coastguard Worker ) 646