xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/decompose_einsum.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.exir.pass_base import ExportPass, PassResult
9from torch.fx.experimental.proxy_tensor import make_fx
10
11
class DecomposeEinsum(ExportPass):
    """
    Decompose einsum for quantization annotation to work properly.

    Every ``aten.einsum.default`` node in the top-level graph is re-traced
    with ``make_fx`` in fake-tensor mode. The ops recorded by that trace are
    copied into the original graph in front of the einsum node. Every user of
    the einsum is then rewired to the copied result, and the einsum node is
    erased.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _decompose_einsum_node(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        """Replace the single einsum ``node`` in ``graph`` with its decomposition.

        Args:
            graph: Graph that owns ``node``. It is mutated in place.
            node: An ``aten.einsum.default`` node. ``node.args[0]`` is the
                equation string and ``node.args[1]`` is the list of operand
                nodes. Each operand must carry a ``meta["val"]`` value, which
                is used as the tracing input.

        Raises:
            KeyError: If an operand lacks ``meta["val"]``, or if a traced
                placeholder name does not match the expected ``arg1_{i}``
                pattern.
        """
        equation, operands = node.args[0], node.args[1]
        decomposed_module = make_fx(
            node.target,
            tracing_mode="fake",
        )(equation, [operand.meta["val"] for operand in operands])

        # The traced placeholders are named after their position in the call
        # (equation, [operands...]). Names containing "arg0" belong to the
        # equation string. "arg1_{i+1}" is the i-th operand, which maps back
        # to the original operand node in `graph`.
        operand_by_name = {
            f"arg1_{i + 1}": operand for i, operand in enumerate(operands)
        }
        # Maps each node of the decomposed graph to its counterpart in `graph`,
        # so that references are correctly updated in the copied nodes.
        remap = {}

        with graph.inserting_before(node):
            for decomposed_node in decomposed_module.graph.nodes:
                # The equation-string placeholder is not needed after decomposition.
                if "arg0" in decomposed_node.name:
                    continue

                if decomposed_node.op == "placeholder":
                    # Operand placeholders are not copied; they resolve to the
                    # already existing input nodes.
                    remap[decomposed_node] = operand_by_name.pop(decomposed_node.name)
                elif decomposed_node.op == "output":
                    # The output node is not copied. Its args[0] is indexed as
                    # a sequence whose first entry is the decomposed einsum
                    # result, and every consumer of the einsum is pointed at
                    # that result's copy.
                    replacement = remap[decomposed_node.args[0][0]]
                    for user in list(node.users):
                        user.replace_input_with(node, replacement)
                else:
                    remap[decomposed_node] = graph.node_copy(
                        decomposed_node,
                        arg_transform=lambda arg: remap[arg],
                    )

        # All users were rewired above, so the einsum node is now unused.
        graph.erase_node(node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Decompose every einsum node of ``graph_module`` in place.

        Args:
            graph_module: Module whose top-level graph is rewritten.

        Returns:
            ``PassResult`` wrapping the same (mutated) ``graph_module``. The
            modified flag is always ``True``.
        """
        graph = graph_module.graph
        # Take a snapshot first: decomposition inserts new nodes and erases the
        # einsum nodes, so the live node list must not be walked while
        # mutating it.
        einsum_nodes = [
            node
            for node in graph.nodes
            if node.target == torch.ops.aten.einsum.default
        ]
        for node in einsum_nodes:
            self._decompose_einsum_node(graph, node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
66