# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class TagIOQuantPass(ExportPass):
    """
    Pass run before partitioning to tag Q/DQ on any placeholder and output
    to ensure we don't greedily partition them for device. Float conversion
    has to happen outside a TOSA base inference profile.

    Tagging is done by setting ``meta["arm_override_partition"] = False`` on:
      * every quantize node that directly consumes a graph placeholder, and
      * every dequantize node that directly feeds the graph output.
    """

    def is_quant_node(self, node: torch.fx.node.Node) -> bool:
        """Return True if ``node`` targets one of the edge-dialect quantize ops."""
        return node.target in {
            exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        }

    def is_dequant_node(self, node: torch.fx.node.Node) -> bool:
        """Return True if ``node`` targets one of the edge-dialect dequantize ops."""
        return node.target in {
            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Tag the quantize nodes on graph inputs and dequantize nodes on graph
        outputs of ``graph_module``.

        Args:
            graph_module: The graph module whose nodes are tagged in place.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``, always
            reported as modified.
        """
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                # Tag q of input: any quantize op directly consuming an input.
                for user in node.users:
                    if self.is_quant_node(user):
                        user.meta["arm_override_partition"] = False

            elif node.op == "output":
                # Tag dq of outputs. `all_input_nodes` yields only the Node
                # inputs of the output node, so non-Node entries (e.g. None
                # or constants), nested containers, and a single non-tuple
                # output are all handled without raising.
                for producer in node.all_input_nodes:
                    if self.is_dequant_node(producer):
                        producer.meta["arm_override_partition"] = False

        graph_module.recompile()
        return PassResult(graph_module, True)