# mypy: allow-untyped-defs
import operator
import warnings
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.observer import (
    _with_args,
    ObserverBase,
    PerChannelMinMaxObserver,
)
from torch.ao.quantization.utils import _parent_name, check_min_max_valid
from torch.fx import GraphModule
from torch.fx.graph import Node

from .utils import (
    get_new_attr_name_with_prefix,
    maybe_get_next_module,
    node_arg_is_weight,
)


CUSTOM_MODULE_SUPP_LIST: List[Any] = []


def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
    """Reshapes the scale so that we can multiply it to the input by the given axis."""
    new_shape = [1] * input.ndim
    new_shape[axis] = input.size(axis)
    return scale.view(new_shape)

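# A minimal usage sketch of reshape_scale (an editor's illustration; the
# shapes below are hypothetical): a per-channel scale reshaped along axis=1
# of a 4-D conv input broadcasts over the batch and spatial dimensions.
#
#     scale = torch.tensor([0.5, 2.0, 1.0])   # one scale per input channel
#     x = torch.randn(8, 3, 32, 32)           # (N, C, H, W)
#     reshaped = reshape_scale(scale, 1, x)   # shape (1, 3, 1, 1)
#     scaled_x = torch.mul(x, reshaped)       # scales each channel of x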

qscheme_mapping_per_tensor_to_per_channel = {
    torch.per_tensor_affine: torch.per_channel_affine,
    torch.per_tensor_symmetric: torch.per_channel_symmetric,
}


class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.
    This observer is intended to be used along with a WeightEqualizationObserver
    to calculate the equalization scale.
    """

    def __init__(
        self,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
    ) -> None:
        super().__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        self.dtype = dtype
        self.qscheme = qscheme

        per_channel_qscheme = qscheme_mapping_per_tensor_to_per_channel[qscheme]
        self.input_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=per_channel_qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )

        self.equalization_scale = torch.tensor(1)
        self.equalization_shape: List[int] = []

    def forward(self, x_orig):
        if not (2 <= x_orig.ndim <= 5):
            raise ValueError(
                "InputEqualizationObserver only supports Linear and Conv layers"
            )

        # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
        self.equalization_shape = [1] * x_orig.ndim
        self.equalization_shape[1] = x_orig.size(1)

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        return (self.input_obs.min_val, self.input_obs.max_val)

    def set_equalization_scale(self, equalization_scale):
        # Reshape the equalization scale along axis=1 so that it can be
        # multiplied with the input along axis=1
        if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
            return
        self.equalization_scale = torch.reshape(
            equalization_scale, self.equalization_shape
        )

    def calculate_scaled_minmax(self):
        r"""Returns the scaled min/max inputs"""
        if (
            self.equalization_scale.nelement() == 1
            and self.equalization_scale == torch.tensor(1)
        ):
            warnings.warn(
                "Must call calculate_equalization_scale before calling calculate_scaled_minmax. "
                + "Will not scale the next quantization observer."
            )
            return None, None

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index
        (min_inputs, max_inputs) = self.get_input_minmax()
        equalization_scale_reshaped = reshape_scale(
            self.equalization_scale, 0, min_inputs
        )
        min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
        max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))

        return min_input_scaled, max_input_scaled

    with_args = classmethod(_with_args)

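# A minimal usage sketch (an editor's illustration, not part of the original
# module): the input observer records per-column min/max values of the
# activations it sees.
#
#     obs = _InputEqualizationObserver.with_args(dtype=torch.quint8)()
#     obs(torch.randn(16, 4))                      # 16 rows, 4 input columns
#     min_cols, max_cols = obs.get_input_minmax()  # each of shape (4,)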

class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    This observer is made up of one PerChannelMinMaxObserver, `weight_col_obs`,
    used to record the running minimum and maximum of columns of incoming
    weight tensors. This observer is intended to be used along with an
    InputEqualizationObserver to calculate the equalization scale.

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
    """

    def __init__(
        self,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_affine,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
    ) -> None:
        super().__init__()

        self.dtype = dtype
        self.qscheme = qscheme
        self.ch_axis = 1

        per_channel_qscheme = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            per_channel_qscheme = qscheme_mapping_per_tensor_to_per_channel[qscheme]
        self.weight_col_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=per_channel_qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )

        self.equalization_scale = torch.tensor(1)

    def forward(self, w_orig):
        if not (2 <= w_orig.ndim <= 5):
            raise ValueError(
                "WeightEqualizationObserver only supports Linear and Conv layers"
            )

        return self.weight_col_obs(w_orig)

    def get_weight_col_minmax(self):
        return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)

    def set_equalization_scale(self, equalization_scale):
        self.equalization_scale = equalization_scale

    with_args = classmethod(_with_args)


def calculate_equalization_scale(
    input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver
) -> torch.Tensor:
    r"""Calculates the equalization scale. The caller is responsible for
    setting the resulting scale on the input and weight observers.

    Args:
        input_obs: Observer that tracks the ranges for the input columns
        weight_obs: Observer that tracks the ranges for the weight columns
    """

    (min_inputs, max_inputs) = input_obs.get_input_minmax()
    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()

    if not (
        check_min_max_valid(min_inputs, max_inputs)
        and check_min_max_valid(min_weights, max_weights)
    ):
        warnings.warn(
            "Must run observer before calling calculate_equalization_scale. "
            + "Returning default equalization scale torch.tensor(1)."
        )
        return torch.tensor(1)

    if min_inputs.shape != min_weights.shape:
        raise ValueError(
            "Input and Weight must have the same column dimension. "
            + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
        )

    equalization_scale = torch.sqrt(
        (max_weights - min_weights) / (max_inputs - min_inputs)
    )
    # Replace all 'inf', 'nan', 0's with 1s to prevent errors
    equalization_scale[equalization_scale == 0.0] = 1
    equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1)
    return equalization_scale

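# A worked sketch of the formula above (an editor's illustration with made-up
# numbers): for each column j, scale_j = sqrt(w_range_j / x_range_j), which
# balances the per-column ranges of the input and the weight.
#
#     x_min, x_max = torch.tensor([-1.0, -4.0]), torch.tensor([1.0, 4.0])
#     w_min, w_max = torch.tensor([-2.0, -1.0]), torch.tensor([2.0, 1.0])
#     scale = torch.sqrt((w_max - w_min) / (x_max - x_min))
#     # scale == tensor([1.4142, 0.5000]); multiplying the input columns by
#     # scale and dividing the weight columns by it equalizes both ranges.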

class EqualizationQConfig(
    namedtuple("EqualizationQConfig", ["input_activation", "weight"])
):
    """
    Describes how to quantize a layer or a part of the network specifically for
    input-weight equalization by providing settings (observer classes) for
    inputs, outputs, and weights.

    Note that EqualizationQConfig needs to contain observer **classes** (like
    MinMaxObserver) or a callable that returns instances on invocation, not the
    concrete observer instances themselves.
    The quantization function will instantiate observers multiple times for each
    of the layers.

    Observer classes usually have reasonable default arguments, but they can be
    overridden with the `with_args` method (which behaves like functools.partial):

    my_qconfig = EqualizationQConfig(
        input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
        weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8),
    )
    """

    def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
        if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError(
                "EqualizationQConfig received observer instance, please pass observer class instead. "
                + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
            )
        self = super().__new__(cls, input_activation, weight)
        return self


input_equalization_observer = _InputEqualizationObserver.with_args(
    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
)
weight_equalization_observer = _WeightEqualizationObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
default_equalization_qconfig = EqualizationQConfig(
    input_activation=input_equalization_observer, weight=weight_equalization_observer
)

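# A minimal sketch (an editor's illustration): the default qconfig is
# typically applied globally through an equalization qconfig dict, e.g. one
# passed to prepare_fx via its `_equalization_config` argument (hedged; check
# the quantize_fx API of your torch version):
#
#     equalization_qconfig_dict = {"": default_equalization_qconfig}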

def fused_module_supports_equalization(module) -> bool:
    """Checks if the fused module supports equalization."""
    return type(module) in [
        nni.LinearReLU,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
    ]


def nn_module_supports_equalization(module) -> bool:
    """Checks if the torch.nn module supports equalization."""
    return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]


def custom_module_supports_equalization(module) -> bool:
    """Checks if the custom module supports equalization."""
    return type(module) in CUSTOM_MODULE_SUPP_LIST


def node_supports_equalization(node: Node, modules) -> bool:
    """Checks if the current node supports equalization.
    Currently we only support nn.Linear/F.linear and nn.Conv*/F.conv* layers.
    """
    if node.op == "call_module":
        return (
            nn_module_supports_equalization(modules[str(node.target)])
            or fused_module_supports_equalization(modules[str(node.target)])
            or custom_module_supports_equalization(modules[str(node.target)])
        )
    elif node.op == "call_function":
        return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
    return False


def is_equalization_observer(observer: nn.Module) -> bool:
    return isinstance(
        observer, (_InputEqualizationObserver, _WeightEqualizationObserver)
    )


###############################################################################
# Functions for equalization during convert                                   #
###############################################################################


def get_op_node_and_weight_eq_obs(
    input_eq_obs_node: Node, model: GraphModule, modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
    """Gets the following weight equalization observer. There should always
    exist a weight equalization observer after an input equalization observer.

    Returns the operation node that follows the input equalization observer node
    and the weight equalization observer.
    """

    # Find the op node that comes directly after the input equalization observer
    op_node = None
    for user in input_eq_obs_node.users.keys():
        if node_supports_equalization(user, modules):
            op_node = user
            break

    assert op_node is not None
    if op_node.op == "call_module":
        # If the op_node is a supported module (e.g. nn.Linear), then it must
        # have a WeightEqualizationObserver configuration
        maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
            model, "equalization_node_name_to_qconfig"
        )
        assert maybe_equalization_node_name_to_config is not None
        equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config  # type: ignore[assignment]
        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
        weight_eq_obs = equalization_node_name_to_qconfig.get(
            op_node.name, None
        ).weight()

        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
        return op_node, weight_eq_obs

    elif op_node.op == "call_function":
        weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
        if weight_node is not None:
            weight_eq_obs = modules[str(weight_node.target)]
            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
            return op_node, weight_eq_obs

    return None, None


def maybe_get_weight_eq_obs_node(
    op_node: Node, modules: Dict[str, nn.Module]
) -> Optional[Node]:
    """Gets the weight equalization observer node if it exists."""
    assert op_node.op == "call_function"
    for node_arg in op_node.args:
        if node_arg_is_weight(op_node, node_arg):
            assert (
                isinstance(node_arg, Node)
                and node_arg.op == "call_module"
                and isinstance(
                    modules[str(node_arg.target)], _WeightEqualizationObserver
                )
            )
            return node_arg
    return None


def maybe_get_next_input_eq_obs(
    node: Node, modules: Dict[str, nn.Module]
) -> Optional[_InputEqualizationObserver]:
    """Gets the following input equalization observer if it exists.

    For example, in the case of connecting linear layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    If the node being passed in is the linear1 node, then we want to return eq_obs2,
    the following equalization observer for linear2.

    However, if there are no connecting layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
    Then we want to return None.

    In the case of an unfused linear-relu layer with a connecting linear layer:
        linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    Since it is unfused, we want to skip over the relu layer and return eq_obs2,
    the following equalization observer for linear2.
    """

    assert node_supports_equalization(node, modules)

    # Locate the following nn.ReLU or F.relu node if it exists
    maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
    if maybe_relu_node is None:
        maybe_relu_node = maybe_get_next_module(
            node, modules, target_functional_type=F.relu
        )

    # Locate the following output observer if it exists.
    # We will skip the relu node if it exists.
    maybe_obs_node = (
        maybe_get_next_module(node, modules, ObserverBase)
        if maybe_relu_node is None
        else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
    )
    if maybe_obs_node is None:
        return None

    maybe_eq_obs_node = maybe_get_next_module(
        maybe_obs_node, modules, _InputEqualizationObserver
    )
    if maybe_eq_obs_node is None:
        return None

    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
    return maybe_eq_obs


def maybe_get_next_equalization_scale(
    node: Node, modules: Dict[str, nn.Module]
) -> Optional[torch.Tensor]:
    """If the following node (skipping over any output observer and relu) is an
    InputEqualizationObserver, then we want to return its equalization scale;
    otherwise we return None.

    This is used in the case where there are two connecting linear layers:
        linear1 -> LinearOutObs -> InputEqObs -> linear2
    In this case, the node given is linear1 and we want to locate the InputEqObs.
    """
    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
    if next_inp_eq_obs:
        if (
            next_inp_eq_obs.equalization_scale.nelement() == 1
            and next_inp_eq_obs.equalization_scale == torch.tensor(1)
        ):
            return None
        return next_inp_eq_obs.equalization_scale
    return None


def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
    """Scales the following input quantization observer's min/max values by
    updating the values with the scaled min/max values calculated by the input
    equalization observer
    """
    input_eq_obs = modules[str(node.target)]
    assert isinstance(input_eq_obs, _InputEqualizationObserver)

    input_quant_obs_node = node.args[0]
    assert isinstance(input_quant_obs_node, Node)

    input_quant_obs = modules[str(input_quant_obs_node.target)]
    if not isinstance(input_quant_obs, ObserverBase):
        return

    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
    if min_input_scaled is None and max_input_scaled is None:
        return
    input_quant_obs.min_val = min_input_scaled
    input_quant_obs.max_val = max_input_scaled


def scale_weight_node(
    node: Node,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """Scale the weights for input-weight equalization by multiplying the
    weight by 1/equalization_scale and next_equalization_scale

    Args:
        node: Current node whose weights we want to scale
        modules: Dictionary mapping module names to the modules in the model
        equalization_scale: Current node's calculated equalization scale
        next_equalization_scale: Next node's calculated equalization scale if
            the following node needs to be equalized, None otherwise
    """
    if equalization_scale is None:
        return

    if fused_module_supports_equalization(modules[str(node.target)]):
        op_module = modules[str(node.target)][0]  # type: ignore[index]
    else:
        op_module = modules[str(node.target)]
    assert nn_module_supports_equalization(
        op_module
    ) or custom_module_supports_equalization(op_module)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    weight = op_module.weight
    assert isinstance(weight, torch.Tensor)

    # Scale the weights by the reciprocal of the equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))

    if next_equalization_scale is None:
        op_module.weight = nn.Parameter(scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

    op_module.weight = nn.Parameter(scaled_weight)

    # Multiply the bias element wise by the next equalization scale
    bias = op_module.bias
    if bias is None:
        return
    assert isinstance(bias, torch.Tensor)

    # Reshape the equalization scale so that we can multiply it element-wise to the bias
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
    op_module.bias = nn.Parameter(scaled_bias)

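# A worked sketch of the scaling above (an editor's illustration; the shapes
# are hypothetical): for a Linear weight W of shape (out_features,
# in_features), the current scale divides W column-wise, and the next layer's
# scale multiplies W (and the bias) row-wise.
#
#     W = torch.randn(4, 3)
#     s = torch.tensor([0.5, 1.0, 2.0])            # current scale, length 3
#     s_next = torch.tensor([2.0, 2.0, 1.0, 1.0])  # next scale, length 4
#     W1 = torch.mul(W, torch.reciprocal(reshape_scale(s, 1, W)))  # (1, 3)
#     W2 = torch.mul(W1, reshape_scale(s_next, 0, W))              # (4, 1)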

def scale_weight_functional(
    op_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """Scales the weight value for functional layers"""
    if equalization_scale is None:
        return

    # From the given op_node, the path looks like:
    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
    # So we want to trace back from the op_node to get the equalization observer
    # node, then the quantization observer node, and then finally the weight
    # node which contains the weight values.

    # Get the equalization observer node
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    # Get the quantization observer node
    weight_quant_obs_node = weight_eq_obs_node.args[0]
    if weight_quant_obs_node is None:
        return
    assert isinstance(weight_quant_obs_node, Node) and isinstance(
        modules[str(weight_quant_obs_node.target)], ObserverBase
    )

    # Get the get_attr(weight) node
    weight_node = weight_quant_obs_node.args[0]
    if weight_node is None:
        return
    assert isinstance(weight_node, Node) and weight_node.op == "get_attr"

    weight_parent_name, weight_name = _parent_name(weight_node.target)
    weight = getattr(modules[weight_parent_name], weight_name)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))

    if next_equalization_scale is None:
        setattr(modules[weight_parent_name], weight_name, scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
    next_equalization_scale_reshaped = reshape_scale(
        next_equalization_scale, 0, scaled_weight
    )
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

    setattr(modules[weight_parent_name], weight_name, scaled_weight)
    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)

    # Multiply the bias element wise by the next equalization scale
    bias_node = None
    for node in op_node.args:
        # Find the node containing the bias values
        if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name:
            bias_node = node
            break
    if bias_node is None:
        return

    bias_parent_name, bias_name = _parent_name(bias_node.target)
    bias = getattr(modules[bias_parent_name], bias_name)

    # Reshape the equalization scale so that we can multiply it element-wise to the bias
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
    setattr(modules[bias_parent_name], bias_name, scaled_bias)


def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None:
    """Given the operation node, we want to find the corresponding quantization
    observer and reset its min/max values
    """
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    weight_quant_obs_node = weight_eq_obs_node.args[0]
    if weight_quant_obs_node is None:
        return
    assert isinstance(weight_quant_obs_node, Node)

    weight_quant_obs = modules[str(weight_quant_obs_node.target)]
    assert isinstance(weight_quant_obs, ObserverBase)
    weight_quant_obs.reset_min_max_vals()  # type: ignore[operator]


def remove_node(model: GraphModule, node: Node, prev_node: Node):
    """Removes the given node from the model by replacing all of its users with
    the given previous node
    """
    # For all of the current node's users, replace the current node with
    # the previous node
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        user_node.replace_input_with(node, prev_node)

    # Erase the node from the graph
    model.graph.erase_node(node)


def update_obs_for_equalization(
    model: GraphModule, modules: Dict[str, nn.Module]
) -> Dict[str, _WeightEqualizationObserver]:
    """Update all of the observers' equalization scales. For each
    InputEqualizationObserver, we will find the location of the next
    WeightEqualizationObserver, create it, and calculate the equalization scale
    based on the two observers.

    We will then return a dictionary mapping operation node names to
    the corresponding WeightEqualizationObservers for that operation.
    """
    weight_eq_obs_dict = {}
    for node in model.graph.nodes:
        if node.op == "call_module" and isinstance(
            modules[node.target], _InputEqualizationObserver
        ):
            input_eq_obs = modules[node.target]
            assert isinstance(input_eq_obs, _InputEqualizationObserver)
            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)

            if op_node is None or weight_eq_obs is None:
                continue

            if op_node.op == "call_module":
                # Calibrate the weight equalization observer since it has just
                # been created
                if fused_module_supports_equalization(modules[str(op_node.target)]):
                    module = modules[str(op_node.target)][0]  # type: ignore[index]
                    assert nn_module_supports_equalization(module)
                    weight_eq_obs(module.weight)
                else:
                    weight_eq_obs(modules[str(op_node.target)].weight)

            # Calculate and set the equalization scale values
            equalization_scale = calculate_equalization_scale(
                input_eq_obs, weight_eq_obs
            )
            input_eq_obs.set_equalization_scale(equalization_scale)
            weight_eq_obs.set_equalization_scale(equalization_scale)

            weight_eq_obs_dict[op_node.name] = weight_eq_obs

    return weight_eq_obs_dict


def convert_eq_obs(
    model: GraphModule,
    modules: Dict[str, nn.Module],
    weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
) -> None:
    """Converts the equalization operations and updates the other nodes in the
    following way:
        - Removes the input equalization observers and inserts a mul operator
          along with an equalization scale node wherever applicable (we do not
          want to insert a mul operator between connecting linear layers).
        - Updates the input quantization observers with the scaled input min/max
          values.
        - Scales the weights by the current and next equalization scales.
        - Removes the weight equalization observer node if it exists.

    Before (after prepare):
                                    weight values
                                          |
                                    WeightQuantObs
                                          |
                                      WeightEqObs
                                          |
        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs

    After this function:
                                              scaled weight values
                                                      |
       equalization scale                       WeightQuantObs
              |                                       |
        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs

    After convert:
       equalization scale                 scaled weight values
              |                                    |
        x -> mul -> quantize_per_tensor -> quantized::linear

    Note that although the equalization observer appeared after the quantization
    observer after prepare_fx, the mul node appears before the quantization node
    after convert_fx. This is because placing the equalization observer after
    the quantization observer in prepare_fx lets us keep the invariant that
    inserting observers for the current node does not modify the graph before
    it.

    Having the equalization observer before the quantization observer would also
    cause some inconsistencies between the ordering of the quantization and
    equalization observers.
    For example, a single linear layer would look like:
        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
    But between two connected linear layers, it would look like:
        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
    """
    for node in model.graph.nodes:
        if node.op == "call_module" and isinstance(
            modules[node.target], _InputEqualizationObserver
        ):
            inp_quant_obs_node = node.args[0]
            prev_node = inp_quant_obs_node.args[0]

            # If the previous node is a layer that needs to be equalized, then
            # we will remove the current node because we do not need to add any
            # equalization nodes between two layers that need to be equalized

            # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
            # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
            if (
                node_supports_equalization(prev_node, modules)
                or "relu" in prev_node.name
            ):
                remove_node(model, node, inp_quant_obs_node)
                continue

            # Update the following input quantization observer's min/max values
            scale_input_observer(node, modules)

            # Remove the InputEqualization node and add a mul operator before
            # the quantization observer node that appears before the equalization node
            # Before: x -> input_quant_obs -> input_eq_obs -> linear
            # After: x -> mul -> input_quant_obs -> linear

            # Create a node containing the equalization scale
            with model.graph.inserting_before(inp_quant_obs_node):
                get_new_eq_scale_name = get_new_attr_name_with_prefix(
                    prev_node.name + "_equalization_scale"
                )
                name = get_new_eq_scale_name(modules)
                setattr(model, name, modules[node.target].equalization_scale)
                eq_scale_node = model.graph.create_node("get_attr", name)

            # Create a node multiplying the input with the equalization scale
            with model.graph.inserting_after(eq_scale_node):
                inputs = (prev_node, eq_scale_node)
                mul_node = model.graph.create_node("call_function", torch.mul, inputs)

            # Set the mul node to be the input_quant_obs_node's input instead of
            # the previous node
            inp_quant_obs_node.replace_input_with(prev_node, mul_node)
            remove_node(model, node, inp_quant_obs_node)

        elif weight_eq_obs_dict.get(node.name, None) is not None:
            weight_eq_obs = weight_eq_obs_dict.get(node.name)
            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
            equalization_scale = weight_eq_obs.equalization_scale

            if (
                equalization_scale.nelement() == 1
                and equalization_scale == torch.tensor(1)
            ):
                equalization_scale = None  # type: ignore[assignment]
            maybe_next_equalization_scale = maybe_get_next_equalization_scale(
                node, modules
            )

            # Scale the weight nodes
            if node.op == "call_module":
                scale_weight_node(
                    node, modules, equalization_scale, maybe_next_equalization_scale
                )
            elif node.op == "call_function":
                scale_weight_functional(
                    node,
                    model,
                    modules,
                    equalization_scale,
                    maybe_next_equalization_scale,
                )

                weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
                if weight_eq_obs_node is None:
                    return
                assert isinstance(
                    modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
                )

                # Clear the quantization observer's min/max values so that they
                # can get updated later based on the new scale values
                clear_weight_quant_obs_node(node, modules)

                # Erase the weight equalization observer node
                prev_node = weight_eq_obs_node.args[0]
                remove_node(model, weight_eq_obs_node, prev_node)
            else:
                raise ValueError(
                    "Expected operation node to be 'call_module' or 'call_function'. "
                    + f"Instead got node {node.name} as '{node.op}'."
                )


def _convert_equalization_ref(model: GraphModule):
    """Reference function which applies changes needed for equalization, but
    does not quantize the nodes
    """
    modules = dict(model.named_modules(remove_duplicate=False))

    # Calculate the equalization scale, update the observers with the scaled
    # inputs, and scale the weight
    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
    convert_eq_obs(model, modules, weight_eq_obs_dict)

    return GraphModule(model, model.graph)
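
# A minimal end-to-end sketch (an editor's illustration; `MyModel` is a
# hypothetical module, and the `_equalization_config` argument of prepare_fx
# is an assumption to verify against your torch version's quantize_fx API):
#
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import prepare_fx
#
#     model = MyModel().eval()
#     example_inputs = (torch.randn(1, 4),)
#     prepared = prepare_fx(
#         model,
#         get_default_qconfig_mapping(),
#         example_inputs,
#         _equalization_config={"": default_equalization_qconfig},
#     )
#     prepared(*example_inputs)                       # calibrate
#     equalized = _convert_equalization_ref(prepared)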


###############################################################################
# Functions for running the equalized model on the Numeric Suite              #
###############################################################################


def get_layer_sqnr_dict(
    model_a: nn.Module, model_b: nn.Module, x: torch.Tensor
) -> Dict[str, float]:
    """Runs the Numeric Suite on model_a and model_b and returns a dictionary
    containing the SQNR between layers in model_a and model_b.

    Note: In order to support equalized models, this function has a hacky fix in
    which we do not match any torch.mul operators. This is because equalized
    models contain extra mul operators to scale the input by the equalization
    scale, but this edge case has not been resolved yet within the numeric suite code.

    Args:
        model_a: A float model
        model_b: A quantized model
        x: Inputs to use during calibration
    """
    import torch.ao.ns._numeric_suite_fx as ns
    from torch.ao.ns.fx.mappings import get_unmatchable_types_map

    unmatchable_types_map = get_unmatchable_types_map()
    unmatchable_types_map["funs_unmatchable"].add(torch.mul)

    model_a_ns, model_b_ns = ns.add_loggers(
        "fp32",
        model_a,
        "int8",
        model_b,
        ns.OutputLogger,
        unmatchable_types_map=unmatchable_types_map,
    )

    model_a_ns(x)
    model_b_ns(x)

    activation_comparison_dict = ns.extract_logger_info(
        model_a_ns, model_b_ns, ns.OutputLogger, "int8"
    )
    ns.extend_logger_results_with_comparison(
        activation_comparison_dict,
        "fp32",
        "int8",
        torch.ao.ns.fx.utils.compute_sqnr,
        "sqnr",
    )

    # Construct a dictionary mapping layer names to the SQNR values
    layer_sqnr_dict = {}
    for key in activation_comparison_dict:
        layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"]
        sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0]
        layer_sqnr_dict[layer] = sqnr

    return layer_sqnr_dict


925
926def get_equalization_qconfig_dict(
927    layer_sqnr_dict: Dict[str, float], num_layers_to_equalize: int
928) -> Any:
929    """Given the layer to SQNR dictionary, find the layers with the highest
930    quantization errors, and return an equalization_qconfig_dict
931    specifying to only equalize those top layers.
932
933    Args:
934        layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
935            when comparing an equalized model against a float model)
936        num_layers_to_equalize: Number of layers with the highest quantization
937           errors to equalize
938    """
939
940    # Sort the layer_sqnr_dictionary values and get the layers with the lowest
941    # SQNR values (aka highest quantization errors)
942    layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1))
943    layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize]
944
945    # Constructs an equalization_qconfig_dict that specifies to only equalize
946    # the layers with the highest quantization errors
947    module_to_qconfig_list = [
948        (item[0], default_equalization_qconfig) for item in layers_to_equalize
949    ]
950    equalization_qconfig_dict = {"module_name": module_to_qconfig_list}
951    return equalization_qconfig_dict
952
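# A minimal usage sketch (an editor's illustration; `float_model`,
# `equalized_model`, and `calib_x` are hypothetical):
#
#     layer_sqnr_dict = get_layer_sqnr_dict(float_model, equalized_model, calib_x)
#     eq_qconfig_dict = get_equalization_qconfig_dict(layer_sqnr_dict, 2)
#     # eq_qconfig_dict now maps the 2 lowest-SQNR layers to
#     # default_equalization_qconfig under the "module_name" key.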