xref: /aosp_15_r20/external/executorch/exir/passes/quantize_io_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2import logging
3from typing import Any, Dict, List, Optional, Union
4
5import numpy as np
6
7import torch
8
9from executorch.exir import EdgeProgramManager
10from executorch.exir.dialects._ops import ops as exir_ops
11
12from executorch.exir.pass_base import ExportPass
13from executorch.exir.tensor import scalar_type_enum
14from torch.fx.passes.infra.pass_base import PassResult
15
16logger = logging.getLogger(__name__)
17
18
def quantize_input(
    exported_program, input_index, qparams: Optional[Dict[str, Any]] = None
):
    """
    Modify the program to expect quantized input at the given index.

    The model is expected to be quantizing this input as its first step: the
    input placeholder must have exactly one user, which must be a
    ``quantize_per_tensor`` op. By default that quantize op is bypassed so the
    program consumes already-quantized data directly. If ``qparams`` is given
    and differs from the quantize op's own parameters, a
    ``dequantize_per_tensor`` using ``qparams`` is inserted in front of the
    existing quantize op instead, so the program requantizes the input.

    Must be called before permute_input_layout.

    Args:
        exported_program: Program whose graph is modified in place.
        input_index: Index into ``exported_program.graph_signature.user_inputs``.
        qparams: Optional dict with ``"scale"``, ``"zp"`` and ``"dtype"`` keys
            describing the quantization the caller will feed in.

    Returns:
        The scale, zero point, qmin, qmax, and dtype of the expected input
        quantization.

    Raises:
        ValueError: If no placeholder matches the input, the placeholder does
            not have exactly one user, that user is not a quantize op, or a
            requantization between unsupported dtypes is requested.
    """
    graph = exported_program.graph_module.graph
    name = exported_program.graph_signature.user_inputs[input_index]
    placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name]
    # Explicit error rather than a bare assert (which is stripped under -O and
    # carried no message).
    if not placeholders:
        raise ValueError(f"Input {input_index} ({name}) has no placeholder node")
    target_placeholder = placeholders[0]

    num_users = len(target_placeholder.users)
    if num_users == 0:
        raise ValueError(f"Input {input_index} has no users")
    if num_users > 1:
        raise ValueError(f"Input {input_index} has more than one users")
    quantize = next(iter(target_placeholder.users))
    if (
        quantize.target
        != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ):
        raise ValueError(f"Input {input_index} is not used by a quantize op")

    # quantize.args is read below as (input, scale, zero_point, qmin, qmax, dtype).
    # If user specified qparams are different from args of quantize op, we do
    # requantization instead of eliminating quantize op.
    need_requant = False
    if qparams is not None:
        assert all(
            qparam in qparams for qparam in ["scale", "zp", "dtype"]
        ), "dtype/scale/zp must be specified in qparam for input requantization"
        if qparams["dtype"] != quantize.args[5]:
            if any(
                dtype
                not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16]
                for dtype in [qparams["dtype"], quantize.args[5]]
            ):
                raise ValueError(
                    f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}"
                )

            need_requant = True
        elif (
            not np.isclose(qparams["scale"], quantize.args[1])
            or qparams["zp"] != quantize.args[2]
        ):
            need_requant = True

    if need_requant:
        assert qparams is not None
        dtype = qparams["dtype"]
        qmin = torch.iinfo(dtype).min
        qmax = torch.iinfo(dtype).max
        scale = qparams["scale"]
        zero_point = qparams["zp"]
        quant_args = (scale, zero_point, qmin, qmax, dtype)
        logger.info(
            f"Modifying program to requantize quantized input at index {input_index}"
        )
        logger.info(f"Quantization parameters: {quant_args}")

        # Rewire placeholder -> dequantize(user qparams) -> existing quantize op,
        # so the program accepts data quantized with the caller's parameters.
        with graph.inserting_before(quantize):
            input_dequant = graph.call_function(
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                args=(
                    target_placeholder,
                    *quant_args,
                ),
            )
            input_dequant.meta["input_qparams"] = [
                {
                    "scale": scale,
                    "zero_point": zero_point,
                    "qmin": qmin,
                    "qmax": qmax,
                    "dtype": dtype,
                }
            ]
            input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32)
            target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype)
            quantize.replace_input_with(target_placeholder, input_dequant)
    else:
        quant_args = quantize.args[1:]
        logger.info(f"Modifying program to take quantized input at index {input_index}")
        logger.info(f"Quantization parameters: {quant_args}")

        # Give the placeholder the quantized example value and route the quantize
        # op's users straight to the placeholder; the now-unused quantize op can
        # then be dropped by eliminate_dead_code() below.
        target_placeholder.meta["val"] = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                target_placeholder.meta["val"], *quant_args
            )
        )
        quantize.replace_all_uses_with(quantize.args[0])

    graph.eliminate_dead_code()
    return quant_args
