# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils import _pytree as pytree

from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements


def _fuse_quantized_cat(model: GraphModule) -> None:
    """Fuse "dequantize -> cat -> quantize" into a cat on the quantized inputs.

    The fusion only happens when:
      * every input of the cat is produced by a ``dequantize_per_tensor`` node,
      * the quantization parameters of all those dequantize nodes are equal to
        the quantization parameters of the quantize node that follows the cat,
      * the quantize node is the only consumer of the cat.

    The graph of ``model`` is modified in place; the caller is responsible for
    dead-code elimination and recompilation.

    Args:
        model: graph module whose graph is scanned and rewritten.
    """
    quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    dequantize_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    cat_op = exir_ops.edge.aten.cat.default

    def _is_call_to(node, target) -> bool:
        # True when `node` is an fx node calling exactly `target`.
        return (
            isinstance(node, torch.fx.Node)
            and node.op == "call_function"
            and node.target == target
        )

    # get quantization parameters for the node, either for quantize or dequantize node
    def _get_qparams(node):
        assert node.target in (quantize_op, dequantize_op)
        # skip input (args[0]); everything after it is a quantization parameter
        return list(node.args[1:])

    for qnode in model.graph.nodes:
        if not _is_call_to(qnode, quantize_op):
            continue

        maybe_cat = qnode.args[0]
        if not _is_call_to(maybe_cat, cat_op):
            continue

        # The rewrite below changes what the cat node itself computes. If anything
        # besides this quantize node consumes the cat, that consumer expects the
        # dequantized result and would silently receive quantized data instead.
        if len(maybe_cat.users) != 1:
            continue

        tensor_args = maybe_cat.args[0]
        if not isinstance(tensor_args, (tuple, list)) or not tensor_args:
            continue

        # make sure every tensor in the concat list comes from a dequantize node
        # whose qparams match the output qparams
        output_qparams = _get_qparams(qnode)
        if not all(
            _is_call_to(tensor_arg, dequantize_op)
            and _get_qparams(tensor_arg) == output_qparams
            for tensor_arg in tensor_args
        ):
            continue

        # now we matched a pattern for quantized cat, e.g.
        # input1 (int8) -> dq1 -> cat -> q -> following_op
        # input2 (int8) -> dq2 -/

        # remove dq for inputs and q for output and run cat on the int8 input directly
        # input1 (int8) -> cat -> following_op
        # input2 (int8) -/

        # reroute the input of dq to the cat node
        for tensor_arg in tensor_args:
            maybe_cat.replace_input_with(tensor_arg, tensor_arg.args[0])

        # The cat now yields what the quantize node used to yield, so carry over
        # the quantize node's recorded value (when there is one) instead of
        # leaving the cat with metadata that describes its old output.
        if "val" in qnode.meta:
            maybe_cat.meta["val"] = qnode.meta["val"]

        # remove q for output
        qnode.replace_all_uses_with(maybe_cat)
        model.graph.erase_node(qnode)


class QuantFusionPass(ExportPass):
    """Pass that rewrites reference quantized patterns into fused quantized ops."""

    def __init__(self, _fix_node_meta_val=False):
        """Create the pass.

        Args:
            _fix_node_meta_val: when true, ``call`` fills in ``node.meta["val"]``
                for every ``call_function`` node that is missing it.
        """
        super().__init__()
        # TODO This pass violate IR spec because it produces a graph missing node.meta['val']
        self._fix_node_meta_val = _fix_node_meta_val

    def call(self, graph_module: GraphModule) -> PassResult:
        """Lower a quantized reference model (with reference quantized operator patterns)
        to executorch backend, that has a canonical set of quantized operators. This pass
        is a backend pass and should be applied on top of Edge dialect, ideally in
        `ExecutorchBackendConfig.passes`. See `test_quant_fusion_pass.py` for an example.

        Args:
            graph_module: module to rewrite; it is modified in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``, always marked as modified.
        """
        # linear, conv2d
        # dynamic_linear
        # add
        # batchnorm2d, relu, adaptive_avg_pool2d, reshape, squeeze, permute
        for (
            pattern,
            replacement,
            match_filters,
        ) in get_quant_patterns_and_replacements():
            subgraph_rewriter.replace_pattern_with_filters(
                graph_module, pattern, replacement, match_filters
            )

        _fuse_quantized_cat(graph_module)

        if self._fix_node_meta_val:
            for n in graph_module.graph.nodes:
                if n.op == "call_function" and "val" not in n.meta:
                    # Run the op on the recorded values of its inputs to obtain
                    # the value to record for this node.
                    args, kwargs = pytree.tree_map_only(
                        torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                    )
                    n.meta["val"] = n.target(*args, **kwargs)

        graph_module.graph.lint()
        graph_module.graph.eliminate_dead_code()
        # The graph was edited directly above (cat fusion, dead code removal), so
        # regenerate the module's forward code to match the final graph.
        graph_module.recompile()
        return PassResult(graph_module, True)