# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeScaledDotProductAttention(ExportPass):
    """
    Decompose scaled_dot_product_attention into multiple simpler nodes.
    """

    def __init__(self, allow_non_fake_inputs: bool = True) -> None:
        super().__init__()
        # Tracing with allow_non_fake_inputs=False keeps _unsafe_view ops out
        # of the decomposed graph, so the flag is exposed here to allow
        # disabling it.
        self._allow_non_fake_inputs = allow_non_fake_inputs

    def call(
        self,
        graph_module: torch.fx.GraphModule,
        allow_non_fake_inputs: Optional[bool] = None,
    ) -> PassResult:
        # Fall back to the value captured in the constructor so that
        # DecomposeScaledDotProductAttention(allow_non_fake_inputs=False)
        # takes effect when the pass is invoked without an explicit argument.
        if allow_non_fake_inputs is None:
            allow_non_fake_inputs = self._allow_non_fake_inputs

        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                # FakeTensor values recorded on the argument nodes; these are
                # used as example inputs when retracing the decomposition.
                input_tensors = tuple(arg.meta["val"] for arg in node.args)

                # refer to pytorch/test/test_decomp.py
                decomposed_module = make_fx(
                    node.target,
                    decomposition_table=get_decompositions(  # pyre-fixme[6]
                        [
                            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
                        ]
                    ),
                    tracing_mode="fake",
                    _allow_non_fake_inputs=allow_non_fake_inputs,
                )(*input_tensors)
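
                # Note: the pass assumes make_fx names positional placeholders
                # "arg0_1", "arg1_1", ...; the mapping below relies on that
                # convention to pair each placeholder of the decomposed module
                # with the corresponding argument of the original SDPA node.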
                with graph.inserting_before(node):
                    name_to_input_tensor_map = {}
                    for i, arg in enumerate(node.args):
                        name_to_input_tensor_map[f"arg{i}_1"] = arg

                    decomposed_node_to_subgraph_node = {}
                    last_decomposed_node = None
                    # Map input nodes of the decomposed module to the original
                    # nodes. In the decomposed module, placeholders correspond
                    # exactly to the input tensors.
                    for decomposed_node in decomposed_module.graph.nodes:
                        if decomposed_node.op == "placeholder":
                            decomposed_node_to_subgraph_node[decomposed_node] = (
                                name_to_input_tensor_map[decomposed_node.name]
                            )

                        if decomposed_node.op == "output":
                            last_decomposed_node = decomposed_node.args[0]

                    # Copy the remaining nodes of the decomposed graph module
                    # into the original graph.
                    for decomposed_node in decomposed_module.graph.nodes:
                        if decomposed_node.op == "placeholder":
                            continue

                        if (
                            decomposed_node.op == "output"
                            and last_decomposed_node is not None
                        ):
                            # Rewire every user of the original SDPA node to
                            # the copied node producing the decomposed result.
                            for user in node.users.copy():
                                user.replace_input_with(
                                    node,
                                    decomposed_node_to_subgraph_node[
                                        last_decomposed_node
                                    ],
                                )
                            continue

                        subgraph_node = graph.node_copy(
                            decomposed_node,
                            arg_transform=lambda x: decomposed_node_to_subgraph_node[  # noqa: B023
                                x
                            ],
                        )
                        subgraph_node.meta["source_fn_stack"] = [
                            (subgraph_node, subgraph_node.target)
                        ]
                        decomposed_node_to_subgraph_node[decomposed_node] = (
                            subgraph_node
                        )

                    graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
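

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the pass): a minimal example of
# running the pass on an exported graph. It assumes the exported graph still
# contains aten.scaled_dot_product_attention.default, which depends on the
# PyTorch version and export mode; the SelfAttention module, input shapes,
# and the final print are assumptions made only for this demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F

    class SelfAttention(torch.nn.Module):
        def forward(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
        ) -> torch.Tensor:
            return F.scaled_dot_product_attention(q, k, v)

    example_inputs = (
        torch.randn(1, 4, 16, 8),
        torch.randn(1, 4, 16, 8),
        torch.randn(1, 4, 16, 8),
    )
    exported = torch.export.export(SelfAttention(), example_inputs)

    # Run the pass directly on the exported GraphModule. If no SDPA node is
    # present (e.g. it was already decomposed during export), the graph is
    # returned unchanged.
    result = DecomposeScaledDotProductAttention().call(exported.graph_module)
    print(result.graph_module.code)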