# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
from typing import List, Optional, Union

# pyre-fixme[21]: Could not find module `executorch.exir.verification.bindings`.
import executorch.exir.verification.bindings as bindings  # @manual=//executorch/exir/verification:bindings
import executorch.extension.pytree as ex_pytree

import torch

from executorch import exir

from executorch.exir.schema import (
    Bool,
    BoolList,
    Double,
    DoubleList,
    ExecutionPlan,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    KernelTypes,
    MoveCall,
    Null,
    Operator,
    OptionalTensorList,
    Program,
    String,
    Tensor,
    TensorList,
)
from executorch.exir.tensor import get_scalar_type, stride_from_dim_order
from torch.library import impl, Library
from torch.utils._pytree import PyTree


class Uninitialized:
    """
    Sentinel placed in a value-list slot that has not been loaded or set yet.
    """


# Python types a list-valued slot of the interpreter's value list may hold.
ValueListType = Union[
    List[torch.Tensor],
    List[bool],
    List[float],
    List[int],
]

# Python types a scalar slot may hold, including the `Uninitialized` sentinel
# and None.
ValueScalarType = Union[int, str, float, torch.Tensor, Uninitialized, None]

# Anything that may occupy a slot of the interpreter's value list.
ValueType = Union[ValueScalarType, ValueListType]

# defining the operator executorch.move
executorch_lib = Library("executorch", "DEF")
executorch_lib.define("move(Tensor self) -> Tensor")


@impl(executorch_lib, "move", "CPU")
def move_impl(self: torch.Tensor) -> torch.Tensor:
    """
    CPU implementation of ``executorch::move``: returns ``torch.clone(self)``.
    """
    return torch.clone(self)


def comp_types(val: KernelTypes, input_val: ValueType) -> Optional[bool]:  # noqa
    """
    Compares a schema type (val) with Python type (input_val)
    Map: Int -> int, Bool -> bool, Double -> float,
    String -> str, Tensor -> torch.Tensor, XList -> list[x]
    (OptionalTensorList -> list[torch.Tensor | None])

    Args:
        `val`: value from value list with type from `schema.py`

        `input_val`: value with Python type, normally from the result of executing an operation

    Returns:
        True/False for the schema types listed above, None for any other
        schema type (callers treat a falsy result as a mismatch).

    Raises:
        TypeError: if `val` is `Null`, since nothing may be written to a Null slot.
    """
    if isinstance(val, Int):
        return isinstance(input_val, type(val.int_val))
    elif isinstance(val, Bool):
        return isinstance(input_val, type(val.bool_val))
    elif isinstance(val, Double):
        return isinstance(input_val, type(val.double_val))
    elif isinstance(val, String):
        return isinstance(input_val, type(val.string_val))
    elif isinstance(val, Tensor):
        return isinstance(input_val, torch.Tensor)
    elif isinstance(val, IntList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, int) for x in input_val)
    elif isinstance(val, BoolList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, bool) for x in input_val)
    elif isinstance(val, DoubleList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, float) for x in input_val)
    elif isinstance(val, TensorList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, torch.Tensor) for x in input_val)
    elif isinstance(val, OptionalTensorList):
        if not isinstance(input_val, list):
            return False
        # Optional tensor lists may hold None entries (load_from_value_list
        # produces None for item index -1).
        return all(x is None or isinstance(x, torch.Tensor) for x in input_val)
    elif isinstance(val, Null):
        raise TypeError("Setting a value where there should be a Null")
    return None


def resolve_op(operator: Operator) -> torch._ops.OpOverload:
    """
    Resolves a serialized `Operator` whose name has the form
    "namespace::op_name" to the matching `torch.ops` overload.

    `torch.ops` exposes the unnamed overload under the attribute "default",
    so an empty overload string is mapped to that name.
    """
    # pattern matching out the namespace and operation name
    namespace, op_name = operator.name.split("::")
    overload_name = operator.overload or "default"
    return getattr(getattr(getattr(torch.ops, namespace), op_name), overload_name)


