1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import unittest 8from random import randint 9from typing import Any, List, Tuple 10 11import torch 12import torch.nn.functional as F 13from executorch import exir 14 15from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( 16 XnnpackDynamicallyQuantizedPartitioner, 17 XnnpackPartitioner, 18) 19from executorch.backends.xnnpack.utils.configs import ( 20 get_transform_passes, 21 get_xnnpack_edge_compile_config, 22 get_xnnpack_executorch_backend_config, 23) 24from executorch.backends.xnnpack.utils.utils import capture_graph_for_xnnpack 25 26# import the xnnpack backend implementation 27from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend 28from executorch.devtools import BundledProgram 29 30from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite 31from executorch.devtools.bundled_program.serialize import ( 32 serialize_from_bundled_program_to_flatbuffer, 33) 34from executorch.exir import ExecutorchProgram, ExirExportedProgram 35from executorch.exir.backend.backend_api import to_backend, validation_disabled 36 37from executorch.exir.passes.spec_prop_pass import SpecPropPass 38 39from executorch.extension.pybindings.portable_lib import ( # @manual 40 _load_for_executorch_from_buffer, 41) 42from executorch.extension.pytree import tree_flatten 43 44from torch.ao.quantization import ( # @manual 45 default_per_channel_symmetric_qnnpack_qconfig, 46 PlaceholderObserver, 47 QConfig, 48 QConfigMapping, 49) 50 51from torch.ao.quantization.backend_config.executorch import ( 52 get_executorch_backend_config, 53) 54 55from torch.ao.quantization.observer import ( 56 per_channel_weight_observer_range_neg_127_to_127, 57 # default_weight_observer, 58 weight_observer_range_neg_127_to_127, 59) 60from torch.ao.quantization.qconfig_mapping import ( 61 _get_default_qconfig_mapping_with_default_qconfig, 62 _get_symmetric_qnnpack_qconfig_mapping, 63) 64 65from torch.ao.quantization.quantize_fx import ( 66 _convert_to_reference_decomposed_fx, 67 prepare_fx, 68) 69 70from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e 71from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 72 get_symmetric_quantization_config, 73 XNNPACKQuantizer, 74) 75from torch.export import export_for_training 76 77from torch.testing import FileCheck 78 79 80def randomize_bn(num_features: int, dimensionality: int = 2) -> torch.nn.Module: 81 if dimensionality == 1: 82 bn = torch.nn.BatchNorm1d(num_features) 83 input_size = (1, num_features, 5) 84 elif dimensionality == 2: 85 bn = torch.nn.BatchNorm2d(num_features) 86 input_size = (1, num_features, 5, 5) 87 else: 88 raise AssertionError( 89 f"Only dimensionality 1 or 2 supported in randomize_bn, got {dimensionality}" 90 ) 91 92 bn.weight = torch.nn.Parameter(torch.randn(num_features)) 93 bn.bias = torch.nn.Parameter(torch.randn(num_features)) 94 95 for _ in range(5): 96 bn(torch.randn(size=input_size)) 97 98 return bn 99 100 101def save_bundled_program( 102 representative_inputs, executorch_program, ref_output, output_path 103): 104 niter = 1 105 106 print("generating bundled program inputs / outputs") 107 108 method_test_cases: List[MethodTestCase] = [] 109 for _ in range(niter): 110 method_test_cases.append( 111 MethodTestCase( 112 inputs=representative_inputs, 113 expected_outputs=ref_output, 114 ) 115 ) 116 117 method_test_suites = [ 118 MethodTestSuite(method_name="forward", method_test_cases=method_test_cases) 119 ] 120 121 print("creating bundled program...") 122 bundled_program = BundledProgram(executorch_program, method_test_suites) 123 124 print("serializing bundled program...") 125 bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( 126 bundled_program 127 ) 128 output_path_with_postfix = f"{output_path}_bundled.pte" 129 print(f"saving bundled program to {output_path}...") 130 131 with open(output_path_with_postfix, "wb") as file: 132 file.write(bundled_program_buffer) 133 134 135class TestXNNPACK(unittest.TestCase): 136 def assert_outputs_equal(self, model_output, ref_output): 137 """ 138 Helper testing function that asserts that the model output and the reference output 139 are equal with some tolerance. Due to numerical differences between eager mode and 140 the XNNPACK's backend, we relax the detal such that absolute tolerance is 1e-3. and 141 relative tolerance is 1e-3. 142 """ 143 144 # Compare the result from executor and eager mode direclty 145 if isinstance(ref_output, tuple) or isinstance(ref_output, list): 146 # Multiple outputs executor always returns tuple, even if there is one output 147 self.assertTrue(len(ref_output) == len(model_output)) 148 for i in range(len(ref_output)): 149 self.assertTrue( 150 torch.allclose( 151 model_output[i], ref_output[i], atol=1e-03, rtol=1e-03 152 ) 153 ) 154 else: 155 # If one output, eager returns tensor while executor tuple of size 1 156 self.assertTrue( 157 torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03) 158 ) 159 160 def lower_module_and_test_output( 161 self, 162 module: Any, 163 sample_inputs: Tuple[torch.Tensor], 164 use_partitioner: bool = False, 165 quantized: bool = False, 166 quantized_dynamic: bool = False, 167 # TODO: remove this after we migrate to use long term flow 168 quantizer_api_test: bool = False, 169 dump_bundled_program: bool = False, # for debugging, dump the generated bundled program file 170 ) -> ExirExportedProgram: 171 """ 172 Helper testing function that takes a torch.nn.Module and lowers it to XNNPACK with 173 the given sample inputs. It then runs the lowered module and compares its 174 outputs with the outputs of the eager module. 175 """ 176 177 if quantizer_api_test: 178 assert isinstance(module, ExirExportedProgram) 179 edge_program = module 180 else: 181 182 class WrappedModule(torch.nn.Module): 183 def __init__(self): 184 super().__init__() 185 self.one_module = module 186 187 def forward(self, *args): 188 return self.one_module(*args) 189 190 edge_program = capture_graph_for_xnnpack(WrappedModule(), sample_inputs) 191 192 partitioner = None 193 if quantized: 194 if quantized_dynamic: 195 partitioner = XnnpackDynamicallyQuantizedPartitioner() 196 else: 197 partitioner = XnnpackPartitioner() 198 else: 199 partitioner = XnnpackPartitioner() 200 201 if use_partitioner: 202 with validation_disabled(): 203 delegated_program = edge_program 204 delegated_program.exported_program = to_backend( 205 edge_program.exported_program, partitioner 206 ) 207 208 executorch_program: ExecutorchProgram = delegated_program.to_executorch( 209 get_xnnpack_executorch_backend_config([SpecPropPass()]), 210 ) 211 else: 212 delegated_program = to_backend( 213 "XnnpackBackend", edge_program.exported_program, [] 214 ) 215 216 exported_program: ExirExportedProgram = capture_graph_for_xnnpack( 217 delegated_program, sample_inputs 218 ) 219 executorch_program: ExecutorchProgram = exported_program.to_executorch( 220 get_xnnpack_executorch_backend_config(), 221 ) 222 223 # print("Graph Module with delegate:") 224 # delegated_module.print_readable() 225 226 # Assert the backend name is xnnpack 227 self.assertEqual( 228 executorch_program.program.execution_plan[0].delegates[0].id, 229 XnnpackBackend.__name__, 230 ) 231 232 ref_output = delegated_program(*sample_inputs) 233 if dump_bundled_program: 234 save_bundled_program( 235 representative_inputs=sample_inputs, 236 executorch_program=executorch_program, 237 ref_output=ref_output, 238 output_path=f"/tmp/xnnpack_test_{randint(1, 99999)}", 239 ) 240 241 # Test the model with executor 242 executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer) 243 # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. 244 inputs_flattened, _ = tree_flatten(sample_inputs) 245 246 model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) 247 248 self.assert_outputs_equal(model_output, ref_output) 249 250 return delegated_program 251 252 def lower_and_test_with_partitioner( 253 self, 254 graph_module, 255 example_inputs, 256 quantized: bool = False, 257 quantized_dynamic: bool = False, 258 ): 259 self.lower_module_and_test_output( 260 graph_module, 261 example_inputs, 262 use_partitioner=True, 263 quantized=quantized, 264 quantized_dynamic=quantized_dynamic, 265 ) 266 self.lower_module_and_test_output( 267 graph_module, 268 example_inputs, 269 use_partitioner=False, 270 quantized=quantized, 271 quantized_dynamic=quantized_dynamic, 272 ) 273 274 def quantize_and_test_model( 275 self, 276 module, 277 example_inputs, 278 per_channel_quant=False, 279 ): 280 if per_channel_quant: 281 qconfig = default_per_channel_symmetric_qnnpack_qconfig 282 qconfig_mapping = _get_default_qconfig_mapping_with_default_qconfig( 283 False, "qnnpack", qconfig 284 ) 285 else: 286 qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping() 287 module.eval() 288 prepared = prepare_fx( 289 module, 290 qconfig_mapping, 291 example_inputs, 292 backend_config=get_executorch_backend_config(), 293 ) 294 converted = _convert_to_reference_decomposed_fx( 295 prepared, 296 backend_config=get_executorch_backend_config(), 297 ) 298 299 # Let's assert quant flow did something (not care what, but anything) for this module. 300 # This is to ensure we are not just passing through an unquantized model. 301 FileCheck().check("torch.ops.quantized_decomposed").run(converted.code) 302 303 self.lower_module_and_test_output( 304 module=converted, 305 sample_inputs=example_inputs, 306 use_partitioner=True, 307 quantized=True, 308 ) 309 310 # TODO: replace quantize_and_test_model with this after 311 # QNNPACKQuantizer is more mature 312 def quantize_and_test_model_with_quantizer( 313 self, 314 module, 315 example_inputs, 316 ): 317 module.eval() 318 # program capture 319 320 m = export_for_training( 321 module, 322 example_inputs, 323 ).module() 324 325 quantizer = XNNPACKQuantizer() 326 quantization_config = get_symmetric_quantization_config() 327 quantizer.set_global(quantization_config) 328 prepared = prepare_pt2e(m, quantizer) 329 converted = convert_pt2e(prepared) 330 331 captured_program = exir.capture( 332 converted, 333 example_inputs, 334 config=exir.CaptureConfig(enable_aot=True, _unlift=True), 335 ) 336 337 edge_program = captured_program.to_edge( 338 get_xnnpack_edge_compile_config() 339 ).transform(*get_transform_passes()) 340 delegated_module = self.lower_module_and_test_output( 341 module=edge_program, 342 sample_inputs=example_inputs, 343 use_partitioner=True, 344 quantized=True, 345 quantizer_api_test=True, 346 ) 347 supported_ops = { 348 "torch.ops.aten.addmm.default", 349 "torch.ops.aten.convolution.default", 350 "torch.ops.aten.relu.default", 351 "torch.ops.aten.add.Tensor", 352 "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 353 } 354 for op in supported_ops: 355 FileCheck().check_count(op, 0, exactly=True).run( 356 delegated_module.exported_program.graph_module.code 357 ) 358 359 def _test_xnnpack_dqlinear( 360 self, 361 weight_qconfig, 362 use_bias: bool, 363 dump_bundled_program: bool = False, 364 ): 365 assert weight_qconfig in [ 366 weight_observer_range_neg_127_to_127, 367 per_channel_weight_observer_range_neg_127_to_127, 368 ] 369 in_size = 2 370 input_size = 4 371 output_size = 5 372 linear = torch.nn.Linear(input_size, output_size, bias=use_bias) 373 linear.weight = torch.nn.Parameter(torch.rand(output_size, input_size)) 374 if use_bias: 375 linear.bias = torch.nn.Parameter(torch.rand(output_size)) 376 example_inputs = (torch.rand(3, in_size, input_size, dtype=torch.float),) 377 act_affine_quant_obs = PlaceholderObserver.with_args( 378 dtype=torch.qint8, 379 qscheme=torch.per_tensor_affine, 380 quant_min=-128, 381 quant_max=127, 382 eps=2**-12, 383 is_dynamic=True, 384 ) 385 qconfig_mapping = QConfigMapping().set_object_type( 386 F.linear, 387 QConfig( 388 activation=act_affine_quant_obs, 389 weight=weight_qconfig, 390 ), 391 ) 392 393 prepared_linear = prepare_fx( 394 linear, 395 qconfig_mapping, 396 example_inputs, 397 backend_config=get_executorch_backend_config(), 398 ) 399 400 converted_linear = _convert_to_reference_decomposed_fx( 401 prepared_linear, 402 ) 403 404 captured_dqlinear = capture_graph_for_xnnpack(converted_linear, example_inputs) 405 406 captured_dqlinear.exported_program.graph_module.graph.print_tabular() 407 408 lowered_module = to_backend( 409 "XnnpackBackend", captured_dqlinear.exported_program, [] 410 ) 411 412 class CompositeModule(torch.nn.Module): 413 def __init__(self): 414 super().__init__() 415 self.lowered_module = lowered_module 416 417 def forward(self, x): 418 return self.lowered_module(x) 419 420 composite_model = CompositeModule() 421 composite_model(*example_inputs) 422 423 exported_program: ExirExportedProgram = capture_graph_for_xnnpack( 424 composite_model, example_inputs 425 ) 426 executorch_program: ExecutorchProgram = exported_program.to_executorch( 427 get_xnnpack_executorch_backend_config(), 428 ) 429 430 self.assertEqual( 431 executorch_program.program.execution_plan[0].delegates[0].id, 432 XnnpackBackend.__name__, 433 ) 434 435 ref_output = captured_dqlinear(*example_inputs) 436 ref_output = composite_model(*example_inputs) 437 print("ref_output:", ref_output) 438 439 if dump_bundled_program: 440 mm_str = "addmm" if use_bias else "mm" 441 filename = f"/tmp/dqlinear_{mm_str}" 442 if weight_qconfig == weight_observer_range_neg_127_to_127: 443 filename = f"{filename}_per_tensor" 444 else: 445 filename = f"{filename}_per_channel" 446 447 save_bundled_program( 448 representative_inputs=example_inputs, 449 executorch_program=executorch_program, 450 ref_output=ref_output, 451 output_path=filename, 452 ) 453 454 executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer) 455 # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. 456 inputs_flattened, _ = tree_flatten(example_inputs) 457 458 model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) 459 ref_output = composite_model(*example_inputs) 460 print("ref_output (composite):", ref_output) 461 462 print("Model_output:", model_output[0]) 463 464 # Compare the result from executor and eager mode directly 465 self.assertTrue( 466 torch.allclose(model_output[0], ref_output, atol=4e-03, rtol=1e-03) 467 ) 468 469 def _get_dqlinear_graph_module(self, weight_qconfig, linear, example_inputs): 470 act_affine_quant_obs = PlaceholderObserver.with_args( 471 dtype=torch.qint8, 472 qscheme=torch.per_tensor_affine, 473 quant_min=-128, 474 quant_max=127, 475 eps=2**-12, 476 is_dynamic=True, 477 ) 478 qconfig_mapping = QConfigMapping().set_object_type( 479 F.linear, 480 QConfig( 481 activation=act_affine_quant_obs, 482 weight=weight_qconfig, 483 ), 484 ) 485 486 prepared_linear = prepare_fx( 487 linear, 488 qconfig_mapping, 489 example_inputs, 490 backend_config=get_executorch_backend_config(), 491 ) 492 493 converted_dqlinear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx( 494 prepared_linear, backend_config=get_executorch_backend_config() 495 ) 496 497 return converted_dqlinear 498 499 def _test_xnnpack_dqlinear_with_partitioner(self, weight_qconfig, use_bias=True): 500 in_size = 1 501 input_size = 4 502 output_size = 5 503 linear = torch.nn.Linear(input_size, output_size, bias=use_bias) 504 linear.weight = torch.nn.Parameter(torch.rand(output_size, input_size)) 505 if use_bias: 506 linear.bias = torch.nn.Parameter(torch.rand(output_size)) 507 example_inputs = (torch.rand(in_size, input_size, dtype=torch.float),) 508 converted_dqlinear = self._get_dqlinear_graph_module( 509 weight_qconfig, linear, example_inputs 510 ) 511 512 self.lower_and_test_with_partitioner( 513 graph_module=converted_dqlinear, 514 example_inputs=example_inputs, 515 quantized=True, 516 quantized_dynamic=True, 517 ) 518 519 def _test_xnnpack_custom_dqlinear_with_partitioner_only( 520 self, LinearModule, example_inputs 521 ): 522 linear = LinearModule() 523 weight_qconfig = per_channel_weight_observer_range_neg_127_to_127 524 converted_dqlinear = self._get_dqlinear_graph_module( 525 weight_qconfig, linear, example_inputs 526 ) 527 528 # Only run test with partitioner 529 self.lower_module_and_test_output( 530 module=converted_dqlinear, 531 sample_inputs=example_inputs, 532 use_partitioner=True, 533 quantized=True, 534 quantized_dynamic=True, 535 ) 536