# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple


class DelayedMulTensor(_Tensor):
    """Lazily evaluated elementwise product ``lhs * rhs`` of two tensors.

    The product is not computed at construction time. ``sum`` contracts the
    two operands directly with a single ``torch.einsum`` call, so a
    multiply-then-reduce never materializes the full product. Other accessors
    (``_batchtensor``, ``_tensor``, ``ndim``) fall back to computing the
    product explicitly and cache the result.
    """

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        # NOTE(review): not read in this class; presumably consumed by the
        # _Tensor base class -- confirm before removing.
        self._data = None
        # Lazily filled caches for the properties below.
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        """Levels of the product: ``lhs``'s levels, then ``rhs``'s new ones.

        Order is preserved, with lhs levels first. Membership is tested with
        ``llist`` rather than a set or dict, presumably because levels may be
        objects with an overloaded ``==`` (TODO confirm). The result is
        cached as an ``ltuple``.
        """
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for lvl in self._rhs._levels:
                if lvl not in levels:
                    levels.append(lvl)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        """Materialized product as a batched tensor, computed once and cached.

        This is the slow fallback path. It multiplies the operands' batched
        tensors inside ``_enable_layers(self._levels)``.
        """
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                self._batchtensor_data = (
                    self._lhs._batchtensor * self._rhs._batchtensor
                )
        return self._batchtensor_data

    @property
    def _tensor(self):
        """Positional tensor for the product, derived from ``_batchtensor``.

        The result is cached. Accessing it forces the product to be
        materialized.
        """
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        """Number of dimensions of the materialized batched product.

        Accessing it forces the product to be computed via ``_batchtensor``.
        """
        return self._batchtensor.ndim

    @property
    def dims(self):
        """The base class's ``dims``, returned as an ``ltuple``."""
        return ltuple(super().dims)

    def sum(self, dim):
        """Sum the product over ``dim`` without materializing the product.

        Each level of the product gets an einsum subscript letter. The two
        operands' positional tensors are then contracted in one
        ``torch.einsum`` call whose output keeps, in order, every level not
        listed in ``dim``. The subscript string assumes each operand's
        ``_tensor`` has exactly one axis per entry of its ``_levels``, in the
        same order.

        Args:
            dim: dimension(s) to reduce over; normalized through ``_dims``.

        Returns:
            A ``Tensor`` built by ``Tensor.from_positional`` over the
            remaining levels.
        """
        dims = _dims(dim, 0, False, False)
        all_levels = self._levels

        def to_char(d):
            # torch.einsum accepts subscripts in [a-zA-Z]. Use lowercase for
            # the first 26 levels and uppercase for the next 26. More than 52
            # levels still yields a non-letter that einsum rejects.
            i = all_levels.index(d)
            return chr(ord("a") + i) if i < 26 else chr(ord("A") + i - 26)

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        new_levels = [lvl for lvl in all_levels if lvl not in dims]
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        # NOTE(review): has_device is passed as True unconditionally rather
        # than self._has_device -- confirm this is intended.
        return Tensor.from_positional(result_data, new_levels, True)