xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.backends.qualcomm.builders.utils import get_parameter
9from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING
10from executorch.exir.dialects._ops import ops as exir_ops
11
12
# Edge-dialect quantize ops: the per-channel variant together with both
# per-tensor overloads (``.default`` and ``.tensor``).
q_ops = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
} | {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
}
18
# Edge-dialect dequantize ops: both per-tensor overloads (``.default`` and
# ``.tensor``) together with the per-channel variant.
dq_ops = {
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
} | {
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}
24
25
def get_quant_attrs(
    edge_program: torch.export.ExportedProgram, quant_node: torch.fx.Node
):
    """Collect the quantization attributes carried by a (de)quantize node.

    Attribute names come from the operator schema of ``quant_node.target``.
    The schema's first argument (the tensor operand) is skipped. Every other
    argument value found on the node, whether positional or keyword, is stored
    under its schema name. Values that are graph nodes are resolved to concrete
    values with ``get_parameter``. A node whose target is
    ``aten._to_copy.default`` is looked through, and its first input is
    resolved instead.

    Args:
        edge_program: Exported program used to resolve node-valued arguments.
        quant_node: Node whose target is a quantize/dequantize op
            (e.g. a member of ``q_ops`` / ``dq_ops``).

    Returns:
        Dict mapping every schema argument name except the first to its value.
        Arguments not supplied on the node map to ``None``. The dict also
        contains ``QCOM_ENCODING`` mapped to ``quant_node.target``.
    """

    def _resolve(attr):
        # Node-valued attributes could be commonly shared between q & dq;
        # fetch the underlying value, looking through an explicit _to_copy.
        if not isinstance(attr, torch.fx.node.Node):
            return attr
        if attr.target == exir_ops.edge.aten._to_copy.default:
            return get_parameter(attr.args[0], edge_program)
        return get_parameter(attr, edge_program)

    quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments][1:]
    quant_attrs = dict.fromkeys(quant_attr_keys)

    # Positional args line up with the schema in order; args[0] is the input.
    for key, attr in zip(quant_attr_keys, quant_node.args[1:]):
        quant_attrs[key] = _resolve(attr)

    # Arguments passed by keyword (e.g. keyword-only ones such as ``out_dtype``,
    # if the schema declares them) never appear in ``args``; record them too so
    # they are not left as None.
    for key, attr in quant_node.kwargs.items():
        if key in quant_attrs:
            quant_attrs[key] = _resolve(attr)

    quant_attrs[QCOM_ENCODING] = quant_node.target
    return quant_attrs
46