# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
from typing import List, Optional, Union

# pyre-fixme[21]: Could not find module `executorch.exir.verification.bindings`.
import executorch.exir.verification.bindings as bindings  # @manual=//executorch/exir/verification:bindings
import executorch.extension.pytree as ex_pytree

import torch

from executorch import exir

from executorch.exir.schema import (
    Bool,
    BoolList,
    Double,
    DoubleList,
    ExecutionPlan,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    KernelTypes,
    MoveCall,
    Null,
    Operator,
    OptionalTensorList,
    Program,
    String,
    Tensor,
    TensorList,
)
from executorch.exir.tensor import get_scalar_type, stride_from_dim_order
from torch.library import impl, Library
from torch.utils._pytree import PyTree


class Uninitialized:
    pass


ValueListType = Union[
    List[torch.Tensor],
    List[bool],
    List[float],
    List[int],
]

ValueScalarType = Union[int, str, float, torch.Tensor, Uninitialized, None]

ValueType = Union[
    ValueScalarType,
    ValueListType,
]

# Define the operator `executorch::move`.
executorch_lib = Library("executorch", "DEF")
executorch_lib.define("move(Tensor self) -> Tensor")


@impl(executorch_lib, "move", "CPU")
def move_impl(self: torch.Tensor) -> torch.Tensor:
    return torch.clone(self)


def comp_types(val: KernelTypes, input_val: ValueType) -> Optional[bool]:  # noqa
    """
    Compares a schema type (`val`) against a Python value (`input_val`).
    Mapping: Int -> int, Bool -> bool, Double -> float,
    String -> str, Tensor -> torch.Tensor, XList -> list[x]

    Returns None if `val` is not one of the recognized schema types.

    Args:
        `val`: value from the value list, with a type from `schema.py`

        `input_val`: value with a Python type, normally the result of executing an operation
    """
    if isinstance(val, Int):
        return isinstance(input_val, type(val.int_val))
    elif isinstance(val, Bool):
        return isinstance(input_val, type(val.bool_val))
    elif isinstance(val, Double):
        return isinstance(input_val, type(val.double_val))
    elif isinstance(val, String):
        return isinstance(input_val, type(val.string_val))
    elif isinstance(val, Tensor):
        return isinstance(input_val, torch.Tensor)
    elif isinstance(val, IntList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, int) for x in input_val)
    elif isinstance(val, BoolList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, bool) for x in input_val)
    elif isinstance(val, DoubleList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, float) for x in input_val)
    elif isinstance(val, (TensorList, OptionalTensorList)):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, torch.Tensor) for x in input_val)
    elif isinstance(val, Null):
        raise TypeError("Setting a value where there should be a Null")


def resolve_op(operator: Operator) -> torch._ops.OpOverload:
    # Pattern matching out the namespace and operation name,
    # e.g. "aten::add" -> ("aten", "add").
    namespace, op_name = operator.name.split("::")
    op = getattr(getattr(getattr(torch.ops, namespace), op_name), operator.overload)
    return op
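

# Illustrative sketch of how `resolve_op` maps a schema `Operator` onto a
# registered torch overload via attribute lookup on `torch.ops` (assuming the
# standard `aten` namespace; the example `Operator` entry is hypothetical):
#
#     op = resolve_op(Operator(name="aten::add", overload="Tensor"))
#     assert op is torch.ops.aten.add.Tensor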


def make_operators_list(
    execution_plan: ExecutionPlan,
) -> List[torch._ops.OpOverload]:
    operator_list = [resolve_op(operator) for operator in execution_plan.operators]
    return operator_list


class Interpreter:
    def __init__(self, program: Program) -> None:
        # Currently there is only 1 execution plan in the list -- this assert
        # will help catch any changes in the future.
        assert len(program.execution_plan) == 1
        self.execution_plan: exir.schema.ExecutionPlan = program.execution_plan[0]
        self.container_metatype: exir.schema.ContainerMetadata = program.execution_plan[
            0
        ].container_meta_type

        # Create buffers in memory and keep references to them.
        # pyre-ignore
        self.data_buffers: List[bindings.DataBuffer] = [
            # pyre-ignore
            bindings.DataBuffer(b.storage, len(b.storage))
            for b in program.constant_buffer
        ]

        # Generate the list of values (including tensors) and operators from
        # the execution plan.
        self._value_list: List[ValueType] = [
            Uninitialized() for val in self.execution_plan.values
        ]
        self._operators_list: List[torch._ops.OpOverload] = make_operators_list(
            self.execution_plan
        )

    def get_value_list(self) -> List[ValueType]:
        # TODO(meghajain) may need to change deepcopy to clone
        return copy.deepcopy(self._value_list)

    def get_operators_list(self) -> List[torch._ops.OpOverload]:
        return self._operators_list

    def get_constant_tensors(self) -> List[Tensor]:
        """
        Returns the program's constant tensors without having to run the
        program. Has no side effects on the Interpreter's value list.
        """
        tensors = []
        for elem in self.execution_plan.values:
            val = elem.val
            if isinstance(val, Tensor) and val.data_buffer_idx != 0:
                # Load the constant tensor referenced by `val`.
                # pyre-fixme[16]
                tensor = bindings.convert_to_tensor(
                    self.data_buffers[val.data_buffer_idx],
                    val.scalar_type,
                    val.sizes,
                    stride_from_dim_order(val.sizes, val.dim_order),
                )
                tensors.append(tensor)
        return tensors

    def load_value(self, idx: int) -> None:
        """
        Given an index into the value list: if the value is `Uninitialized`, or
        is a mutable object like a Tensor List, calls `load_from_value_list` to
        load and initialize the value in the Interpreter's value list.

        Args:
            `idx` : index in the value list that we want to load

        Returns: No returned values - the value list is updated in place
        """
        # If this is an instance of any mutable object, load it regardless of
        # whether it has already been initialized.
        if isinstance(
            self.execution_plan.values[idx].val, (TensorList, OptionalTensorList)
        ) or isinstance(self._value_list[idx], Uninitialized):
            self.load_from_value_list(idx)
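
    # Illustrative sketch of the lazy-loading contract (hypothetical indices):
    # scalars and constant tensors are materialized once and then cached,
    # whereas TensorList / OptionalTensorList values are rebuilt on every
    # `load_value` call so that writes into their element slots (e.g. by
    # out-variant kernels) are observed:
    #
    #     interp.load_value(3)  # a TensorList over items [1, 2] is rebuilt as
    #                           # [values[1], values[2]] on each call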

    def load_from_value_list(self, idx: int) -> None:  # noqa
        """
        Accesses the Execution Plan's value list at the same index (see
        schema.py) to load and initialize the value in the Interpreter's value
        list. Extracts the necessary values depending on the type of the value,
        e.g. Tensor Lists hold indices into the value list, so their elements
        are loaded recursively. If an EValue is a Constant Tensor (denoted by
        allocation_info=None), it converts the Python object to a torch tensor
        object.

        Args:
            `idx` : index in the value list that we want to load

        Returns: No returned values - the value list is updated in place
        """
        assert idx >= 0
        val = self.execution_plan.values[idx].val

        # Case through all possible EValue types.
        if isinstance(val, Int):
            self._value_list[idx] = val.int_val
        elif isinstance(val, Bool):
            self._value_list[idx] = val.bool_val
        elif isinstance(val, Double):
            self._value_list[idx] = val.double_val
        elif isinstance(val, String):
            self._value_list[idx] = val.string_val
        elif isinstance(val, (IntList, BoolList, DoubleList)):
            unboxed_list = []
            for item in val.items:
                assert isinstance(item, int)
                assert isinstance(self.execution_plan.values[item].val, Int)
                # pyre-fixme [16] Undefined attribute [16]: Item `Bool` has no
                # attribute `int_val`.
                unboxed_list.append(self.execution_plan.values[item].val.int_val)
            self._value_list[idx] = unboxed_list
        elif isinstance(val, (TensorList, OptionalTensorList)):
            tensor_list = []
            for i in val.items:
                if i == -1:
                    tensor_list.append(None)
                    continue
                self.load_value(i)
                tensor_list.append(self._value_list[i])
            self._value_list[idx] = tensor_list
        elif isinstance(val, Tensor):
            if val.data_buffer_idx == 0:
                # TODO(zhengxu) Verify that argument is actually an out variant
                self._value_list[idx] = torch.empty(
                    val.sizes, dtype=get_scalar_type(val.scalar_type)
                )
            else:
                # Constant Tensor conversion.
                # pyre-fixme [16]
                tensor = bindings.convert_to_tensor(
                    self.data_buffers[val.data_buffer_idx],
                    val.scalar_type,
                    val.sizes,
                    stride_from_dim_order(val.sizes, val.dim_order),
                )
                self._value_list[idx] = tensor
        elif isinstance(val, Null):
            self._value_list[idx] = None
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

        assert not isinstance(self._value_list[idx], Uninitialized)

    def set_value(self, idx: int, input_val: ValueType) -> None:
        """
        Given an index into the value list and a value, updates the
        Interpreter's value list at that index in place. If the value at the
        given index is meant to be a TensorList, iterates through all the
        indices in the TensorList and updates each placeholder with the
        corresponding Tensor from `input_val`.

        Args:
            `idx` : index in the value list that we want to set

            `input_val` : value we want to put at `self._value_list[idx]`

        Returns: No returned values - the value list is updated in place
        """
        evalue = self.execution_plan.values[idx]
        val = evalue.val

        if not comp_types(val, input_val):
            raise TypeError(
                f"Program trying to set a value of {input_val} : {type(input_val)} in memory location where {type(val)} is expected."
            )

        # Case through all possible EValue types.
        if isinstance(
            val,
            (Int, Bool, Double, String, IntList, BoolList, DoubleList, Tensor, Null),
        ):
            self._value_list[idx] = input_val
        elif isinstance(val, (TensorList, OptionalTensorList)):
            assert isinstance(input_val, List)
            assert len(val.items) == len(input_val)
            tensor_list = []
            for i in range(len(val.items)):
                val_idx = val.items[i]
                self._value_list[val_idx] = input_val[i]
                tensor_list.append(input_val[i])
            self._value_list[idx] = tensor_list
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )
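
    # Illustrative sketch of the TensorList fan-out in `set_value`
    # (hypothetical indices): setting value 3, a TensorList over items [1, 2],
    # to `[t_a, t_b]` writes `t_a` into slot 1, `t_b` into slot 2, and
    # `[t_a, t_b]` into slot 3, keeping the element slots and the list view
    # consistent:
    #
    #     interp.set_value(3, [t_a, t_b])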

    def call_kernel(self, kernel: KernelCall) -> None:
        """
        Calls the operator for a kernel:
        1. Determines the kernel's operation through kernel.op_index,
           which indexes into the operator list of the Program's execution plan.
        2. After identifying the operation, determines the number of arguments
           and keyword arguments through the operator schema.
        3. Extracts the arguments from the value list and calls the operator.
        4. Sets the given output indices in the value list to the values
           returned from the operation.

        Args:
            `kernel` : stores information about the operator and which indices
                in the value list contain the necessary arguments

        Returns: No returned values - the value list is updated in place with
            the outputs from the operator
        """

        operator = self._operators_list[kernel.op_index]
        num_args = len(
            [arg for arg in operator._schema.arguments if not arg.kwarg_only]
        )
        kwarg_list = [kwarg for kwarg in operator._schema.arguments if kwarg.kwarg_only]
        num_kwargs = len(kwarg_list)

        # Extract arguments and keyword arguments from value_list given the
        # indices in kernel.args.
        args = []
        for i in kernel.args[:num_args]:
            self.load_value(i)
            args.append(self._value_list[i])

        kwargs = {}
        for j in range(num_kwargs):
            i = kernel.args[num_args + j]
            keyword = kwarg_list[j].name

            self.load_value(i)
            kwargs[keyword] = self._value_list[i]

        res = operator(*args, **kwargs)
        output_idxs = kernel.args[num_args + num_kwargs :]

        assert (
            len(output_idxs) == 1
        ), "emitter is expected to pack multiple outputs into a TensorList"
        if isinstance(res, tuple):
            self.set_value(output_idxs[0], list(res))
        else:
            self.set_value(output_idxs[0], res)
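
    # Illustrative layout of `kernel.args` as consumed by `call_kernel`
    # (hypothetical sizes): for an operator with 2 positional arguments and
    # 1 keyword-only argument, `kernel.args == [a0, a1, kw0, out]`, where
    # a0/a1 index the positional inputs, kw0 the keyword input, and `out` is
    # the single value-list slot that receives the result (multiple outputs
    # arrive as a tuple and are stored as a list via `set_value`).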

    def run(self, *raw_args: torch.Tensor) -> PyTree:
        """
        Loops through the instructions, given some inputs.

        Args:
            `raw_args` : inputs required for interpretation

        Returns:
            Outputs after completing all computations
        """

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        args, pytree = ex_pytree.tree_flatten((raw_args, {}))

        if pytree.to_str() != self.container_metatype.encoded_inp_str:
            raise TypeError(
                f"Arguments provided do not match required type. \nRequired: {self.container_metatype.encoded_inp_str} \nProvided: {pytree.to_str()}"
            )

        # Initialize user inputs in the value list.
        if len(self.execution_plan.inputs) != len(args):
            raise RuntimeError(
                f"Incorrect number of arguments provided. Expected {len(self.execution_plan.inputs)} values, but received {len(args)}"
            )
        for i in range(len(self.execution_plan.inputs)):
            idx = self.execution_plan.inputs[i]
            self._value_list[idx] = args[i]

        assert len(self.execution_plan.chains) == 1
        chain = self.execution_plan.chains[0]

        # Instruction pointer.
        ip = 0

        # Kernel loop.
        while ip < len(chain.instructions):
            instruction = chain.instructions[ip]
            if isinstance(instruction.instr_args, KernelCall):
                self.call_kernel(instruction.instr_args)
            elif isinstance(instruction.instr_args, JumpFalseCall):
                self.load_value(instruction.instr_args.cond_value_index)
                ip = (
                    ip + 1
                    # pyre-ignore
                    if self._value_list[instruction.instr_args.cond_value_index]
                    # pyre-ignore
                    else instruction.instr_args.destination_instruction
                )
                continue
            elif isinstance(instruction.instr_args, MoveCall):
                move_to = instruction.instr_args.move_to
                move_from = instruction.instr_args.move_from
                self.load_value(move_from)
                self.load_value(move_to)
                self._value_list[move_to] = self._value_list[move_from]
            else:
                raise RuntimeError(
                    f"Received unknown instruction from program: {instruction}."
                )
            ip += 1

        ret = [self._value_list[i] for i in self.execution_plan.outputs]
        # pyre-fixme[16]: Module `pytree` has no attribute `from_str`.
        treespec = ex_pytree.from_str(self.container_metatype.encoded_out_str)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_unflatten`.
        return ex_pytree.tree_unflatten(ret, treespec)
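

# Minimal usage sketch (assumptions: `prog` is a deserialized `Program`, e.g.
# obtained via exir's serialization utilities, and `x` matches the pytree
# structure of the program's inputs):
#
#     interp = Interpreter(prog)
#     out = interp.run(x)
#     constants = interp.get_constant_tensors()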