xref: /aosp_15_r20/external/executorch/backends/transforms/fuse_dequant_linear.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseDequantLinearPass(ExportPass):
    """
    Fuses weight dequantize_per_channel nodes with linear nodes into
    weight_int8pack_mm nodes, for 8-bit weight-only quantization.

    Replaces dq(weight) -> linear(activation, dq)       with weight_int8pack_mm
    Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add
    """

    def fuse_dequant_with_linear(
        self,
        graph_module: torch.fx.GraphModule,
        dequant_node: torch.fx.Node,
        linear_node: torch.fx.Node,
    ) -> None:
        """Rewrite ``linear(act, dequant(w, scale, ...), bias?)`` in place.

        Inserts ``_weight_int8pack_mm(act, w, scale)`` (plus an ``add`` node
        when a bias is present), reroutes all users of ``linear_node`` to the
        new node(s), and erases the now-dead ``linear_node``. ``dequant_node``
        is assumed to be the weight argument of ``linear_node``.
        """
        activations = linear_node.args[0]
        # linear's bias is optional: it may be omitted or passed explicitly as None.
        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
        quant_weight = dequant_node.args[0]
        scale = dequant_node.args[1]

        with graph_module.graph.inserting_before(linear_node):
            weight_int8pack_mm_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten._weight_int8pack_mm.default,
                (activations, quant_weight, scale),
            )
            if bias is not None:
                # _weight_int8pack_mm has no bias input, so re-apply it with an
                # explicit elementwise add after the matmul.
                add_node = graph_module.graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.add.Tensor,
                    (weight_int8pack_mm_node, bias),
                )
                linear_node.replace_all_uses_with(add_node)
            else:
                linear_node.replace_all_uses_with(weight_int8pack_mm_node)
            graph_module.graph.erase_node(linear_node)
            # The same dequant may feed several linear nodes. erase_node raises
            # if the node still has users, so only erase it once it is dead;
            # remaining consumers are fused on later iterations of call(), after
            # which this branch finally removes the dequant.
            if len(dequant_node.users) == 0:
                graph_module.graph.erase_node(dequant_node)

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        """Return True iff ``node`` is a call_function node invoking ``target``."""
        return node.op == "call_function" and node.target == target

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Scan the graph for dq(weight) -> linear patterns and fuse each one.

        Returns a PassResult whose graph_module has every eligible pattern
        replaced by _weight_int8pack_mm (+ add when the linear had a bias).
        """
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, exir_ops.edge.aten.linear.default):
                weight_node = node.args[1]
                # The weight argument may not be a Node (e.g. a constant);
                # guard before dereferencing .op/.target in is_node_target.
                if isinstance(weight_node, torch.fx.Node) and self.is_node_target(
                    weight_node,
                    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
                ):
                    # only fuse if weight tensor is int8 packed
                    quant_weight = weight_node.args[0]
                    if quant_weight.meta["val"].dtype != torch.int8:
                        continue
                    self.fuse_dequant_with_linear(graph_module, weight_node, node)

        graph_module.recompile()
        # Re-run the base ExportPass machinery so graph metadata is rebuilt
        # for the newly created nodes.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
78