import copy
import operator
from typing import Any, Callable, Optional, Tuple

import torch
from torch.ao.quantization import (
    default_weight_fake_quant,
    default_weight_observer,
    FakeQuantizeBase,
    QConfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx


# TODO: move all LSTM util functions from fx/utils.py to this file
def _get_lstm_with_individually_observed_parts(
    float_lstm: torch.nn.LSTM,
    example_inputs: Tuple[Any, ...],
    backend_config: Optional[BackendConfig] = None,
    linear_output_obs_ctr: Optional[_PartialWrapper] = None,
    sigmoid_obs_ctr: Optional[_PartialWrapper] = None,
    tanh_obs_ctr: Optional[_PartialWrapper] = None,
    cell_state_obs_ctr: Optional[_PartialWrapper] = None,
    hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
) -> torch.ao.nn.quantizable.LSTM:
    """
    Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
    with specific observers or fake quantizes assigned to the inner ops or submodules.

    In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
    used as an observed custom module, which is responsible for inserting its own
    observers. By default, all inner ops inherit the parent custom module's QConfig.
    Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
    and use this helper function to customize the observer insertion logic.

    This is meant to be used to convert a float module to an observed module in the
    custom module flow.

    Args:
        `float_lstm`: The float LSTM module
        `example_inputs`: example inputs for the forward function of the LSTM module
        `backend_config`: BackendConfig to use to observe the LSTM module
        `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
            where W is the weight matrix, b is the bias, and x is either the inputs
            or the hidden state from the previous layer (if any)
        `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
        `tanh_obs_ctr`: observer or fake quantize for tanh activations
        `cell_state_obs_ctr`: observer or fake quantize for the cell state
        `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
            the output

    Return:
        A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
        assigned to the inner ops.
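
    Example (a minimal sketch; the fixed qparams below are illustrative values
    assumed for this example, not values required by the API)::

        float_lstm = torch.nn.LSTM(input_size=2, hidden_size=2)
        float_lstm.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        sigmoid_obs_ctr = torch.ao.quantization.FixedQParamsObserver.with_args(
            scale=2 ** -8, zero_point=0, dtype=torch.quint8
        )
        tanh_obs_ctr = torch.ao.quantization.FixedQParamsObserver.with_args(
            scale=2 ** -7, zero_point=128, dtype=torch.quint8
        )
        example_inputs = (torch.rand(5, 1, 2),)
        observed_lstm = _get_lstm_with_individually_observed_parts(
            float_lstm,
            example_inputs,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
        )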
    """

    def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig:
        """
        Make a QConfig that pairs the given activation observer or fake quantize
        constructor with the matching default weight observer or fake quantize.
        """
        if isinstance(obs_ctr(), FakeQuantizeBase):
            weight = default_weight_fake_quant
        else:
            weight = default_weight_observer
        return QConfig(activation=obs_ctr, weight=weight)

    quantizable_lstm = torch.ao.nn.quantizable.LSTM(
        float_lstm.input_size,
        float_lstm.hidden_size,
        float_lstm.num_layers,
        float_lstm.bias,
        float_lstm.batch_first,
        float_lstm.dropout,
        float_lstm.bidirectional,
    )
    quantizable_lstm.qconfig = float_lstm.qconfig

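    # Build each inner `_LSTMLayer` from the corresponding float layer.
    # Note: the layers are constructed with batch_first=False; the parent
    # quantizable LSTM handles batch-first inputs by transposing once in its
    # own forward rather than in every layer.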
    for idx in range(float_lstm.num_layers):
        quantizable_lstm.layers[idx] = (
            torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
                float_lstm, idx, float_lstm.qconfig, batch_first=False
            )
        )

    # Build QConfigMapping for the LSTM cell
    # Note: FloatFunctional qconfigs will be configured separately below
    cell_qm = QConfigMapping().set_global(float_lstm.qconfig)  # type: ignore[arg-type]
    if sigmoid_obs_ctr is not None:
        cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr))
        cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr))
        cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr))
    if tanh_obs_ctr is not None:
        cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr))

    # Insert observers into each LSTM cell
    # TODO: maybe make this work for layer_bw as well
    for layer in quantizable_lstm.layers:
        cell = layer.layer_fw.cell
        cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
        # HACK: Manually replace the activation_post_process following these ops.
        # This is needed for FloatFunctional ops because there is currently no
        # way to configure them in FX graph mode quantization: the FloatFunctional
        # modules simply disappear from the graph after tracing. In the future,
        # we should rewrite the quantizable LSTM without FloatFunctionals.
        op_index_to_activation_post_process_ctr = {
            (torch.add, 0): linear_output_obs_ctr,  # gates.add
            (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
            (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
            (torch.add, 1): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
            (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
        }
        add_count = 0
        mul_count = 0
        for node in cell.graph.nodes:
            op_index: Optional[Tuple[Callable, int]] = None  # e.g. (torch.add, 1)
            if node.target == torch.add:
                op_index = (torch.add, add_count)
                add_count += 1
            elif node.target == torch.mul:
                op_index = (torch.mul, mul_count)
                mul_count += 1
            else:
                # Neither torch.add nor torch.mul
                continue
            if op_index not in op_index_to_activation_post_process_ctr:
                continue
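            # prepare_fx inserts exactly one activation_post_process module
            # after each observed op, so the add/mul node should have a single
            # user here: the observer (or fake quantize) we are about to swap.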
            assert len(node.users) == 1
            activation_post_process_name = next(iter(node.users.keys())).name
            activation_post_process_ctr = op_index_to_activation_post_process_ctr[
                op_index
            ]
            if activation_post_process_ctr is not None:
                setattr(
                    cell, activation_post_process_name, activation_post_process_ctr()
                )
        layer.layer_fw.cell = cell
    return quantizable_lstm


def _get_reference_quantized_lstm_module(
    observed_lstm: torch.ao.nn.quantizable.LSTM,
    backend_config: Optional[BackendConfig] = None,
) -> torch.ao.nn.quantized.LSTM:
    """
    Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM`
    with observers or fake quantizes inserted through `prepare_fx`, e.g. from
    `_get_lstm_with_individually_observed_parts`.

    This is meant to be used to convert an observed module to a quantized module in the
    custom module flow.

    Args:
        `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx`
        `backend_config`: BackendConfig to use to produce the reference quantized model

    Return:
        A reference `torch.ao.nn.quantized.LSTM` module.
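
    Example (a minimal sketch; `observed_lstm` is assumed to be the output of
    `_get_lstm_with_individually_observed_parts` above, run on calibration data
    first)::

        observed_lstm(torch.rand(5, 1, 2))  # calibrate
        quantized_lstm = _get_reference_quantized_lstm_module(observed_lstm)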
    """
    quantized_lstm = torch.ao.nn.quantized.LSTM(
        observed_lstm.input_size,
        observed_lstm.hidden_size,
        observed_lstm.num_layers,
        observed_lstm.bias,
        observed_lstm.batch_first,
        observed_lstm.dropout,
        observed_lstm.bidirectional,
    )

    for i, layer in enumerate(quantized_lstm.layers):
        cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell)  # type: ignore[union-attr]
        cell = convert_to_reference_fx(cell, backend_config=backend_config)  # type: ignore[arg-type]
        assert isinstance(cell, torch.fx.GraphModule)
        # HACK: Manually remove input quantize nodes and output dequantize nodes,
        # since custom modules expect quint8 inputs and outputs for now. Note that
        # this functionality is nominally handled through PrepareCustomConfig's
        # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that
        # API doesn't currently handle tuple inputs and outputs, so we have to do
        # this manually for now. In the future we should (1) relax the restriction
        # on custom module input/output dtypes, and (2) expand support for complex
        # input/output structures.
        for node in cell.graph.nodes:
            if node.target == torch.quantize_per_tensor:
                arg = node.args[0]
                # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
                if arg.target == "x" or (
                    arg.target == operator.getitem and arg.args[0].target == "hidden"
                ):
                    with cell.graph.inserting_before(node):
                        node.replace_all_uses_with(arg)
                        cell.graph.erase_node(node)
            if node.target == "output":
                # Remove all dequantize nodes in the output tuple
                for arg in node.args[0]:
                    with cell.graph.inserting_before(node):
                        node.replace_input_with(arg, arg.args[0])
        cell.graph.eliminate_dead_code()
        cell.recompile()
        layer.layer_fw.cell = cell
    return quantized_lstm