# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING
from executorch.exir.dialects._ops import ops as exir_ops


# Edge-dialect quantize operators (per-channel and per-tensor variants).
q_ops = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
}

# Edge-dialect dequantize operators (per-tensor and per-channel variants).
dq_ops = {
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}


def get_quant_attrs(
    edge_program: torch.export.ExportedProgram, quant_node: torch.fx.Node
) -> dict:
    """Collect the attributes of a quantize/dequantize node into a dict.

    The keys are the argument names of ``quant_node.target``'s schema with the
    first argument skipped; each key is paired positionally with
    ``quant_node.args[1:]``.

    Args:
        edge_program: Exported program handed to ``get_parameter`` when an
            attribute is itself a graph node.
        quant_node: Node whose ``target`` exposes ``_schema.arguments`` and
            whose ``args`` carry the attribute values.

    Returns:
        A dict mapping every schema argument name (except the first) to its
        value. Names with no matching positional argument stay ``None``. The
        node's ``target`` is additionally stored under ``QCOM_ENCODING``.
    """
    # Skip the first schema argument: only the remaining ones are treated as
    # quantization attributes.
    quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments][1:]
    # Pre-populate with None so keys not supplied positionally are still present.
    quant_attrs = dict.fromkeys(quant_attr_keys)

    for key, attr in zip(quant_attr_keys, quant_node.args[1:]):
        value = attr
        if isinstance(attr, torch.fx.Node):
            # could be a commonly shared attribute between q & dq
            if attr.target == exir_ops.edge.aten._to_copy.default:
                # Look through the copy and resolve its source node instead.
                value = get_parameter(attr.args[0], edge_program)
            else:
                value = get_parameter(attr, edge_program)
        quant_attrs[key] = value

    quant_attrs[QCOM_ENCODING] = quant_node.target
    return quant_attrs