xref: /aosp_15_r20/external/executorch/exir/emit/_emitter.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7"""Takes an ExportedArtifact, or a collection of ExportedArtifacts, in execution dialect, and turns
8them into a single ExecuTorch Program.
9
10The provided ExportedArtifact's graph modules are in execution dialect and the emitter parses and
11converts them into executorch instructions. The emitter walks the provided graphs and as it
12encounters concrete values such as tensors or ints, it converts them to the serialized format and
13stores them in a list for later use. The emitter walks the graph by traversing fx.nodes, which can
14come in a variety of forms and are the primitives of execution at the graph module level. The three
15most common ones we care about are 'call_function', 'placeholder', and 'output'. 'placeholder' and
16'output' handle IO for the module and 'call_function' handles everything else. Within 'call_function'
17we may encounter an operator or delegate call, in which case we parse the schema and emit all the
18inputs and outputs (unless they have already been emitted), and then we convert the actual function
19call into an executorch instruction such as KernelCall or DelegateCall.
20
21When control flow is present in the graphmodule it will take the form of a few different types of
22'call_function'. Today (June 14th 2023) only cond and map are supported. The actual operations of
23these, such as the true/false branches of cond, or the mapping function of map, are stored as sub
24graphmodules. When these are encountered during emission, the emitter will recursively emit them and
25their instructions.
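
As a rough sketch of how this module is typically driven, emission happens on a whole program via
emit_program in this package rather than by constructing _Emitter directly (the call below is
illustrative; the exact signature may differ):

    # Illustrative only; assumes an ExportedProgram already lowered to execution dialect.
    from executorch.exir.emit import emit_program

    emitter_output = emit_program(exported_program)  # or a dict of {method_name: exported_program}
    program = emitter_output.program  # the schema Program that gets serialized to flatbuffer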
26"""
27# TODO(jakeszwe): add information here about how weights and other parameters are handled in the
28# presence of aot autograd param lifting.
29
30# pyre-strict
31import ctypes
32import hashlib
33import operator
34import typing
35import warnings
36from dataclasses import dataclass, field
37from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union
38
39import executorch.exir.memory as memory
40import executorch.extension.pytree as ex_pytree
41import torch
42import torch.fx
43from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
44from executorch.exir.dialects.backend._ops import BackendOpOverload
45from executorch.exir.dialects.edge._ops import EdgeOpOverload
46from executorch.exir.error import ExportError, ExportErrorType, InternalError
47from executorch.exir.operator.convert import is_out_variant
48from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
49from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
50from executorch.exir.schema import (
51    BackendDelegate,
52    BackendDelegateDataReference,
53    BackendDelegateInlineData,
54    Bool,
55    BoolList,
56    Buffer,
57    Chain,
58    ContainerMetadata,
59    DataLocation,
60    DelegateCall,
61    Double,
62    DoubleList,
63    EValue,
64    ExecutionPlan,
65    FreeCall,
66    Instruction,
67    Int,
68    IntList,
69    JumpFalseCall,
70    KernelCall,
71    MoveCall,
72    Null,
73    Operator,
74    OptionalTensorList,
75    ScalarType,
76    String,
77    Tensor,
78    TensorList,
79    TensorShapeDynamism,
80)
81from executorch.exir.tensor import (
82    AddressSpaceOverflowException,
83    layout_enum,
84    make_allocation_info,
85    make_tensor_value,
86    memory_format_enum,
87    scalar_type_enum,
88    TensorSpec,
89)
90from executorch.exir.types import LeafValueSpec, ValueSpec
91from torch._subclasses.fake_tensor import FakeTensor
92
93from torch.export.exported_program import ExportedProgram
94from torch.utils import _pytree as pytree
95
96from typing_extensions import TypeAlias
97
98
99@dataclass
100class _ProgramState:
101    """State shared between all methods of a program and the graph module it represents.
102
103    Initialized once within emit_program and then shared across each entry point as they are
104    emitted.
105    """
106
107    # List of specs parallel to the buffers that back them; add 1 to any index in here
108    # because index 0 in the constant_buffer is reserved.
109    allocated_specs: List[TensorSpec] = field(default_factory=list)
110    # Weights in any arbitrary graph_module only need to compare against weights from previously
111    # emitted graph modules, not any weights emitted from itself. This should speed up the lookup
112    # from O(N) to O(1).
113    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
114    cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict)
115    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
116    constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
117    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
118    mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
119    # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
120    # and should be copied to Program.backend_delegate_data.
121    backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
122
123
124@dataclass
125class _EmitterState:
126    """State of a single emitter.
127
128    Local to at least the entry point, and may be local to a subgraph of an entry point originating
129    from control flow.
130    """
131
132    values: List[EValue]
133    operators: List[Operator]
134    delegates: List[BackendDelegate]
135    operator_cache: Dict[Tuple[str, str], int]
136    delegate_cache: Dict[bytes, int]
137    emit_stacktrace: bool
138
139    spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)
140
141    def spec2id(self, spec: TensorSpec) -> int:
142        """Map a TensorSpec to value index in the values array."""
143        assert spec in self.spec2id_dict, f"Spec is not found: {spec.debug()}"
144        return self.spec2id_dict[spec]
145
146
147@dataclass
148class _AbstractValue:
149    """Represents an already emitted EValue"""
150
151    # Index in the values table of this EValue.
152    id: int
153
154    # Used for sanity checks for functions that expect to only receive AbstractValues.
155    tensor: Optional[Tensor]
156
157
158_EmitterValue: TypeAlias = Union[
159    _AbstractValue, List[_AbstractValue], Tuple[_AbstractValue, ...]
160]
161
162_PythonValue: TypeAlias = Union[bool, int, float, torch.Tensor, List["_PythonValue"]]
163_SchemaType: TypeAlias = Union[
164    torch.OptionalType,
165    torch.ListType,
166    torch.FloatType,
167    torch.BoolType,
168    torch.IntType,
169    torch.StringType,
170    torch.TensorType,
171]
172
173_Target: TypeAlias = Union[Callable[..., _PythonValue], str]
174
175_Argument: TypeAlias = Union[
176    _EmitterValue,
177    Tuple["_Argument", ...],
178    List["_Argument"],
179    Dict[str, "_Argument"],
180    str,
181    int,
182    float,
183    bool,
184    complex,
185    torch.dtype,
186    torch.Tensor,
187    torch.memory_format,
188    torch.layout,
189    None,
190]
191
192_DelegateDebugIdentifierMap: TypeAlias = Union[
193    Dict[int, Tuple[int]], Dict[str, Tuple[int]]
194]
195
196
197# pyre-ignore[13]: Attribute `node` is never initialized.
198class _Emitter(torch.fx.Interpreter):
199    """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
200    given traced torch.fx.GraphModule to the flatbuffer schema."""
201
202    node: torch.fx.Node
203
204    def __init__(
205        self,
206        graph_module: torch.fx.GraphModule,
207        emitter_state: _EmitterState,
208        program_state: _ProgramState,
209        instruction_start_offset: int = 0,
210        binding_input_values: Optional[List[_AbstractValue]] = None,
211        binding_output_values: Optional[List[_AbstractValue]] = None,
212    ) -> None:
213        super().__init__(graph_module)
214        self.emitter_state = emitter_state
215        self.program_state = program_state
216        self.outputs: List[int] = []
217
218        self.chain = Chain(
219            inputs=[],
220            outputs=[],
221            instructions=[],
222            stacktrace=None,
223        )
224
225        if "non_const_buffer_sizes" not in graph_module.meta.keys():
226            raise RuntimeError(
227                "Must set 'non_const_buffer_sizes' in graph meta in memory planning pass"
228            )
229        self.instruction_start_offset = instruction_start_offset
230        self.binding_input_values = binding_input_values
231        self.binding_output_values = binding_output_values
232        self.graph_module: torch.fx.GraphModule = graph_module
233        self.nodes: List[torch.fx.Node] = list(self.graph_module.graph.nodes)
234
235        # Tracks the order of placeholder nodes so that we can match each one with the corresponding
236        # abstract value coming from the top level.
237        self.placeholder_count = 0
238
239        self.concrete_output_ids: List[_AbstractValue] = []
240        self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
241        self.instr_id_to_delegate_debug_id_map: Dict[
242            int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
243        ] = {}
244
245    def _emit_node_specific_error(self, node: torch.fx.Node, err_msg: str) -> str:
246        """Returns 'err_msg' with node specific information attached."""
247        err_msg = f"Failed with error: {str(err_msg)}\n" + inspect_node(
248            self.graph_module.graph, node
249        )
250        return err_msg
251
252    def _internal_assert_emitter(
253        self, pred: bool, node: torch.fx.Node, assert_msg: str
254    ) -> None:
255        """If pred is False, construct and raise a node specific error message."""
256        if not pred:
257            raise InternalError(self._emit_node_specific_error(node, assert_msg))
258
259    def _emit_int_list(self, val: List[_Argument]) -> EValue:
260        """Emits a list of integers as a collection of EValues.
261
262        For every argument in 'val':
263            - If it is a concrete value, then emit it and then place its location in the boxed list
264            - If it is already an abstract value, then just place its location in the boxed list
265
266        Int lists are boxed to handle symints whose values are determined at runtime, but could
267        still end up inside lists for ops like view_copy(Tensor self, SymInt[] size)
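
        For illustration (value indices are made up): emitting the size argument of
        view_copy(x, [2, s0]), where s0 is a symint already emitted at value index 7 and the
        concrete 2 is emitted as a new Int EValue at index 12, produces

            EValue(IntList([12, 7]))  # locations in the values table, not the raw ints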
268        """
269        boxed_list = []
270        for item in val:
271            if isinstance(item, _AbstractValue):
272                boxed_list.append(item.id)
273            elif isinstance(item, int):
274                boxed_list.append(
275                    self._emit_evalue(self._constant_to_evalue(item, None)).id
276                )
277            else:
278                self._internal_assert_emitter(
279                    False, self.node, "Unsupported type encountered in int list."
280                )
281
282        return EValue(IntList(boxed_list))
283
284    def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
285        """Emits a list type.
286
287        Emits the list stored in val. If the list is of Tensors, Optionals, or Ints, the emitted list
288        is boxed; otherwise the values are constant at runtime and stored inline.
289
290        NOTE: When symbool and symfloat are supported, bool and float lists will be stored boxed.
291        """
292
293        if isinstance(val_type, torch.BoolType):
294            return EValue(BoolList(typing.cast(List[bool], val)))
295
296        if isinstance(val_type, torch.IntType):
297            return self._emit_int_list(val)
298
299        if isinstance(val_type, torch.FloatType):
300            return EValue(DoubleList(typing.cast(List[float], val)))
301
302        if isinstance(val_type, torch.TensorType):
303            values = []
304            for v in val:
305                assert isinstance(v, _AbstractValue)
306                self._internal_assert_emitter(
307                    v.tensor is not None,
308                    self.node,
309                    "AbstractValue corresponding to tensor type doesn't contain tensor value",
310                )
311                values.append(v.id)
312            return EValue(TensorList(values))
313
314        if isinstance(val_type, torch.OptionalType):
315            # refine further
316            actual_type = val_type.getElementType()
317            if isinstance(actual_type, torch.TensorType):
318                vals = []
319                for v in val:
320                    if v is None:
321                        vals.append(-1)
322                    else:
323                        assert isinstance(v, _AbstractValue)
324                        vals.append(v.id)
325                return EValue(OptionalTensorList(vals))
326
327        raise ExportError(
328            ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
329        )
330
331    def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
332        """Constructs an EValue from the given TensorSpec."""
333
334        allocation_info = None
335        buffer_idx = 0
336
337        # Memory-planned (activation) tensors have both mem_id and mem_offset set.
338        # Some users set mem_id on all tensors and then rely on the
339        # default algos to set offsets, so we need to check both.
340        if spec.mem_id is not None and spec.mem_offset is not None:
341            # Tensor is an activation.
342            self._internal_assert_emitter(
343                isinstance(spec.mem_id, int) and spec.mem_id >= 0,
344                self.node,
345                f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
346            )
347
348            self._internal_assert_emitter(
349                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
350                self.node,
351                f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
352            )
353            try:
354                allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
355            except AddressSpaceOverflowException as e:
356                raise InternalError(
357                    self._emit_node_specific_error(
358                        self.node,
359                        (
360                            f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
361                            f"such as ConstraintBasedSymShapeEvalPass, this may be the cause of an "
362                            f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) "
363                            "during torch.export()."
364                        ),
365                    )
366                )
367
368        if spec.const:
369            # Tensor with a blob we need to serialize. May not actually be constant at runtime
370            # if it's a weight with an associated gradient
371            spec_array_type = (
372                ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
373            )
374
375            buffer_data = (
376                bytes(
377                    ctypes.cast(
378                        typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
379                        ctypes.POINTER(spec_array_type),
380                    ).contents
381                )
382                if spec.allocated_memory != 0
383                else b""
384            )
385
386            hashed = hashlib.sha256(buffer_data).hexdigest()
387
388            if allocation_info:
389                buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
390                    hashed, -1
391                )
392            else:
393                buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)
394
395            # Haven't seen this constant before
396            if buffer_idx == -1:
397                # Update buffer_idx to point to the end of the list where we are adding the new buffer.
398                buffer = Buffer(storage=buffer_data)
399                self.program_state.allocated_specs.append(spec)
400                # +1 because the first buffer location is reserved
401
402                if allocation_info:
403                    buffer_idx = len(self.program_state.mutable_buffer)
404                    self.program_state.cached_spec_mutable_hash_values[hashed] = (
405                        buffer_idx
406                    )
407                    self.program_state.mutable_buffer.append(buffer)
408                else:
409                    buffer_idx = len(self.program_state.constant_buffer)
410                    self.program_state.cached_spec_hash_values[hashed] = buffer_idx
411                    self.program_state.constant_buffer.append(buffer)
412
413            if spec.const and spec.nbytes() != len(buffer_data):
414                raise InternalError(
415                    self._emit_node_specific_error(
416                        self.node,
417                        f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}",
418                    )
419                )
420
421        # For constant tensors, allocation_info = None.
422        return EValue(make_tensor_value(buffer_idx, allocation_info, spec))
423
424    def _get_list_tuple_jit_type(
425        self, val: Union[Tuple[_Argument], List[_Argument]]
426    ) -> _SchemaType:
427        """Returns the JIT type for the given python type."""
428        assert isinstance(
429            val, (list, tuple)
430        ), f"Input to _get_list_tuple_jit_type was expected to be an instance of a list or tuple but received {type(val)}"
431        is_tensor_type = all(
432            isinstance(v, _AbstractValue) and v.tensor is not None for v in val
433        )
434        if is_tensor_type:
435            return torch.TensorType.get()
436        elif isinstance(val[0], int):
437            return torch.IntType.get()
438        elif isinstance(val[0], bool):
439            return torch.BoolType.get()
440        elif isinstance(val[0], float):
441            return torch.FloatType.get()
442
443        raise InternalError(
444            self._emit_node_specific_error(
445                self.node,
446                "Couldn't determine JitType for list/tuple of elements. Only supports int, float, bool, and Tensor.",
447            )
448        )
449
450    def _constant_to_evalue(  # noqa: C901
451        self,
452        val: _Argument,
453        val_type: Optional[_SchemaType],
454    ) -> EValue:
455        """Converts a constant value to an EValue.
456
457        Returns an EValue given the Python representation and JIT type. On common paths there should
458        always be a JIT type provided. Users can pass in a None to infer the JIT type but this
459        should never be the default case due to the existence of container types.
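
        For illustration, a few of the mappings performed below:

            None          -> EValue(Null())
            3             -> EValue(Int(3))
            2.5           -> EValue(Double(2.5))
            True          -> EValue(Bool(True))
            torch.float32 -> EValue(Int(scalar_type_enum(torch.float32)))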
460        """
461        if val is None:
462            return EValue(Null())
463
464        if isinstance(val, (list, tuple)):
465            # Refine Optional[List[T]] -> List[T]. This works because if the val was None it would
466            # have been converted to Null before this function call.
467            if val_type is None:
468                val_type = torch.ListType(
469                    self._get_list_tuple_jit_type(val)  # pyre-ignore
470                )
471            if isinstance(val_type, torch.OptionalType):
472                val_type = val_type.getElementType()
473            assert isinstance(val_type, torch.ListType)
474            return self._emit_list(
475                typing.cast(List[_Argument], val),
476                typing.cast(_SchemaType, val_type.getElementType()),
477            )
478
479        if isinstance(val, float):
480            return EValue(Double(val))
481
482        if isinstance(val, bool):
483            return EValue(Bool(val))
484
485        if isinstance(val, int):
486            return EValue(Int(val))
487
488        if isinstance(val, str):
489            return EValue(String(val))
490
491        if isinstance(val, torch.dtype):
492            return EValue(Int(scalar_type_enum(val)))
493
494        if isinstance(val, torch.layout):
495            return EValue(Int(layout_enum(val)))
496
497        if isinstance(val, torch.memory_format):
498            try:
499                return EValue(Int(memory_format_enum(val)))
500            except KeyError:
501                raise InternalError(
502                    self._emit_node_specific_error(
503                        self.node,
504                        f"Tensor has a memory_format that is unsupported in ExecuTorch: {val}",
505                    )
506                )
507
508        if isinstance(val, torch.Tensor):
509            raise ExportError(
510                ExportErrorType.NOT_SUPPORTED,
511                self._emit_node_specific_error(
512                    self.node,
513                    "constant_to_evalue should not be encountering constant tensors, they should be emitted through other codepaths.",
514                ),
515            )
516
517        raise ExportError(
518            ExportErrorType.NOT_SUPPORTED,
519            self._emit_node_specific_error(
520                self.node, f"Unsupported constant type: {type(val).__name__}"
521            ),
522        )
523
524    def _emit_evalue(self, val: EValue) -> _AbstractValue:
525        """Writes an EValue to the emitter state.
526
527        Given an Evalue, adds it to the emitter_state's values table, and returns the AbstractValue
528        representing it.
529        """
530        self.emitter_state.values.append(val)
531        tensor = val.val if isinstance(val.val, Tensor) else None
532        return _AbstractValue(len(self.emitter_state.values) - 1, tensor)
533
534    def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
535        """Given the provided spec constructs the corresponding EValue from it and then emits it."""
536
537        def _process(spec: LeafValueSpec) -> _AbstractValue:
538            if isinstance(spec, (list, tuple)):
539                raise InternalError(
540                    self._emit_node_specific_error(
541                        self.node,
542                        "Node spec should be either non-nested container or a scalar object",
543                    )
544                )
545
546            # ScalarSpec can theoretically be handled fine, but it shouldn't be appearing so rather
547            # than handle it, assert that it isn't supposed to be present. In the future if it has a
548            # reason to appear we can relax this assert.
549            self._internal_assert_emitter(
550                isinstance(spec, TensorSpec),
551                self.node,
552                f"Invalid node spec expected TensorSpec received {spec}",
553            )
554
555            ret = self._emit_evalue(self._tensor_spec_to_evalue(spec))  # pyre-ignore
556            self.emitter_state.spec2id_dict[spec] = ret.id  # pyre-ignore
557            return ret
558
559        return pytree.tree_map(_process, spec)
560
561    def _merge_chain(self, chain: Chain) -> None:
562        """Merges the chain generated from subgraphs (like those originating from control flow) back
563        into the main program chain."""
564        self.chain.instructions.extend(chain.instructions)
565
566    def _emit_cond(
567        self,
568        args: Tuple[_Argument, ...],
569        subemitter_binding_output_values: Optional[List[_AbstractValue]],
570    ) -> List[_AbstractValue]:
571        """Emits control_flow.cond.
572
573        Converts the higher order op into jumps and inlines the submodules of the true and false
        branches. Control flow can be nested. The general emitted structure is:
            <Jump Instruction>  - decides which branch to take
            <True Branch>
            <Jump Instruction>  - jumps to End Of Cond after executing the true branch
            <False Branch>
            <End Of Cond>
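
        As an illustrative sketch (instruction counts made up), if the true branch contributes two
        instructions, the false branch three, and the first cond instruction lands at offset k, the
        emitted chain looks like:

            k        JumpFalseCall(pred)   -> jumps to k+4 (false branch) when pred is False
            k+1..k+2 <true branch instructions>
            k+3      JumpFalseCall(False)  -> always jumps to k+7, past the false branch
            k+4..k+6 <false branch instructions>
            k+7      <first instruction after the cond>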
577        """
578        pred, true_branch, false_branch, inputs = args
579
580        # Emit the true branch.
581        assert isinstance(true_branch, torch.fx.GraphModule)
582        true_branch_emitter = _Emitter(
583            true_branch,
584            self.emitter_state,
585            self.program_state,
586            instruction_start_offset=self.instruction_start_offset
587            + len(self.chain.instructions)
588            + 1,
589            binding_input_values=typing.cast(List[_AbstractValue], inputs),
590            binding_output_values=subemitter_binding_output_values,
591        )
592        true_branch_emitter.run()
593
594        # Emit the jump.
595        assert isinstance(pred, _AbstractValue)
596        jf_instruction_to_skip_true = Instruction(
597            JumpFalseCall(
598                cond_value_index=pred.id,
599                destination_instruction=self.instruction_start_offset
600                + len(self.chain.instructions)
601                + len(true_branch_emitter.chain.instructions)
602                # This jump instruction should point at instruction that is after all instructions
603                # for the true branch. The reason we add 2 is because we need to account for this
604                # instruction we are creating right now and the jump instruction that true branch
605                # will create.
606                + 2,
607            )
608        )
609
610        # Insert the branch picking jump instruction to the main chain.
611        self.chain.instructions.append(jf_instruction_to_skip_true)
612        # Now that we created the true branch instructions, we move them to the main chain.
613        self._merge_chain(true_branch_emitter.chain)
614
615        # emit false branch
616        assert isinstance(false_branch, torch.fx.GraphModule)
617        false_branch_emitter = _Emitter(
618            false_branch,
619            self.emitter_state,
620            self.program_state,
621            instruction_start_offset=self.instruction_start_offset
622            + len(self.chain.instructions)
623            + 1,
624            binding_input_values=typing.cast(List[_AbstractValue], inputs),
625            binding_output_values=subemitter_binding_output_values,
626        )
627        false_branch_emitter.run()
628
629        # We bake in constant False because this will trigger the instruction to jump over all false
630        # branch instructions and point at the start of instruction right after control flow.
631        value = self._emit_evalue(EValue(Bool(False)))
632        jf_instruction_to_skip_false = Instruction(
633            JumpFalseCall(
634                cond_value_index=value.id,
635                destination_instruction=self.instruction_start_offset
636                + len(self.chain.instructions)
637                + len(false_branch_emitter.chain.instructions)
638                + 1,
639            )
640        )
641        self.chain.instructions.append(jf_instruction_to_skip_false)
642        self._merge_chain(false_branch_emitter.chain)
643        return subemitter_binding_output_values
644
645    def _emit_map(
646        self,
647        args: Tuple[_Argument, ...],
648        subemitter_binding_output_values: List[_AbstractValue],
649    ) -> List[_AbstractValue]:
650        """Emits torch.map.
651
652        Converts the higher order op into a loop constructed from jump instructions and primitive
653        int operations. A concat-like custom op is also injected into the submodule code to handle
654        the construction of the map output.
655
        Below is what the input graph that is provided to emit_map looks like.

            class TestMapCond(torch.nn.Module):
                def __init__(self):
                    super().__init__()

                def forward(self, x, y):
                    return control_flow.map(map_fn, x, y)

        Corresponding graph:

            def forward(self, arg0_1, arg1_1):
                submodule_0 = self.submodule_0
                map_1 = torch.ops.higher_order.map_impl(submodule_0, arg0_1, arg1_1);  submodule_0 = arg0_1 = arg1_1 = None
                return [map_1]

        submodule_0:

            def forward(self, arg0_1, arg1_1):
                add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return add_tensor

        Post the transformations done by emit_map, this is what the submodule program looks like.

            def forward(self, arg0_1, arg1_1):
                sym_size = torch.ops.aten.sym_size(arg0_1)
                # Emitter creates a variable here to track the iteration index.
                select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index)
                add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1);  arg0_1 = arg1_1 = None
                output_of_map = torch.ops.executorch.prim.et_copy_index(output_of_map, add_tensor, iteration_index)
                iteration_index = torch.ops.executorch.prim.add.int(iteration_index, 1, iteration_index)
                done_bool = torch.ops.executorch.prim.eq.int(iteration_index, sym_size, done_bool)
                # Emitter inserts an instruction here: if done_bool == False, jump to the
                # select_copy op above; if not, continue.
                return add_tensor
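
        At the instruction level, the emitted loop then roughly looks like this (a sketch; the
        exact operators are the ones requested in the code below):

            sym_size.int(x, 0, sym_size)                      # number of iterations
            select_copy.int_out(x, 0, iter_idx, out, out)     # slice for this iteration  <- loop start
            <submodule instructions>
            et_copy_index(accumulator, submodule_out, iter_idx)
            add.Scalar(iter_idx, 1, iter_idx)
            eq.Scalar(iter_idx, sym_size, done)
            JumpFalseCall(done)                               # jump back to loop start until done
            sub.Scalar(iter_idx, sym_size, iter_idx)          # reset iter_idx for the next run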
681        """
682        assert isinstance(
683            subemitter_binding_output_values, (list, tuple)
684        ), f"Expect a list for subemitter_binding_output_values for map. Got {subemitter_binding_output_values}."
685
686        if len(subemitter_binding_output_values) != 1:
687            raise RuntimeError(
688                f"Multiple outputs are not supported. Got {len(subemitter_binding_output_values)}."
689            )
690        f, mapped_args, inputs = args
691        assert isinstance(mapped_args, (list, tuple))
692        num_mapped_args: int = len(mapped_args)
693        if num_mapped_args != 1:
694            raise RuntimeError(
695                f"Emitting map with more than one mapped args is not supported. Got {num_mapped_args}."
696            )
697        x = mapped_args[0]
698
699        assert isinstance(f, torch.fx.GraphModule)
700
701        # Generate the EValue that we will use as our iterator index to keep track of which
702        # iteration we are currently on.
703        iter_idx = self._emit_evalue(EValue(Int(0)))
704        # Generate the kernel call that will output the number of iterations we need to run for this
705        # input tensor.
706        op_index, op = self._get_operator(
707            name="aten::sym_size",
708            overload="int",
709        )
710        sym_size = self._emit_evalue(EValue(Int(0)))
711        kernel = Instruction(
712            KernelCall(
713                op_index=op_index,
714                args=[x.id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
715            )
716        )
717        self.chain.instructions.append(kernel)
718
719        # This kernel call will slice the input tensor along the index specified in iter_idx to
720        # generate the input slice on which this iteration will be working on.
721        op_index, op = self._get_operator(
722            name="aten::select_copy",
723            overload="int_out",
724        )
725        # This select copy has to output to the tensor which is the input placeholder to the map
726        # sub-graph. That placeholder isn't allocated an EValue id until the map emitter is run, so
727        # we temporarily store -1 until the map emitter is run during which the placeholder will be
728        # allocated an EValue id. After the map emitter is run we will retrieve that id and replace
729        # the -1's.
730        kernel = Instruction(
731            KernelCall(
732                op_index=op_index,
733                args=[
734                    x.id,
735                    self._emit_evalue(EValue(Int(0))).id,
736                    iter_idx.id,
737                    -1,  # input_tensor_value.id,
738                    -1,  # input_tensor_value.id,
739                ],
740            )
741        )
742        # Store the index of this instruction as it will be where we will jump back to after the end
743        # of an iteration.
744        jump_to_instruction = self.instruction_start_offset + len(
745            self.chain.instructions
746        )
747        self.chain.instructions.append(kernel)
748
749        # Emit the map operator submodule.
750        map_emitter = _Emitter(
751            f,
752            self.emitter_state,
753            self.program_state,
754            instruction_start_offset=self.instruction_start_offset
755            + len(self.chain.instructions),
756            # Only the first input is a placeholder, rest of the inputs are args to the map fn.
757            binding_input_values=[-1, *inputs],
758            binding_output_values=subemitter_binding_output_values,
759        )
760        map_emitter.run()
761
762        # Merge all the instructions from the map submodule.
763        self._merge_chain(map_emitter.chain)
764        # Get rid of the return instruction emitted by the map subemitter.
765        self.chain.instructions.pop()
766        # At the end of each submodule emit we insert a move call that moves the output of the
767        # submodule to a deterministic EValue, which is especially useful for if/else branches where
768        # we want the output of either branch to be in the same EValue, but we don't need a move
769        # here as our custom op executorch_prim::et_copy_index which is inserted later does that
770        # for us.
771
772        # Now that the map emitter has finished running retrieve the input placeholder EValue id and
773        # update the select_copy kernel call to output to those id's.
774        kernel.instr_args.args[-1] = map_emitter.binding_input_values[0].id
775        kernel.instr_args.args[-2] = kernel.instr_args.args[-1]
776
777        self._internal_assert_emitter(
778            len(map_emitter.concrete_output_ids) == 1,
779            self.node,
780            "Map should return only one element",
781        )
782
783        # Here we call the custom op, specially added for the map operator. The output of this
784        # iteration will be appended to the accumulator tensor that we are maintaining. This
785        # accumulator tensor is the actual output of the map submodule.
786        op_index, op = self._get_operator(
787            name="executorch_prim::et_copy_index",
788            overload="tensor",
789        )
790        kernel = Instruction(
791            KernelCall(
792                op_index,
793                args=[
794                    subemitter_binding_output_values[0].id,
795                    map_emitter.concrete_output_ids[0].id,
796                    iter_idx.id,
797                ],
798            )
799        )
800        self.chain.instructions.append(kernel)
801
802        # Increment iter_idx to mark that we have completed an iteration.
803        op_index, op = self._get_operator(
804            name="executorch_prim::add",
805            overload="Scalar",
806        )
807        kernel = Instruction(
808            KernelCall(
809                op_index,
810                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
811            )
812        )
813        self.chain.instructions.append(kernel)
814
815        jump_bool_value = self._emit_evalue(EValue(Bool(False)))
816
817        # Generate the kernel call to check whether or not we have completed all the iterations. If
818        # not jump back to the select_copy instruction that we generated at the beginning of this
819        # section.
820        op_index, op = self._get_operator(
821            name="executorch_prim::eq",
822            overload="Scalar",
823        )
824        kernel = Instruction(
825            KernelCall(
826                op_index,
827                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
828            )
829        )
830        self.chain.instructions.append(kernel)
831
832        jf_beginning_loop = Instruction(
833            JumpFalseCall(
834                cond_value_index=jump_bool_value.id,
835                destination_instruction=jump_to_instruction,
836            )
837        )
838
839        self.chain.instructions.append(jf_beginning_loop)
840
841        # Reset iter_idx in case we plan to run the model again.
842        op_index, op = self._get_operator(
843            name="executorch_prim::sub",
844            overload="Scalar",
845        )
846        kernel = Instruction(
847            KernelCall(
848                op_index,
849                args=[iter_idx.id, sym_size.id, iter_idx.id],
850            )
851        )
852        self.chain.instructions.append(kernel)
853
854        return subemitter_binding_output_values
855
856    def _emit_control_flow(
857        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
858    ) -> _EmitterValue:
859        """Wraps common logic for emitting all control flow operations.
860
861        See the more specific emission functions for more details on how cond or map get emitted.
862        """
863        subemitter_binding_output_values = pytree.tree_map(
864            lambda spec: self._emit_spec(spec),
865            self.node.meta["spec"],
866        )
867
868        if target is torch.ops.higher_order.cond:
869            return self._emit_cond(args, subemitter_binding_output_values)
870        elif target is torch.ops.higher_order.map_impl:
871            return self._emit_map(args, subemitter_binding_output_values)
872        else:
873            raise InternalError(
874                self._emit_node_specific_error(
875                    self.node, f"Unsupported control flow operator: {target}"
876                )
877            )
878
879    def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
880        assert len(args) == 2
881
882        self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
883        size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
884        out_arg = self._emit_argument(
885            self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
886        )
887
888        op_idx, op = self._get_operator(
889            name="executorch_prim::et_view",
890            overload="default",
891        )
892        kernel = Instruction(
893            KernelCall(
894                op_idx,
895                args=[
896                    self_arg.id,
897                    size_arg.id,
898                    out_arg.id,
899                ],
900            )
901        )
902        self.chain.instructions.append(kernel)
903        return out_arg
904
905    def _add_debug_handle(
906        self,
907        emitter_id: int,
908        target: _Target,
909        # pyre-ignore[11]: Annotation `LoweredBackendModule` is not defined as a type.
910        lowered_module: "Optional[LoweredBackendModule]" = None,  # noqa: F821
911    ) -> None:
912        """Updates the debug handle information for the current node.
913
914        If the current node is a delegate, we aggregate the debug handles of the subgraph and store
915        them in the map. If the current node is any other type, we store the original information in
916        the debug handle map and replace it with the executorch instruction index corresponding to
917        this node.
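
        For example (handles are illustrative): if a delegate call emitted as instruction 3
        consumed nodes whose original debug handles were 5, 6, and 7, then
        debug_handle_map[3] == [5, 6, 7]; a regular op emitted as instruction 4 with original
        debug handle 9 yields debug_handle_map[4] == 9.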
918        """
919        # If it's a delegate call, collect the list of debug handles that are consumed by this
920        # delegate call and store it in the debug handle map.
921        if target == executorch_call_delegate:
922            debug_handle_list = []
923            # Use the lowered_module to fetch the original graph and its debug
924            # handles.
925            for node in lowered_module.original_module.graph.nodes:
926                if (
927                    node.op == "call_function"
928                    and node.meta.get("debug_handle") is not None
929                ):
930                    debug_handle_list.append(node.meta.get("debug_handle"))
931            self.debug_handle_map[emitter_id] = debug_handle_list
932            # Debug handle for this node is the emitter_id which is essentially the index of the
933            # instruction in the chain.
934            self.node.meta["debug_handle"] = emitter_id
935            return
936
937        if self.node.meta.get("debug_handle") is not None:
938            # Store the original debug handle in the debug handle map.
939            self.debug_handle_map[emitter_id] = self.node.meta.get("debug_handle")
940            # Replace the debug handle in the metadata of the node with the emitter id which
941            # represents the instruction index in the chain. We do this because in the runtime the
942            # instruction index is what is logged during perf/debug data logging and hence we want
943            # to store this in the node so that we can map the data logged by the runtime back to
944            # the node.
945            self.node.meta["debug_handle"] = emitter_id
946
947    def _add_delegate_map(
948        self,
949        lowered_module: "LoweredBackendModule",  # noqa
950        delegate_instruction_id: int,
951    ) -> None:
952        """
953        Store the delegate map from this lowered module into the dictionary of delegate maps. It
954        will later be used for various debugging purposes such as linking back to original source
955        code, module hierarchy etc.
956        """
957        delegate_map = {}
958        if hasattr(lowered_module, "meta"):
959            delegate_map = lowered_module.meta.get("debug_handle_map", {})
960
961        self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
962            "name": lowered_module.backend_id,
963            "delegate_map": delegate_map,
964        }
965
966    def _emit_argument(
967        self, arg: _Argument, arg_type: Optional[_SchemaType]
968    ) -> _AbstractValue:
969        """Emit an argument to an operator or delegate if it had not already been emitted otherwise
970        return the previously emitted location"""
971        if isinstance(arg, _AbstractValue):
972            return arg
973        return self._emit_evalue(self._constant_to_evalue(arg, arg_type))
974
975    def _get_sym_ret(
976        self,
977        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
978    ) -> Optional[_AbstractValue]:
979        """
980        Returns the emitted EValue for a sym value, or None if 'val' is not a sym value.
981        """
982        ret = None
983        if isinstance(val, torch.SymInt):
984            ret = self._emit_evalue(EValue(Int(0)))
985        elif isinstance(val, torch.BoolType):
986            ret = self._emit_evalue(EValue(Bool(False)))
987        elif isinstance(val, torch.FloatType):
988            ret = self._emit_evalue(EValue(Double(0)))
989        return ret
990
991    def _get_sym_and_fake_tensor_ret(
992        self,
993        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
994        spec: TensorSpec,
995    ) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]:
996        # Try to get the ret if it's a sym value.
997        ret = self._get_sym_ret(val)
998        # If the ret is None, it means that the val is not a sym value, but a regular tensor
999        if ret is None:
1000            ret = self._emit_spec(spec)
1001        assert ret is not None, "Can't have a None ret"
1002        return ret
1003
1004    def _emit_delegate(
1005        self,
1006        lowered_module: "LoweredBackendModule",  # noqa
1007        args: Tuple[_Argument, ...],
1008        kwargs: Dict[str, _Argument],
1009    ) -> _EmitterValue:
1010        """Emit the delegates inputs and outputs as specified by the schema, then emit the
1011        delegate's blob."""
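        # Rough shape of what gets emitted here (values are illustrative): a BackendDelegate entry
        # whose processed blob is stored inline in program_state.backend_delegate_data (deduplicated
        # by its bytes), plus a DelegateCall instruction whose args are the EValue ids of the inputs
        # followed by the ids of the emitted outputs.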
1012        processed_bytes = lowered_module.processed_bytes
1013
1014        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
1015        delegate_ret = None
1016
1017        if isinstance(self.node.meta["spec"], list):
1018            delegate_ret = []
1019            for index, _ in enumerate(self.node.meta["val"]):
1020                ret = self._get_sym_and_fake_tensor_ret(
1021                    self.node.meta["val"][index], self.node.meta["spec"][index]
1022                )
1023                delegate_ret.append(ret)
1024        elif isinstance(self.node.meta["spec"], tuple):
1025            if isinstance(self.node.meta["val"], FakeTensor):
1026                # There is a case when node.meta["spec"] is (TensorSpec, ) while node.meta["val"] is FakeTensor
1027                ret = self._get_sym_and_fake_tensor_ret(
1028                    self.node.meta["val"], self.node.meta["spec"][0]
1029                )
1030                delegate_ret = (ret,)
1031            else:
1032                delegate_ret = []
1033                for index, _ in enumerate(self.node.meta["val"]):
1034                    ret = self._get_sym_and_fake_tensor_ret(
1035                        self.node.meta["val"][index], self.node.meta["spec"][index]
1036                    )
1037                    delegate_ret.append(ret)
1038                delegate_ret = tuple(delegate_ret)
1039        elif isinstance(self.node.meta["spec"], TensorSpec):
1040            ret = self._get_sym_and_fake_tensor_ret(
1041                self.node.meta["val"], self.node.meta["spec"]
1042            )
1043            delegate_ret = ret
1044        else:
1045            raise NotImplementedError(
1046                f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported"
1047            )
1048        assert delegate_ret is not None, "Can't have a None delegate_ret"
1049        if delegate_index is None:
1050            # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
1051            # present.
1052            data_index: int = len(self.program_state.backend_delegate_data)
1053            self.program_state.backend_delegate_data.append(
1054                BackendDelegateInlineData(data=processed_bytes)
1055            )
1056
1057            backend_delegate = BackendDelegate(
1058                id=lowered_module.backend_id,
1059                processed=BackendDelegateDataReference(
1060                    location=DataLocation.INLINE, index=data_index
1061                ),
1062                compile_specs=lowered_module.compile_specs,
1063            )
1064            delegate_index = len(self.emitter_state.delegate_cache)
1065            self.emitter_state.delegates.append(backend_delegate)
1066            self.emitter_state.delegate_cache[processed_bytes] = delegate_index
1067
1068        # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
1069        # function's spec and with default arguments. This requires us to store the function's spec
1070        # in to_backend()
1071        delegate_args = [
1072            self._emit_argument(arg, None).id
1073            for arg in typing.cast(List[_Argument], args)
1074        ]
1075
1076        for elem in pytree.tree_flatten(delegate_ret)[0]:
1077            delegate_args.append(elem.id)
1078
1079        self.chain.instructions.append(
1080            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
1081        )
1082
1083        return delegate_ret
1084
1085    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
1086        """Given a fully qualified name, lookups the operator in the ExecuTorch Program, or adds it
1087        if it is not already present"""
1088        key = (name, overload)
1089        op_index = self.emitter_state.operator_cache.get(key)
1090        if op_index is not None:
1091            return op_index, self.emitter_state.operators[op_index]
1092
1093        op_index, operator = len(self.emitter_state.operators), Operator(
1094            name=name, overload=overload
1095        )
1096        self.emitter_state.operators.append(operator)
1097        self.emitter_state.operator_cache[key] = op_index
1098        return op_index, operator
1099
1100    def _emit_operator(  # noqa: C901
1101        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1102    ) -> _EmitterValue:
1103        """Emits an operator (aten or custom), directly translates to a call_kernel instruction."""
1104        assert isinstance(
1105            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
1106        ), f"target is {target}"
1107
1108        # grab the name
1109        op_name = target._overloadpacket._qualified_op_name
1110        op_overload = ""
1111        if target._overloadname != "default":
1112            op_overload = target._overloadname
1113
1114        def _get_empty_tensor_evalue() -> EValue:
1115            """Constructs an EValue for an empty tensor."""
1116            return EValue(
1117                Tensor(
1118                    scalar_type=ScalarType.BYTE,
1119                    # The runtime currently only supports tensors with offset 0.
1120                    storage_offset=0,
1121                    sizes=[0],
1122                    dim_order=[],
1123                    requires_grad=False,
1124                    layout=0,
1125                    data_buffer_idx=0,
1126                    allocation_info=None,
1127                    shape_dynamism=TensorShapeDynamism.STATIC,
1128                )
1129            )
1130
1131        op_index, operator = self._get_operator(name=op_name, overload=op_overload)
1132
1133        # Emit the args and kwargs in the order according to the function schema.
1134        kernel_args = []
1135        out_args = []
1136        for i, schema_arg in enumerate(target._schema.arguments):
1137            if schema_arg.name in kwargs:
1138                kernel_arg = kwargs[schema_arg.name]
1139            elif not schema_arg.kwarg_only and i < len(args):
1140                kernel_arg = args[i]
1141            else:
1142                # Emit default values
1143                kernel_arg = schema_arg.default_value
1144
1145            if kernel_arg is None and isinstance(schema_arg.type, torch.TensorType):
1146                kernel_arg = self._emit_evalue(_get_empty_tensor_evalue())
1147
1148            kernel_args.append(self._emit_argument(kernel_arg, schema_arg.type).id)
1149
1150            if schema_arg.is_out:
1151                out_args.append((schema_arg.name, kernel_arg))
1152
1153        if is_out_variant(op_name, op_overload):
1154            ret = [val for _, val in out_args]
1155            ret = ret[0] if len(ret) == 1 else ret
1156        elif is_sym_op(target):
1157            assert (
1158                len(target._schema.returns) == 1
1159            ), "Only returning a single Sym from symbolic ops is supported currently."
1160            assert type(target._schema.returns[0].type) in (
1161                torch.IntType,
1162                torch.FloatType,
1163                torch.BoolType,
1164                torch.NumberType,
1165            ), f"Only symbolic ops that return a Int Bool Float are supported currently got {type(target._schema.returns[0].type)}."
1166            ret = self._get_sym_ret(target._schema.returns[0])
1167            if ret is None:  # type(target._schema.returns[0].type) == torch.NumberType:
1168                # Can't definitively say what type this is; the runtime operator just overrides the EValue completely
1169                # anyway, so we can just serialize a placeholder of any type.
1170                ret = self._emit_evalue(EValue(Int(0)))
1171        else:
1172            ret = self._emit_spec(self.node.meta["spec"])
1173
1174        out_args = (
1175            self._emit_evalue(
1176                EValue(TensorList([cast(_AbstractValue, val).id for val in ret]))
1177            )
1178            if isinstance(ret, list)
1179            else ret
1180        )
1181
1182        for elem in pytree.tree_flatten(out_args)[0]:
1183            kernel_args.append(cast(_AbstractValue, elem).id)
1184
1185        self.chain.instructions.append(
1186            Instruction(KernelCall(op_index=op_index, args=kernel_args))
1187        )
1188        self._add_debug_handle(len(self.chain.instructions) - 1, target)
1189
1190        # Get the stacktrace if it exists for each instruction.
1191        if self.emitter_state.emit_stacktrace:
1192            stack_trace = self.node.meta["stack_trace"]
1193            chain_stacktrace = self.chain.stacktrace or []
1194
1195            chain_stacktrace.append(_stacktrace_to_framelist(stack_trace))
1196            self._internal_assert_emitter(
1197                len(chain_stacktrace) == len(self.chain.instructions),
1198                self.node,
1199                f"Each instruction should have corresponding stacktrace received {len(self.chain.instructions)} \
1200                instructions and {len(chain_stacktrace)} stacktraces",
1201            )
1202            self.chain.stacktrace = chain_stacktrace
1203
1204        return cast(_EmitterValue, ret)
1205
1206    def _emit_free(self, spec: TensorSpec) -> _AbstractValue:
1207        """Emits a FreeCall instruction to release a given Unbounded Tensor's memory."""
1208        self.chain.instructions.append(
1209            Instruction(FreeCall(value_index=self.emitter_state.spec2id(spec)))
1210        )
1211        # The value is not used but the caller expects an AbstractValue returned.
1212        return _AbstractValue(None, None)  # pyre-ignore
1213
1214    def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
1215        """
1216        Given a mapping of function names to return values, emit simple execution
1217        plans that just return these constant values.
1218
1219        Precondition: All the values are primitives (bool, float, int, str, enum)
1220        or structures (list, dict) of them.
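
        For example (names and values are illustrative): prim_getters = {"num_layers": 12} produces
        an ExecutionPlan named "num_layers" containing a single EValue(Int(12)), no operators,
        delegates, or instructions, and outputs pointing at that value.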
1221        """
1222        plans = []
1223        # flatten any structures
1224        for method, vals in prim_getters.items():
1225            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
1226            flattened_output, spec = ex_pytree.tree_flatten(vals)
1227            spec = spec.to_str()
1228            chain = Chain(
1229                inputs=[],
1230                outputs=[],
1231                instructions=[],
1232                stacktrace=None,
1233            )
1234
1235            # switch on type of prim
1236            values = []
1237            for val in flattened_output:
1238                if isinstance(val, float):
1239                    values.append(EValue(Double(val)))
1240
1241                elif isinstance(val, bool):
1242                    values.append(EValue(Bool(val)))
1243
1244                elif isinstance(val, int):
1245                    values.append(EValue(Int(val)))
1246
1247                elif isinstance(val, str):
1248                    values.append(EValue(String(val)))
1249
1250                elif isinstance(val, torch.dtype):
1251                    values.append(EValue(Int(scalar_type_enum(val))))
1252
1253                elif isinstance(val, torch.layout):
1254                    values.append(EValue(Int(layout_enum(val))))
1255
1256                elif isinstance(val, torch.Tensor):
1257                    values.append(
1258                        self._tensor_spec_to_evalue(
1259                            TensorSpec.from_tensor(val, const=True)
1260                        )
1261                    )
1262
1263                else:
1264                    raise ExportError(
1265                        ExportErrorType.NOT_SUPPORTED,
1266                        f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive",
1267                    )
1268
1269            # add to plans
1270            plans.append(
1271                ExecutionPlan(
1272                    name=method,
1273                    values=values,
1274                    inputs=[],
1275                    outputs=list(range(0, len(values))),
1276                    chains=[chain],
1277                    operators=[],
1278                    delegates=[],
1279                    non_const_buffer_sizes=[0],
1280                    container_meta_type=ContainerMetadata("", spec),
1281                )
1282            )
1283        return plans
1284
1285    def fetch_attr(self, target: _Target) -> _AbstractValue:
1286        """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
1287        attr = super().fetch_attr(target)  # pyre-fixme[6]
1288
1289        if isinstance(attr, torch.Tensor):
1290            return self._emit_evalue(
1291                self._tensor_spec_to_evalue(TensorSpec.from_tensor(attr, const=True))
1292            )
1293
1294        elif isinstance(attr, torch._C.ScriptObject):
1295            raise ExportError(
1296                ExportErrorType.NOT_SUPPORTED,
1297                f"Custom class {attr} is not supported in EXIR",
1298            )
1299
1300        else:
1301            return attr
1302
1303    def call_module(  # pyre-fixme[14]
1304        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1305    ) -> None:
1306        """Unsupported in execution IR, so unhandled by the emitter."""
1307        raise InternalError(
1308            self._emit_node_specific_error(self.node, "call_module is not supported")
1309        )
1310
1311    def call_method(  # pyre-fixme[14]
1312        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1313    ) -> _EmitterValue:
1314        """Unsupported in execution IR, so unhandled by the emitter."""
1315        raise InternalError(
1316            self._emit_node_specific_error(self.node, "call_method is not supported")
1317        )
1318
1319    def placeholder(  # pyre-fixme[14]
1320        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1321    ) -> _AbstractValue:
1322        """Performs actions for the placeholder node of a graph module.
1323
1324        The placeholder node of the top level entry point is handled by TopLevelEmitter. This
1325        function only executes on control flow subgraphs. Takes the inputs of the subgraph that have
1326        not previously been emitted and emits them.
1327        """
1328        # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
1329        value = self.binding_input_values[self.placeholder_count]
1330        # This indicates that the placeholder wasn't allocated an EValue id before this sub-emitter
1331        # was run, so we generate one now.
1332        if value == -1:
1333            value = self._emit_evalue(
1334                self._tensor_spec_to_evalue(self.node.meta["spec"])
1335            )
1336            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
1337            self.binding_input_values[self.placeholder_count] = value
1338        self.placeholder_count += 1
1339        return value
1340
1341    def output(  # pyre-fixme[14]
1342        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1343    ) -> None:
1344        """Performs actions for the output node of a graph module.
1345
1346        The output node of the top level entry point is handled by TopLevelEmitter. This function
1347        only executes on control flow subgraphs. Takes the outputs of the subgraph (if any) and
1348        inserts instructions to move them to the common output location between control flow
1349        branches.
1350        """
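        # Illustrative example (hypothetical ids): if this branch produced its result
        # in EValue id 7 but the shared output slot agreed on by both cond branches is
        # id 3, a MoveCall(move_from=7, move_to=3) instruction is appended so every
        # branch leaves its result in the same location.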
1351        self.concrete_output_ids = list(pytree.tree_flatten(args[0])[0])
1352        binding_output_values = self.binding_output_values
1353        if binding_output_values is not None:
1354            binding_output_list, _ = pytree.tree_flatten(binding_output_values)
1355
1356            self._internal_assert_emitter(
1357                len(binding_output_list) == len(self.concrete_output_ids),
1358                self.node,
1359                "The number of binding output values should match the args to output",
1360            )
1361
1362            for move_from, move_to in zip(
1363                self.concrete_output_ids, binding_output_list
1364            ):
1365                if move_from != move_to:
1366                    instruction = Instruction(
1367                        MoveCall(move_from=move_from.id, move_to=move_to.id)
1368                    )
1369                    self.chain.instructions.append(instruction)
1370
1371    def call_function(  # pyre-fixme[14]
1372        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1373    ) -> _EmitterValue:
1374        """Performs actions for the call_function node of a graph module.
1375
1376        Dispatches based on 'target' and emits the corresponding function. 'call_function' is a
1377        powerful node that contains many operations ranging from control_flow, to memory management,
1378        to delegate and operator calls.
1379        """
1380
1381        # Delegate and operator calls are the only functions that should have a debug handle
1382        # associated with them. All the others, such as memory.alloc and getitem, should be ignored.
1383        # Default to none and let delegates and ops override.
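        # Rough dispatch summary: getitem and the memory.* helpers are handled
        # inline, the higher-order ops (cond, map_impl) recurse into sub-emitters via
        # _emit_control_flow, executorch_call_delegate goes through _emit_delegate,
        # and any remaining ATen/Edge/Backend operator overload is emitted via
        # _emit_operator.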
1384        if target == operator.getitem:
1385            assert len(args) == 2
1386            head = typing.cast(Mapping[int, _EmitterValue], args[0])
1387            index = typing.cast(int, args[1])
1388            return head[index]
1389
1390        elif target == memory.alloc:
1391            assert len(args) == 1
1392            return self._emit_spec(self.node.meta["spec"])
1393
1394        elif target == memory.view:
1395            return self._emit_view(args)
1396
1397        elif target == memory.free:
1398            assert len(args) == 1
1399            # pyre-ignore
1400            return self._emit_free(args[0])
1401
1402        elif target is torch.ops.higher_order.cond:
1403            return self._emit_control_flow(target, args, kwargs)
1404
1405        elif target is torch.ops.higher_order.map_impl:
1406            return self._emit_control_flow(target, args, kwargs)
1407
1408        elif target == executorch_call_delegate:
1409            lowered_module = args[0]
1410            assert is_lowered_module(lowered_module)
1411            v = self._emit_delegate(lowered_module, args[1:], kwargs)
1412            delegate_instruction_id = len(self.chain.instructions) - 1
1413            self._add_debug_handle(delegate_instruction_id, target, lowered_module)
1414            self._add_delegate_map(lowered_module, delegate_instruction_id)
1415            return v
1416
1417        elif isinstance(
1418            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
1419        ):
1420            return self._emit_operator(target, args, kwargs)
1421
1422        else:
1423            raise InternalError(
1424                self._emit_node_specific_error(
1425                    self.node, f"invalid target for call_function {target}"
1426                )
1427            )
1428
1429    def run(  # pyre-fixme[14]
1430        self,
1431        *args: _Argument,
1432        initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None,
1433    ) -> None:
1434        """Traverses all nodes in the graph module and emits each one appropriately."""
1435        super().run(*args, initial_env, enable_io_processing=False)
1436
1437    def run_node(self, n: torch.fx.Node) -> None:
1438        """Executes and emits the specified node.
1439
1440        For more context on what a node is and what execution means see
1441        https://pytorch.org/docs/stable/fx.html#torch.fx.Node
1442        """
1443        self.node = n
1444        try:
1445            ret = super().run_node(n)
1446        except Exception as e:
1447            if isinstance(e, (InternalError, ExportError)):
1448                raise e
1449            else:
1450                raise InternalError(
1451                    self._emit_node_specific_error(self.node, str(e))
1452                ) from e
1453        return ret
1454
1455
1456class _TopLevelEmitter(_Emitter):
1457    """An emitter that manages the root level operations within a graph module.
1458
1459    Exists as a separate class so that 'Emitter' can handle the special behavior of 'placeholder'
1460    and 'output' nodes in control flow submodules.
1461    """
1462
1463    def __init__(
1464        self,
1465        name: str,
1466        exported_program: ExportedProgram,
1467        graph_module: torch.fx.GraphModule,
1468        program_state: _ProgramState,
1469        emitter_state: _EmitterState,
1470    ) -> None:
1471        super().__init__(graph_module, emitter_state, program_state)
1472        self.name = name
1473        self.exported_program = exported_program
1474
1475        self.inputs: List[int] = []
1476        self.outputs: List[int] = []
1477        self.given_mutable_buffer_warning = False
1478
1479        def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
1480            if spec is None:
1481                return ""
1482            assert isinstance(spec, pytree.TreeSpec), type(spec)
1483            dummy_leaves = [0] * spec.num_leaves
1484            tree = torch.utils._pytree.tree_unflatten(dummy_leaves, spec)
1485            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
1486            _, tree = ex_pytree.tree_flatten(tree)
1487            return tree.to_str()
1488
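        # Note (informal): the strings produced here describe only the container
        # structure of the inputs/outputs (how tuples/lists/dicts nest), not the leaf
        # values; they are stored in ContainerMetadata so callers can rebuild the
        # original container shape around the flat list of EValues.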
1489        inp_container_str = create_container_str(exported_program.call_spec.in_spec)
1490        out_container_str = create_container_str(exported_program.call_spec.out_spec)
1491
1492        self.container_meta_type = ContainerMetadata(
1493            inp_container_str, out_container_str
1494        )
1495
1496    def _find_fqn_for_placeholder(
1497        self, target: _Target, spec: Any  # pyre-ignore[2]
1498    ) -> Tuple[Optional[str], bool]:
1499        # Find the fully qualified name
1500        fqn = None
1501        is_mutable_buffer = False
1502        if target in self.exported_program.graph_signature.inputs_to_parameters:
1503            fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
1504
1505        elif target in self.exported_program.graph_signature.inputs_to_buffers:
1506            fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
1507
1508            # if the buffer is mutated then record that
1509            if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
1510                is_mutable_buffer = True
1511                if not self.given_mutable_buffer_warning:
1512                    warnings.warn(
1513                        "Mutation on a buffer in the model is detected. ExecuTorch assumes "
1514                        "buffers that are mutated in the graph have a meaningless initial state, "
1515                        "only the shape and dtype will be serialized.",
1516                        UserWarning,
1517                        stacklevel=1,
1518                    )
1519                    self.given_mutable_buffer_warning = True
1520
1521        elif (
1522            target
1523            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
1524        ):
1525            fqn = (
1526                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
1527                    target
1528                ]
1529            )
1530        return fqn, is_mutable_buffer
1531
1532    def placeholder(
1533        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1534    ) -> _AbstractValue:
1535        """Emits the value within the placeholder node.
1536
1537        For more information on placeholder nodes see
1538        https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
1539        """
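        # Informal overview: at this level placeholders cover both true user inputs
        # and values lifted by export (parameters, buffers, tensor constants). The
        # lifted ones are resolved to their real tensors below and typically emitted
        # as constant EValues (mutable buffers being the exception); only genuine
        # user inputs end up in self.inputs.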
1540        spec = self.node.meta["spec"]
1541        is_user_input = True
1542
1543        if isinstance(target, str) and isinstance(spec, TensorSpec):
1544            fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
1545
1546            # From the fqn find the corresponding tensor
1547            real_tensor = None
1548            if fqn in self.exported_program.state_dict:
1549                real_tensor = self.exported_program.state_dict[fqn]
1550                is_user_input = False
1551
1552            elif fqn in self.exported_program.constants:
1553                real_tensor = self.exported_program.constants[fqn]
1554                is_user_input = False
1555            elif fqn is not None:
1556                buffers = self.exported_program.named_buffers()
1557                buf = next((x[1] for x in buffers if x[0] == fqn), None)
1558                if buf is not None:
1559                    real_tensor = buf
1560                    is_user_input = False
1561                else:
1562                    raise InternalError(
1563                        self._emit_node_specific_error(
1564                            self.node,
1565                            f"Could not find buffer with fqn {fqn} in state_dict or named_buffers",
1566                        )
1567                    )
1568
1569            # assign the storage of the placeholder spec to the storage of the real tensor if there is one
1570            if real_tensor is not None:
1571                # for non-contiguous tensors, convert to a contiguous one
1572                real_tensor = real_tensor.contiguous()
1573                # Weights cannot be views during emission or serialization
1574                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
1575                    real_tensor = real_tensor.clone()
1576
1577                spec.storage = real_tensor.untyped_storage()
1578
1579            # User inputs and mutable buffers are not constants, other buffers or parameters are.
1580            spec.const = not (is_user_input or is_mutable_buffer)
1581
1582        evalue = (
1583            self._tensor_spec_to_evalue(spec)
1584            if isinstance(spec, TensorSpec)
1585            else self._constant_to_evalue(spec, None)
1586        )
1587        value = self._emit_evalue(evalue)
1588
1589        # Only user inputs should remain as inputs.
1590        if is_user_input:
1591            self.inputs.append(value.id)
1592
1593        return value
1594
1595    def output(
1596        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1597    ) -> None:
1598        """Records the ExecutionPlan's outputs based on the output node in the graph."""
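        # Illustrative example (hypothetical model): a forward() returning
        # (some_tensor, 1.0) reaches this point as (_AbstractValue, 1.0); the tensor
        # already has an emitted EValue id, while the bare float is emitted here as a
        # constant EValue before both ids are recorded in self.outputs.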
1599        if isinstance(args[0], dict):
1600            args_tuple, _ = pytree.tree_flatten(args[0])
1601        else:
1602            args_tuple = typing.cast(Tuple[_AbstractValue, ...], args[0])
1603        if isinstance(args_tuple, _AbstractValue):
1604            self.outputs.append(args_tuple.id)
1605        else:
1606            for arg in args_tuple:
1607                if isinstance(arg, (int, float, bool, type(None))):
1608                    arg = self._emit_evalue(self._constant_to_evalue(arg, None))
1609                elif isinstance(arg, str):
1610                    # TODO(jackkhuu): T181599879 Add support for string outputs IFF compiler supports
1611                    raise InternalError(
1612                        self._emit_node_specific_error(
1613                            self.node,
1614                            f"Returning {arg} is not yet supported in the emitter.",
1615                        )
1616                    )
1617                else:
1618                    # Every other output should already have its value emitted.
1619                    # They should only be abstract IDs at this point.
1620                    assert isinstance(arg, _AbstractValue)
1621                self.outputs.append(arg.id)
1622
1623    def plan(self) -> ExecutionPlan:
1624        """Returns the execution plan emitted from this entry point."""
1625        return ExecutionPlan(
1626            name=self.name,
1627            values=self.emitter_state.values,
1628            inputs=self.inputs,
1629            outputs=self.outputs,
1630            chains=[self.chain],
1631            operators=self.emitter_state.operators,
1632            delegates=self.emitter_state.delegates,
1633            # The non_const_buffer_sizes field is set by the memory_planning_pass. If the field is
1634            # missing, e.g. in unit tests that do not enable memory planning, assume an
1635            # empty list.
1636            non_const_buffer_sizes=typing.cast(
1637                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB...
1638                List[int], self.module.meta["non_const_buffer_sizes"]
1639            ),
1640            container_meta_type=self.container_meta_type,
1641        )
1642