# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Takes an ExportedArtifact, or a collection of ExportedArtifacts, in execution dialect, and turns
them into a single ExecuTorch Program.

The provided ExportedArtifact's graph modules are in execution dialect and the emitter parses and
converts them into executorch instructions. The emitter walks the provided graphs and, as it
encounters concrete values such as tensors or ints, converts them to the serialized format and
stores them in a list for later use. The emitter walks the graph by traversing fx.Nodes, which can
come in a variety of forms and are the primitives of execution at the graph module level. The three
most common ones we care about are 'call_function', 'placeholder', and 'output'. 'placeholder' and
'output' handle I/O for the module and 'call_function' handles everything else. Within
'call_function' we may encounter an operator or delegate call, in which case we parse the schema and
emit all the inputs and outputs (unless they have already been emitted), and then convert the actual
function call into an executorch instruction such as KernelCall or DelegateCall.

When control flow is present in the graph module it will take the form of a few different types of
'call_function'. Today (June 14th 2023) only cond and map are supported. The actual operations of
these, such as the true/false branches of cond or the mapping function of map, are stored as
sub-graph-modules. When these are encountered during emission, the emitter recursively emits them
and their instructions.
"""
# TODO(jakeszwe): add information here about how weights and other parameters are handled in the
# presence of aot autograd param lifting.
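
# To make the emission flow above concrete: for a graph containing a single call such as
# aten.add.out(x, y, out=z), the emitter (roughly) produces
#   * one Operator table entry ("aten::add", "out"),
#   * EValues for x, y, and z (tensors become TensorSpec-derived EValues; constants are serialized
#     into the constant buffer), and
#   * one KernelCall instruction whose args are the value-table indices of x, y, and z, appended to
#     the method's Chain.
# This sketch is illustrative only; the exact serialized layout is defined by executorch.exir.schema.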

# pyre-strict
import ctypes
import hashlib
import operator
import typing
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union

import executorch.exir.memory as memory
import executorch.extension.pytree as ex_pytree
import torch
import torch.fx
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Bool,
    BoolList,
    Buffer,
    Chain,
    ContainerMetadata,
    DataLocation,
    DelegateCall,
    Double,
    DoubleList,
    EValue,
    ExecutionPlan,
    FreeCall,
    Instruction,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    MoveCall,
    Null,
    Operator,
    OptionalTensorList,
    ScalarType,
    String,
    Tensor,
    TensorList,
    TensorShapeDynamism,
)
from executorch.exir.tensor import (
    AddressSpaceOverflowException,
    layout_enum,
    make_allocation_info,
    make_tensor_value,
    memory_format_enum,
    scalar_type_enum,
    TensorSpec,
)
from executorch.exir.types import LeafValueSpec, ValueSpec
from torch._subclasses.fake_tensor import FakeTensor

from torch.export.exported_program import ExportedProgram
from torch.utils import _pytree as pytree

from typing_extensions import TypeAlias


@dataclass
class _ProgramState:
    """State shared between all methods of a program and the graph module it represents.

    Initialized once within emit_program and then shared across each entry point as they are
    emitted.
    """

    # Parallel list of specs and the buffers that backed them, have to add + 1 to any index in here
    # as index 0 in the constant_buffer is reserved.
    allocated_specs: List[TensorSpec] = field(default_factory=list)
    # Weights in any arbitrary graph_module only need to compare against weights from previously
    # emitted graph modules, not any weights emitted from itself. This should speed up the lookup,
    # from O(N) to O(1).
    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
    cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict)
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
    # and should be copied to Program.backend_delegate_data.
    backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)


@dataclass
class _EmitterState:
    """State of a single emitter.

    Local to at least the entry point, and may be local to a subgraph of an entry point originating
    from control flow.
    """

    values: List[EValue]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    operator_cache: Dict[Tuple[str, str], int]
    delegate_cache: Dict[bytes, int]
    emit_stacktrace: bool

    spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)

    def spec2id(self, spec: TensorSpec) -> int:
        """Map a TensorSpec to its value index in the values array."""
        assert spec in self.spec2id_dict, f"Spec is not found: {spec.debug()}"
        return self.spec2id_dict[spec]


@dataclass
class _AbstractValue:
    """Represents an already emitted EValue."""

    # Index in the values table of this EValue.
    id: int

    # Used for sanity checks for functions that expect to only receive AbstractValues.
    tensor: Optional[Tensor]


_EmitterValue: TypeAlias = Union[
    _AbstractValue, List[_AbstractValue], Tuple[_AbstractValue, ...]
]

_PythonValue: TypeAlias = Union[bool, int, float, torch.Tensor, List["_PythonValue"]]
_SchemaType: TypeAlias = Union[
    torch.OptionalType,
    torch.ListType,
    torch.FloatType,
    torch.BoolType,
    torch.IntType,
    torch.StringType,
    torch.TensorType,
]

_Target: TypeAlias = Union[Callable[..., _PythonValue], str]

_Argument: TypeAlias = Union[
    _EmitterValue,
    Tuple["_Argument", ...],
    List["_Argument"],
    Dict[str, "_Argument"],
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.memory_format,
    torch.layout,
    None,
]

_DelegateDebugIdentifierMap: TypeAlias = Union[
    Dict[int, Tuple[int]], Dict[str, Tuple[int]]
]


# pyre-ignore[13]: Attribute `node` is never initialized.
class _Emitter(torch.fx.Interpreter):
    """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
    given traced torch.fx.GraphModule to the flatbuffer schema."""

    node: torch.fx.Node

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        emitter_state: _EmitterState,
        program_state: _ProgramState,
        instruction_start_offset: int = 0,
        binding_input_values: Optional[List[_AbstractValue]] = None,
        binding_output_values: Optional[List[_AbstractValue]] = None,
    ) -> None:
        super().__init__(graph_module)
        self.emitter_state = emitter_state
        self.program_state = program_state
        self.outputs: List[int] = []

        self.chain = Chain(
            inputs=[],
            outputs=[],
            instructions=[],
            stacktrace=None,
        )

        if "non_const_buffer_sizes" not in graph_module.meta.keys():
            raise RuntimeError(
                "Must set 'non_const_buffer_sizes' in graph meta in memory planning pass"
            )
        self.instruction_start_offset = instruction_start_offset
        self.binding_input_values = binding_input_values
        self.binding_output_values = binding_output_values
        self.graph_module: torch.fx.GraphModule = graph_module
        self.nodes: List[torch.fx.Node] = list(self.graph_module.graph.nodes)

        # Marks the placeholder node with its order so that we can match them with the
        # corresponding Abstract Value coming from top level.
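        # For control-flow subgraphs, the i-th placeholder visited during run() pairs with
        # binding_input_values[i] supplied by the parent emitter.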
        self.placeholder_count = 0

        self.concrete_output_ids: List[_AbstractValue] = []
        self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
        self.instr_id_to_delegate_debug_id_map: Dict[
            int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
        ] = {}

    def _emit_node_specific_error(self, node: torch.fx.Node, err_msg: str) -> str:
        """Returns 'err_msg' with node specific information attached."""
        err_msg = f"Failed with error: {str(err_msg)}\n" + inspect_node(
            self.graph_module.graph, node
        )
        return err_msg

    def _internal_assert_emitter(
        self, pred: bool, node: torch.fx.Node, assert_msg: str
    ) -> None:
        """If pred is False, construct and raise a node specific error message."""
        if not pred:
            raise InternalError(self._emit_node_specific_error(node, assert_msg))

    def _emit_int_list(self, val: List[_Argument]) -> EValue:
        """Emits a list of integers as a collection of EValues.

        For every argument in 'val':
        - If it is a concrete value, then emit it and then place its location in the boxed list.
        - If it is already an abstract value, then just place its location in the boxed list.

        Int lists are boxed to handle symints whose values are determined at runtime but could
        still end up inside lists for ops like view_copy(Tensor self, SymInt[] size).
        """
        boxed_list = []
        for item in val:
            if isinstance(item, _AbstractValue):
                boxed_list.append(item.id)
            elif isinstance(item, int):
                boxed_list.append(
                    self._emit_evalue(self._constant_to_evalue(item, None)).id
                )
            else:
                self._internal_assert_emitter(
                    False, self.node, "Unsupported type encountered in int list."
                )

        return EValue(IntList(boxed_list))

    def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
        """Emits a list type.

        Emits the list stored in val. If the list is of Tensors, Optionals, or Ints the emitted
        list is boxed, otherwise the values are constant at runtime and stored inline.

        NOTE: When symbool and symfloat are supported, bool and float lists will be stored boxed.
        """

        if isinstance(val_type, torch.BoolType):
            return EValue(BoolList(typing.cast(List[bool], val)))

        if isinstance(val_type, torch.IntType):
            return self._emit_int_list(val)

        if isinstance(val_type, torch.FloatType):
            return EValue(DoubleList(typing.cast(List[float], val)))

        if isinstance(val_type, torch.TensorType):
            values = []
            for v in val:
                assert isinstance(v, _AbstractValue)
                self._internal_assert_emitter(
                    v.tensor is not None,
                    self.node,
                    "AbstractValue corresponding to tensor type doesn't contain tensor value",
                )
                values.append(v.id)
            return EValue(TensorList(values))

        if isinstance(val_type, torch.OptionalType):
            # Refine further.
            actual_type = val_type.getElementType()
            if isinstance(actual_type, torch.TensorType):
                vals = []
                for v in val:
                    if v is None:
                        vals.append(-1)
                    else:
                        assert isinstance(v, _AbstractValue)
                        vals.append(v.id)
                return EValue(OptionalTensorList(vals))

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
        )

    def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
        """Constructs an EValue from the given TensorSpec."""

        allocation_info = None
        buffer_idx = 0

        # Need to memory plan.
        # Some users set mem_id on all tensors and then rely on the
        # default algos to set offsets, so need to check both.
        if spec.mem_id is not None and spec.mem_offset is not None:
            # Tensor is an activation.
            self._internal_assert_emitter(
                isinstance(spec.mem_id, int) and spec.mem_id >= 0,
                self.node,
                f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
            )

            self._internal_assert_emitter(
                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
                self.node,
                f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
            )
            try:
                allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
            except AddressSpaceOverflowException as e:
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        (
                            f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
                            f"such as ConstraintBasedSymShapeEvalPass, this may be caused by an "
                            f"unbacked SymInt whose upper bound was lazily set to 2^64-1 (uint64 max) "
                            "during torch.export()."
                        ),
                    )
                )

        if spec.const:
            # Tensor with a blob we need to serialize. May not actually be constant at runtime
            # if it's a weight with an associated gradient.
            spec_array_type = (
                ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
            )

            buffer_data = (
                bytes(
                    ctypes.cast(
                        typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
                        ctypes.POINTER(spec_array_type),
                    ).contents
                )
                if spec.allocated_memory != 0
                else b""
            )

            hashed = hashlib.sha256(buffer_data).hexdigest()

            if allocation_info:
                buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                    hashed, -1
                )
            else:
                buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

            # Haven't seen this constant before.
            if buffer_idx == -1:
                # Update buffer_idx to point to the end of the list where we are adding the new
                # buffer.
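                # (Constants are deduplicated by the sha256 hash computed above: a weight whose
                # bytes match one that was already emitted reuses the existing Buffer entry and
                # never reaches this branch.)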
                buffer = Buffer(storage=buffer_data)
                self.program_state.allocated_specs.append(spec)
                # +1 because the first buffer location is reserved.

                if allocation_info:
                    buffer_idx = len(self.program_state.mutable_buffer)
                    self.program_state.cached_spec_mutable_hash_values[hashed] = (
                        buffer_idx
                    )
                    self.program_state.mutable_buffer.append(buffer)
                else:
                    buffer_idx = len(self.program_state.constant_buffer)
                    self.program_state.cached_spec_hash_values[hashed] = buffer_idx
                    self.program_state.constant_buffer.append(buffer)

            if spec.const and spec.nbytes() != len(buffer_data):
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}",
                    )
                )

        # For constant tensors, allocation_info = None.
        return EValue(make_tensor_value(buffer_idx, allocation_info, spec))

    def _get_list_tuple_jit_type(
        self, val: Union[Tuple[_Argument], List[_Argument]]
    ) -> _SchemaType:
        """Returns the JIT type for the given python type."""
        assert isinstance(
            val, (list, tuple)
        ), f"Input to _get_list_tuple_jit_type was expected to be an instance of a list or tuple but received {type(val)}"
        is_tensor_type = all(
            isinstance(v, _AbstractValue) and v.tensor is not None for v in val
        )
        if is_tensor_type:
            return torch.TensorType.get()
        elif isinstance(val[0], int):
            return torch.IntType.get()
        elif isinstance(val[0], bool):
            return torch.BoolType.get()
        elif isinstance(val[0], float):
            return torch.FloatType.get()

        raise InternalError(
            self._emit_node_specific_error(
                self.node,
                "Couldn't determine JitType for list/tuple of elements. Only supports int, float, bool, and Tensor.",
            )
        )

    def _constant_to_evalue(  # noqa: C901
        self,
        val: _Argument,
        val_type: Optional[_SchemaType],
    ) -> EValue:
        """Converts a constant value to an EValue.

        Returns an EValue given the Python representation and JIT type. On common paths there should
        always be a JIT type provided. Users can pass in a None to infer the JIT type but this
        should never be the default case due to the existence of container types.
        """
        if val is None:
            return EValue(Null())

        if isinstance(val, (list, tuple)):
            # Refine Optional[List[T]] -> List[T]. This works because if the val was None it would
            # have been converted to Null before this function call.
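            # Example: an argument whose schema type is Optional[int[]] arrives here as a concrete
            # list such as [1, 2]; the Optional wrapper is peeled off below and the value is
            # emitted as a plain int list.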
            if val_type is None:
                val_type = torch.ListType(
                    self._get_list_tuple_jit_type(val)  # pyre-ignore
                )
            if isinstance(val_type, torch.OptionalType):
                val_type = val_type.getElementType()
            assert isinstance(val_type, torch.ListType)
            return self._emit_list(
                typing.cast(List[_Argument], val),
                typing.cast(_SchemaType, val_type.getElementType()),
            )

        if isinstance(val, float):
            return EValue(Double(val))

        if isinstance(val, bool):
            return EValue(Bool(val))

        if isinstance(val, int):
            return EValue(Int(val))

        if isinstance(val, str):
            return EValue(String(val))

        if isinstance(val, torch.dtype):
            return EValue(Int(scalar_type_enum(val)))

        if isinstance(val, torch.layout):
            return EValue(Int(layout_enum(val)))

        if isinstance(val, torch.memory_format):
            try:
                return EValue(Int(memory_format_enum(val)))
            except KeyError:
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        f"Tensor has a memory_format that is unsupported in ExecuTorch: {val}",
                    )
                )

        if isinstance(val, torch.Tensor):
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                self._emit_node_specific_error(
                    self.node,
                    "constant_to_evalue should not be encountering constant tensors, they should be emitted through other codepaths.",
                ),
            )

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED,
            self._emit_node_specific_error(
                self.node, f"Unsupported constant type: {type(val).__name__}"
            ),
        )

    def _emit_evalue(self, val: EValue) -> _AbstractValue:
        """Writes an EValue to the emitter state.

        Given an EValue, adds it to the emitter_state's values table, and returns the AbstractValue
        representing it.
        """
        self.emitter_state.values.append(val)
        tensor = val.val if isinstance(val.val, Tensor) else None
        return _AbstractValue(len(self.emitter_state.values) - 1, tensor)

    def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
        """Given the provided spec, constructs the corresponding EValue and then emits it."""

        def _process(spec: LeafValueSpec) -> _AbstractValue:
            if isinstance(spec, (list, tuple)):
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        "Node spec should be either non-nested container or a scalar object",
                    )
                )

            # ScalarSpec can theoretically be handled fine, but it shouldn't be appearing, so rather
            # than handle it, assert that it isn't supposed to be present. In the future if it has a
            # reason to appear we can relax this assert.
            self._internal_assert_emitter(
                isinstance(spec, TensorSpec),
                self.node,
                f"Invalid node spec, expected TensorSpec, received {spec}",
            )

            ret = self._emit_evalue(self._tensor_spec_to_evalue(spec))  # pyre-ignore
            self.emitter_state.spec2id_dict[spec] = ret.id  # pyre-ignore
            return ret

        return pytree.tree_map(_process, spec)

    def _merge_chain(self, chain: Chain) -> None:
        """Merges the chain generated from subgraphs (like those originating from control flow) back
        into the main program chain."""
        self.chain.instructions.extend(chain.instructions)

    def _emit_cond(
        self,
        args: Tuple[_Argument, ...],
        subemitter_binding_output_values: Optional[List[_AbstractValue]],
    ) -> List[_AbstractValue]:
        """Emits control_flow.cond.

        Converts the higher order op into jumps and inlines the submodules of the true and false
        branches.

        Control flow can be nested. The general emitted structure is:

        <Jump Instruction> - decides which branch to take
        <True Branch>
        <Jump Instruction> - jumps to End Of Cond after executing the true branch
        <False Branch>
        <End Of Cond>
        """
        pred, true_branch, false_branch, inputs = args

        # Emit the true branch.
        assert isinstance(true_branch, torch.fx.GraphModule)
        true_branch_emitter = _Emitter(
            true_branch,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions)
            + 1,
            binding_input_values=typing.cast(List[_AbstractValue], inputs),
            binding_output_values=subemitter_binding_output_values,
        )
        true_branch_emitter.run()

        # Emit the jump.
        assert isinstance(pred, _AbstractValue)
        jf_instruction_to_skip_true = Instruction(
            JumpFalseCall(
                cond_value_index=pred.id,
                destination_instruction=self.instruction_start_offset
                + len(self.chain.instructions)
                + len(true_branch_emitter.chain.instructions)
                # This jump instruction should point at the instruction right after all of the true
                # branch's instructions. We add 2 to account for the instruction we are creating
                # right now and the jump instruction that the true branch will create.
                + 2,
            )
        )

        # Insert the branch-picking jump instruction into the main chain.
        self.chain.instructions.append(jf_instruction_to_skip_true)
        # Now that we have created the true branch instructions, move them to the main chain.
        self._merge_chain(true_branch_emitter.chain)

        # Emit the false branch.
        assert isinstance(false_branch, torch.fx.GraphModule)
        false_branch_emitter = _Emitter(
            false_branch,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions)
            + 1,
            binding_input_values=typing.cast(List[_AbstractValue], inputs),
            binding_output_values=subemitter_binding_output_values,
        )
        false_branch_emitter.run()

        # We bake in a constant False because this will make the instruction jump over all false
        # branch instructions and land on the first instruction after the control flow.
        value = self._emit_evalue(EValue(Bool(False)))
        jf_instruction_to_skip_false = Instruction(
            JumpFalseCall(
                cond_value_index=value.id,
                destination_instruction=self.instruction_start_offset
                + len(self.chain.instructions)
                + len(false_branch_emitter.chain.instructions)
                + 1,
            )
        )
        self.chain.instructions.append(jf_instruction_to_skip_false)
        self._merge_chain(false_branch_emitter.chain)
        return subemitter_binding_output_values

    def _emit_map(
        self,
        args: Tuple[_Argument, ...],
        subemitter_binding_output_values: List[_AbstractValue],
    ) -> List[_AbstractValue]:
        """Emits torch.map.

        Converts the higher order op into a loop constructed from jump instructions and primitive
        int operations. A concat-like custom op is also injected into the submodule code to handle
        the construction of the map output.

        Below is what the input graph that is provided to emit_map looks like:

        class TestMapCond(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return control_flow.map(map_fn, x, y)

        Corresponding graph:

        def forward(self, arg0_1, arg1_1):
            submodule_0 = self.submodule_0
            map_1 = torch.ops.higher_order.map_impl(submodule_0, arg0_1, arg1_1); submodule_0 = arg0_1 = arg1_1 = None
            return [map_1]

        submodule_0:

        def forward(self, arg0_1, arg1_1):
            add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
            return add_tensor

        After the transformations done by emit_map, the submodule program looks like this:

        def forward(self, arg0_1, arg1_1):
            sym_size = torch.ops.aten.sym_size(arg0_1)
            # The emitter creates a variable here to track the iteration index.
            select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index)
            add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1); arg0_1 = arg1_1 = None
            output_of_map = torch.ops.executorch.prim.et_copy_index(output_of_map, add_tensor, iteration_index)
            iteration_index = torch.ops.executorch.prim.add.int(iteration_index, 1, iteration_index)
            done_bool = torch.ops.executorch.prim.eq.int(iteration_index, sym_size, done_bool)
            # The emitter inserts an instruction here: if done_bool == False, jump back to the
            # select_copy op; otherwise continue.
            return add_tensor
        """
        assert isinstance(
            subemitter_binding_output_values, (list, tuple)
        ), f"Expect a list for subemitter_binding_output_values for map. Got {subemitter_binding_output_values}."

        if len(subemitter_binding_output_values) != 1:
            raise RuntimeError(
                f"Multiple outputs are not supported. Got {len(subemitter_binding_output_values)}."
            )
        f, mapped_args, inputs = args
        assert isinstance(mapped_args, (list, tuple))
        num_mapped_args: int = len(mapped_args)
        if num_mapped_args != 1:
            raise RuntimeError(
                f"Emitting map with more than one mapped arg is not supported. Got {num_mapped_args}."
            )
        x = mapped_args[0]

        assert isinstance(f, torch.fx.GraphModule)

        # Generate the EValue that we will use as our iterator index to keep track of which
        # iteration we are currently on.
        iter_idx = self._emit_evalue(EValue(Int(0)))
        # Generate the kernel call that will output the number of iterations we need to run for
        # this input tensor.
        op_index, op = self._get_operator(
            name="aten::sym_size",
            overload="int",
        )
        sym_size = self._emit_evalue(EValue(Int(0)))
        kernel = Instruction(
            KernelCall(
                op_index=op_index,
                args=[x.id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
            )
        )
        self.chain.instructions.append(kernel)

        # This kernel call will slice the input tensor along the index specified in iter_idx to
        # generate the input slice on which this iteration will operate.
        op_index, op = self._get_operator(
            name="aten::select_copy",
            overload="int_out",
        )
        # This select_copy has to output to the tensor which is the input placeholder to the map
        # sub-graph. That placeholder isn't allocated an EValue id until the map emitter is run, so
        # we temporarily store -1 until the map emitter is run, during which the placeholder will
        # be allocated an EValue id. After the map emitter is run we will retrieve that id and
        # replace the -1's.
        kernel = Instruction(
            KernelCall(
                op_index=op_index,
                args=[
                    x.id,
                    self._emit_evalue(EValue(Int(0))).id,
                    iter_idx.id,
                    -1,  # input_tensor_value.id
                    -1,  # input_tensor_value.id
                ],
            )
        )
        # Store the index of this instruction as it will be where we will jump back to after the
        # end of an iteration.
        jump_to_instruction = self.instruction_start_offset + len(
            self.chain.instructions
        )
        self.chain.instructions.append(kernel)

        # Emit the map operator submodule.
        map_emitter = _Emitter(
            f,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions),
            # Only the first input is a placeholder; the rest of the inputs are args to the map fn.
            binding_input_values=[-1, *inputs],
            binding_output_values=subemitter_binding_output_values,
        )
        map_emitter.run()

        # Merge all the instructions from the map submodule.
        self._merge_chain(map_emitter.chain)
        # Get rid of the return instruction emitted by the map subemitter.
        self.chain.instructions.pop()
        # At the end of each submodule emit we insert a move call that moves the output of the
        # submodule to a deterministic EValue. That is especially useful for if/else branches where
        # we want the output of either branch to end up in the same EValue, but we don't need a
        # move here because our custom op executorch_prim::et_copy_index, which is inserted later,
        # does that for us.

        # Now that the map emitter has finished running, retrieve the input placeholder EValue id
        # and update the select_copy kernel call to output to those ids.
        kernel.instr_args.args[-1] = map_emitter.binding_input_values[0].id
        kernel.instr_args.args[-2] = kernel.instr_args.args[-1]

        self._internal_assert_emitter(
            len(map_emitter.concrete_output_ids) == 1,
            self.node,
            "Map should return only one element",
        )

        # Here we call the custom op specially added for the map operator. The output of this
        # iteration will be appended to the accumulator tensor that we are maintaining. This
        # accumulator tensor is the actual output of the map submodule.
        op_index, op = self._get_operator(
            name="executorch_prim::et_copy_index",
            overload="tensor",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[
                    subemitter_binding_output_values[0].id,
                    map_emitter.concrete_output_ids[0].id,
                    iter_idx.id,
                ],
            )
        )
        self.chain.instructions.append(kernel)

        # Increment iter_idx to mark that we have completed an iteration.
        op_index, op = self._get_operator(
            name="executorch_prim::add",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
            )
        )
        self.chain.instructions.append(kernel)

        jump_bool_value = self._emit_evalue(EValue(Bool(False)))

        # Generate the kernel call to check whether or not we have completed all the iterations.
        # If not, jump back to the select_copy instruction that we generated at the beginning of
        # this section.
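        # For orientation, the emitted loop as a whole looks roughly like this (op names follow the
        # kernels requested above; indices are illustrative only):
        #     sym_size.int(x, 0)                    -> sym_size
        #     select_copy.int_out(x, 0, iter_idx)   -> mapped_input   <-- loop head / jump target
        #     <instructions emitted from the map submodule>
        #     et_copy_index(accumulator, submodule_output, iter_idx)
        #     add.Scalar(iter_idx, 1)               -> iter_idx
        #     eq.Scalar(iter_idx, sym_size)         -> done
        #     JumpFalseCall(done)                   -> jump to loop head if not done
        #     sub.Scalar(iter_idx, sym_size)        -> iter_idx (reset for the next run)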
        op_index, op = self._get_operator(
            name="executorch_prim::eq",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
            )
        )
        self.chain.instructions.append(kernel)

        jf_beginning_loop = Instruction(
            JumpFalseCall(
                cond_value_index=jump_bool_value.id,
                destination_instruction=jump_to_instruction,
            )
        )

        self.chain.instructions.append(jf_beginning_loop)

        # Reset iter_idx in case we plan to run the model again.
        op_index, op = self._get_operator(
            name="executorch_prim::sub",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, sym_size.id, iter_idx.id],
            )
        )
        self.chain.instructions.append(kernel)

        return subemitter_binding_output_values

    def _emit_control_flow(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Wraps common logic for emitting all control flow operations.

        See the more specific emission functions for details on how cond and map get emitted.
        """
        subemitter_binding_output_values = pytree.tree_map(
            lambda spec: self._emit_spec(spec),
            self.node.meta["spec"],
        )

        if target is torch.ops.higher_order.cond:
            return self._emit_cond(args, subemitter_binding_output_values)
        elif target is torch.ops.higher_order.map_impl:
            return self._emit_map(args, subemitter_binding_output_values)
        else:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node, f"Unsupported control flow operator: {target}"
                )
            )

    def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
        assert len(args) == 2

        self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
        size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
        out_arg = self._emit_argument(
            self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
        )

        op_idx, op = self._get_operator(
            name="executorch_prim::et_view",
            overload="default",
        )
        kernel = Instruction(
            KernelCall(
                op_idx,
                args=[
                    self_arg.id,
                    size_arg.id,
                    out_arg.id,
                ],
            )
        )
        self.chain.instructions.append(kernel)
        return out_arg

    def _add_debug_handle(
        self,
        emitter_id: int,
        target: _Target,
        # pyre-ignore[11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "Optional[LoweredBackendModule]" = None,  # noqa: F821
    ) -> None:
        """Updates the debug handle information for the current node.

        If the current node is a delegate we aggregate the debug handles of the subgraph and store
        them in the map. If the current node is any other type we store the original information in
        the debug handle map and replace it with the executorch instruction index corresponding to
        this node.
        """
        # If it's a delegate call, collect the list of debug handles that are consumed by this
        # delegate call and store it in the debug handle map.
        if target == executorch_call_delegate:
            debug_handle_list = []
            # Use the lowered_module to fetch the original graph and its debug
            # handles.
            for node in lowered_module.original_module.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.meta.get("debug_handle") is not None
                ):
                    debug_handle_list.append(node.meta.get("debug_handle"))
            self.debug_handle_map[emitter_id] = debug_handle_list
            # The debug handle for this node is the emitter_id, which is essentially the index of
            # the instruction in the chain.
            self.node.meta["debug_handle"] = emitter_id
            return

        if self.node.meta.get("debug_handle") is not None:
            # Store the original debug handle in the debug handle map.
            self.debug_handle_map[emitter_id] = self.node.meta.get("debug_handle")
            # Replace the debug handle in the metadata of the node with the emitter id, which
            # represents the instruction index in the chain. We do this because at runtime the
            # instruction index is what is logged during perf/debug data logging, so storing it on
            # the node lets us map the data logged by the runtime back to the node.
            self.node.meta["debug_handle"] = emitter_id

    def _add_delegate_map(
        self,
        lowered_module: "LoweredBackendModule",  # noqa
        delegate_instruction_id: int,
    ) -> None:
        """
        Store the delegate map from this lowered module into the dictionary of delegate maps. It
        will later be used for various debugging purposes such as linking back to original source
        code, module hierarchy, etc.
        """
        delegate_map = {}
        if hasattr(lowered_module, "meta"):
            delegate_map = lowered_module.meta.get("debug_handle_map", {})

        self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
            "name": lowered_module.backend_id,
            "delegate_map": delegate_map,
        }

    def _emit_argument(
        self, arg: _Argument, arg_type: Optional[_SchemaType]
    ) -> _AbstractValue:
        """Emits an argument to an operator or delegate if it has not already been emitted;
        otherwise returns the previously emitted location."""
        if isinstance(arg, _AbstractValue):
            return arg
        return self._emit_evalue(self._constant_to_evalue(arg, arg_type))

    def _get_sym_ret(
        self,
        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
    ) -> Optional[_AbstractValue]:
        """
        Returns the emitted placeholder EValue for a sym value, or None if val is not a sym value.
        """
        ret = None
        if isinstance(val, torch.SymInt):
            ret = self._emit_evalue(EValue(Int(0)))
        elif isinstance(val, torch.BoolType):
            ret = self._emit_evalue(EValue(Bool(False)))
        elif isinstance(val, torch.FloatType):
            ret = self._emit_evalue(EValue(Double(0)))
        return ret

    def _get_sym_and_fake_tensor_ret(
        self,
        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
        spec: TensorSpec,
    ) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]:
        # Try to get the ret if it's a sym value.
        ret = self._get_sym_ret(val)
        # If ret is None, val is not a sym value but a regular tensor.
        if ret is None:
            ret = self._emit_spec(spec)
        assert ret is not None, "Can't have a None ret"
        return ret

    def _emit_delegate(
        self,
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[_Argument, ...],
        kwargs: Dict[str, _Argument],
    ) -> _EmitterValue:
        """Emits the delegate's inputs and outputs as specified by the schema, then emits the
        delegate's blob."""
        processed_bytes = lowered_module.processed_bytes

        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
        delegate_ret = None

        if isinstance(self.node.meta["spec"], list):
            delegate_ret = []
            for index, _ in enumerate(self.node.meta["val"]):
                ret = self._get_sym_and_fake_tensor_ret(
                    self.node.meta["val"][index], self.node.meta["spec"][index]
                )
                delegate_ret.append(ret)
        elif isinstance(self.node.meta["spec"], tuple):
            if isinstance(self.node.meta["val"], FakeTensor):
                # There is a case where node.meta["spec"] is (TensorSpec,) while node.meta["val"]
                # is a FakeTensor.
                ret = self._get_sym_and_fake_tensor_ret(
                    self.node.meta["val"], self.node.meta["spec"][0]
                )
                delegate_ret = (ret,)
            else:
                delegate_ret = []
                for index, _ in enumerate(self.node.meta["val"]):
                    ret = self._get_sym_and_fake_tensor_ret(
                        self.node.meta["val"][index], self.node.meta["spec"][index]
                    )
                    delegate_ret.append(ret)
                delegate_ret = tuple(delegate_ret)
        elif isinstance(self.node.meta["spec"], TensorSpec):
            ret = self._get_sym_and_fake_tensor_ret(
                self.node.meta["val"], self.node.meta["spec"]
            )
            delegate_ret = ret
        else:
            raise NotImplementedError(
                f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported"
            )
        assert delegate_ret is not None, "Can't have a None delegate_ret"
        if delegate_index is None:
            # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
            # present.
            data_index: int = len(self.program_state.backend_delegate_data)
            self.program_state.backend_delegate_data.append(
                BackendDelegateInlineData(data=processed_bytes)
            )

            backend_delegate = BackendDelegate(
                id=lowered_module.backend_id,
                processed=BackendDelegateDataReference(
                    location=DataLocation.INLINE, index=data_index
                ),
                compile_specs=lowered_module.compile_specs,
            )
            delegate_index = len(self.emitter_state.delegate_cache)
            self.emitter_state.delegates.append(backend_delegate)
            self.emitter_state.delegate_cache[processed_bytes] = delegate_index

        # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
        # function's spec and with default arguments.
        # This requires us to store the function's spec in to_backend().
        delegate_args = [
            self._emit_argument(arg, None).id
            for arg in typing.cast(List[_Argument], args)
        ]

        for elem in pytree.tree_flatten(delegate_ret)[0]:
            delegate_args.append(elem.id)

        self.chain.instructions.append(
            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
        )

        return delegate_ret

    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
        """Given a fully qualified name, looks up the operator in the ExecuTorch Program, or adds it
        if it is not already present."""
        key = (name, overload)
        op_index = self.emitter_state.operator_cache.get(key)
        if op_index is not None:
            return op_index, self.emitter_state.operators[op_index]

        op_index, operator = len(self.emitter_state.operators), Operator(
            name=name, overload=overload
        )
        self.emitter_state.operators.append(operator)
        self.emitter_state.operator_cache[key] = op_index
        return op_index, operator

    def _emit_operator(  # noqa: C901
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Emits an operator (aten or custom), directly translating to a call_kernel instruction."""
        assert isinstance(
            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
        ), f"target is {target}"

        # Grab the name.
        op_name = target._overloadpacket._qualified_op_name
        op_overload = ""
        if target._overloadname != "default":
            op_overload = target._overloadname

        def _get_empty_tensor_evalue() -> EValue:
            """Constructs an EValue for an empty tensor."""
            return EValue(
                Tensor(
                    scalar_type=ScalarType.BYTE,
                    # The runtime currently only supports tensors with offset 0.
                    storage_offset=0,
                    sizes=[0],
                    dim_order=[],
                    requires_grad=False,
                    layout=0,
                    data_buffer_idx=0,
                    allocation_info=None,
                    shape_dynamism=TensorShapeDynamism.STATIC,
                )
            )

        op_index, operator = self._get_operator(name=op_name, overload=op_overload)

        # Emit the args and kwargs in the order dictated by the function schema.
        kernel_args = []
        out_args = []
        for i, schema_arg in enumerate(target._schema.arguments):
            if schema_arg.name in kwargs:
                kernel_arg = kwargs[schema_arg.name]
            elif not schema_arg.kwarg_only and i < len(args):
                kernel_arg = args[i]
            else:
                # Emit default values.
                kernel_arg = schema_arg.default_value

            if kernel_arg is None and isinstance(schema_arg.type, torch.TensorType):
                kernel_arg = self._emit_evalue(_get_empty_tensor_evalue())

            kernel_args.append(self._emit_argument(kernel_arg, schema_arg.type).id)

            if schema_arg.is_out:
                out_args.append((schema_arg.name, kernel_arg))

        if is_out_variant(op_name, op_overload):
            ret = [val for _, val in out_args]
            ret = ret[0] if len(ret) == 1 else ret
        elif is_sym_op(target):
            assert (
                len(target._schema.returns) == 1
            ), "Only returning a single Sym from symbolic ops is supported currently."
            assert type(target._schema.returns[0].type) in (
                torch.IntType,
                torch.FloatType,
                torch.BoolType,
                torch.NumberType,
            ), f"Only symbolic ops that return an Int, Bool, or Float are supported currently, got {type(target._schema.returns[0].type)}."
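            # Sym ops produce a single scalar EValue; the kernel fills in the concrete value at
            # runtime, so all we need to emit here is a typed placeholder.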
            ret = self._get_sym_ret(target._schema.returns[0])
            if ret is None:  # type(target._schema.returns[0].type) == torch.NumberType
                # We can't definitively say what type this is, but the runtime operator just
                # overwrites the EValue completely, so we can serialize anything as a placeholder.
                ret = self._emit_evalue(EValue(Int(0)))
        else:
            ret = self._emit_spec(self.node.meta["spec"])

        out_args = (
            self._emit_evalue(
                EValue(TensorList([cast(_AbstractValue, val).id for val in ret]))
            )
            if isinstance(ret, list)
            else ret
        )

        for elem in pytree.tree_flatten(out_args)[0]:
            kernel_args.append(cast(_AbstractValue, elem).id)

        self.chain.instructions.append(
            Instruction(KernelCall(op_index=op_index, args=kernel_args))
        )
        self._add_debug_handle(len(self.chain.instructions) - 1, target)

        # Get the stacktrace if it exists for each instruction.
        if self.emitter_state.emit_stacktrace:
            stack_trace = self.node.meta["stack_trace"]
            chain_stacktrace = self.chain.stacktrace or []

            chain_stacktrace.append(_stacktrace_to_framelist(stack_trace))
            self._internal_assert_emitter(
                len(chain_stacktrace) == len(self.chain.instructions),
                self.node,
                f"Each instruction should have a corresponding stacktrace, received "
                f"{len(self.chain.instructions)} instructions and {len(chain_stacktrace)} stacktraces",
            )
            self.chain.stacktrace = chain_stacktrace

        return cast(_EmitterValue, ret)

    def _emit_free(self, spec: TensorSpec) -> _AbstractValue:
        """Emits a FreeCall instruction to release a given unbounded tensor's memory."""
        self.chain.instructions.append(
            Instruction(FreeCall(value_index=self.emitter_state.spec2id(spec)))
        )
        # The value is not used but the caller expects an AbstractValue returned.
        return _AbstractValue(None, None)  # pyre-ignore

    def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
        """
        Given a mapping of function names to return values, emit simple execution
        plans that just return these constant values.

        Precondition: All the values are primitives (bool, float, int, str, enum)
        or structures (list, dict) of them.
        """
        plans = []
        # Flatten any structures.
        for method, vals in prim_getters.items():
            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
            flattened_output, spec = ex_pytree.tree_flatten(vals)
            spec = spec.to_str()
            chain = Chain(
                inputs=[],
                outputs=[],
                instructions=[],
                stacktrace=None,
            )

            # Switch on the type of the prim.
            values = []
            for val in flattened_output:
                if isinstance(val, float):
                    values.append(EValue(Double(val)))

                elif isinstance(val, bool):
                    values.append(EValue(Bool(val)))

                elif isinstance(val, int):
                    values.append(EValue(Int(val)))

                elif isinstance(val, str):
                    values.append(EValue(String(val)))

                elif isinstance(val, torch.dtype):
                    values.append(EValue(Int(scalar_type_enum(val))))

                elif isinstance(val, torch.layout):
                    values.append(EValue(Int(layout_enum(val))))

                elif isinstance(val, torch.Tensor):
                    values.append(
                        self._tensor_spec_to_evalue(
                            TensorSpec.from_tensor(val, const=True)
                        )
                    )

                else:
                    raise ExportError(
                        ExportErrorType.NOT_SUPPORTED,
                        f"Error emitting {method}, which returns a value of type {type(val)}; this is not a supported primitive",
                    )

            # Add to plans.
            plans.append(
                ExecutionPlan(
                    name=method,
                    values=values,
                    inputs=[],
                    outputs=list(range(0, len(values))),
                    chains=[chain],
                    operators=[],
                    delegates=[],
                    non_const_buffer_sizes=[0],
                    container_meta_type=ContainerMetadata("", spec),
                )
            )
        return plans

    def fetch_attr(self, target: _Target) -> _AbstractValue:
        """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
        attr = super().fetch_attr(target)  # pyre-fixme[6]

        if isinstance(attr, torch.Tensor):
            return self._emit_evalue(
                self._tensor_spec_to_evalue(TensorSpec.from_tensor(attr, const=True))
            )

        elif isinstance(attr, torch._C.ScriptObject):
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                f"Custom class {attr} is not supported in EXIR",
            )

        else:
            return attr

    def call_module(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Unsupported in execution IR, so unhandled by the emitter."""
        raise InternalError(
            self._emit_node_specific_error(self.node, "call_module is not supported")
        )

    def call_method(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Unsupported in execution IR, so unhandled by the emitter."""
        raise InternalError(
            self._emit_node_specific_error(self.node, "call_method is not supported")
        )

    def placeholder(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Performs actions for the placeholder node of a graph module.

        The placeholder node of the top level entry point is handled by TopLevelEmitter. This
        function only executes on control flow subgraphs. Takes the inputs of the subgraph that had
        not previously been emitted and emits them.
        """
        # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
        value = self.binding_input_values[self.placeholder_count]
        # This indicates that the placeholder wasn't allocated an EValue id before this sub-emitter
        # was run, so we generate one now.
        if value == -1:
            value = self._emit_evalue(
                self._tensor_spec_to_evalue(self.node.meta["spec"])
            )
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
            self.binding_input_values[self.placeholder_count] = value
        self.placeholder_count += 1
        return value

    def output(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Performs actions for the output node of a graph module.

        The output node of the top level entry point is handled by TopLevelEmitter. This function
        only executes on control flow subgraphs. Takes the outputs of the subgraph (if any) and
        inserts instructions to move them to the common output location between control flow
        branches.
        """
        self.concrete_output_ids = list(pytree.tree_flatten(args[0])[0])
        binding_output_values = self.binding_output_values
        if binding_output_values is not None:
            binding_output_list, _ = pytree.tree_flatten(binding_output_values)

            self._internal_assert_emitter(
                len(binding_output_list) == len(self.concrete_output_ids),
                self.node,
                "The number of binding output values should match the args to output",
            )

            for move_from, move_to in zip(
                self.concrete_output_ids, binding_output_list
            ):
                if move_from != move_to:
                    instruction = Instruction(
                        MoveCall(move_from=move_from.id, move_to=move_to.id)
                    )
                    self.chain.instructions.append(instruction)

    def call_function(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Performs actions for the call_function node of a graph module.

        Dispatches based on 'target' and emits the corresponding function. 'call_function' is a
        powerful node that contains many operations ranging from control_flow, to memory management,
        to delegate and operator calls.
        """

        # Delegate and operator calls are the only functions that should have a debug handle
        # associated with them. All the others such as memory.alloc, getitem should be ignored.
        # Default to none and let delegates and ops override.
        if target == operator.getitem:
            assert len(args) == 2
            head = typing.cast(Mapping[int, _EmitterValue], args[0])
            index = typing.cast(int, args[1])
            return head[index]

        elif target == memory.alloc:
            assert len(args) == 1
            return self._emit_spec(self.node.meta["spec"])

        elif target == memory.view:
            return self._emit_view(args)

        elif target == memory.free:
            assert len(args) == 1
            # pyre-ignore
            return self._emit_free(args[0])

        elif target is torch.ops.higher_order.cond:
            return self._emit_control_flow(target, args, kwargs)

        elif target is torch.ops.higher_order.map_impl:
            return self._emit_control_flow(target, args, kwargs)

        elif target == executorch_call_delegate:
            lowered_module = args[0]
            assert is_lowered_module(lowered_module)
            v = self._emit_delegate(lowered_module, args[1:], kwargs)
            delegate_instruction_id = len(self.chain.instructions) - 1
            self._add_debug_handle(delegate_instruction_id, target, lowered_module)
            self._add_delegate_map(lowered_module, delegate_instruction_id)
            return v

        elif isinstance(
            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
        ):
            return self._emit_operator(target, args, kwargs)

        else:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node, f"invalid target for call_function {target}"
                )
            )

    def run(  # pyre-fixme[14]
        self,
        *args: _Argument,
        initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None,
    ) -> None:
        """Traverses all nodes in the graph module and emits each one appropriately."""
        super().run(*args, initial_env, enable_io_processing=False)

    def run_node(self, n: torch.fx.Node) -> None:
        """Executes and emits the specified node.

        For more context on what a node is and what execution means see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Node
        """
        self.node = n
        try:
            ret = super().run_node(n)
        except Exception as e:
            if isinstance(e, (InternalError, ExportError)):
                raise e
            else:
                raise InternalError(
                    self._emit_node_specific_error(self.node, str(e))
                ) from e
        return ret


class _TopLevelEmitter(_Emitter):
    """An emitter that manages the root level operations within a graph module.

    Exists as a separate class so that 'Emitter' can handle the special behavior of 'placeholder'
    and 'output' nodes in control flow submodules.
    """

    def __init__(
        self,
        name: str,
        exported_program: ExportedProgram,
        graph_module: torch.fx.GraphModule,
        program_state: _ProgramState,
        emitter_state: _EmitterState,
    ) -> None:
        super().__init__(graph_module, emitter_state, program_state)
        self.name = name
        self.exported_program = exported_program

        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.given_mutable_buffer_warning = False

        def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
            if spec is None:
                return ""
            assert isinstance(spec, pytree.TreeSpec), type(spec)
            dummy_leaves = [0] * spec.num_leaves
            tree = torch.utils._pytree.tree_unflatten(dummy_leaves, spec)
            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
            _, tree = ex_pytree.tree_flatten(tree)
            return tree.to_str()

        inp_container_str = create_container_str(exported_program.call_spec.in_spec)
        out_container_str = create_container_str(exported_program.call_spec.out_spec)

        self.container_meta_type = ContainerMetadata(
            inp_container_str, out_container_str
        )

    def _find_fqn_for_placeholder(
        self, target: _Target, spec: Any  # pyre-ignore[2]
    ) -> Tuple[Optional[str], bool]:
        # Find the fully qualified name.
        fqn = None
        is_mutable_buffer = False
        if target in self.exported_program.graph_signature.inputs_to_parameters:
            fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

        elif target in self.exported_program.graph_signature.inputs_to_buffers:
            fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

            # If the buffer is mutated then record that.
            if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
                is_mutable_buffer = True
                if not self.given_mutable_buffer_warning:
                    warnings.warn(
                        "Mutation on a buffer in the model is detected. ExecuTorch assumes "
                        "buffers that are mutated in the graph have a meaningless initial state, "
                        "only the shape and dtype will be serialized.",
                        UserWarning,
                        stacklevel=1,
                    )
                    self.given_mutable_buffer_warning = True

        elif (
            target
            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
        ):
            fqn = (
                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                    target
                ]
            )
        return fqn, is_mutable_buffer

    def placeholder(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Emits the value within the placeholder node.

        For more information on placeholder nodes see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
        """
        spec = self.node.meta["spec"]
        is_user_input = True

        if isinstance(target, str) and isinstance(spec, TensorSpec):
            fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

            # From the fqn find the corresponding tensor.
            real_tensor = None
            if fqn in self.exported_program.state_dict:
                real_tensor = self.exported_program.state_dict[fqn]
                is_user_input = False

            elif fqn in self.exported_program.constants:
                real_tensor = self.exported_program.constants[fqn]
                is_user_input = False
            elif fqn is not None:
                buffers = self.exported_program.named_buffers()
                buf = next((x[1] for x in buffers if x[0] == fqn), None)
                if buf is not None:
                    real_tensor = buf
                    is_user_input = False
                else:
                    raise InternalError(
                        self._emit_node_specific_error(
                            self.node,
                            f"Could not find buffer with fqn {fqn} in state_dict or named_buffers",
                        )
                    )

            # Assign the storage of the placeholder spec to the storage of the real tensor if there
            # is one.
            if real_tensor is not None:
                # For non-contiguous tensors, convert to a contiguous one.
                real_tensor = real_tensor.contiguous()
                # Weights cannot be views during emission or serialization.
                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
                    real_tensor = real_tensor.clone()

                spec.storage = real_tensor.untyped_storage()

            # User inputs and mutable buffers are not constants; other buffers and parameters are.
            spec.const = not (is_user_input or is_mutable_buffer)

        evalue = (
            self._tensor_spec_to_evalue(spec)
            if isinstance(spec, TensorSpec)
            else self._constant_to_evalue(spec, None)
        )
        value = self._emit_evalue(evalue)

        # Only user inputs should remain as inputs.
        if is_user_input:
            self.inputs.append(value.id)

        return value

    def output(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Records the ExecutionPlan's outputs based on the output node in the graph."""
        if isinstance(args[0], dict):
            args_tuple, _ = pytree.tree_flatten(args[0])
        else:
            args_tuple = typing.cast(Tuple[_AbstractValue, ...], args[0])
        if isinstance(args_tuple, _AbstractValue):
            self.outputs.append(args_tuple.id)
        else:
            for arg in args_tuple:
                if isinstance(arg, (int, float, bool, type(None))):
                    arg = self._emit_evalue(self._constant_to_evalue(arg, None))
                elif isinstance(arg, str):
                    # TODO(jackkhuu): T181599879 Add support for string outputs IFF compiler supports
                    raise InternalError(
                        self._emit_node_specific_error(
                            self.node,
                            f"Returning {arg} is not yet supported in the emitter.",
                        )
                    )
                else:
                    # Every other output should already have its value emitted.
                    # They should only be abstract IDs at this point.
                    assert isinstance(arg, _AbstractValue)
                self.outputs.append(arg.id)

    def plan(self) -> ExecutionPlan:
        """Returns the execution plan emitted from this entry point."""
        return ExecutionPlan(
            name=self.name,
            values=self.emitter_state.values,
            inputs=self.inputs,
            outputs=self.outputs,
            chains=[self.chain],
            operators=self.emitter_state.operators,
            delegates=self.emitter_state.delegates,
            # non_const_buffer_sizes field is set by the memory_planning_pass. In case the field is
            # missing, in scenarios like a unit test that does not enable memory planning, assume
            # an empty list.
            non_const_buffer_sizes=typing.cast(
                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB...
                List[int], self.module.meta["non_const_buffer_sizes"]
            ),
            container_meta_type=self.container_meta_type,
        )
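

# A rough sketch of how this module is typically driven (names follow exir's emit_program entry
# point; treat this as illustrative rather than a stable API):
#
#   program_state = _ProgramState()
#   emitter_state = _EmitterState(
#       values=[], operators=[], delegates=[], operator_cache={}, delegate_cache={},
#       emit_stacktrace=False,
#   )
#   emitter = _TopLevelEmitter(
#       "forward", exported_program, exported_program.graph_module, program_state, emitter_state
#   )
#   emitter.run()
#   execution_plan = emitter.plan()  # becomes one entry point of the serialized Program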