xref: /aosp_15_r20/external/executorch/exir/verification/interpreter.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport copy
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import List, Optional, Union
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Worker# pyre-fixme[21]: Could not find module `executorch.exir.verification.bindings`.
13*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.verification.bindings as bindings  # @manual=//executorch/exir/verification:bindings
14*523fa7a6SAndroid Build Coastguard Workerimport executorch.extension.pytree as ex_pytree
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerimport torch
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch import exir
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import (
21*523fa7a6SAndroid Build Coastguard Worker    Bool,
22*523fa7a6SAndroid Build Coastguard Worker    BoolList,
23*523fa7a6SAndroid Build Coastguard Worker    Double,
24*523fa7a6SAndroid Build Coastguard Worker    DoubleList,
25*523fa7a6SAndroid Build Coastguard Worker    ExecutionPlan,
26*523fa7a6SAndroid Build Coastguard Worker    Int,
27*523fa7a6SAndroid Build Coastguard Worker    IntList,
28*523fa7a6SAndroid Build Coastguard Worker    JumpFalseCall,
29*523fa7a6SAndroid Build Coastguard Worker    KernelCall,
30*523fa7a6SAndroid Build Coastguard Worker    KernelTypes,
31*523fa7a6SAndroid Build Coastguard Worker    MoveCall,
32*523fa7a6SAndroid Build Coastguard Worker    Null,
33*523fa7a6SAndroid Build Coastguard Worker    Operator,
34*523fa7a6SAndroid Build Coastguard Worker    OptionalTensorList,
35*523fa7a6SAndroid Build Coastguard Worker    Program,
36*523fa7a6SAndroid Build Coastguard Worker    String,
37*523fa7a6SAndroid Build Coastguard Worker    Tensor,
38*523fa7a6SAndroid Build Coastguard Worker    TensorList,
39*523fa7a6SAndroid Build Coastguard Worker)
40*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import get_scalar_type, stride_from_dim_order
41*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import impl, Library
42*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._pytree import PyTree
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker
class Uninitialized:
    """
    Sentinel stored in the interpreter's value list for slots that have not
    been loaded from the execution plan or written by an instruction yet.
    """
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker
# Python-side representations of list-typed EValues held in the value list.
ValueListType = Union[List[torch.Tensor], List[bool], List[float], List[int]]

# Python-side representations of scalar EValues. `Uninitialized` marks a slot
# that has not been loaded yet; `None` is what a Null EValue loads as.
ValueScalarType = Union[int, str, float, torch.Tensor, Uninitialized, None]

# Anything that may occupy a slot of the interpreter's value list.
ValueType = Union[ValueScalarType, ValueListType]
62*523fa7a6SAndroid Build Coastguard Worker
# Create the `executorch` operator library and declare the schema of the
# `executorch::move` operator so it is reachable through `torch.ops`.
_MOVE_OP_SCHEMA = "move(Tensor self) -> Tensor"
executorch_lib = Library("executorch", "DEF")
executorch_lib.define(_MOVE_OP_SCHEMA)
66*523fa7a6SAndroid Build Coastguard Worker
67*523fa7a6SAndroid Build Coastguard Worker
@impl(executorch_lib, "move", "CPU")
def move_impl(self: torch.Tensor) -> torch.Tensor:
    """CPU kernel for `executorch::move`: returns a freshly allocated copy of `self`."""
    copied = self.clone()
    return copied
71*523fa7a6SAndroid Build Coastguard Worker
72*523fa7a6SAndroid Build Coastguard Worker
def comp_types(val: KernelTypes, input_val: ValueType) -> Optional[bool]:  # noqa
    """
    Compares a schema type (val) with Python type (input_val)
    Map: Int -> int, Bool -> bool, Double -> float,
         String -> str, Tensor -> torch.Tensor, XList -> list[x]

    Args:
    `val`: value from value list with type from `schema.py`

    `input_val`: value with Python type, normally from the result of executing an operation

    Returns:
    True/False when `val` is a recognized schema type, None otherwise.

    Raises:
    TypeError: if `val` is `Null`; nothing may be written into a Null slot.
    """
    if isinstance(val, Int):
        return isinstance(input_val, int)
    elif isinstance(val, Bool):
        return isinstance(input_val, bool)
    elif isinstance(val, Double):
        # Compare against `float` explicitly rather than `type(val.double_val)`:
        # the serialized payload may not itself be a float (e.g. a string-encoded
        # non-finite value), while the runtime value produced by an op always is.
        return isinstance(input_val, float)
    elif isinstance(val, String):
        return isinstance(input_val, str)
    elif isinstance(val, Tensor):
        return isinstance(input_val, torch.Tensor)
    elif isinstance(val, IntList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, int) for x in input_val)
    elif isinstance(val, BoolList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, bool) for x in input_val)
    elif isinstance(val, DoubleList):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, float) for x in input_val)
    elif isinstance(val, (TensorList, OptionalTensorList)):
        if not isinstance(input_val, list):
            return False
        return all(isinstance(x, torch.Tensor) for x in input_val)
    elif isinstance(val, Null):
        raise TypeError("Setting a value where there should be a Null")
    # Unrecognized schema type: callers treat this falsy result as a mismatch.
    return None
