# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import unittest

import torch
import torchvision
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import QConfigMapping  # @manual
from torch.ao.quantization.backend_config import get_executorch_backend_config
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.testing import FileCheck
from torch.testing._internal.common_quantized import override_quantized_engine

# load the executorch out-variant quantized ops
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")


class TestQuantization(unittest.TestCase):
    """prepare_pt2e and convert_pt2e are OSS APIs; the rest are Meta-only APIs
    for now, but we plan to open source them in the future.
    """

    def test_resnet(self) -> None:
        with override_quantized_engine("qnnpack"):
            torch.backends.quantized.engine = "qnnpack"
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            m_copy = copy.deepcopy(m)
            # program capture
            m = torch.export.export_for_training(
                m, copy.deepcopy(example_inputs)
            ).module()

            quantizer = XNNPACKQuantizer()
            operator_config = get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(operator_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            # prepare_pt2e should deduplicate observers here: both attribute
            # names are expected to resolve to the same observer instance
            self.assertEqual(
                id(m.activation_post_process_3), id(m.activation_post_process_2)
            )
            after_prepare_result = m(*example_inputs)[0]
            m = convert_pt2e(m)

            # TODO: conv, conv_relu, linear delegation
            # quantized ops to implement: add_relu
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
            )
            m = to_edge(
                export(m, example_inputs), compile_config=compile_config
            ).transform([QuantFusionPass(), SpecPropPass()])

            after_quant_result = m.exported_program().module()(*example_inputs)[0]
            # verify that quantize/dequantize ops survived lowering to the edge dialect
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor"
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor"
            ).run(m.exported_program().graph_module.code)
            # after_quant_fusion_result = m(*example_inputs)[0]

            # TODO: implement torch.ops.quantized_decomposed.add_relu.out
            # m = m.to_executorch().dump_graph_module()
            # after_to_executorch = m(*example_inputs)[0]
            # test that the results before and after to_executorch match
            # TODO: debug why this is a mismatch
            # self.assertTrue(torch.equal(after_quant_fusion_result, after_to_executorch))
            # self.assertEqual(
            #     compute_sqnr(after_quant_fusion_result, after_to_executorch),
            #     torch.tensor(float("inf")),
            # )

            # compare with the existing FX graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_executorch_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            after_prepare_result_fx = m_fx(*example_inputs)
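            # NOTE: _convert_to_reference_decomposed_fx lowers the observed FX
            # model to a reference quantized model whose quant/dequant ops are
            # decomposed (torch.ops.quantized_decomposed.*), which makes its
            # numerics directly comparable to the convert_pt2e output above.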
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            after_quant_result_fx = m_fx(*example_inputs)

            # the results should match exactly after prepare
            self.assertTrue(
                torch.allclose(after_prepare_result, after_prepare_result_fx, atol=1e-6)
            )

            # there are slight differences after convert due to different
            # implementations of quant/dequant, so bound the maximum absolute
            # difference and require a high SQNR (in dB; higher means closer)
            self.assertTrue(
                torch.max(torch.abs(after_quant_result - after_quant_result_fx)) < 1e-1
            )
            self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
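

# Standard unittest entry point so this file can also be run directly
# (assuming the quantized custom ops library loaded above is available
# in the environment).
if __name__ == "__main__":
    unittest.main()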