# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import random
import string
from typing import List, Tuple

import torch
from executorch.devtools.bundled_program.config import (
    MethodInputType,
    MethodOutputType,
    MethodTestCase,
    MethodTestSuite,
)

from executorch.exir import ExecutorchProgramManager, to_edge
from torch.export import export
from torch.export.unflatten import _assign_attr, _AttrKind

# A hacky integer to deal with a mismatch between execution plan and compiler.
#
# Execution plan supports multiple types of inputs, like Tensor, Int, etc,
# rather than only Tensor. However, compiler only supports Tensor as input type.
# All other inputs will remain the same as default value in the model, which
# means during model execution, each function will use the preset default value
# for non-tensor inputs, rather than the one we manually set. However, eager
# model supports multiple types of inputs.
#
# In order to show that bundled program can support multiple input types while
# executorch model can generate the same output as eager model, we hackily set
# all Int inputs in Bundled Program and default int inputs in model as a same
# value, called DEFAULT_INT_INPUT.
#
# TODO(gasoonjia): track the situation. Stop supporting multiple input types in
# bundled program if execution plan stops supporting it, or remove this hacky
# method if compiler can support multiple input types
DEFAULT_INT_INPUT = 2


class SampleModel(torch.nn.Module):
    """An example model with multiple methods.

    Each method takes multiple inputs and returns a single output.
    """

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))
        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32))
        # Methods that get_common_executorch_program exports, one program each.
        self.method_names = ["encode", "decode"]

    def encode(
        self, x: torch.Tensor, q: torch.Tensor, a: int = DEFAULT_INT_INPUT
    ) -> torch.Tensor:
        """Return ``self.a * x + a * self.b + q``.

        Uses explicit ``out=`` ops so the exported graph exercises out-variants.
        """
        z = x.clone()
        torch.mul(self.a, x, out=z)
        y = x.clone()
        torch.add(z, self.b, alpha=a, out=y)
        torch.add(y, q, out=y)
        return y

    def decode(
        self, x: torch.Tensor, q: torch.Tensor, a: int = DEFAULT_INT_INPUT
    ) -> torch.Tensor:
        """Return ``x * q + a * self.b``."""
        y = x * q
        torch.add(y, self.b, alpha=a, out=y)
        return y


def get_rand_input_values(
    n_tensors: int,
    sizes: List[List[int]],
    n_int: int,
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
    n_method_test_suites: int,
) -> List[List[MethodInputType]]:
    """Build random method inputs, nested as [suite][test case][input].

    Each innermost input list holds ``n_tensors`` tensors (tensor ``i`` has
    shape ``sizes[i]``) followed by ``n_int`` copies of DEFAULT_INT_INPUT.

    NOTE(review): values are ``torch.rand(...) - 0.5``, i.e. in [-0.5, 0.5);
    for integer dtypes the cast appears to truncate every element to 0 —
    confirm whether all-zero integer inputs are intended.
    """
    # pyre-ignore[7]: expected `List[List[List[Union[bool, float, int, Tensor]]]]` but got `List[List[List[Union[int, Tensor]]]]`
    return [
        [
            [(torch.rand(*sizes[i]) - 0.5).to(dtype) for i in range(n_tensors)]
            + [DEFAULT_INT_INPUT for _ in range(n_int)]
            for _ in range(n_sets_per_plan_test)
        ]
        for _ in range(n_method_test_suites)
    ]


def get_rand_output_values(
    n_tensors: int,
    sizes: List[List[int]],
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
    n_method_test_suites: int,
) -> List[List[MethodOutputType]]:
    """Build random expected outputs, nested as [suite][test case][output].

    Each innermost list holds ``n_tensors`` tensors, tensor ``i`` having shape
    ``sizes[i]``; values are generated the same way as get_rand_input_values.
    """
    # pyre-ignore [7]: Expected `List[List[Sequence[Tensor]]]` but got `List[List[List[Tensor]]]`.
    return [
        [
            [(torch.rand(*sizes[i]) - 0.5).to(dtype) for i in range(n_tensors)]
            for _ in range(n_sets_per_plan_test)
        ]
        for _ in range(n_method_test_suites)
    ]


def get_rand_method_names(n_method_test_suites: int) -> List[str]:
    """Return ``n_method_test_suites`` distinct random 5-letter ASCII names."""
    unique_strings = set()
    while len(unique_strings) < n_method_test_suites:
        # set.add ignores duplicates, so a repeated draw just loops again.
        unique_strings.add("".join(random.choices(string.ascii_letters, k=5)))
    return list(unique_strings)


