xref: /aosp_15_r20/external/pytorch/functorch/dim/dim.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Copyright (c) Facebook, Inc. and its affiliates.
2*da0073e9SAndroid Build Coastguard Worker# All rights reserved.
3*da0073e9SAndroid Build Coastguard Worker#
4*da0073e9SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*da0073e9SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*da0073e9SAndroid Build Coastguard Workerimport dis
7*da0073e9SAndroid Build Coastguard Workerimport inspect
8*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
9*da0073e9SAndroid Build Coastguard Workerfrom typing import Union
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerfrom . import DimList
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker_vmap_levels = []
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
@dataclass(init=True, repr=True, eq=True)
class LevelInfo:
    """Bookkeeping entry for one vmap level entered by binding a Dim."""

    # Level value returned by _vmap_increment_nesting for the owning Dim.
    level: int
    # Flipped to False once the owning Dim is released.
    alive: bool = True
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
class Dim:
    """A named first-class dimension that can be bound to a size.

    Binding a size (through the constructor or the ``size`` setter) enters a
    new vmap nesting level and pushes a ``LevelInfo`` for it onto the
    module-level ``_vmap_levels`` stack; ``__del__`` releases it again.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
        # ``getattr`` because ``__del__`` may run on an object whose
        # ``__init__`` never completed (e.g. it was called with bad arguments).
        if getattr(self, "_vmap_level", None) is None:
            return
        # Mark this Dim's level dead, then unwind dead entries from the top of
        # the stack for as long as the top entry is also what current_level()
        # reports. Entries released out of order stay on the stack, marked
        # dead, until everything above them has been released too.
        _vmap_levels[self._vmap_stack].alive = False
        while (
            _vmap_levels  # the loop pops, so the stack can run empty
            and not _vmap_levels[-1].alive
            and current_level() == _vmap_levels[-1].level  # noqa: F821
        ):
            _vmap_decrement_nesting()  # noqa: F821
            _vmap_levels.pop()

    @property
    def size(self):
        """The size this dim is bound to; only valid once ``is_bound``."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind this dim to ``size``.

        The first binding enters a vmap level and records it on
        ``_vmap_levels``. Re-binding to the same size is a no-op; re-binding
        to a different size raises ``DimensionBindError``.
        """
        from . import DimensionBindError

        if self._size is None:
            # Enter the vmap level before recording anything on ``self`` so
            # that a failure here leaves the Dim unbound, not half-bound.
            level = _vmap_increment_nesting(size, "same")  # noqa: F821
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))
            self._vmap_level = level
            self._size = size

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        """True once a size has been bound to this dim."""
        return self._size is not None

    def __repr__(self):
        return self.name
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
def extract_name(inst):
    """Return the variable name written by a store instruction.

    ``inst`` is a ``dis.Instruction`` for one assignment target of a
    ``dims()`` call. Local (STORE_FAST), module/class-level (STORE_NAME),
    ``global`` (STORE_GLOBAL) and closure-cell (STORE_DEREF) targets are
    accepted; for each of these ``argval`` is the target's name.

    Raises:
        AssertionError: if ``inst`` does not store into a plain name
            (e.g. an attribute or subscript target).
    """
    assert inst.opname in (
        "STORE_FAST",
        "STORE_NAME",
        "STORE_GLOBAL",
        "STORE_DEREF",
    ), f"dims() results must be assigned to plain names, got {inst.opname}"
    return inst.argval
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker_cache = {}
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker
def dims(lists=0):
    """Create first-class dims named after the variables they are assigned to.

    The calling frame's bytecode is inspected to find where the result of
    this call is stored, so ``batch = dims()`` creates ``Dim(name='batch')``
    and ``i, j = dims()`` creates ``(Dim('i'), Dim('j'))``.

    Args:
        lists: number of trailing unpacking targets to create as ``DimList``
            instead of ``Dim``. For a plain (non-unpacking) assignment, any
            non-zero value makes the single target a ``DimList``.

    Returns:
        One object for a plain assignment, or a tuple with one object per
        target for an unpacking assignment.

    The bytecode is analysed once per (call site, ``lists``) pair; the
    resulting factory is memoized in ``_cache``.
    """
    frame = inspect.currentframe()
    try:
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        code, lasti = calling_frame.f_code, calling_frame.f_lasti
    finally:
        # ``frame`` refers to this very frame, which forms a reference cycle
        # that can also keep the caller's frame (and the Dims it holds) alive
        # until the cyclic GC runs. Drop it now, as the ``inspect`` docs
        # recommend, so ``Dim.__del__`` can run deterministically.
        del frame
    # ``lists`` is part of the key: one call site may pass different values,
    # and the factory bakes in which targets become DimLists.
    key = (code, lasti, lists)
    factory = _cache.get(key)
    if factory is None:
        factory = _cache[key] = _make_dims_factory(code, lasti, lists)
    return factory()


def _make_dims_factory(code, lasti, lists):
    """Build the zero-argument factory for the ``dims()`` call at ``lasti``.

    ``code`` and ``lasti`` are the caller's code object and the byte offset
    its ``f_lasti`` reported during the call; the instruction(s) that follow
    the call name the assignment target(s).
    """
    store_ops = ("STORE_FAST", "STORE_NAME", "STORE_GLOBAL", "STORE_DEREF")
    # EXTENDED_ARG only widens the following instruction's argument, which
    # ``dis`` already folds into that instruction's ``argval``.
    instructions = [
        inst for inst in dis.get_instructions(code) if inst.opname != "EXTENDED_ARG"
    ]
    # Locate the instruction after the call by byte offset. Indexing the list
    # by ``lasti // 2`` would assume one 2-byte code unit per listed
    # instruction, which does not hold on CPython 3.11+: calls are followed by
    # inline CACHE entries that ``dis.get_instructions`` does not list, and
    # ``f_lasti`` may refer to one of those caches rather than the call.
    first = next(
        (i for i, inst in enumerate(instructions) if inst.offset > lasti), None
    )
    assert first is not None, "dims() must be used in an assignment"
    target = instructions[first]

    if target.opname in store_ops:
        # just a single dim, not a list
        name = target.argval
        ctor = Dim if lists == 0 else DimList
        return lambda: ctor(name=name)

    assert (
        target.opname == "UNPACK_SEQUENCE"
    ), f"unsupported assignment target for dims(): {target.opname}"
    ndims = target.argval
    # NOTE(review): CPython 3.13 may fuse adjacent local stores into a single
    # STORE_FAST_STORE_FAST instruction, which extract_name does not accept;
    # confirm whether that case needs handling here.
    names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
    # The last ``lists`` targets become DimLists, the rest plain Dims.
    first_list = len(names) - lists
    return lambda: tuple(
        Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names)
    )
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Workerdef _dim_set(positional, arg):
109*da0073e9SAndroid Build Coastguard Worker    def convert(a):
110*da0073e9SAndroid Build Coastguard Worker        if isinstance(a, Dim):
111*da0073e9SAndroid Build Coastguard Worker            return a
112*da0073e9SAndroid Build Coastguard Worker        else:
113*da0073e9SAndroid Build Coastguard Worker            assert isinstance(a, int)
114*da0073e9SAndroid Build Coastguard Worker            return positional[a]
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker    if arg is None:
117*da0073e9SAndroid Build Coastguard Worker        return positional
118*da0073e9SAndroid Build Coastguard Worker    elif not isinstance(arg, (Dim, int)):
119*da0073e9SAndroid Build Coastguard Worker        return tuple(convert(a) for a in arg)
120*da0073e9SAndroid Build Coastguard Worker    else:
121*da0073e9SAndroid Build Coastguard Worker        return (convert(arg),)
122