# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import executorch.exir as exir

import torch
from executorch.backends.fb.qnnpack.partition.qnnpack_partitioner import (
    QnnpackPartitioner,
)
from executorch.backends.fb.qnnpack.qnnpack_preprocess import QnnpackBackend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
)

# import the xnnpack backend implementation
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend

from executorch.exir import CaptureConfig
from executorch.exir.backend.backend_api import to_backend, validation_disabled
from executorch.exir.passes.spec_prop_pass import SpecPropPass

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.observer import (
    default_dynamic_quant_observer,
    default_per_channel_weight_observer,
)
from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)


class TestXnnQnnBackends(unittest.TestCase):
    """Tests that lower a single model to the Qnnpack and Xnnpack backends."""

    def test_add_xnnpack_and_dqlinear_qnn(self):
        """Lower a linear+add model to two backends and run the result.

        The model is quantized with a dynamic activation observer and a
        per-channel weight observer on ``torch.nn.Linear``, captured to the
        edge dialect, partitioned first with ``QnnpackPartitioner`` and then
        with ``XnnpackFloatingPointPartitioner``, and finally serialized and
        executed through the pybindings. The test asserts the order of the
        delegate ids in the emitted program and that the executed output is
        close to the output of calling ``captured_mod`` in Python.
        """
        batch_size, num_in, num_out = 1, 3, 4

        class LinearAndAdd(torch.nn.Module):
            """``linear(x) + y`` with a single ``torch.nn.Linear`` layer."""

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(num_in, num_out)

            def forward(self, x, y):
                projected = self.linear(x)
                return projected + y

        # Step 1: quantize. Only torch.nn.Linear receives a qconfig.
        linear_qconfig = QConfig(
            activation=default_dynamic_quant_observer,
            weight=default_per_channel_weight_observer,
        )
        qconfig_mapping = QConfigMapping().set_object_type(
            torch.nn.Linear, linear_qconfig
        )

        eager_mod = LinearAndAdd()
        example_inputs = (
            torch.ones(batch_size, num_in, dtype=torch.float),
            torch.ones(batch_size, num_out, dtype=torch.float),
        )

        observed_mod = prepare_fx(
            eager_mod,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        converted_mod: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
            observed_mod
        )

        # Step 2: EXIR capturing, then conversion to the edge dialect.
        traced = exir.capture(
            converted_mod,
            example_inputs,
            config=CaptureConfig(enable_dynamic_shape=False),
        )
        captured_mod = traced.to_edge(
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        # Step 3.1: lower the dynamically quantized linear to Qnnpack.
        # The exported program is replaced on `captured_mod` itself, so
        # every later use of `captured_mod` sees the lowered program.
        with validation_disabled():
            captured_mod.exported_program = to_backend(
                captured_mod.exported_program, QnnpackPartitioner()
            )

        # Step 3.2: lower the add to Xnnpack, again in place.
        with validation_disabled():
            captured_mod.exported_program = to_backend(
                captured_mod.exported_program,
                XnnpackFloatingPointPartitioner(),
            )

        program_with_delegates = captured_mod.to_executorch(
            exir.ExecutorchBackendConfig(passes=[SpecPropPass()]),
        )

        # Delegates are expected in lowering order: Qnnpack, then Xnnpack.
        delegates = program_with_delegates.program.execution_plan[0].delegates
        self.assertEqual(delegates[0].id, QnnpackBackend.__name__)
        self.assertEqual(delegates[1].id, XnnpackBackend.__name__)

        executorch_module = _load_for_executorch_from_buffer(
            program_with_delegates.buffer
        )
        flat_inputs, _ = tree_flatten(example_inputs)

        model_output = executorch_module.run_method("forward", tuple(flat_inputs))
        ref_output = captured_mod(*example_inputs)

        # Compare the executor result directly against the Python-side
        # result of calling `captured_mod`.
        outputs_match = torch.allclose(
            model_output[0], ref_output, atol=1e-03, rtol=1e-03
        )
        self.assertTrue(outputs_match)