# xref: /aosp_15_r20/external/executorch/exir/verification/interpreter.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
9import copy
10from typing import List, Optional, Union
11
12# pyre-fixme[21]: Could not find module `executorch.exir.verification.bindings`.
13import executorch.exir.verification.bindings as bindings  # @manual=//executorch/exir/verification:bindings
14import executorch.extension.pytree as ex_pytree
15
16import torch
17
18from executorch import exir
19
20from executorch.exir.schema import (
21    Bool,
22    BoolList,
23    Double,
24    DoubleList,
25    ExecutionPlan,
26    Int,
27    IntList,
28    JumpFalseCall,
29    KernelCall,
30    KernelTypes,
31    MoveCall,
32    Null,
33    Operator,
34    OptionalTensorList,
35    Program,
36    String,
37    Tensor,
38    TensorList,
39)
40from executorch.exir.tensor import get_scalar_type, stride_from_dim_order
41from torch.library import impl, Library
42from torch.utils._pytree import PyTree
43
44
class Uninitialized:
    """Sentinel type marking an Interpreter value slot that has not been loaded yet."""
47
48
# List-shaped contents a value-list slot may hold (loaded TensorList, BoolList,
# DoubleList and IntList entries).
ValueListType = Union[
    List[torch.Tensor],
    List[bool],
    List[float],
    List[int],
]

# Scalar contents a value-list slot may hold. `Uninitialized` marks a slot that
# has not been loaded yet; `None` is what a `Null` EValue loads as.
ValueScalarType = Optional[Union[int, str, float, torch.Tensor, Uninitialized]]

# Anything an Interpreter value-list slot may hold.
ValueType = Union[ValueScalarType, ValueListType]
62
# Declare the custom `executorch::move` operator (schema plus a CPU kernel) so it
# can be looked up as `torch.ops.executorch.move`.
executorch_lib = Library("executorch", "DEF")
executorch_lib.define("move(Tensor self) -> Tensor")


@impl(executorch_lib, "move", "CPU")
def move_impl(self: torch.Tensor) -> torch.Tensor:
    """CPU kernel for `executorch::move`: return a fresh copy of `self`."""
    return self.clone()
71
72
def comp_types(val: KernelTypes, input_val: ValueType) -> Optional[bool]:  # noqa
    """
    Compares a schema type (val) with Python type (input_val)
    Map: Int -> int, Bool -> bool, Double -> float,
         String -> str, Tensor -> torch.Tensor, XList -> list[x],
         OptionalTensorList -> list[Optional[torch.Tensor]]

    Args:
    `val`: value from value list with type from `schema.py`

    `input_val`: value with Python type, normally from the result of executing an operation

    Returns:
    True if `input_val` has the Python type that `val`'s schema type maps to,
    False otherwise (including for schema types not listed above).

    Raises:
    TypeError: if `val` is a `Null`, since no value may be written to a Null slot.
    """
    # Scalars are checked against the documented Python type rather than
    # `type(val.<x>_val)`, so the answer does not depend on how the placeholder
    # value happens to be stored inside the schema object.
    if isinstance(val, Int):
        return isinstance(input_val, int)
    elif isinstance(val, Bool):
        return isinstance(input_val, bool)
    elif isinstance(val, Double):
        return isinstance(input_val, float)
    elif isinstance(val, String):
        return isinstance(input_val, str)
    elif isinstance(val, Tensor):
        return isinstance(input_val, torch.Tensor)
    elif isinstance(val, IntList):
        return isinstance(input_val, list) and all(
            isinstance(x, int) for x in input_val
        )
    elif isinstance(val, BoolList):
        return isinstance(input_val, list) and all(
            isinstance(x, bool) for x in input_val
        )
    elif isinstance(val, DoubleList):
        return isinstance(input_val, list) and all(
            isinstance(x, float) for x in input_val
        )
    elif isinstance(val, TensorList):
        return isinstance(input_val, list) and all(
            isinstance(x, torch.Tensor) for x in input_val
        )
    elif isinstance(val, OptionalTensorList):
        # Absent entries (item index -1) are loaded as None, so None is a
        # legitimate element of an optional tensor list.
        return isinstance(input_val, list) and all(
            x is None or isinstance(x, torch.Tensor) for x in input_val
        )
    elif isinstance(val, Null):
        raise TypeError("Setting a value where there should be a Null")
    # Unrecognized schema type: report a mismatch explicitly rather than
    # implicitly returning None.
    return False