113
114
def quantize_output(exported_program, output_index):
    """
    Modify the program to produce quantized output at the given index.

    The model is expected to be dequantizing this output as the last step: the
    graph output at ``output_index`` must be a ``dequantize_per_tensor`` node.
    That node's input is returned from the graph instead, and the now-unused
    dequantize op is removed by dead-code elimination.

    Must be called before permute_output_layout.

    Args:
        exported_program: Program whose graph is modified in place.
        output_index: Position in the graph's output list.

    Returns:
        The scale, zero point, qmin, qmax, and dtype of the output
        quantization (the dequantize op's arguments after its input).

    Raises:
        NotImplementedError: If the graph does not have exactly one output node.
        ValueError: If ``output_index`` is out of range or the selected output
            is not a dequantize op.
    """
    graph = exported_program.graph_module.graph
    outputs = [n for n in graph.nodes if n.op == "output"]
    if len(outputs) != 1:
        raise NotImplementedError("Only 1 output node is supported")

    output_node = outputs[0]
    output_list = list(output_node.args[0])
    if output_index >= len(output_list):
        raise ValueError(
            f"{len(output_list)} outputs available, "
            + f"output index out of bounds: {output_index}"
        )

    target_output = output_list[output_index]
    # Non-Node outputs (e.g. None or constants) have no `.target`; check the
    # type first so they are reported as a ValueError, not an AttributeError.
    if (
        not isinstance(target_output, torch.fx.Node)
        or target_output.target
        != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    ):
        raise ValueError(f"Output {output_index} is not a dequantize op")

    dequant = target_output
    # Return the dequantize op's (quantized) input directly.
    output_list[output_index] = dequant.args[0]
    output_node.args = (output_list,)
    dequant_args = dequant.args[1:]
    graph.eliminate_dead_code()

    logger.info(
        f"Modifying program to produce quantized output at index {output_index}"
    )
    logger.info(f"Dequantization parameters: {dequant_args}")
    return dequant_args
153
154
def get_config_method_name(
    prefix: Optional[str] = "forward",
    arg_type: str = "input",
    index: int = 0,
    key: str = "scale",
):
    """
    Build the config-method name under which one quantization parameter is stored.

    The name has the form ``"{prefix}_{arg_type}{index}_{key}"``. When
    ``prefix`` is None the ``"{prefix}_"`` part is omitted entirely, e.g.
    ``"forward_input0_scale"`` versus ``"output3_dtype"``.
    """
    head = "" if prefix is None else prefix + "_"
    assert arg_type in ("input", "output"), "arg_type must be either input or output"
    assert index >= 0, "index must be non-negative"
    allowed_keys = ("scale", "zp", "quant_min", "quant_max", "dtype")
    assert (
        key in allowed_keys
    ), "key must be one of scale, zp, quant_min, quant_max, dtype"
    return "{}{}{}_{}".format(head, arg_type, index, key)
175
176
class QuantizeInputs(ExportPass):
    """
    Pass that rewrites the selected inputs of the managed program (via
    ``quantize_input``) so they take quantized data, and records each input's
    quantization parameters in ``edge_program_manager._config_methods``.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose exported program is rewritten
                and whose ``_config_methods`` receive the parameters.
            quantized_inputs_idx: Either a mapping from input index to the
                qparams passed to ``quantize_input``, or a list of input
                indices (each then uses ``None`` qparams).
            method_name: Prefix used when naming the config methods.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager
        if isinstance(quantized_inputs_idx, dict):
            self.quantized_inputs_idx_dict = quantized_inputs_idx
        else:
            self.quantized_inputs_idx_dict = dict.fromkeys(quantized_inputs_idx)
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule):
        """Quantize each configured input and store its qparams as config methods."""
        manager = self.edge_program_manager
        prefix = self.param_prefix_name
        for input_idx, qparams in self.quantized_inputs_idx_dict.items():
            quant_args = quantize_input(manager.exported_program(), input_idx, qparams)

            if not manager._config_methods:
                manager._config_methods = {}
            config_methods = manager._config_methods  # pyre-ignore

            # quant_args is (scale, zp, quant_min, quant_max, dtype); the first
            # four are stored as-is, the dtype is stored as its scalar-type enum.
            for pos, key in enumerate(("scale", "zp", "quant_min", "quant_max")):
                config_methods[
                    get_config_method_name(prefix, "input", input_idx, key)
                ] = quant_args[pos]
            config_methods[
                get_config_method_name(prefix, "input", input_idx, "dtype")
            ] = scalar_type_enum(quant_args[4])
        return PassResult(graph_module, True)
220
221
class QuantizeOutputs(ExportPass):
    """
    Pass that rewrites the selected outputs of the managed program (via
    ``quantize_output``) so they are returned quantized, and records each
    output's dequantization parameters in ``edge_program_manager._config_methods``.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_outputs_idx_list: List[int],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose exported program is rewritten
                and whose ``_config_methods`` receive the parameters.
            quantized_outputs_idx_list: Output indices to quantize.
            method_name: Prefix used when naming the config methods.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager
        self.quantized_outputs_idx_list = quantized_outputs_idx_list
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule):
        """Quantize each configured output and store its qparams as config methods."""
        manager = self.edge_program_manager
        prefix = self.param_prefix_name
        for output_idx in self.quantized_outputs_idx_list:
            dequant_args = quantize_output(manager.exported_program(), output_idx)

            if not manager._config_methods:
                manager._config_methods = {}
            config_methods = manager._config_methods  # pyre-ignore

            # dequant_args is (scale, zp, quant_min, quant_max, dtype); the first
            # four are stored as-is, the dtype is stored as its scalar-type enum.
            for pos, key in enumerate(("scale", "zp", "quant_min", "quant_max")):
                config_methods[
                    get_config_method_name(prefix, "output", output_idx, key)
                ] = dequant_args[pos]
            config_methods[
                get_config_method_name(prefix, "output", output_idx, "dtype")
            ] = scalar_type_enum(dequant_args[4])

        return PassResult(graph_module, True)
260