# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple


class DelayedMulTensor(_Tensor):
    """A lazy elementwise product of two first-class-dim tensors.

    Materializing `lhs * rhs` is deferred so that a following `sum(dim)`
    can be fused into a single `torch.einsum` contraction instead of
    allocating the full broadcasted product.
    """

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        self._data = None
        self._levels_data = None
        # The result lives on a device if either operand does.
        self._has_device = lhs._has_device or rhs._has_device
        # Lazily computed caches for the fallback properties below.
        self._batchtensor_data = None
        self._tensor_data = None
    @property
    def _levels(self):
        # Merge the operands' levels: keep the lhs levels in order, then
        # append any levels that appear only on the rhs.
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                # Fallback path: materialize the full product as a batched
                # tensor instead of fusing it into a later reduction.
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

    @property
    def _tensor(self):
        if self._tensor_data is None:
            # Materialize via the batched fallback, then unwrap to a plain tensor.
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)

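    # `sum(dim)` below never materializes the broadcasted product: it assigns
    # each level a letter and issues a single `torch.einsum` contraction.
    # Illustrative example (hypothetical level names): with lhs levels (i, j),
    # rhs levels (j, k), and dim=j, the format string built below is
    # "ab,bc->ac", i.e. a matrix multiply.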
    def sum(self, dim):
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
            # Give each level a unique single-letter einsum subscript.
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        new_levels = [l for l in self._levels if l not in dims]
        # Levels being summed out appear in the operands' subscripts but not
        # in the output, so einsum contracts over them in one fused call.
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        # Propagate whether either operand carried a device.
        return Tensor.from_positional(result_data, new_levels, self._has_device)
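
# A minimal usage sketch, assuming the first-class-dims frontend
# (functorch.dim.dims) routes `dimtensor * dimtensor` through
# DelayedMulTensor so that the trailing `.sum(j)` lowers to the single
# einsum above instead of materializing the full i x j x k product.
# Run with `python -m`, since this module uses relative imports.
if __name__ == "__main__":
    from functorch.dim import dims

    i, j, k = dims(3)
    x = torch.randn(3, 4)[i, j]  # bind positional dims: x has dims (i, j)
    y = torch.randn(4, 5)[j, k]  # y has dims (j, k)
    z = (x * y).sum(j)  # contracts over j: equivalent to a matmul
    print(z.order(i, k).shape)  # torch.Size([3, 5])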