#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

"""Predicates for recognizing quantize / dequantize nodes in an edge-dialect graph."""

import torch
from executorch.exir.dialects._ops import ops as exir_ops

# Dequantize ops parameterized per channel group.
DQ_GROUP_TARGETS = {
    exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
}

# Quantize ops parameterized per channel group.
Q_GROUP_TARGETS = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
}

# All recognized dequantize op targets (per-tensor, per-channel, per-token,
# and the per-channel-group ops above).
DQ_TARGETS = {
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
}.union(DQ_GROUP_TARGETS)

# All recognized quantize op targets (per-tensor, per-channel, per-token,
# and the per-channel-group ops above).
Q_TARGETS = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_token.default,
}.union(Q_GROUP_TARGETS)


def is_quant(tensor: torch.fx.Node) -> bool:
    """Return True if the node's target is one of the quantize ops in ``Q_TARGETS``.

    Args:
        tensor: FX graph node to inspect.
    """
    return tensor.target in Q_TARGETS


def is_dequant(tensor: torch.fx.Node) -> bool:
    """Return True if the node's target is one of the dequantize ops in ``DQ_TARGETS``.

    Args:
        tensor: FX graph node to inspect.
    """
    return tensor.target in DQ_TARGETS


def is_groupwise_q_dq(tensor: torch.fx.Node) -> bool:
    """Return True if the node is a per-channel-group quantize or dequantize op.

    Args:
        tensor: FX graph node to inspect.

    Returns:
        True when ``tensor.target`` is in ``DQ_GROUP_TARGETS`` or
        ``Q_GROUP_TARGETS``; False otherwise.
    """
    # Test membership in each set directly. Writing
    # ``in [DQ_GROUP_TARGETS, Q_GROUP_TARGETS]`` would compare the target
    # against the set objects themselves, which never matches an op.
    target = tensor.target
    return target in DQ_GROUP_TARGETS or target in Q_GROUP_TARGETS