xref: /aosp_15_r20/external/executorch/backends/apple/mps/utils/quant_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#
2#  Copyright (c) 2024 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8
# Edge-dialect dequantize ops that work on per-channel groups.
DQ_GROUP_TARGETS = {
    exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
}

# Edge-dialect quantize ops that work on per-channel groups.
Q_GROUP_TARGETS = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
}

# All dequantize ops recognized here: per-tensor (default and tensor
# overloads), per-channel and per-token, plus the group-wise set above.
DQ_TARGETS = {
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
} | DQ_GROUP_TARGETS

# All quantize ops recognized here: per-tensor (default and tensor
# overloads), per-channel and per-token, plus the group-wise set above.
Q_TARGETS = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_token.default,
} | Q_GROUP_TARGETS
30
31
def is_quant(tensor: torch.fx.Node) -> bool:
    """Report whether ``tensor`` is a quantize node.

    Args:
        tensor: FX graph node to inspect.

    Returns:
        True when ``tensor.target`` is one of the ops listed in ``Q_TARGETS``,
        otherwise False.
    """
    node_target = tensor.target
    return node_target in Q_TARGETS
34
35
def is_dequant(tensor: torch.fx.Node) -> bool:
    """Report whether ``tensor`` is a dequantize node.

    Args:
        tensor: FX graph node to inspect.

    Returns:
        True when ``tensor.target`` is one of the ops listed in ``DQ_TARGETS``,
        otherwise False.
    """
    node_target = tensor.target
    return node_target in DQ_TARGETS
38
39
def is_groupwise_q_dq(tensor: torch.fx.Node) -> bool:
    """Report whether ``tensor`` is a group-wise quantize or dequantize node.

    Args:
        tensor: FX graph node to inspect.

    Returns:
        True when ``tensor.target`` is in ``DQ_GROUP_TARGETS`` or in
        ``Q_GROUP_TARGETS``, otherwise False.
    """
    target = tensor.target
    # Check membership in each set. Writing ``target in [DQ_GROUP_TARGETS,
    # Q_GROUP_TARGETS]`` would instead compare ``target`` for equality with
    # the two set objects themselves, which never matches an op.
    return target in DQ_GROUP_TARGETS or target in Q_GROUP_TARGETS
42