# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union

from . import DimList


# Stack of vmap levels opened by bound ``Dim`` objects, in binding order.
# ``Dim._vmap_stack`` is the index of a Dim's entry in this list.
_vmap_levels = []


@dataclass
class LevelInfo:
    """Bookkeeping entry for one vmap level opened by a bound ``Dim``."""

    # Value returned by ``_vmap_increment_nesting`` when the Dim was bound.
    level: int
    # Set to False by ``Dim.__del__``; a dead entry is popped once it is on
    # top of ``_vmap_levels`` and matches ``current_level()``.
    alive: bool = True


class Dim:
    """A named first-class dimension, optionally bound to a concrete size.

    Binding a size (via the constructor or the ``size`` setter) opens a vmap
    level and pushes a ``LevelInfo`` onto ``_vmap_levels``. ``__del__`` marks
    that entry dead and unwinds dead levels from the top of the stack.

    NOTE(review): ``_vmap_increment_nesting``, ``_vmap_decrement_nesting`` and
    ``current_level`` are neither defined nor imported in this file (hence the
    ``noqa: F821`` markers); presumably they come from functorch's C
    extension — confirm before relying on this class.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        # Index of this Dim's entry in ``_vmap_levels``; set when bound.
        self._vmap_stack = None
        if size is not None:
            self.size = size

    def __del__(self):
        """Mark this Dim's level dead and pop dead levels off the top."""
        if self._vmap_level is not None:
            _vmap_levels[self._vmap_stack].alive = False
            # Only the topmost levels can be released, and only while each is
            # the interpreter's current level; a dead level buried under a
            # live one waits until the live one above it is released.
            while (
                _vmap_levels
                and not _vmap_levels[-1].alive
                and current_level() == _vmap_levels[-1].level  # noqa: F821
            ):
                _vmap_decrement_nesting()  # noqa: F821
                _vmap_levels.pop()

    @property
    def size(self):
        """The bound size. Asserts that the Dim has been bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind this Dim to ``size``.

        The first binding opens a vmap level for this Dim. Re-binding to the
        same size is a no-op.

        Raises:
            DimensionBindError: if the Dim is already bound to another size.
        """
        from . import DimensionBindError

        if self._size is None:
            # Open the level before recording any state, so a failure here
            # does not leave a Dim that looks bound but owns no level.
            level = _vmap_increment_nesting(size, "same")  # noqa: F821
            self._size = size
            self._vmap_level = level
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        """True once a size has been assigned."""
        return self._size is not None

    def __repr__(self):
        return self.name


# Opcodes that store to exactly one name; their ``argval`` is that name.
_STORE_OPNAMES = ("STORE_FAST", "STORE_NAME", "STORE_DEREF", "STORE_GLOBAL")


def extract_name(inst):
    """Return the target name of a single-name store instruction."""
    assert inst.opname in _STORE_OPNAMES, inst.opname
    return inst.argval


# Maps (code object, call-site offset, lists) -> factory producing the dims.
_cache = {}


def _store_targets(instructions, start, count):
    """Collect ``count`` assignment-target names starting at ``instructions[start]``."""
    names = []
    idx = start
    while len(names) < count:
        inst = instructions[idx]
        if inst.opname == "STORE_FAST_STORE_FAST":
            # Python 3.13+ fuses two consecutive local stores into a single
            # instruction whose argval is the pair of names, in store order.
            names.extend(inst.argval)
        else:
            names.append(extract_name(inst))
        idx += 1
    return tuple(names)


def _dims_factory(code, lasti, lists):
    """Decode the assignment following the call at ``lasti`` in ``code``.

    Returns a zero-argument callable that builds fresh Dim/DimList objects
    named after the assignment targets.
    """
    instructions = list(dis.get_instructions(code))
    # Find the instruction after the call by byte offset. Since Python 3.11,
    # inline CACHE entries (omitted by get_instructions) make the old
    # ``lasti // 2 + 1`` index arithmetic wrong; on older versions this is
    # equivalent to it.
    first = next(
        (i for i, inst in enumerate(instructions) if inst.offset > lasti),
        None,
    )
    assert first is not None, "dims() must be used in an assignment"
    unpack = instructions[first]

    if unpack.opname in _STORE_OPNAMES:
        # just a single dim, not a list
        name = unpack.argval
        ctor = Dim if lists == 0 else DimList
        return lambda: ctor(name=name)

    assert unpack.opname == "UNPACK_SEQUENCE", unpack.opname
    names = _store_targets(instructions, first + 1, unpack.argval)
    # The last ``lists`` targets become DimLists; the rest are plain Dims.
    first_list = len(names) - lists
    return lambda: tuple(
        Dim(n) if i < first_list else DimList(name=n)
        for i, n in enumerate(names)
    )


def dims(lists=0):
    """Create dims named after the variables the result is assigned to.

    ``x = dims()`` returns one ``Dim`` named ``"x"`` (a ``DimList`` if
    ``lists`` is non-zero). ``a, b, c = dims(lists=1)`` returns a tuple whose
    last ``lists`` entries are ``DimList`` objects and the rest ``Dim``.

    The caller's bytecode is disassembled once per (code object, call-site
    offset, ``lists``) and the resulting factory is cached in ``_cache``.
    """
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    # Release frame references promptly to avoid reference cycles.
    del frame, calling_frame
    # ``lists`` is part of the key: one call site may pass different values.
    key = (code, lasti, lists)
    if key not in _cache:
        _cache[key] = _dims_factory(code, lasti, lists)
    return _cache[key]()
Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Workerdef _dim_set(positional, arg): 109*da0073e9SAndroid Build Coastguard Worker def convert(a): 110*da0073e9SAndroid Build Coastguard Worker if isinstance(a, Dim): 111*da0073e9SAndroid Build Coastguard Worker return a 112*da0073e9SAndroid Build Coastguard Worker else: 113*da0073e9SAndroid Build Coastguard Worker assert isinstance(a, int) 114*da0073e9SAndroid Build Coastguard Worker return positional[a] 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker if arg is None: 117*da0073e9SAndroid Build Coastguard Worker return positional 118*da0073e9SAndroid Build Coastguard Worker elif not isinstance(arg, (Dim, int)): 119*da0073e9SAndroid Build Coastguard Worker return tuple(convert(a) for a in arg) 120*da0073e9SAndroid Build Coastguard Worker else: 121*da0073e9SAndroid Build Coastguard Worker return (convert(arg),) 122