112*523fa7a6SAndroid Build Coastguard Worker
113*523fa7a6SAndroid Build Coastguard Worker
def resolve_op(operator: Operator) -> torch._ops.OpOverload:
    """
    Resolves a serialized `Operator` to its `torch.ops` overload, e.g. name
    "aten::add" with overload "Tensor" -> `torch.ops.aten.add.Tensor`.

    Raises:
    ValueError: if `operator.name` is not of the form "namespace::op_name".
    Any lookup error raised by `torch.ops` for an unregistered op propagates.
    """
    parts = operator.name.split("::")
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            f"Operator name {operator.name!r} is not of the form 'namespace::op_name'."
        )
    namespace, op_name = parts
    # Treat an empty overload name as torch's "default" overload.
    overload_name = operator.overload or "default"
    op_packet = getattr(getattr(torch.ops, namespace), op_name)
    return getattr(op_packet, overload_name)
119*523fa7a6SAndroid Build Coastguard Worker
120*523fa7a6SAndroid Build Coastguard Worker
def make_operators_list(
    execution_plan: ExecutionPlan,
) -> List[torch._ops.OpOverload]:
    """
    Resolves every operator declared by `execution_plan` to its torch overload.

    Order is preserved, so a kernel's `op_index` indexes directly into the
    returned list.
    """
    return list(map(resolve_op, execution_plan.operators))
