# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import random
import string
from typing import List, Tuple

import torch
from executorch.devtools.bundled_program.config import (
    MethodInputType,
    MethodOutputType,
    MethodTestCase,
    MethodTestSuite,
)

from executorch.exir import ExecutorchProgramManager, to_edge
from torch.export import export
from torch.export.unflatten import _assign_attr, _AttrKind

# A hacky integer to deal with a mismatch between the execution plan and the compiler.
#
# The execution plan supports multiple input types, such as Tensor, Int, etc.,
# rather than only Tensor. However, the compiler only supports Tensor as an input
# type. All other inputs keep their default values from the model, which means
# that during model execution each function uses the preset default value for
# non-tensor inputs rather than the one we manually set. The eager model, on the
# other hand, supports multiple input types.
#
# To show that a bundled program can support multiple input types while the
# executorch model generates the same output as the eager model, we hackily set
# all Int inputs in the bundled program and all default int inputs in the model
# to the same value, called DEFAULT_INT_INPUT.
#
# TODO(gasoonjia): track the situation. Stop supporting multiple input types in
# bundled program if the execution plan stops supporting them, or remove this
# hacky method once the compiler supports multiple input types.
DEFAULT_INT_INPUT = 2


class SampleModel(torch.nn.Module):
    """An example model with multiple methods. Each method has multiple inputs and a single output."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))
        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32))
        self.method_names = ["encode", "decode"]

    def encode(
        self, x: torch.Tensor, q: torch.Tensor, a: int = DEFAULT_INT_INPUT
    ) -> torch.Tensor:
        z = x.clone()
        torch.mul(self.a, x, out=z)
        y = x.clone()
        torch.add(z, self.b, alpha=a, out=y)
        torch.add(y, q, out=y)
        return y

    def decode(
        self, x: torch.Tensor, q: torch.Tensor, a: int = DEFAULT_INT_INPUT
    ) -> torch.Tensor:
        y = x * q
        torch.add(y, self.b, alpha=a, out=y)
        return y
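
# Illustrative sketch (not used by the existing tests): how SampleModel's methods
# can be invoked eagerly. The tensor values below are arbitrary assumptions; any
# int32 tensors of shape (2, 2) would do.
def _example_sample_model_usage() -> torch.Tensor:
    model = SampleModel()
    x = torch.ones(2, 2, dtype=torch.int32)
    q = torch.ones(2, 2, dtype=torch.int32)
    # Both methods take two tensors plus an optional int defaulting to DEFAULT_INT_INPUT.
    encoded = model.encode(x, q, DEFAULT_INT_INPUT)
    return model.decode(encoded, q, DEFAULT_INT_INPUT)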


def get_rand_input_values(
    n_tensors: int,
    sizes: List[List[int]],
    n_int: int,
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
    n_method_test_suites: int,
) -> List[List[MethodInputType]]:
    # pyre-ignore[7]: expected `List[List[List[Union[bool, float, int, Tensor]]]]` but got `List[List[List[Union[int, Tensor]]]]`
    return [
        [
            [(torch.rand(*sizes[i]) - 0.5).to(dtype) for i in range(n_tensors)]
            + [DEFAULT_INT_INPUT for _ in range(n_int)]
            for _ in range(n_sets_per_plan_test)
        ]
        for _ in range(n_method_test_suites)
    ]
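
# A minimal sketch (parameter values are assumptions; not used by the tests) showing
# the nesting of the value returned by get_rand_input_values: one entry per method
# test suite, each holding n_sets_per_plan_test input sets of n_tensors tensors plus
# n_int ints.
def _example_rand_input_nesting() -> None:
    inputs = get_rand_input_values(
        n_tensors=2,
        sizes=[[2, 2], [2, 2]],
        n_int=1,
        dtype=torch.int32,
        n_sets_per_plan_test=10,
        n_method_test_suites=2,
    )
    assert len(inputs) == 2  # one entry per method test suite
    assert len(inputs[0]) == 10  # test sets per execution plan
    assert len(inputs[0][0]) == 3  # two tensors plus one int (DEFAULT_INT_INPUT)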


def get_rand_output_values(
    n_tensors: int,
    sizes: List[List[int]],
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
    n_method_test_suites: int,
) -> List[List[MethodOutputType]]:
    # pyre-ignore [7]: Expected `List[List[Sequence[Tensor]]]` but got `List[List[List[Tensor]]]`.
    return [
        [
            [(torch.rand(*sizes[i]) - 0.5).to(dtype) for i in range(n_tensors)]
            for _ in range(n_sets_per_plan_test)
        ]
        for _ in range(n_method_test_suites)
    ]


def get_rand_method_names(n_method_test_suites: int) -> List[str]:
    unique_strings = set()
    while len(unique_strings) < n_method_test_suites:
        rand_str = "".join(random.choices(string.ascii_letters, k=5))
        unique_strings.add(rand_str)
    return list(unique_strings)


def get_random_test_suites(
    n_model_inputs: int,
    model_input_sizes: List[List[int]],
    n_model_outputs: int,
    model_output_sizes: List[List[int]],
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
    n_method_test_suites: int,
) -> Tuple[
    List[str],
    List[List[MethodInputType]],
    List[List[MethodOutputType]],
    List[MethodTestSuite],
]:
    """Helper function to generate test suites filled with random inputs and expected outputs.

    The random inputs are returned as a List[List[MethodInputType]]: each inner list
    holds all test sets for a single execution plan, and the outer list spans the
    multiple execution plans.

    The same structure applies to the random expected outputs.
    """

    rand_method_names = get_rand_method_names(n_method_test_suites)

    rand_inputs_per_program = get_rand_input_values(
        n_tensors=n_model_inputs,
        sizes=model_input_sizes,
        n_int=1,
        dtype=dtype,
        n_sets_per_plan_test=n_sets_per_plan_test,
        n_method_test_suites=n_method_test_suites,
    )

    rand_expected_output_per_program = get_rand_output_values(
        n_tensors=n_model_outputs,
        sizes=model_output_sizes,
        dtype=dtype,
        n_sets_per_plan_test=n_sets_per_plan_test,
        n_method_test_suites=n_method_test_suites,
    )

    rand_method_test_suites: List[MethodTestSuite] = []

    for (
        rand_method_name,
        rand_inputs_per_method,
        rand_expected_output_per_method,
    ) in zip(
        rand_method_names, rand_inputs_per_program, rand_expected_output_per_program
    ):
        rand_method_test_cases: List[MethodTestCase] = []
        for rand_inputs, rand_expected_outputs in zip(
            rand_inputs_per_method, rand_expected_output_per_method
        ):
            rand_method_test_cases.append(
                MethodTestCase(
                    inputs=rand_inputs, expected_outputs=rand_expected_outputs
                )
            )

        rand_method_test_suites.append(
            MethodTestSuite(
                method_name=rand_method_name, test_cases=rand_method_test_cases
            )
        )

    return (
        rand_method_names,
        rand_inputs_per_program,
        rand_expected_output_per_program,
        rand_method_test_suites,
    )
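
# A minimal usage sketch (parameter values are arbitrary assumptions; not used by
# the existing tests): build fully random test suites for two randomly named methods.
def _example_get_random_test_suites() -> List[MethodTestSuite]:
    _, _, _, rand_method_test_suites = get_random_test_suites(
        n_model_inputs=2,
        model_input_sizes=[[2, 2], [2, 2]],
        n_model_outputs=1,
        model_output_sizes=[[2, 2]],
        dtype=torch.int32,
        n_sets_per_plan_test=10,
        n_method_test_suites=2,
    )
    return rand_method_test_suites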


def get_random_test_suites_with_eager_model(
    eager_model: torch.nn.Module,
    method_names: List[str],
    n_model_inputs: int,
    model_input_sizes: List[List[int]],
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
) -> Tuple[List[List[MethodInputType]], List[MethodTestSuite]]:
    """Generate test suites filled with random inputs for each inference method of the given eager model.

    The structure of the returned values is the same as described in
    get_random_test_suites, except that the expected outputs are computed by
    running the eager model on the random inputs.
    """
    inputs_per_program = get_rand_input_values(
        n_tensors=n_model_inputs,
        sizes=model_input_sizes,
        n_int=1,
        dtype=dtype,
        n_sets_per_plan_test=n_sets_per_plan_test,
        n_method_test_suites=len(method_names),
    )

    method_test_suites: List[MethodTestSuite] = []

    for method_name, inputs_per_method in zip(method_names, inputs_per_program):
        method_test_cases: List[MethodTestCase] = []
        for inputs in inputs_per_method:
            method_test_cases.append(
                MethodTestCase(
                    inputs=inputs,
                    expected_outputs=getattr(eager_model, method_name)(*inputs),
                )
            )

        method_test_suites.append(
            MethodTestSuite(method_name=method_name, test_cases=method_test_cases)
        )

    return inputs_per_program, method_test_suites
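
# Usage sketch (mirrors the call in get_common_executorch_program below; the
# parameter values are assumptions): derive test suites whose expected outputs come
# from running SampleModel eagerly.
def _example_test_suites_from_eager_model() -> List[MethodTestSuite]:
    model = SampleModel()
    _, method_test_suites = get_random_test_suites_with_eager_model(
        eager_model=model,
        method_names=model.method_names,
        n_model_inputs=2,
        model_input_sizes=[[2, 2], [2, 2]],
        dtype=torch.int32,
        n_sets_per_plan_test=10,
    )
    return method_test_suites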


class StatefulWrapperModule(torch.nn.Module):
    """A wrapper module that preserves the base module's parameters/buffers.

    Use this if you are planning to wrap a non-forward method of an existing
    module.
    """

    def __init__(self, base_mod, method) -> None:  # pyre-ignore
        super().__init__()
        state_dict = base_mod.state_dict()
        for name, value in base_mod.named_parameters():
            _assign_attr(value, self, name, _AttrKind.PARAMETER)
        for name, value in base_mod.named_buffers():
            # Buffers that show up in the state_dict are persistent; the rest are not.
            _assign_attr(
                value, self, name, _AttrKind.BUFFER, persistent=name in state_dict
            )
        self.fn = method  # pyre-ignore

    def forward(self, *args, **kwargs):  # pyre-ignore
        return self.fn(*args, **kwargs)
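
# Illustrative sketch (mirrors the export step in get_common_executorch_program
# below; the example inputs are assumptions): export a non-forward method such as
# SampleModel.encode through the wrapper.
def _example_export_non_forward_method() -> None:
    model = SampleModel()
    wrapped = StatefulWrapperModule(model, model.encode)
    example_inputs = (
        torch.ones(2, 2, dtype=torch.int32),
        torch.ones(2, 2, dtype=torch.int32),
        DEFAULT_INT_INPUT,
    )
    _ = export(wrapped, example_inputs)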


def get_common_executorch_program() -> (
    Tuple[ExecutorchProgramManager, List[MethodTestSuite]]
):
    """Helper function to generate a sample ExecutorchProgramManager and its method test suites."""
    eager_model = SampleModel()
    # Example inputs used to capture each method.
    capture_inputs = {
        m_name: (
            (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
            (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
            DEFAULT_INT_INPUT,
        )
        for m_name in eager_model.method_names
    }

    # Trace each method to an FX graph via torch.export.
    method_graphs = {
        m_name: export(
            StatefulWrapperModule(eager_model, getattr(eager_model, m_name)),
            capture_inputs[m_name],
        )
        for m_name in eager_model.method_names
    }

    # Lower the exported methods to edge dialect and emit the executorch program.
    executorch_program = to_edge(method_graphs).to_executorch()

    _, method_test_suites = get_random_test_suites_with_eager_model(
        eager_model=eager_model,
        method_names=eager_model.method_names,
        n_model_inputs=2,
        model_input_sizes=[[2, 2], [2, 2]],
        dtype=torch.int32,
        n_sets_per_plan_test=10,
    )
    return executorch_program, method_test_suites
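
# Sketch of how the helpers above are typically combined into a bundled program.
# This assumes the BundledProgram class from executorch.devtools.bundled_program.core
# accepts the program manager and the test suites; verify against the local API
# before relying on it:
#
#     from executorch.devtools.bundled_program.core import BundledProgram
#
#     executorch_program, method_test_suites = get_common_executorch_program()
#     bundled_program = BundledProgram(executorch_program, method_test_suites)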