def get_random_test_suites(
    n_model_inputs: int,
    model_input_sizes: List[List[int]],
    n_model_outputs: int,
    model_output_sizes: List[List[int]],
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
    n_method_test_suites: int,
) -> Tuple[
    List[str],
    List[List[MethodInputType]],
    List[List[MethodOutputType]],
    List[MethodTestSuite],
]:
    """Helper function to generate config filled with random inputs and expected outputs.

    Returns a 4-tuple ``(method_names, inputs, expected_outputs, test_suites)``:
    ``inputs`` is a List[List[MethodInputType]] whose inner list holds all test
    sets for a single method and whose outer list spans methods;
    ``expected_outputs`` is nested the same way with MethodOutputType; and
    ``test_suites`` contains one MethodTestSuite per random method name, pairing
    those inputs and outputs. Inputs and outputs are unrelated random values;
    no model is run.
    """

    rand_method_names = get_rand_method_names(n_method_test_suites)

    rand_inputs_per_program = get_rand_input_values(
        n_tensors=n_model_inputs,
        sizes=model_input_sizes,
        n_int=1,
        dtype=dtype,
        n_sets_per_plan_test=n_sets_per_plan_test,
        n_method_test_suites=n_method_test_suites,
    )

    rand_expected_output_per_program = get_rand_output_values(
        n_tensors=n_model_outputs,
        sizes=model_output_sizes,
        dtype=dtype,
        n_sets_per_plan_test=n_sets_per_plan_test,
        n_method_test_suites=n_method_test_suites,
    )

    rand_method_test_suites: List[MethodTestSuite] = []

    for (
        rand_method_name,
        rand_inputs_per_method,
        rand_expected_output_per_method,
    ) in zip(
        rand_method_names, rand_inputs_per_program, rand_expected_output_per_program
    ):
        rand_method_test_cases: List[MethodTestCase] = [
            MethodTestCase(inputs=rand_inputs, expected_outputs=rand_expected_outputs)
            for rand_inputs, rand_expected_outputs in zip(
                rand_inputs_per_method, rand_expected_output_per_method
            )
        ]

        rand_method_test_suites.append(
            MethodTestSuite(
                method_name=rand_method_name, test_cases=rand_method_test_cases
            )
        )

    return (
        rand_method_names,
        rand_inputs_per_program,
        rand_expected_output_per_program,
        rand_method_test_suites,
    )


def get_random_test_suites_with_eager_model(
    eager_model: torch.nn.Module,
    method_names: List[str],
    n_model_inputs: int,
    model_input_sizes: List[List[int]],
    dtype: torch.dtype,
    n_sets_per_plan_test: int,
) -> Tuple[List[List[MethodInputType]], List[MethodTestSuite]]:
    """Generate config filled with random inputs for each inference method given eager model.

    Unlike get_random_test_suites, the expected outputs are not random: each is
    computed by calling ``getattr(eager_model, method_name)`` on the inputs.
    The inputs are nested the same way as in get_random_test_suites
    ([method][test set][input]); one MethodTestSuite is returned per entry of
    ``method_names``.
    """
    inputs_per_program = get_rand_input_values(
        n_tensors=n_model_inputs,
        sizes=model_input_sizes,
        n_int=1,
        dtype=dtype,
        n_sets_per_plan_test=n_sets_per_plan_test,
        n_method_test_suites=len(method_names),
    )

    method_test_suites: List[MethodTestSuite] = []

    for method_name, inputs_per_method in zip(method_names, inputs_per_program):
        method_test_cases: List[MethodTestCase] = []
        for inputs in inputs_per_method:
            method_test_cases.append(
                MethodTestCase(
                    inputs=inputs,
                    expected_outputs=getattr(eager_model, method_name)(*inputs),
                )
            )

        method_test_suites.append(
            MethodTestSuite(method_name=method_name, test_cases=method_test_cases)
        )

    return inputs_per_program, method_test_suites


class StatefulWrapperModule(torch.nn.Module):
    """A version of wrapper module that preserves parameters/buffers.

    Use this if you are planning to wrap a non-forward method on an existing
    module.
    """

    def __init__(self, base_mod, method) -> None:  # pyre-ignore
        super().__init__()
        # Buffers absent from state_dict are non-persistent; keep that flag.
        state_dict = base_mod.state_dict()
        for name, value in base_mod.named_parameters():
            _assign_attr(value, self, name, _AttrKind.PARAMETER)
        for name, value in base_mod.named_buffers():
            _assign_attr(
                value, self, name, _AttrKind.BUFFER, persistent=name in state_dict
            )
        self.fn = method  # pyre-ignore

    def forward(self, *args, **kwargs):  # pyre-ignore
        """Forward all arguments to the wrapped method."""
        return self.fn(*args, **kwargs)


def get_common_executorch_program() -> (
    Tuple[ExecutorchProgramManager, List[MethodTestSuite]]
):
    """Export SampleModel's methods and build matching test suites.

    Returns the multi-method ExecutorchProgramManager (one program per entry in
    ``SampleModel.method_names``) and method test suites whose expected outputs
    come from running the eager model; together they can be used to build a
    sample BundledProgram.
    """
    eager_model = SampleModel()
    # Example inputs used to trace each method.
    # NOTE(review): (rand - 0.5) cast to int32 appears to yield all-zero
    # tensors; only shape/dtype matter for export, but confirm intent.
    capture_inputs = {
        m_name: (
            (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
            (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
            DEFAULT_INT_INPUT,
        )
        for m_name in eager_model.method_names
    }

    # Trace to FX Graph and emit the program
    method_graphs = {
        m_name: export(
            StatefulWrapperModule(eager_model, getattr(eager_model, m_name)),
            capture_inputs[m_name],
        )
        for m_name in eager_model.method_names
    }

    executorch_program = to_edge(method_graphs).to_executorch()

    _, method_test_suites = get_random_test_suites_with_eager_model(
        eager_model=eager_model,
        method_names=eager_model.method_names,
        n_model_inputs=2,
        model_input_sizes=[[2, 2], [2, 2]],
        dtype=torch.int32,
        n_sets_per_plan_test=10,
    )
    return executorch_program, method_test_suites