# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


def quantize(model, example_inputs):
    """This is the official recommended flow for quantization in PyTorch 2 export."""
    logging.info(f"Original model: {model}")
    quantizer = XNNPACKQuantizer()
    # If we set is_per_channel to True, we also need to add the out_variant of
    # quantize_per_channel/dequantize_per_channel.
    operator_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(operator_config)
    # Insert observers according to the quantizer's annotations.
    m = prepare_pt2e(model, quantizer)
    # Calibration: run representative inputs through the observed model to
    # collect the statistics used for the quantization parameters.
    m(*example_inputs)
    # Replace observers with quantize/dequantize ops.
    m = convert_pt2e(m)
    logging.info(f"Quantized model: {m}")
    # Make sure we can export to flat buffer.
    return m
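

# A minimal usage sketch (not part of the original flow): the tiny Linear model
# below is hypothetical, and the capture step assumes the pre-autograd export
# API (`torch._export.capture_pre_autograd_graph`) that prepare_pt2e expects on
# the PyTorch versions matching the torch.ao import paths above.
if __name__ == "__main__":
    import torch
    from torch._export import capture_pre_autograd_graph

    example_inputs = (torch.randn(1, 8),)
    eager_model = torch.nn.Linear(8, 4).eval()
    # prepare_pt2e operates on a captured graph module, not the eager module.
    captured = capture_pre_autograd_graph(eager_model, example_inputs)
    quantized = quantize(captured, example_inputs)
    print(quantized(*example_inputs))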