def make_operators_list(
    execution_plan: ExecutionPlan,
) -> List[torch._ops.OpOverload]:
    """
    Resolves every operator of `execution_plan`, preserving order so that a
    kernel's `op_index` indexes directly into the returned list.
    """
    return [resolve_op(operator) for operator in execution_plan.operators]


class Interpreter:
    """
    Python interpreter for an ExecuTorch `Program`.

    It mirrors the execution plan's `values` table with a Python value list
    whose entries are materialized lazily, and walks the single instruction
    chain, dispatching kernel calls to the resolved `torch.ops` overloads.
    """

    def __init__(self, program: Program) -> None:
        # Currently there is only 1 execution plan in the list -- this assert will help
        # catch any changes in the future
        assert len(program.execution_plan) == 1
        self.execution_plan: exir.schema.ExecutionPlan = program.execution_plan[0]
        self.container_metatype: exir.schema.ContainerMetadata = (
            self.execution_plan.container_meta_type
        )

        # create buffer in memory and get reference to it
        # pyre-ignore
        self.data_buffers: List[bindings.DataBuffer] = [
            # pyre-ignore
            bindings.DataBuffer(b.storage, len(b.storage))
            for b in program.constant_buffer
        ]

        # generate the list of values (including tensors) and operators from the execution plan
        self._value_list: List[ValueType] = [
            Uninitialized() for _ in self.execution_plan.values
        ]
        self._operators_list: List[torch._ops.OpOverload] = make_operators_list(
            self.execution_plan
        )

    def get_value_list(self) -> List[ValueType]:
        """Returns a deep copy of the current value list."""
        # TODO(meghajain) may need to change deepcopy to clone
        return copy.deepcopy(self._value_list)

    def get_operators_list(self) -> List[torch._ops.OpOverload]:
        """Returns the resolved operators, indexed by kernel `op_index`."""
        return self._operators_list

    def get_constant_tensors(self) -> List[torch.Tensor]:
        """
        No side effects on Interpreter's value list. List of constant tensors returned
        without having to run program. A tensor EValue is treated as constant
        when its `data_buffer_idx` is non-zero.
        """
        tensors = []
        for elem in self.execution_plan.values:
            val = elem.val
            if isinstance(val, Tensor) and val.data_buffer_idx != 0:
                # pyre-fixme[16]
                tensor = bindings.convert_to_tensor(
                    self.data_buffers[val.data_buffer_idx],
                    val.scalar_type,
                    val.sizes,
                    stride_from_dim_order(val.sizes, val.dim_order),
                )
                tensors.append(tensor)
        return tensors

    def load_value(self, idx: int) -> None:
        """
        Given an index in the value list, if value is `Uninitialized` or is a mutable object,
        like a Tensor List, calls `load` to load and initialize value into Interpreter's value_list.

        Args:
            `idx` : index in value lists that we want to load

        Returns: No returned values - value list is updated in place
        """
        # if instance of any mutable object, load regardless of being initialized
        if isinstance(
            self.execution_plan.values[idx].val, (TensorList, OptionalTensorList)
        ) or isinstance(self._value_list[idx], Uninitialized):
            self.load_from_value_list(idx)

    def load_from_value_list(self, idx: int) -> None:  # noqa
        """
        Accesses Execution Plan's value list at same index (see schema.py) to
        load and initialize value into Interpreter's value list. Extracts the
        necessary values depending on the type of the value. E.g. Tensor Lists
        and Int Lists hold indices into the value list, so their elements are
        loaded from those slots. A Tensor with a non-zero `data_buffer_idx` is
        a constant and is converted from its data buffer into a torch tensor;
        otherwise an empty tensor of the declared sizes and dtype is allocated.

        Args:
            `idx` : index in value lists that we want to load

        Returns: No returned values - value list is updated in place

        Raises:
            TypeError: if the EValue at `idx` has an unsupported type.
        """
        assert idx >= 0
        val = self.execution_plan.values[idx].val

        # Case through all possible Evalue Types
        if isinstance(val, Int):
            self._value_list[idx] = val.int_val
        elif isinstance(val, Bool):
            self._value_list[idx] = val.bool_val
        elif isinstance(val, Double):
            self._value_list[idx] = val.double_val
        elif isinstance(val, String):
            self._value_list[idx] = val.string_val
        elif isinstance(val, IntList):
            # IntList items are indices of boxed `Int` EValues in the values table.
            unboxed_list = []
            for item in val.items:
                assert isinstance(item, int)
                boxed = self.execution_plan.values[item].val
                assert isinstance(boxed, Int)
                unboxed_list.append(boxed.int_val)
            self._value_list[idx] = unboxed_list
        elif isinstance(val, (BoolList, DoubleList)):
            # Per the ExecuTorch program schema, bool/double lists store their
            # elements inline rather than as indices into the values table.
            self._value_list[idx] = list(val.items)
        elif isinstance(val, (TensorList, OptionalTensorList)):
            tensor_list = []
            for i in val.items:
                # -1 marks an absent entry of an optional tensor list.
                if i == -1:
                    tensor_list.append(None)
                    continue
                self.load_value(i)
                tensor_list.append(self._value_list[i])
            self._value_list[idx] = tensor_list
        elif isinstance(val, Tensor):
            if val.data_buffer_idx == 0:
                # TODO(zhengxu) Verify that argument is actually an out variant
                self._value_list[idx] = torch.empty(
                    val.sizes, dtype=get_scalar_type(val.scalar_type)
                )
            else:
                # Constant Tensor conversion
                # pyre-fixme [16]
                tensor = bindings.convert_to_tensor(
                    self.data_buffers[val.data_buffer_idx],
                    val.scalar_type,
                    val.sizes,
                    stride_from_dim_order(val.sizes, val.dim_order),
                )
                self._value_list[idx] = tensor
        elif isinstance(val, Null):
            self._value_list[idx] = None
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

        assert not isinstance(self._value_list[idx], Uninitialized)

    def set_value(self, idx: int, input_val: ValueType) -> None:
        """
        Given an index in the value list, and a value, updates
        Interpreter's value in value list at given index in place
        If value is meant to be a TensorList at given index,
        iterate through all the indices in the TensorList and
        update each placeholder with the given Tensor from `input_val`.
        Entries of an OptionalTensorList whose item index is -1 have no
        placeholder and are only stored in the list itself.

        Args:
            `idx` : index in value lists that we want to set

            `input_val` : value we want to put at `self._value_list[idx]`

        Returns: No returned values - value list is updated in place

        Raises:
            TypeError: if `input_val` does not match the EValue type at `idx`.
        """
        evalue = self.execution_plan.values[idx]
        val = evalue.val

        if not comp_types(val, input_val):
            raise TypeError(
                f"Program trying to set a value of {input_val} : {type(input_val)} in memory location where {type(val)} is expected."
            )

        # Case through all possible Evalue Types
        if isinstance(
            val,
            (Int, Bool, Double, String, IntList, BoolList, DoubleList, Tensor, Null),
        ):
            self._value_list[idx] = input_val
        elif isinstance(val, (TensorList, OptionalTensorList)):
            assert isinstance(input_val, list)
            assert len(val.items) == len(input_val)
            tensor_list = []
            for val_idx, item in zip(val.items, input_val):
                # Writing to index -1 would clobber the last slot of the value list.
                if val_idx != -1:
                    self._value_list[val_idx] = item
                tensor_list.append(item)
            self._value_list[idx] = tensor_list
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

    def call_kernel(self, kernel: KernelCall) -> None:
        """
        Calls operator from kernel:
        1. Determines kernel's operation through kernel.op_index,
        which indexes into operator list from Program's execution plan.
        2. After identifying operation, determines number of arguments and
        keyword arguments through operator schema.
        3. Extracts arguments from value list and calls operator.
        4. Sets the given output indices in value list with the values
        returned from operation.

        Args:
            `kernel` : stores information about operator and which indices in
            the value list contain the necessary arguments

        Returns: No returned values - value list is updated with outputs
        from operator in place
        """

        operator = self._operators_list[kernel.op_index]
        num_args = len(
            [arg for arg in operator._schema.arguments if not arg.kwarg_only]
        )
        kwarg_list = [kwarg for kwarg in operator._schema.arguments if kwarg.kwarg_only]
        num_kwargs = len(kwarg_list)

        # kernel.args is laid out as: positional args, kwarg-only args, outputs.
        args = []
        for i in kernel.args[:num_args]:
            self.load_value(i)
            args.append(self._value_list[i])

        kwargs = {}
        for j in range(num_kwargs):
            i = kernel.args[num_args + j]
            keyword = kwarg_list[j].name

            self.load_value(i)
            kwargs[keyword] = self._value_list[i]

        res = operator(*args, **kwargs)
        output_idxs = kernel.args[num_args + num_kwargs :]

        assert (
            len(output_idxs) == 1
        ), "emitter is expected to pack multiple outputs into a TensorList"
        if isinstance(res, tuple):
            self.set_value(output_idxs[0], list(res))
        else:
            self.set_value(output_idxs[0], res)

    def run(self, *raw_args: torch.Tensor) -> PyTree:
        """
        Loops through instructions given some inputs

        Args:
            `raw_args` : inputs required for interpretation; they are flattened
            together with an empty kwargs dict and must match the program's
            encoded input spec

        Returns:
            Outputs after completing all computations, unflattened with the
            program's encoded output spec

        Raises:
            TypeError: if the input structure does not match the program's spec.
            RuntimeError: on a wrong number of inputs or an unknown instruction.
        """

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        args, pytree = ex_pytree.tree_flatten((raw_args, {}))

        if pytree.to_str() != self.container_metatype.encoded_inp_str:
            raise TypeError(
                f"Arguments provided do not match required type. \nRequired: {self.container_metatype.encoded_inp_str} \nProvided: {pytree.to_str()}"
            )

        # Initialize user inputs in value list
        if len(self.execution_plan.inputs) != len(args):
            raise RuntimeError(
                f"Incorrect number of arguments provided. Expected {len(self.execution_plan.inputs)} values, but received {len(args)}"
            )
        for input_idx, arg in zip(self.execution_plan.inputs, args):
            self._value_list[input_idx] = arg

        assert len(self.execution_plan.chains) == 1
        chain = self.execution_plan.chains[0]

        # instruction pointer
        ip = 0

        # Kernel loop
        while ip < len(chain.instructions):
            instruction = chain.instructions[ip]
            instr_args = instruction.instr_args
            if isinstance(instr_args, KernelCall):
                self.call_kernel(instr_args)
            elif isinstance(instr_args, JumpFalseCall):
                # Load and read the condition from the same slot.
                cond_idx = instr_args.cond_value_index
                self.load_value(cond_idx)
                # Fall through when the condition is truthy, otherwise jump.
                if self._value_list[cond_idx]:
                    ip += 1
                else:
                    ip = instr_args.destination_instruction
                continue
            elif isinstance(instr_args, MoveCall):
                move_to = instr_args.move_to
                move_from = instr_args.move_from
                self.load_value(move_from)
                self.load_value(move_to)
                self._value_list[move_to] = self._value_list[move_from]
            else:
                raise RuntimeError(
                    f"Received unknown instruction from program: {instruction}."
                )
            ip += 1

        ret = [self._value_list[i] for i in self.execution_plan.outputs]
        # pyre-fixme[16]: Module `pytree` has no attribute `from_str`.
        treespec = ex_pytree.from_str(self.container_metatype.encoded_out_str)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_unflatten`.
        return ex_pytree.tree_unflatten(ret, treespec)