# mypy: allow-untyped-defs
"""
Utils shared by different modes of quantization (eager/graph)
"""
import functools
import warnings
from collections import OrderedDict
from inspect import getfullargspec, signature
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch.ao.quantization.quant_type import QuantType
from torch.fx import Node
from torch.nn.utils.parametrize import is_parametrized


NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
NodePattern.__module__ = "torch.ao.quantization.utils"

# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# It is defined separately to prevent circular imports.
# TODO(future PR): improve this.
# Make this public once fixed (it can't be public as is, because setting the
# module directly doesn't work).
QuantizerCls = Any

# Type for fusion patterns. The actual type can be more complicated than the
# following; see pattern.md for documentation.
# TODO: not sure if typing supports recursive data types
Pattern = Union[
    Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
]
Pattern.__module__ = "torch.ao.quantization.utils"


# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """A node pattern that matches all nodes, used in defining
    fusion patterns in FX Graph Mode Quantization
    """


module_type_list = {
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    torch.nn.Identity,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
}
func_list = {
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.elu,
    torch.nn.functional.hardswish,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.silu,
    torch.nn.functional.mish,
    torch.nn.functional.dropout,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    torch.nn.functional.relu,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.sigmoid,
    torch.transpose,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.unsqueeze,
    torch.cat,
}
method_list = {
    torch.mean,
    "relu",
    "relu_",
    "contiguous",
    "detach",
    "detach_",
    "hardsigmoid",
    "hardsigmoid_",
    "permute",
    "repeat",
    "repeat_interleave",
    "reshape",
    "resize_",
    "shape",
    "sigmoid",
    "sigmoid_",
    "size",
    "squeeze",
    "squeeze_",
    "tanh",
    "tanh_",
    "transpose",
    "unsqueeze",
    "unsqueeze_",
    "view",
}


# TODO: not used now, remove
def check_node(node, modules):
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    is_call_function = node.op == "call_function" and node.target in func_list
    is_call_method = node.op == "call_method" and node.target in method_list
    is_call_module = (
        node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
    )
    return is_call_function, is_call_method, is_call_module


def get_combined_dict(default_dict, additional_dict):
    """
    Combines two dictionaries.

    This function takes two dictionaries as input and returns a new dictionary
    that contains all the key-value pairs from both input dictionaries.
    If there are any duplicate keys in the `additional_dict`, the values
    from the `additional_dict` will overwrite those in the `default_dict`.

    Args:
        default_dict (dict): The main dictionary that will be used as the base
        additional_dict (dict): The dictionary used to update `default_dict`

    Returns:
        dict: The resulting dictionary

    Example:
        >>> x = dict(a=1, b=1)
        >>> y = dict(b=2, c=3)
        >>> get_combined_dict(x, y)
        {'a': 1, 'b': 2, 'c': 3}
    """
    d = default_dict.copy()
    d.update(additional_dict)
    return d


def is_per_tensor(qscheme):
    return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric


def is_per_channel(qscheme):
    return qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    ]
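
# Example (illustrative): both predicates take a torch.qscheme value.
#
#     >>> is_per_tensor(torch.per_tensor_affine)
#     True
#     >>> is_per_channel(torch.per_channel_symmetric)
#     True
#     >>> is_per_channel(torch.per_tensor_affine)
#     False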


def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """
    Given an object and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
    """
    return functools.reduce(getattr, fqn.split("."), obj)
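
# Example (a minimal sketch): the fqn is resolved one attribute at a time, so
# it also works for indexed children of container modules.
#
#     >>> m = torch.nn.Sequential(torch.nn.Linear(2, 2))
#     >>> getattr_from_fqn(m, "0.weight").shape
#     torch.Size([2, 2])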


def to_underlying_dtype(qdtype):
    DTYPE_MAPPING = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
        torch.quint4x2: torch.uint8,
        torch.quint2x4: torch.uint8,
        torch.uint8: torch.uint8,
        torch.int8: torch.int8,
        torch.int16: torch.int16,
        torch.int32: torch.int32,
        torch.float8_e5m2: torch.float8_e5m2,
        torch.float8_e4m3fn: torch.float8_e4m3fn,
    }
    assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
    return DTYPE_MAPPING[qdtype]
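
# Example (illustrative): quantized dtypes map to their underlying storage
# dtypes; plain integer and float8 dtypes map to themselves.
#
#     >>> to_underlying_dtype(torch.quint8)
#     torch.uint8
#     >>> to_underlying_dtype(torch.qint32)
#     torch.int32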


def get_qparam_dict(observer_or_fake_quant):
    from torch.ao.quantization.observer import PlaceholderObserver

    qscheme = getattr(observer_or_fake_quant, "qscheme", None)
    dtype = observer_or_fake_quant.dtype
    qparams = {"qscheme": qscheme, "dtype": dtype}

    if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver):
        return {"qscheme": None, "dtype": dtype}

    if is_per_tensor(qscheme):
        qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        # change symmetric to affine since we do not have symmetric
        # quantized Tensor
        if qscheme == torch.per_channel_symmetric:
            qscheme = torch.per_channel_affine
        qparams["axis"] = observer_or_fake_quant.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
    # update qscheme, since we don't have symmetric quant qscheme
    # in quantized Tensor
    qparams["qscheme"] = qscheme

    scale, zero_point = observer_or_fake_quant.calculate_qparams()
    qparams["scale"] = scale
    qparams["zero_point"] = zero_point

    if hasattr(observer_or_fake_quant, "quant_min"):
        qparams["quant_min"] = observer_or_fake_quant.quant_min
    if hasattr(observer_or_fake_quant, "quant_max"):
        qparams["quant_max"] = observer_or_fake_quant.quant_max

    return qparams
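
# Example (a minimal sketch; the scale and zero_point values depend on the
# data the observer has seen, so only the stable fields are shown):
#
#     >>> from torch.ao.quantization.observer import MinMaxObserver
#     >>> obs = MinMaxObserver()
#     >>> _ = obs(torch.randn(4, 4))
#     >>> qparams = get_qparam_dict(obs)
#     >>> qparams["qscheme"], qparams["dtype"]
#     (torch.per_tensor_affine, torch.quint8)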


def get_swapped_custom_module_class(
    custom_module, custom_module_class_mapping, qconfig
):
    """Get the observed/quantized custom module class that we need
    to swap `custom_module` to
    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for input custom module instance
    """
    quant_type = get_quant_type(qconfig)
    class_mapping = custom_module_class_mapping.get(quant_type, {})
    assert type(custom_module) in class_mapping, (
        "did not find corresponding observed "
        f"module class for {type(custom_module)} in mapping: {class_mapping}"
    )
    return class_mapping[type(custom_module)]


def activation_dtype(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype


def weight_dtype(qconfig):
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype


def activation_is_statically_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized or not; this includes quantizing to quint8, qint8, qint32,
    float16, the corresponding plain integer dtypes, and the float8 dtypes
    """
    return activation_dtype(qconfig) in [
        torch.quint8,
        torch.qint8,
        torch.qint32,
        torch.float16,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ] and (not activation_is_dynamically_quantized(qconfig))


def activation_is_dynamically_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    dynamically quantized or not; this includes dynamically quantizing to
    quint8, qint8, and float16
    """
    activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig)
    return activation_is_dynamic


def activation_is_int8_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized to int8 or not; this includes quantizing to quint8 and qint8
    """
    return activation_dtype(qconfig) in [
        torch.quint8,
        torch.qint8,
        torch.uint8,
        torch.int8,
    ]


def activation_is_int32_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized to int32 or not
    """
    return activation_dtype(qconfig) in [torch.qint32, torch.int32]


def weight_is_quantized(qconfig):
    """Given a qconfig, decide if the weight needs to be
    quantized or not
    """
    return weight_dtype(qconfig) in [
        torch.quint8,
        torch.qint8,
        torch.float16,
        torch.quint4x2,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ]


def weight_is_statically_quantized(qconfig):
    """Given a qconfig, decide if the weight needs to be statically
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8]


def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """Given a qconfig, returns True if this op is using int8 dynamic
    quantization
    """
    activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig)
    return (
        activation_dtype in [torch.quint8, torch.uint8]
        # for now, the lines below assume fbgemm or qnnpack
        and weight_dtype in [torch.qint8, torch.int8]
        and activation_is_dynamic
    )


def get_qconfig_dtypes(qconfig):
    r"""Returns the dtype tuple for the given qconfig:
    (activation_dtype, weight_dtype, activation_is_dynamic)
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    act_is_dynamic = getattr(activation, "is_dynamic", False)
    return (activation.dtype, weight.dtype, act_is_dynamic)
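
# Example (illustrative, using the library's default qconfigs; the default
# static qconfig observes activations statically, while the default dynamic
# qconfig marks them as dynamic):
#
#     >>> from torch.ao.quantization.qconfig import (
#     ...     default_dynamic_qconfig,
#     ...     default_qconfig,
#     ... )
#     >>> get_qconfig_dtypes(default_qconfig)
#     (torch.quint8, torch.qint8, False)
#     >>> get_qconfig_dtypes(default_dynamic_qconfig)
#     (torch.quint8, torch.qint8, True)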


def get_quant_type(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [
        torch.quint8,
        torch.qint8,
        torch.quint4x2,
        torch.qint32,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ]
    if weight.dtype in static_dtypes:
        if hasattr(activation, "is_dynamic") and activation.is_dynamic:
            return QuantType.DYNAMIC
        elif activation.dtype in static_dtypes:
            return QuantType.STATIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if hasattr(activation, "is_dynamic") and activation.is_dynamic:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC

    raise Exception(  # noqa: TRY002
        f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype}), "
        f"weight({weight.dtype})"
    )
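
# Example (illustrative): classification of the default qconfigs.
#
#     >>> from torch.ao.quantization.qconfig import (
#     ...     default_dynamic_qconfig,
#     ...     default_qconfig,
#     ... )
#     >>> get_quant_type(default_qconfig)
#     <QuantType.STATIC: 1>
#     >>> get_quant_type(default_dynamic_qconfig)
#     <QuantType.DYNAMIC: 0>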


def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than or equal to the max value.
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. "
            + "Returning default values."
        )
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(
                "must run observer before calling calculate_qparams. "
                + "Returning default values."
            )

            return False

        assert (
            min_val <= max_val
        ), f"min {min_val} should be less than or equal to max {max_val}"
    else:
        assert torch.all(
            min_val <= max_val
        ), f"min {min_val} should be less than or equal to max {max_val}"

    return True


def calculate_qmin_qmax(
    quant_min: int,
    quant_max: int,
    has_customized_qrange: bool,
    dtype: torch.dtype,
    reduce_range: bool,
) -> Tuple[int, int]:
    r"""Calculates actual qmin and qmax based on the quantization range,
    observer datatype and if range is reduced.
    """
    # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
    if has_customized_qrange:
        # This initialization is here to resolve TorchScript compilation issues and
        # to allow refinement to decouple initial_qmin and initial_qmax from the
        # quantization range. The actual values of initial_qmin and initial_qmax
        # will be reset below.
        if dtype in [torch.qint32, torch.int32]:
            initial_quant_min, initial_quant_max = 0, 2**32 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        # The following assignment of quant_min and quant_max to local variables,
        # together with the if check, refines the Optional attributes into valid
        # integers, as required by TorchScript.
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype in [torch.qint8, torch.int8]:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype in [torch.qint32, torch.int32]:
            assert (
                0 < qrange_len <= 2**32
            ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fall back to the default 8-bit qmin and qmax calculation if dynamic range is not used.
        if dtype in [torch.qint8, torch.int8]:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype in [torch.quint8, torch.uint8]:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype in [torch.qint32, torch.int32]:
            quant_min, quant_max = -1 * (2**31), (2**31) - 1
        else:
            quant_min, quant_max = 0, 15
    return quant_min, quant_max
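
# Example (illustrative): with no customized range, qint8 falls back to the
# standard 8-bit bounds, halved when reduce_range is set for backends that
# need one bit of headroom.
#
#     >>> calculate_qmin_qmax(0, 0, False, torch.qint8, reduce_range=False)
#     (-128, 127)
#     >>> calculate_qmin_qmax(0, 0, False, torch.qint8, reduce_range=True)
#     (-64, 63)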


def _parent_name(target):
    """
    Splits a qualified name like 'foo.bar' into ('foo', 'bar')
    """
    r = target.rsplit(".", 1)
    if len(r) == 1:
        return "", r[0]
    else:
        return r[0], r[1]


def has_no_children_ignoring_parametrizations(module):
    """
    Checks if module._modules is empty; if the module is parametrized,
    checks that module._modules only contains the 'parametrizations' module
    """
    if len(module._modules) == 0:
        return True
    elif is_parametrized(module):
        return len(module._modules) == 1 and "parametrizations" in module._modules
    else:
        return False
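
# Example (a minimal sketch; `_Symmetric` is a hypothetical parametrization):
#
#     >>> from torch.nn.utils import parametrize
#     >>> class _Symmetric(torch.nn.Module):
#     ...     def forward(self, X):
#     ...         return X + X.T
#     >>> lin = torch.nn.Linear(3, 3)
#     >>> _ = parametrize.register_parametrization(lin, "weight", _Symmetric())
#     >>> # the only child module is the 'parametrizations' container, so this holds
#     >>> has_no_children_ignoring_parametrizations(lin)
#     True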


def _get_path_of_module(
    root: torch.nn.Module, submodule: torch.nn.Module
) -> Optional[str]:
    """Get the path (fully qualified name) of a submodule

    Example::

    >> class M(torch.nn.Module):
           def __init__(self) -> None:
               super().__init__()
               self.linear = torch.nn.Linear(5, 5)
           def forward(self, x):
               return self.linear(x)

    >> m = M()
    >> l = m.linear
    >> _get_path_of_module(m, l)
    "linear"
    """
    for n, p in root.named_modules():
        if submodule is p:
            return n
    return None


def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
    """Get local keyword arguments

    Example::

    >> def f(self, a, b=9):
           pass
    >> loc = {"a": 6, "c": 7}
    >> _get_signature_locals(f, loc)
    {"a": 6}
    """
    return {k: v for k, v in loc.items() if k in signature(f).parameters}


def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
    """Get all default keyword arguments from function signature

    Example::

    >> def f(self, a, b=9):
           pass
    >> _get_default_kwargs(f)
    {"b": 9}
    """
    kwargs = {}
    for name, param in signature(f).parameters.items():
        if param.default is not param.empty:
            kwargs[name] = param.default
        elif param.kind is param.VAR_POSITIONAL:
            kwargs[name] = ()
        elif param.kind is param.VAR_KEYWORD:
            kwargs[name] = {}
    return OrderedDict(kwargs)


def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
    """Given a function and local function arguments, normalize the keyword
    arguments by filling in default arguments from function signature

    Example::

    >> def f(self, key1=3, key2=3):
           pass
    >> loc = {"key2": 6}
    >> _normalize_kwargs(f, loc)
    {"key1": 3, "key2": 6}
    """
    default_kwargs = _get_default_kwargs(func)
    local_kwargs = _get_signature_locals(func, loc)
    normalized_kwargs = default_kwargs.copy()
    for attr, val in local_kwargs.items():
        if attr in normalized_kwargs:
            # override the default keyword arguments
            normalized_kwargs[attr] = val
    return normalized_kwargs


def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
    r"""Validates that the user-specified quantization range is properly initialized
    and within the given bound supported by the observer dtype.

    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
    torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
    in a tuple of initial qmin and qmax values. One use case is that these customized qmin and qmax
    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
    fake quantization. These estimates are compared against parameters learned through backpropagation.
    The related literature for scale and zero point via backpropagation is as follows:

    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
    """
    # The user-specified quant_min and quant_max may later be adjusted based on whether
    # the quantization range is reduced and on the datatype (signed/unsigned) used by
    # the observer.
    assert (
        quant_min <= 0 <= quant_max
    ), "User-specified quantization range must include 0."
    assert (
        quant_min < quant_max
    ), "qmin must be strictly less than qmax for user-specified quantization range."


# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be
# torchscriptable, however, and qscheme as far as I can tell is not allowed to be
# passed as a parameter in torchscript functions. This makes refactoring observer
# to use this utility a massive pain and very gross. For now I'm opting just to
# duplicate, as this code seems unlikely to change (last update over 1 year ago)
# and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168)
def determine_qparams(
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    eps: torch.Tensor,
    has_customized_qrange: bool,
    qscheme: torch.qscheme = torch.per_tensor_affine,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the quantization parameters, given min and max
    value tensors. Works for both per tensor and per channel cases

    Args:
        min_val: Minimum values per channel
        max_val: Maximum values per channel

    Returns:
        scales: Scales tensor of shape (#channels,)
        zero_points: Zero points tensor of shape (#channels,)
    """
    if not check_min_max_valid(min_val, max_val):
        return torch.tensor([1.0], device=min_val.device.type), torch.tensor(
            [0], device=min_val.device.type
        )

    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device
    scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, eps)
        if dtype in [torch.uint8, torch.quint8]:
            if has_customized_qrange:
                # When a customized quantization range is used, the down-rounded midpoint of the range is chosen.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
    elif qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # Xq = Round(Xf * inv_scale + zero_point);
        # setting zero_point to (-1 * min * inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
    else:
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, eps)
        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)

    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() after adding JIT support
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() after adding JIT support
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )

    return scale.to(torch.double), zero_point.to(torch.int64)
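
# Example (illustrative): an affine per-tensor range of [-1.0, 1.0] mapped onto
# [0, 255] gives a scale of 2/255 and a zero_point at the middle of the range.
#
#     >>> determine_qparams(
#     ...     torch.tensor(-1.0),
#     ...     torch.tensor(1.0),
#     ...     quant_min=0,
#     ...     quant_max=255,
#     ...     dtype=torch.quint8,
#     ...     eps=torch.tensor(torch.finfo(torch.float32).eps),
#     ...     has_customized_qrange=False,
#     ... )
#     (tensor([0.0078], dtype=torch.float64), tensor([128]))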


def _get_num_pos_args(f: Callable) -> int:
    """Get number of positional args for a function

    Example::

    >> def f(self, key1=3, key2=3):
           pass
    >> _get_num_pos_args(f)
    3
    """
    return len(getfullargspec(f).args)


def get_fqn_to_example_inputs(
    model: torch.nn.Module, example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
    """Given a model and its example inputs, return a dictionary from
    fully qualified name of submodules to example_inputs for that submodule,
    e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
          "sub.linear1": (tensor4,), ...}

    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
    example inputs.

    Also works for keyword arguments with default values: we flatten keyword
    arguments into positional arguments and fill in the missing keyword args with
    default values. E.g. if we have a forward function:
    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6)
    we'll get example_inputs: (x, 3, 6)

    The user can also override `key1` with positional arguments as well:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    Variable positional arguments and variable keyword arguments in the forward
    function are not supported currently, so please make sure no submodule is
    using them.
    """
    root = model
    fqn_to_example_inputs = {}

    def _patched_module_call(self, *args, **kwargs):
        submodule_example_inputs = list(args).copy()
        normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
        # minus 1 to skip counting `self`
        num_args = _get_num_pos_args(self.forward) - 1
        num_to_pop = num_args - len(submodule_example_inputs)
        while num_to_pop and normalized_kwargs:
            normalized_kwargs.popitem(last=False)
            num_to_pop -= 1
        submodule_example_inputs.extend(normalized_kwargs.values())
        submodule_example_inputs_tuple = tuple(submodule_example_inputs)
        fqn = _get_path_of_module(root, self)
        if fqn is not None:
            fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
        return orig_module_call(self, *args, **kwargs)

    orig_module_call = torch.nn.Module.__call__
    torch.nn.Module.__call__ = _patched_module_call  # type: ignore[method-assign]
    try:
        model(*example_inputs)
    finally:
        # restore the module call even if there is an exception
        torch.nn.Module.__call__ = orig_module_call  # type: ignore[method-assign]
    return fqn_to_example_inputs
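
# Example (a minimal sketch; `M` is a hypothetical module):
#
#     >>> class M(torch.nn.Module):
#     ...     def __init__(self) -> None:
#     ...         super().__init__()
#     ...         self.linear = torch.nn.Linear(5, 5)
#     ...     def forward(self, x):
#     ...         return self.linear(x)
#     >>> fqn_to_inputs = get_fqn_to_example_inputs(M(), (torch.randn(1, 5),))
#     >>> sorted(fqn_to_inputs.keys())  # '' is the root module itself
#     ['', 'linear']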


def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    devices = {p.device for p in module.parameters()} | {
        p.device for p in module.buffers()
    }
    # As a temporary workaround for the AIMP HHC publish we added this CPU check;
    # remove it later. T163614564
    if {torch.device("cpu"), torch.device("meta")} == devices:
        warnings.warn(
            "Both 'meta' and 'cpu' are present in the list of devices. "
            "Module can have only one device. Selecting 'cpu'."
        )
        devices = {torch.device("cpu")}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        f"but got devices {devices}"
    )
    device = next(iter(devices)) if len(devices) > 0 else None
    return device
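
# Example (illustrative): a freshly constructed module lives on the CPU.
#
#     >>> _assert_and_get_unique_device(torch.nn.Linear(2, 2))
#     device(type='cpu')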


__all__ = [
    "NodePattern",
    "Pattern",
    "MatchAllNode",
    "check_node",
    "get_combined_dict",
    "is_per_tensor",
    "is_per_channel",
    "getattr_from_fqn",
    "get_qparam_dict",
    "get_swapped_custom_module_class",
    "activation_dtype",
    "weight_dtype",
    "activation_is_statically_quantized",
    "activation_is_dynamically_quantized",
    "activation_is_int8_quantized",
    "activation_is_int32_quantized",
    "weight_is_quantized",
    "weight_is_statically_quantized",
    "op_is_int8_dynamically_quantized",
    "get_qconfig_dtypes",
    "get_quant_type",
    "check_min_max_valid",
    "calculate_qmin_qmax",
    "has_no_children_ignoring_parametrizations",
    "get_fqn_to_example_inputs",
    "to_underlying_dtype",
    "determine_qparams",
    "validate_qmin_qmax",
]