# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.exir.passes.quantize_io_pass import (
    get_config_method_name,
    QuantizeInputs,
    QuantizeOutputs,
)
from executorch.exir.tensor import get_scalar_type
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.testing import FileCheck

# Names of the edge-dialect per-tensor (de)quantize ops as they appear in the
# generated graph-module code; used as FileCheck patterns by the tests below.
op_str = {
    "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
    "dq": "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
}


class TestQuantIOPass(unittest.TestCase):
    """Tests for the QuantizeInputs / QuantizeOutputs edge-program passes."""

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    @staticmethod
    def _as_tuple(outputs):
        """Return module outputs as a tuple.

        A module whose ``forward`` returns a single tensor yields that tensor
        (not a 1-tuple), so indexing it with ``[0]`` would slice the tensor's
        first dimension instead of selecting the first output.
        """
        if isinstance(outputs, torch.Tensor):
            return (outputs,)
        return tuple(outputs)

    @staticmethod
    def _get_quant_params(config_methods, method_name, io, idx):
        """Read the quantization parameters a QuantizeInputs/Outputs pass recorded.

        Args:
            config_methods: The ``_config_methods`` dict of the transformed
                edge program manager.
            method_name: Name of the program method, e.g. ``"forward"``.
            io: ``"input"`` or ``"output"``.
            idx: Index of the input/output whose parameters are wanted.

        Returns:
            ``(scale, zero_point, quant_min, quant_max, dtype)`` where ``dtype``
            has been converted to a ``torch.dtype`` via ``get_scalar_type``.
        """

        def lookup(field):
            return config_methods[get_config_method_name(method_name, io, idx, field)]

        return (
            lookup("scale"),
            lookup("zp"),
            lookup("quant_min"),
            lookup("quant_max"),
            get_scalar_type(lookup("dtype")),
        )

    def _quantize(self, mod, example_inputs):
        """Quantize ``mod`` with the symmetric XNNPACK config and re-export it.

        Returns the ``ExportedProgram`` of the converted (q/dq-annotated) module.
        """
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config()
        quantizer.set_global(operator_config)
        m = torch.export.export_for_training(
            mod, copy.deepcopy(example_inputs)
        ).module()
        m = prepare_pt2e(m, quantizer)
        # One calibration run so the inserted observers record value ranges.
        _ = m(*example_inputs)
        m = convert_pt2e(m)
        exported_program = torch.export.export_for_training(m, example_inputs)
        return exported_program

    def _check_count(self, op, count, epm):
        """Assert ``op`` occurs exactly ``count`` times in ``epm``'s graph code."""
        code = epm.exported_program().graph_module.code
        FileCheck().check_count(op, count, exactly=True).run(code)

    def _get_edge_prog_manager(self, mod, example_inputs):
        """Quantize ``mod`` and lower it to an edge program manager.

        The tests use the two-input ``Add`` model, which is expected to carry
        one q/dq pair per input plus one on the output (3 of each); this is
        asserted before any I/O pass is applied.
        """
        exported_program = self._quantize(mod, example_inputs)
        edge_program_manager = to_edge_transform_and_lower(
            exported_program,
            transform_passes=[],
            partitioner=None,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

        self._check_count(op_str["dq"], 3, edge_program_manager)
        self._check_count(op_str["q"], 3, edge_program_manager)
        return edge_program_manager

    def test_add_drop_q_inputs(self) -> None:
        """QuantizeInputs removes input quantize ops; feeding pre-quantized
        inputs must reproduce the float program's outputs."""
        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        mod = self.Add().eval()
        edge_program_manager = self._get_edge_prog_manager(mod, example_inputs)
        # Reference is taken before the transform runs, since the pass may
        # rewrite the source graph in place.
        reference_outputs = self._as_tuple(
            edge_program_manager.exported_program().module()(*example_inputs)
        )

        quantized_inputs_idx = [0, 1]
        edge_program_manager_qin = edge_program_manager.transform(
            [
                QuantizeInputs(
                    edge_program_manager=edge_program_manager,
                    quantized_inputs_idx=quantized_inputs_idx,
                    method_name="forward",
                )
            ]
        )
        # Verify the transformed program itself: both input quantize ops are
        # gone, every dequantize op remains.
        self._check_count(op_str["dq"], 3, edge_program_manager_qin)
        self._check_count(op_str["q"], 1, edge_program_manager_qin)

        # Quantize only the inputs the pass was told about; others stay float.
        config_methods = edge_program_manager_qin._config_methods
        quantized_example_inputs = list(example_inputs)
        for idx in quantized_inputs_idx:
            scale, zp, quant_min, quant_max, dtype = self._get_quant_params(
                config_methods, "forward", "input", idx
            )
            quantized_example_inputs[idx] = (
                torch.ops.quantized_decomposed.quantize_per_tensor.default(
                    example_inputs[idx], scale, zp, quant_min, quant_max, dtype
                )
            )

        output = self._as_tuple(
            edge_program_manager_qin.exported_program().module()(
                *quantized_example_inputs
            )
        )
        torch.testing.assert_close(reference_outputs, output)

    def test_add_drop_dq_output(self) -> None:
        """QuantizeOutputs removes output dequantize ops; dequantizing the
        returned values must reproduce the float program's outputs."""
        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        mod = self.Add().eval()
        edge_program_manager = self._get_edge_prog_manager(mod, example_inputs)
        # Reference is taken before the transform runs, since the pass may
        # rewrite the source graph in place.
        reference_outputs = self._as_tuple(
            edge_program_manager.exported_program().module()(*example_inputs)
        )

        quantized_outputs_idx = [0]
        edge_program_manager_dqout = edge_program_manager.transform(
            [
                QuantizeOutputs(
                    edge_program_manager=edge_program_manager,
                    quantized_outputs_idx_list=quantized_outputs_idx,
                    method_name="forward",
                )
            ]
        )
        # Verify the transformed program itself: the output dequantize op is
        # gone, every quantize op remains.
        self._check_count(op_str["dq"], 2, edge_program_manager_dqout)
        self._check_count(op_str["q"], 3, edge_program_manager_dqout)

        outputs = list(
            self._as_tuple(
                edge_program_manager_dqout.exported_program().module()(
                    *example_inputs
                )
            )
        )

        # Dequantize only the outputs the pass quantized; others stay as-is.
        config_methods = edge_program_manager_dqout._config_methods
        for idx in quantized_outputs_idx:
            scale, zp, q_min, q_max, dtype = self._get_quant_params(
                config_methods, "forward", "output", idx
            )
            outputs[idx] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                outputs[idx], scale, zp, q_min, q_max, dtype
            )

        torch.testing.assert_close(reference_outputs, tuple(outputs))