126*523fa7a6SAndroid Build Coastguard Worker
127*523fa7a6SAndroid Build Coastguard Worker
class Interpreter:
    """
    Python reference interpreter for an ExecuTorch `Program`.

    Mirrors the execution plan's value table in a Python list (`_value_list`),
    loading serialized EValues lazily, and executes the plan's single chain of
    instructions by dispatching kernel calls to the matching `torch.ops`
    overloads.
    """

    def __init__(self, program: Program) -> None:
        # Currently there is only 1 execution plan in the list -- this assert will help
        # catch any changes in the future
        assert len(program.execution_plan) == 1
        self.execution_plan: exir.schema.ExecutionPlan = program.execution_plan[0]
        self.container_metatype: exir.schema.ContainerMetadata = (
            self.execution_plan.container_meta_type
        )

        # create buffer in memory and get reference to it
        # pyre-ignore
        self.data_buffers: List[bindings.DataBuffer] = [
            # pyre-ignore
            bindings.DataBuffer(b.storage, len(b.storage))
            for b in program.constant_buffer
        ]

        # One slot per EValue in the execution plan; every slot starts out
        # Uninitialized and is filled on demand by `load_value` / `set_value`.
        self._value_list: List[ValueType] = [
            Uninitialized() for _ in self.execution_plan.values
        ]
        self._operators_list: List[torch._ops.OpOverload] = make_operators_list(
            self.execution_plan
        )

    def get_value_list(self) -> List[ValueType]:
        # TODO(meghajain) may need to change deepcopy to clone
        return copy.deepcopy(self._value_list)

    def get_operators_list(self) -> List[torch._ops.OpOverload]:
        return self._operators_list

    def _load_constant_tensor(self, val: Tensor) -> torch.Tensor:
        """
        Materializes a serialized constant `Tensor` (one with a non-zero
        `data_buffer_idx`) from the matching constant data buffer.
        """
        # pyre-fixme[16]
        return bindings.convert_to_tensor(
            self.data_buffers[val.data_buffer_idx],
            val.scalar_type,
            val.sizes,
            stride_from_dim_order(val.sizes, val.dim_order),
        )

    def get_constant_tensors(self) -> List[torch.Tensor]:
        """
        No side effects on Interpreter's value list. List of constant tensors returned
        without having to run program.
        """
        return [
            self._load_constant_tensor(elem.val)
            for elem in self.execution_plan.values
            if isinstance(elem.val, Tensor) and elem.val.data_buffer_idx != 0
        ]

    def load_value(self, idx: int) -> None:
        """
        Given an index in the value list, if value is `Uninitialized` or is a mutable object,
        like a Tensor List, calls `load` to load and initialize value into Interpreter's value_list.

        Args:
        `idx` : index in value lists that we want to load

        Returns: No returned values - value list is updated in place
        """
        # if instance of any mutable object, load regardless of being initialized
        if isinstance(
            self.execution_plan.values[idx].val, (TensorList, OptionalTensorList)
        ) or isinstance(self._value_list[idx], Uninitialized):
            self.load_from_value_list(idx)

    def load_from_value_list(self, idx: int) -> None:  # noqa
        """
        Accesses Execution Plan's value list at same index (see schema.py) to
        load and initialize value into Interpreter's value list. Extracts the
        necessary values depending on the type of the value. E.g. Tensor Lists
        and Int Lists hold indices into the value list, so their elements are
        resolved through it, while Bool Lists and Double Lists store their
        elements inline. A Tensor with a non-zero `data_buffer_idx` is a
        constant and is materialized from its data buffer; otherwise an empty
        tensor of the right size and dtype is allocated.

        Args:
        `idx` : index in value lists that we want to load

        Returns: No returned values - value list is updated in place

        Raises:
        TypeError: if the EValue has an unsupported type.
        """
        assert idx >= 0
        val = self.execution_plan.values[idx].val

        # Case through all possible Evalue Types
        if isinstance(val, Int):
            self._value_list[idx] = val.int_val
        elif isinstance(val, Bool):
            self._value_list[idx] = val.bool_val
        elif isinstance(val, Double):
            # float() is a no-op for ordinary floats and also decodes a
            # string-encoded non-finite payload such as "inf".
            self._value_list[idx] = float(val.double_val)
        elif isinstance(val, String):
            self._value_list[idx] = val.string_val
        elif isinstance(val, IntList):
            # Each item is the index of an `Int` EValue in the value table.
            unboxed_list = []
            for item in val.items:
                assert isinstance(item, int)
                item_val = self.execution_plan.values[item].val
                assert isinstance(item_val, Int)
                unboxed_list.append(item_val.int_val)
            self._value_list[idx] = unboxed_list
        elif isinstance(val, BoolList):
            # Elements are stored inline, not as value-table indices.
            self._value_list[idx] = list(val.items)
        elif isinstance(val, DoubleList):
            # Elements are stored inline, not as value-table indices.
            self._value_list[idx] = [float(x) for x in val.items]
        elif isinstance(val, (TensorList, OptionalTensorList)):
            tensor_list = []
            for i in val.items:
                # -1 marks an absent (None) entry.
                if i == -1:
                    tensor_list.append(None)
                    continue
                self.load_value(i)
                tensor_list.append(self._value_list[i])
            self._value_list[idx] = tensor_list
        elif isinstance(val, Tensor):
            if val.data_buffer_idx == 0:
                # TODO(zhengxu) Verify that argument is actually an out variant
                self._value_list[idx] = torch.empty(
                    val.sizes, dtype=get_scalar_type(val.scalar_type)
                )
            else:
                # Constant Tensor conversion
                self._value_list[idx] = self._load_constant_tensor(val)
        elif isinstance(val, Null):
            self._value_list[idx] = None
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

        assert not isinstance(self._value_list[idx], Uninitialized)

    def set_value(self, idx: int, input_val: ValueType) -> None:
        """
        Given an index in the value list, and a value, updates
        Interpreter's value in value list at given index in place
        If value is meant to be a TensorList at given index,
        iterate through all the indices in the TensorList and
        update each placeholder with the given Tensor from `input_val`.

        Args:
        `idx` : index in value lists that we want to set

        `input_val` : value we want to put at `self._value_list[idx]`

        Returns: No returned values - value list is updated in place

        Raises:
        TypeError: if `input_val` does not match the EValue type at `idx`.
        """
        val = self.execution_plan.values[idx].val

        if not comp_types(val, input_val):
            raise TypeError(
                f"Program trying to set a value of {input_val} : {type(input_val)} in memory location where {type(val)} is expected."
            )

        # Case through all possible Evalue Types
        if isinstance(
            val,
            (Int, Bool, Double, String, IntList, BoolList, DoubleList, Tensor, Null),
        ):
            self._value_list[idx] = input_val
        elif isinstance(val, (TensorList, OptionalTensorList)):
            assert isinstance(input_val, list)
            assert len(val.items) == len(input_val)
            for val_idx, item in zip(val.items, input_val):
                # -1 marks an absent entry with no backing slot; writing to it
                # would clobber the last value-list slot via negative indexing.
                if val_idx != -1:
                    self._value_list[val_idx] = item
            self._value_list[idx] = list(input_val)
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

    def call_kernel(self, kernel: KernelCall) -> None:
        """
        Calls operator from kernel:
        1. Determines kernel's operation through kernel.op_index,
           which indexes into operator list from Program's execution plan.
        2. After identifying operation, determines number of arguments and
        keyword arguments through operator schema.
        3. Extracts arguments from value list and calls operator.
        4. Sets the given output indices in value list with the values
           returned from operation.

        Args:
        `kernel` : stores information about operator and which indices in
                   the value list contain the necessary arguments

        Returns: No returned values - value list is updated with outputs
                from operator in place
        """
        operator = self._operators_list[kernel.op_index]
        schema_args = operator._schema.arguments
        num_args = sum(1 for arg in schema_args if not arg.kwarg_only)
        kwarg_list = [arg for arg in schema_args if arg.kwarg_only]
        num_kwargs = len(kwarg_list)

        # kernel.args is laid out as [positional args..., kwargs..., output].
        args = []
        for i in kernel.args[:num_args]:
            self.load_value(i)
            args.append(self._value_list[i])

        kwargs = {}
        for j, kwarg in enumerate(kwarg_list):
            i = kernel.args[num_args + j]
            self.load_value(i)
            kwargs[kwarg.name] = self._value_list[i]

        res = operator(*args, **kwargs)
        output_idxs = kernel.args[num_args + num_kwargs :]

        assert (
            len(output_idxs) == 1
        ), "emitter is expected to pack multiple outputs into a TensorList"
        # Multiple outputs come back as a tuple; store them as a list to match
        # the TensorList slot they are packed into.
        self.set_value(output_idxs[0], list(res) if isinstance(res, tuple) else res)

    def run(self, *raw_args: torch.Tensor) -> PyTree:
        """
        Loops through instructions given some inputs

        Args:
        `raw_args` : inputs required for interpretation; their flattened pytree
                     structure must match the program's encoded input spec

        Returns:
        Outputs after completing all computations, unflattened with the
        program's encoded output spec

        Raises:
        TypeError: if the input pytree structure does not match the program.
        RuntimeError: on an input-count mismatch or an unknown instruction.
        """
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        args, pytree = ex_pytree.tree_flatten((raw_args, {}))

        if pytree.to_str() != self.container_metatype.encoded_inp_str:
            raise TypeError(
                f"Arguments provided do not match required type. \nRequired: {self.container_metatype.encoded_inp_str} \nProvided: {pytree.to_str()}"
            )

        # Initialize user inputs in value list
        if len(self.execution_plan.inputs) != len(args):
            raise RuntimeError(
                f"Incorrect number of arguments provided. Expected {len(self.execution_plan.inputs)} values, but received {len(args)}"
            )
        for idx, arg in zip(self.execution_plan.inputs, args):
            self._value_list[idx] = arg

        assert len(self.execution_plan.chains) == 1
        instructions = self.execution_plan.chains[0].instructions

        # instruction pointer
        ip = 0

        # Kernel loop
        while ip < len(instructions):
            instruction = instructions[ip]
            instr_args = instruction.instr_args
            if isinstance(instr_args, KernelCall):
                self.call_kernel(instr_args)
            elif isinstance(instr_args, JumpFalseCall):
                # Load and read the condition through the same index field;
                # fall through when truthy, otherwise jump to the destination.
                # pyre-ignore
                cond_idx = instr_args.cond_value_index
                self.load_value(cond_idx)
                if self._value_list[cond_idx]:
                    ip += 1
                else:
                    # pyre-ignore
                    ip = instr_args.destination_instruction
                continue
            elif isinstance(instr_args, MoveCall):
                move_to = instr_args.move_to
                move_from = instr_args.move_from
                self.load_value(move_from)
                self.load_value(move_to)
                self._value_list[move_to] = self._value_list[move_from]
            else:
                raise RuntimeError(
                    f"Received unknown instruction from program: {instruction}."
                )
            ip += 1

        ret = [self._value_list[i] for i in self.execution_plan.outputs]
        # pyre-fixme[16]: Module `pytree` has no attribute `from_str`.
        treespec = ex_pytree.from_str(self.container_metatype.encoded_out_str)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_unflatten`.
        return ex_pytree.tree_unflatten(ret, treespec)
426