# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeScaledDotProductAttention(ExportPass):
    """
    Decompose from scaled_dot_product_attention to multiple nodes.
    """

    def __init__(self, allow_non_fake_inputs: bool = True) -> None:
        super().__init__()
        # With allow_non_fake_inputs=False, we don't get _unsafe_view ops
        # in the graph, we allow disabling it here.
        self._allow_non_fake_inputs = allow_non_fake_inputs

    def call(
        self,
        graph_module: torch.fx.GraphModule,
        allow_non_fake_inputs: Optional[bool] = None,
    ) -> PassResult:
        """Replace each aten.scaled_dot_product_attention node in the graph
        with the subgraph produced by its flash-attention-for-CPU
        decomposition.

        Args:
            graph_module: FX graph module to transform in place.
            allow_non_fake_inputs: Optional per-call override; when ``None``
                (the default) the value passed to ``__init__`` is used.

        Returns:
            PassResult wrapping the mutated ``graph_module``, marked modified.
        """
        # Bug fix: previously this parameter defaulted to True, which
        # silently shadowed and ignored the flag stored in __init__.
        if allow_non_fake_inputs is None:
            allow_non_fake_inputs = self._allow_non_fake_inputs

        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                # Fake tensors recorded on the node's args are used to trace
                # the decomposition with the right shapes/dtypes.
                input_tensors = (arg.meta["val"] for arg in node.args)

                # refer to pytorch/test/test_decomp.py
                decomposed_module = make_fx(
                    node.target,
                    decomposition_table=get_decompositions(  # pyre-fixme[6]
                        [
                            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
                        ]
                    ),
                    tracing_mode="fake",
                    _allow_non_fake_inputs=allow_non_fake_inputs,
                )(*input_tensors)
                with graph.inserting_before(node):
                    # make_fx names its placeholders arg0_1, arg1_1, ...;
                    # map those names back to the original node's args.
                    name_to_input_tensor_map = {}
                    for i, arg in enumerate(node.args):
                        name_to_input_tensor_map[f"arg{i}_1"] = arg

                    decomposed_node_to_subgraph_node = {}
                    last_decomposed_node = None
                    # Create a mapping from input nodes in decomposed module to original nodes.
                    # In decomposed module, there are only input tensors for placeholder op.
                    for decomposed_node in decomposed_module.graph.nodes:
                        if decomposed_node.op == "placeholder":
                            decomposed_node_to_subgraph_node[decomposed_node] = (
                                name_to_input_tensor_map[decomposed_node.name]
                            )

                        if decomposed_node.op == "output":
                            last_decomposed_node = decomposed_node.args[0]

                    # Copy node from decompose graph module
                    for decomposed_node in decomposed_module.graph.nodes:
                        if decomposed_node.op == "placeholder":
                            continue

                        # At the output node, rewire every user of the
                        # original SDPA node to the copied result node.
                        if (
                            decomposed_node.op == "output"
                            and last_decomposed_node is not None
                        ):
                            for user in node.users.copy():
                                user.replace_input_with(
                                    node,
                                    decomposed_node_to_subgraph_node[
                                        last_decomposed_node
                                    ],
                                )
                            continue

                        subgraph_node = graph.node_copy(
                            decomposed_node,
                            arg_transform=lambda x: decomposed_node_to_subgraph_node[  # noqa: B023
                                x
                            ],
                        )
                        subgraph_node.meta["source_fn_stack"] = [
                            (subgraph_node, subgraph_node.target)
                        ]
                        decomposed_node_to_subgraph_node[decomposed_node] = (
                            subgraph_node
                        )

                    graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)