112
113
def resolve_op(operator: Operator) -> torch._ops.OpOverload:
    """
    Look up the `torch.ops` overload that a schema `Operator` names.

    `operator.name` must have the form "<namespace>::<op_name>" (exactly one
    "::", otherwise unpacking raises ValueError); `operator.overload` selects
    the overload on the resulting op packet.
    """
    namespace, op_name = operator.name.split("::")
    op_namespace = getattr(torch.ops, namespace)
    op_packet = getattr(op_namespace, op_name)
    return getattr(op_packet, operator.overload)
119
120
def make_operators_list(
    execution_plan: ExecutionPlan,
) -> List[torch._ops.OpOverload]:
    """
    Resolve every entry of `execution_plan.operators` to its `torch.ops` overload.

    Order is preserved, so a `KernelCall.op_index` indexes directly into the result.
    """
    return list(map(resolve_op, execution_plan.operators))
126
127
class Interpreter:
    """
    Python reference interpreter for an ExecuTorch `Program`.

    Executes the instructions of the program's single execution plan, calling
    the resolved `torch.ops` overloads, and keeps every value in a Python list
    (`_value_list`) that mirrors `execution_plan.values` index for index.
    Slots start out as `Uninitialized` and are filled lazily.
    """

    def __init__(self, program: Program) -> None:
        # Currently there is only 1 execution plan in the list -- this assert will help
        # catch any changes in the future
        assert len(program.execution_plan) == 1
        self.execution_plan: exir.schema.ExecutionPlan = program.execution_plan[0]
        self.container_metatype: exir.schema.ContainerMetadata = (
            self.execution_plan.container_meta_type
        )

        # Wrap each constant buffer's bytes in a native DataBuffer; constant
        # tensors are later materialized from it via `data_buffer_idx`.
        # pyre-ignore
        self.data_buffers: List[bindings.DataBuffer] = [
            # pyre-ignore
            bindings.DataBuffer(b.storage, len(b.storage))
            for b in program.constant_buffer
        ]

        # One slot per schema value, filled lazily by `load_value` / `set_value`.
        self._value_list: List[ValueType] = [
            Uninitialized() for _ in self.execution_plan.values
        ]
        self._operators_list: List[torch._ops.OpOverload] = make_operators_list(
            self.execution_plan
        )

    def get_value_list(self) -> List[ValueType]:
        """Return a deep copy of the value list so callers cannot mutate interpreter state."""
        # TODO(meghajain) may need to change deepcopy to clone
        return copy.deepcopy(self._value_list)

    def get_operators_list(self) -> List[torch._ops.OpOverload]:
        """Return the resolved operator overloads, indexed by `KernelCall.op_index`."""
        return self._operators_list

    def get_constant_tensors(self) -> List[torch.Tensor]:
        """
        No side effects on Interpreter's value list. List of constant tensors returned
        without having to run program.

        Returns:
        The tensors produced by `bindings.convert_to_tensor` for every schema
        `Tensor` value with a non-zero `data_buffer_idx`, in value-list order.
        """
        tensors = []
        for elem in self.execution_plan.values:
            val = elem.val
            if isinstance(val, Tensor) and val.data_buffer_idx != 0:
                # pyre-fixme[16]
                tensor = bindings.convert_to_tensor(
                    self.data_buffers[val.data_buffer_idx],
                    val.scalar_type,
                    val.sizes,
                    stride_from_dim_order(val.sizes, val.dim_order),
                )
                tensors.append(tensor)
        return tensors

    def load_value(self, idx: int) -> None:
        """
        Given an index in the value list, if value is `Uninitialized` or is a mutable object,
        like a Tensor List, calls `load` to load and initialize value into Interpreter's value_list.

        Args:
        `idx` : index in value lists that we want to load

        Returns: No returned values - value list is updated in place
        """
        # Tensor lists alias other slots, so they are rebuilt on every load to
        # pick up the current contents of those slots.
        if isinstance(
            self.execution_plan.values[idx].val, (TensorList, OptionalTensorList)
        ) or isinstance(self._value_list[idx], Uninitialized):
            self.load_from_value_list(idx)

    def load_from_value_list(self, idx: int) -> None:  # noqa
        """
        Accesses Execution Plan's value list at same index (see schema.py) to
        load and initialize value into Interpreter's value list. Extracts the
        necessary values depending on the type of the value. E.g. Tensor Lists
        and Int Lists have indices into the value list, so they are recursively
        loaded. If an Evalue is a Constant Tensor (denoted by a non-zero
        `data_buffer_idx`), it converts the buffer data to a torch tensor object;
        otherwise an empty tensor of the declared sizes and dtype is allocated.

        Args:
        `idx` : index in value lists that we want to load

        Returns: No returned values - value list is updated in place

        Raises:
        TypeError: if the schema value has an unrecognized type.
        """
        assert idx >= 0
        val = self.execution_plan.values[idx].val

        # Case through all possible Evalue Types
        if isinstance(val, Int):
            self._value_list[idx] = val.int_val
        elif isinstance(val, Bool):
            self._value_list[idx] = val.bool_val
        elif isinstance(val, Double):
            self._value_list[idx] = val.double_val
        elif isinstance(val, String):
            self._value_list[idx] = val.string_val
        elif isinstance(val, IntList):
            # IntList items are indices of boxed `Int` values elsewhere in the
            # value list; unbox each one through that indirection.
            unboxed_list = []
            for item in val.items:
                assert isinstance(item, int)
                boxed = self.execution_plan.values[item].val
                assert isinstance(boxed, Int)
                unboxed_list.append(boxed.int_val)
            self._value_list[idx] = unboxed_list
        elif isinstance(val, (BoolList, DoubleList)):
            # Unlike IntList, BoolList/DoubleList carry their bools/floats
            # inline (which is also what `set_value` stores for them), so the
            # items are copied as-is instead of being treated as indices.
            self._value_list[idx] = list(val.items)
        elif isinstance(val, (TensorList, OptionalTensorList)):
            tensor_list = []
            for i in val.items:
                # -1 marks an absent entry in an optional tensor list.
                if i == -1:
                    tensor_list.append(None)
                    continue
                self.load_value(i)
                tensor_list.append(self._value_list[i])
            self._value_list[idx] = tensor_list
        elif isinstance(val, Tensor):
            if val.data_buffer_idx == 0:
                # TODO(zhengxu) Verify that argument is actually an out variant
                self._value_list[idx] = torch.empty(
                    val.sizes, dtype=get_scalar_type(val.scalar_type)
                )
            else:
                # Constant Tensor conversion
                # pyre-fixme [16]
                tensor = bindings.convert_to_tensor(
                    self.data_buffers[val.data_buffer_idx],
                    val.scalar_type,
                    val.sizes,
                    stride_from_dim_order(val.sizes, val.dim_order),
                )
                self._value_list[idx] = tensor
        elif isinstance(val, Null):
            self._value_list[idx] = None
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

        assert not isinstance(self._value_list[idx], Uninitialized)

    def set_value(self, idx: int, input_val: ValueType) -> None:
        """
        Given an index in the value list, and a value, updates
        Interpreter's value in value list at given index in place
        If value is meant to be a TensorList at given index,
        iterate through all the indices in the TensorList and
        update each placeholder with the given Tensor from `input_val`.

        Args:
        `idx` : index in value lists that we want to set

        `input_val` : value we want to put at `self._value_list[idx]`

        Returns: No returned values - value list is updated in place

        Raises:
        TypeError: if `input_val` does not match the schema type at `idx`.
        """
        evalue = self.execution_plan.values[idx]
        val = evalue.val

        if not comp_types(val, input_val):
            raise TypeError(
                f"Program trying to set a value of {input_val} : {type(input_val)} in memory location where {type(val)} is expected."
            )

        # Case through all possible Evalue Types
        if isinstance(
            val,
            (Int, Bool, Double, String, IntList, BoolList, DoubleList, Tensor, Null),
        ):
            self._value_list[idx] = input_val
        elif isinstance(val, (TensorList, OptionalTensorList)):
            assert isinstance(input_val, list)
            assert len(val.items) == len(input_val)
            for val_idx, elem in zip(val.items, input_val):
                # -1 marks an absent entry with no backing slot; writing
                # through it would clobber the last value in the list.
                if val_idx != -1:
                    self._value_list[val_idx] = elem
            self._value_list[idx] = list(input_val)
        else:
            raise TypeError(
                f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values."
            )

    def call_kernel(self, kernel: KernelCall) -> None:
        """
        Calls operator from kernel:
        1. Determines kernel's operation through kernel.op_index,
           which indexes into operator list from Program's execution plan.
        2. After identifying operation, determines number of arguments and
        keyword arguments through operator schema.
        3. Extracts arguments from value list and calls operator.
        4. Sets the given output indices in value list with the values
           returned from operation.

        Args:
        `kernel` : stores information about operator and which indices in
                   the value list contain the necessary arguments

        Returns: No returned values - value list is updated with outputs
                from operator in place
        """

        operator = self._operators_list[kernel.op_index]
        schema_args = operator._schema.arguments
        num_args = len([arg for arg in schema_args if not arg.kwarg_only])
        kwarg_list = [kwarg for kwarg in schema_args if kwarg.kwarg_only]
        num_kwargs = len(kwarg_list)

        # kernel.args layout: positional args, then kwarg-only args in schema
        # order, then the output slot.
        args = []
        for i in kernel.args[:num_args]:
            self.load_value(i)
            args.append(self._value_list[i])

        kwargs = {}
        for j in range(num_kwargs):
            i = kernel.args[num_args + j]
            self.load_value(i)
            kwargs[kwarg_list[j].name] = self._value_list[i]

        res = operator(*args, **kwargs)
        output_idxs = kernel.args[num_args + num_kwargs :]

        assert (
            len(output_idxs) == 1
        ), "emitter is expected to pack multiple outputs into a TensorList"
        if isinstance(res, tuple):
            self.set_value(output_idxs[0], list(res))
        else:
            self.set_value(output_idxs[0], res)

    def run(self, *raw_args: torch.Tensor) -> PyTree:
        """
        Loops through instructions given some inputs

        Args:
        `raw_args` : inputs required for interpretation; flattened and checked
                     against the program's encoded input tree spec

        Returns:
        Outputs after completing all computations, unflattened with the
        program's encoded output tree spec

        Raises:
        TypeError: if the input tree structure does not match the program's.
        RuntimeError: on a wrong number of inputs or an unknown instruction.
        """

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        args, pytree = ex_pytree.tree_flatten((raw_args, {}))

        if pytree.to_str() != self.container_metatype.encoded_inp_str:
            raise TypeError(
                f"Arguments provided do not match required type. \nRequired: {self.container_metatype.encoded_inp_str} \nProvided: {pytree.to_str()}"
            )

        # Initialize user inputs in value list
        if len(self.execution_plan.inputs) != len(args):
            raise RuntimeError(
                f"Incorrect number of arguments provided. Expected {len(self.execution_plan.inputs)} values, but received {len(args)}"
            )
        for idx, arg in zip(self.execution_plan.inputs, args):
            self._value_list[idx] = arg

        assert len(self.execution_plan.chains) == 1
        chain = self.execution_plan.chains[0]

        # instruction pointer
        ip = 0

        # Kernel loop
        while ip < len(chain.instructions):
            instruction = chain.instructions[ip]
            instr_args = instruction.instr_args
            if isinstance(instr_args, KernelCall):
                self.call_kernel(instr_args)
            elif isinstance(instr_args, JumpFalseCall):
                # Load and read the same slot (previously the read used a
                # nonexistent `cond_val_index` attribute).
                cond_idx = instr_args.cond_value_index
                self.load_value(cond_idx)
                # Fall through when the condition is truthy; otherwise jump.
                ip = (
                    ip + 1
                    if self._value_list[cond_idx]
                    else instr_args.destination_instruction
                )
                continue
            elif isinstance(instr_args, MoveCall):
                move_to = instr_args.move_to
                move_from = instr_args.move_from
                self.load_value(move_from)
                self.load_value(move_to)
                self._value_list[move_to] = self._value_list[move_from]
            else:
                raise RuntimeError(
                    f"Received unknown instruction from program: {instruction}."
                )
            ip += 1

        ret = [self._value_list[i] for i in self.execution_plan.outputs]
        # pyre-fixme[16]: Module `pytree` has no attribute `from_str`.
        treespec = ex_pytree.from_str(self.container_metatype.encoded_out_str)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_unflatten`.
        return ex_pytree.tree_unflatten(ret, treespec)
426