# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseDequantLinearPass(ExportPass):
    """
    Fuses weight dequantize_per_channel nodes with linear nodes into
    weight_int8pack_mm nodes, for 8-bit weight-only quantization.

    Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm
    Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add
    """

    def fuse_dequant_with_linear(
        self,
        graph_module: torch.fx.GraphModule,
        dequant_node: torch.fx.Node,
        linear_node: torch.fx.Node,
    ) -> None:
        """Rewrite linear(activation, dq(weight)[, bias]) as
        _weight_int8pack_mm(activation, quant_weight, scale)[ -> add(bias)]
        and erase the now-dead linear and dequantize nodes.

        Args:
            graph_module: module whose graph is rewritten in place.
            dequant_node: the dequantize_per_channel node producing the weight.
            linear_node: the linear node consuming ``dequant_node``'s output.
        """
        activations = linear_node.args[0]
        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
        quant_weight = dequant_node.args[0]
        scale = dequant_node.args[1]

        with graph_module.graph.inserting_before(linear_node):
            weight_int8pack_mm_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten._weight_int8pack_mm.default,
                (activations, quant_weight, scale),
            )
            # _weight_int8pack_mm takes no bias input, so re-apply it with an
            # explicit add. Compare against None: fx.Node objects are always
            # truthy, so an `if bias:` check would be obscure (though equivalent).
            if bias is not None:
                add_node = graph_module.graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.add.Tensor,
                    (weight_int8pack_mm_node, bias),
                )
                linear_node.replace_all_uses_with(add_node)
            else:
                linear_node.replace_all_uses_with(weight_int8pack_mm_node)
            # Erase the consumer (linear) before its producer (dequant):
            # fx.Graph.erase_node raises if the node still has users.
            graph_module.graph.erase_node(linear_node)
            graph_module.graph.erase_node(dequant_node)

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        """Return True iff ``node`` is a call_function with the given target."""
        return node.op == "call_function" and node.target == target

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Scan the graph for dq(weight) -> linear patterns and fuse them.

        Returns:
            PassResult wrapping the (possibly) modified graph module.
        """
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, exir_ops.edge.aten.linear.default):
                weight_node = node.args[1]
                if self.is_node_target(
                    weight_node,
                    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
                ):
                    # only fuse if weight tensor is int8 packed
                    quant_weight = weight_node.args[0]
                    if quant_weight.meta["val"].dtype != torch.int8:
                        continue
                    # Skip dequant nodes shared by other consumers: erasing
                    # one that still has users raises inside erase_node, and
                    # those consumers would be left with a dangling input.
                    if len(weight_node.users) > 1:
                        continue
                    self.fuse_dequant_with_linear(graph_module, weight_node, node)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)