# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import typing
from typing import Dict, List, Optional, Sequence, Type, Union

import executorch.devtools.bundled_program.schema as bp_schema

import executorch.exir.schema as core_schema

import torch
import torch.fx
from executorch.devtools.bundled_program.config import ConfigValue, MethodTestSuite

from executorch.devtools.bundled_program.version import BUNDLED_PROGRAM_SCHEMA_VERSION

from executorch.exir import ExecutorchProgram, ExecutorchProgramManager
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.tensor import get_scalar_type, scalar_type_enum, TensorSpec

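# Maps program-level (core schema) kernel types to the Python/Torch types that are
# accepted as bundled test-case values for the corresponding program inputs.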
# pyre-ignore
supported_program_type_table: Dict[Type[core_schema.KernelTypes], ConfigValue] = {
    core_schema.Tensor: torch.Tensor,
    core_schema.Int: int,
    core_schema.Double: float,
    core_schema.Bool: bool,
}


class BundledProgram:
    """
    A bundled program contains all of the information needed to execute and verify the program on device.

    Public Attributes:
        method_test_suites: All test suites for verifying methods.
        executorch_program: An ExecutorchProgram-like variable containing the Program to be verified by
                            method_test_suites: either an ExecutorchProgram or an ExecutorchProgramManager.
    """
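    # A minimal usage sketch (the config imports are from this package; the model,
    # example_inputs, and program-manager names below are hypothetical), assuming a
    # model has already been exported and lowered via to_executorch():
    #
    #   from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
    #
    #   method_test_suites = [
    #       MethodTestSuite(
    #           method_name="forward",
    #           test_cases=[
    #               MethodTestCase(inputs=example_inputs, expected_outputs=model(*example_inputs))
    #           ],
    #       )
    #   ]
    #   bundled_program = BundledProgram(executorch_program_manager, method_test_suites)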

    def __init__(
        self,
        executorch_program: Union[
            ExecutorchProgram,
            ExecutorchProgramManager,
        ],
        method_test_suites: Sequence[MethodTestSuite],
    ):
        """Create a BundledProgram by bundling the given program and method_test_suites together.

        Args:
            executorch_program: The program to be bundled.
            method_test_suites: The test cases for certain methods to be bundled.
        """

        method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
        self._assert_valid_bundle(executorch_program, method_test_suites)

        self.executorch_program = executorch_program
        self.method_test_suites = method_test_suites

        # This is the cache for the bundled program in schema type.
        # Users should not access this field directly; use the `serialize_to_schema` method instead.
        self._bundled_program_in_schema: Optional[bp_schema.BundledProgram] = None

    def serialize_to_schema(self) -> bp_schema.BundledProgram:
        """Serialize the current BundledProgram into its schema format for further serialization."""
        # Return cached value if exists
        if self._bundled_program_in_schema is not None:
            return self._bundled_program_in_schema

        program = self._extract_program(self.executorch_program)
        bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = []

        # Emit data and metadata of bundled tensor
        for method_test_suite in self.method_test_suites:
            bundled_test_cases: List[bp_schema.BundledMethodTestCase] = []

            # emit I/O sets for each method test case
            for i in range(len(method_test_suite.test_cases)):
                inputs: List[bp_schema.Value] = []
                expected_outputs: List[bp_schema.Value] = []

                cur_plan_test_inputs = method_test_suite.test_cases[i].inputs
                cur_plan_test_expected_outputs = method_test_suite.test_cases[
                    i
                ].expected_outputs

                for input_val in cur_plan_test_inputs:
                    if type(input_val) is torch.Tensor:
                        self._emit_bundled_tensor(
                            TensorSpec.from_tensor(input_val, const=True),
                            inputs,
                        )
                    else:
                        self._emit_prim(
                            input_val,
                            inputs,
                        )
                for expected_output_tensor in cur_plan_test_expected_outputs:
                    assert (
                        type(expected_output_tensor) is torch.Tensor
                    ), "Only tensor outputs are currently supported."
                    self._emit_bundled_tensor(
                        TensorSpec.from_tensor(expected_output_tensor, const=True),
                        expected_outputs,
                    )
                bundled_test_cases.append(
                    bp_schema.BundledMethodTestCase(
                        inputs=inputs, expected_outputs=expected_outputs
                    )
                )

            # emit the whole execution plan test
            bundled_method_test_suites.append(
                bp_schema.BundledMethodTestSuite(
                    method_name=method_test_suite.method_name,
                    test_cases=bundled_test_cases,
                )
            )

        # TODO(T181463742): avoid calling bytes(..) which may incur large copies.
        program_bytes: bytes = bytes(_serialize_pte_binary(program))
        self._bundled_program_in_schema = bp_schema.BundledProgram(
            version=BUNDLED_PROGRAM_SCHEMA_VERSION,
            method_test_suites=bundled_method_test_suites,
            program=program_bytes,
        )
        return self._bundled_program_in_schema

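    # Downstream, the schema object produced above is what the flatbuffer serializer
    # consumes; a typical flow looks roughly like the sketch below (file name and
    # variable names are illustrative):
    #
    #   from executorch.devtools.bundled_program.serialize import (
    #       serialize_from_bundled_program_to_flatbuffer,
    #   )
    #
    #   flatbuffer_bytes = serialize_from_bundled_program_to_flatbuffer(bundled_program)
    #   with open("bundled_program.bpte", "wb") as f:
    #       f.write(flatbuffer_bytes)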
    def _emit_bundled_tensor(
        self, spec: TensorSpec, bundled_values: List[bp_schema.Value]
    ) -> None:
        # QuantizedSchema in tensor has been deprecated and may not be used anymore,
        # so we don't emit it here.

        if spec.allocated_memory == 0:
            tensor_data: bytes = b""
        else:
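            # View the tensor's untyped storage as a raw C char array of nbytes length,
            # then copy it out into an immutable Python bytes object.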
            array_type = (
                ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
            )
            spec_array = ctypes.cast(
                typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
                ctypes.POINTER(array_type),
            ).contents
            tensor_data: bytes = bytes(spec_array)

        bundled_values.append(
            bp_schema.Value(
                val=bp_schema.Tensor(
                    scalar_type=scalar_type_enum(spec.dtype),
                    sizes=spec.shape,
                    data=tensor_data,
                    dim_order=list(spec.dim_order),
                ),
            )
        )

    def _emit_prim(self, val: ConfigValue, bundled_values: List[bp_schema.Value]):
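        # Wrap a supported Python primitive in its bundled-program schema counterpart.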
        if type(val) is int:
            bundled_values.append(bp_schema.Value(val=bp_schema.Int(int_val=val)))
        elif type(val) is bool:
            bundled_values.append(bp_schema.Value(val=bp_schema.Bool(bool_val=val)))
        elif type(val) is float:
            bundled_values.append(bp_schema.Value(val=bp_schema.Double(double_val=val)))
        else:
            assert 0, "Unsupported primitive type received."

    def _get_program_input(
        self, program: core_schema.Program, plan_idx: int, input_idx: int
    ) -> core_schema.KernelTypes:
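        # plan.inputs holds indices into the execution plan's value table; resolve
        # the index for this input and return the underlying kernel value.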
        return (
            program.execution_plan[plan_idx]
            .values[program.execution_plan[plan_idx].inputs[input_idx]]
            .val
        )

    def _get_program_output(
        self, program: core_schema.Program, plan_idx: int, output_idx: int
    ) -> core_schema.KernelTypes:
        return (
            program.execution_plan[plan_idx]
            .values[program.execution_plan[plan_idx].outputs[output_idx]]
            .val
        )

    def _get_input_dtype(
        self, program: core_schema.Program, plan_idx: int, input_idx: int
    ) -> torch.dtype:
        return get_scalar_type(
            # pyre-fixme[16]: assumes all inputs and outputs are tensors; support multiple data types in the future.
            self._get_program_input(program, plan_idx, input_idx).scalar_type
        )

    def _get_input_type(
        self, program: core_schema.Program, plan_idx: int, input_idx: int
    ) -> type:
        type_lookup = {
            core_schema.Int: int,
            core_schema.Bool: bool,
            core_schema.Double: float,
        }
        # pyre-fixme[6]: Incompatible parameter type [6]: In call `dict.__getitem__`, for 1st positional only parameter
        # expected `Type[Union[core_schema.Bool, core_schema.Double, core_schema.Int]]` but got `Type[Union[core_schema.Bool, core_schema.Double, core_schema.Int, core_schema.Tensor, BoolList, DoubleList,
        # IntList, Null, OptionalTensorList, String, TensorList]]`.
        return type_lookup[type(self._get_program_input(program, plan_idx, input_idx))]

    def _get_output_dtype(
        self, program: core_schema.Program, plan_idx: int, output_idx: int
    ) -> torch.dtype:
        return get_scalar_type(
            # pyre-ignore[16]: assumes all outputs are tensors.
            self._get_program_output(program, plan_idx, output_idx).scalar_type
        )

    def _assert_valid_bundle(
        self,
        executorch_program: Union[
            ExecutorchProgram,
            ExecutorchProgramManager,
        ],
        method_test_suites: Sequence[MethodTestSuite],
    ) -> None:
        """Check that the program and method_test_suites match each other.

        Other checks not related to this correspondence are done in config.py.

        Args:
            executorch_program: The program to be bundled.
            method_test_suites: The test cases for specific methods to be bundled.
        """

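        # The checks below verify that: every tested method exists in the program,
        # the test suites are sorted by method name, the program's own I/O types are
        # supported, each test case supplies the same number and types of inputs as
        # the plan, and every expected output is a tensor with the plan's dtype.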
        program = self._extract_program(executorch_program)

        method_name_of_program = {e.name for e in program.execution_plan}
        method_name_of_test_suites = {t.method_name for t in method_test_suites}

        assert method_name_of_test_suites.issubset(
            method_name_of_program
        ), f"All method names in the bundled config should be found in program.execution_plan, \
            but {str(method_name_of_test_suites - method_name_of_program)} are missing."

        # Check that method_test_suites has been sorted in ascending alphabetical order of method name.
        for test_suite_id in range(1, len(method_test_suites)):
            assert (
                method_test_suites[test_suite_id - 1].method_name
                <= method_test_suites[test_suite_id].method_name
            ), f"Method test suites should be sorted in ascending alphabetical order \
                of method name, but the {test_suite_id-1}-th and {test_suite_id}-th method_test_suite are not."

        # Check that the input types meet the Program's requirements.
        for method_test_suite in method_test_suites:

            # Get the execution plan with the same method name as method_test_suite.
            program_plan_id = -1
            for plan in program.execution_plan:
                if plan.name == method_test_suite.method_name:
                    program_plan_id = program.execution_plan.index(plan)
                    break

            # Raise an AssertionError if the method named by method_test_suite cannot be found in the program.
            assert (
                program_plan_id != -1
            ), f"method_test_suites has test cases for method {method_test_suite.method_name}, but it cannot be found in the given program. All method names in the program are {', '.join([p.name for p in program.execution_plan])}."

            plan = program.execution_plan[program_plan_id]

            # Check that the type of each of the Program's inputs is supported.
            for index in range(len(plan.inputs)):
                assert (
                    type(self._get_program_input(program, program_plan_id, index))
                    in supported_program_type_table
                ), "The type of the program's input isn't supported."

            # Check that the type of each of the Program's outputs is supported.
            for index in range(len(plan.outputs)):
                assert (
                    type(self._get_program_output(program, program_plan_id, index))
                    == core_schema.Tensor
                ), "Only programs with outputs of Tensor type are supported."

            # Check that the I/O sets of each execution plan test match the program's requirements.
            for i in range(len(method_test_suite.test_cases)):
                cur_plan_test_inputs = method_test_suite.test_cases[i].inputs
                cur_plan_test_expected_outputs = method_test_suite.test_cases[
                    i
                ].expected_outputs

                assert len(plan.inputs) == len(
                    cur_plan_test_inputs
                ), "The number of inputs in each bundled set and in the Program should be equal, but got {} and {}".format(
                    len(plan.inputs),
                    len(cur_plan_test_inputs),
                )

                # Check that each bundled input in the current execution plan test shares the same type as the corresponding input in the Program.
                for j in range(len(cur_plan_test_inputs)):
                    assert (
                        type(cur_plan_test_inputs[j])
                        is supported_program_type_table[
                            type(self._get_program_input(program, program_plan_id, j))
                        ]
                    ), "The type of the {}-th input in the {}-th test set of the {}-th execution plan does not meet the Program's requirement: expected {} but got {}".format(
                        j,
                        i,
                        program_plan_id,
                        supported_program_type_table[
                            type(self._get_program_input(program, program_plan_id, j))
                        ],
                        type(cur_plan_test_inputs[j]),
                    )

                    # The dtype of a tensor input should match the execution plan.
                    if type(cur_plan_test_inputs[j]) is torch.Tensor:
                        # pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
                        # has no attribute `dtype`.
                        assert cur_plan_test_inputs[j].dtype == self._get_input_dtype(
                            program, program_plan_id, j
                        ), "The input tensor {} dtype should be {}, but is {}".format(
                            cur_plan_test_inputs[j],
                            self._get_input_dtype(program, program_plan_id, j),
                            cur_plan_test_inputs[j].dtype,
                        )
                    elif type(cur_plan_test_inputs[j]) in (
                        int,
                        bool,
                        float,
                    ):
                        assert type(cur_plan_test_inputs[j]) is self._get_input_type(
                            program, program_plan_id, j
                        ), "The input primitive type should be {}, but is {}".format(
                            self._get_input_type(program, program_plan_id, j),
                            type(cur_plan_test_inputs[j]),
                        )

                # Check that each bundled expected output in the current execution plan test shares the same type as the corresponding output in the Program.
                for j in range(len(cur_plan_test_expected_outputs)):
                    assert (
                        type(cur_plan_test_expected_outputs[j]) is torch.Tensor
                    ), "The {}-th expected output should be a tensor, but is {}".format(
                        j, type(cur_plan_test_expected_outputs[j])
                    )

                    # pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
                    # has no attribute `dtype`.
                    assert cur_plan_test_expected_outputs[
                        j
                    ].dtype == self._get_output_dtype(
                        program, program_plan_id, j
                    ), "The expected output tensor {} dtype should be {}, but is {}".format(
                        cur_plan_test_expected_outputs[j],
                        self._get_output_dtype(program, program_plan_id, j),
                        cur_plan_test_expected_outputs[j].dtype,
                    )

    def _extract_program(
        self,
        executorch_program: Union[
            ExecutorchProgram,
            ExecutorchProgramManager,
        ],
    ):
        if isinstance(executorch_program, ExecutorchProgramManager):
            program = executorch_program.executorch_program
        else:
            assert isinstance(executorch_program, ExecutorchProgram)
            program = executorch_program.program
        return program
373