# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict

import torch
from executorch.backends.transforms import get_shape
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNScaledDotProductAttention,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class SDPAVisitor(NodeVisitor):
    """Lower ``aten.scaled_dot_product_attention.default`` to XNNPACK.

    Emits a single ``XNNScaledDotProductAttention`` xnode. Its inputs are the
    query, key, value and attention-mask tensors of the ATen op, plus a
    synthesized 1-D scale tensor.
    """

    target = "aten.scaled_dot_product_attention.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def get_fake_attr(name: str, value: torch.Tensor) -> torch.fx.Node:
        """Build a detached ``get_attr`` node that exposes ``value`` as a constant.

        The node belongs to a fresh, throwaway ``torch.fx.Graph``. That graph's
        owning ``GraphModule`` holds ``value`` as attribute ``name``, and the
        node's ``meta["val"]`` is set to ``value``. This lets a tensor that is
        not part of the graph being lowered be passed to ``define_tensor``.

        Args:
            name: Used as the node name, the ``get_attr`` target and the
                attribute name on the owning module.
            value: Tensor to expose through the fake node.

        Returns:
            The new ``get_attr`` node. It belongs to the throwaway graph, not
            to the graph being lowered.
        """
        g = torch.fx.Graph()
        gm = torch.fx.GraphModule({}, g)
        fake_node = torch.fx.Node(g, name, "get_attr", target=name, args=(), kwargs={})
        # NOTE(review): presumably define_tensor resolves get_attr constants
        # through the node's owning module, hence this explicit wiring — confirm.
        g._owning_module = gm
        setattr(g._owning_module, name, value)
        fake_node.meta["val"] = value
        return fake_node

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize the SDPA ``node`` into ``xnn_graph``.

        Defines the following tensors, then appends an
        ``XNNScaledDotProductAttention`` xnode that references their ids:

        - the query, key, value and mask inputs (args 0-3);
        - a 1-D scale tensor of length ``query.shape[-1]``;
        - the output tensor.

        Args:
            node: The ``aten.scaled_dot_product_attention.default`` call.
            xnn_graph: Graph being built; receives the new xnode.
            vals_to_ids: Map from FX nodes to XNNPACK value ids. It is expected
                to be filled by the ``define_tensor`` calls below and is read
                back to wire up the xnode.
            debug_handle: Debug handle attached to the emitted xnode.

        Raises:
            AssertionError: If the mask is not float32/float16, or is not 2-D.
        """
        # NOTE(review): dropout_p / is_causal (args 4-5) are not read here;
        # presumably the partitioner rejects nodes that set them — confirm.

        # Inputs: query, key, value, attention mask.
        for i in range(4):
            self.define_tensor(
                get_input_node(node, i),
                xnn_graph,
                vals_to_ids,
            )

        # Make sure the mask is not bool: only float32 / float16 are accepted.
        mask_node = get_input_node(node, 3)
        mask_dtype = mask_node.meta["val"].dtype
        assert mask_dtype in (
            torch.float,
            torch.float16,
        ), "SDPA Mask must be a float (or half) tensor"

        # Make sure the mask is not >2D.
        assert len(get_shape(mask_node)) == 2, "SDPA Mask must be 2D"

        # The default scale is 1/sqrt(E), where E = query.shape[-1].
        q_shape = get_shape(get_input_node(node, 0))
        embedding_dim = q_shape[-1]
        scale = 1 / (embedding_dim**0.5)
        # Compare against None instead of relying on truthiness, so that an
        # explicitly supplied scale of 0.0 is honoured rather than silently
        # replaced by the default.
        scale_arg = node.kwargs.get("scale")
        if scale_arg is not None:
            scale = cast(float, scale_arg)

        # Hack: XNNPACK receives the scale as a tensor id, so broadcast the
        # scalar to a 1-D tensor of length E (in the mask's dtype) and
        # materialize it through a fake get_attr node.
        t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
        scale_node = self.get_fake_attr("scale", t)
        self.define_tensor(
            scale_node,
            xnn_graph,
            vals_to_ids,
        )

        # Output
        outp = node
        self.define_tensor(
            outp,
            xnn_graph,
            vals_to_ids,
        )

        # Look up the ids assigned by the define_tensor calls above.
        q_id = vals_to_ids[get_input_node(node, 0)]
        k_id = vals_to_ids[get_input_node(node, 1)]
        v_id = vals_to_ids[get_input_node(node, 2)]
        mask_id = vals_to_ids[mask_node]
        scale_id = vals_to_ids[scale_node]
        output_id = vals_to_ids[outp]

        sdpa_node = XNode(
            xnode_union=XNNScaledDotProductAttention(
                query_id=q_id,
                key_id=k_id,
                value_id=v_id,
                scale_id=scale_id,
                mask_id=mask_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(sdpa_node)