# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch._subclasses.fake_tensor import FakeTensor


def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
    """
    Converting a tensor to a scalar via tensor[0].item() creates a select_copy +
    local_scalar_dense pattern in the graph. Check if a node is the start of this
    pattern.

    Returns True only when `node` is a `select_copy.int` call with exactly one
    user and that user's target is `_local_scalar_dense.default`.
    """
    if (
        node.op != "call_function"
        or node.target != exir_ops.edge.aten.select_copy.int
        or len(node.users) != 1
    ):
        return False

    # Exactly one user is guaranteed by the guard above.
    user = next(iter(node.users))
    return user.target == torch.ops.aten._local_scalar_dense.default


def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
    """
    A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar
    value instead of a Tensor object. The criteria for identifying a tensor as a scalar
    tensor are as follows:

    1. The tensor has only 1 element
    2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which
       creates a select_copy + local_scalar_dense pattern in the graph

    If all of these criteria are fulfilled, then tag the node for the tensor to mark it
    so that it is added as a scalar value during serialization. The tag is stored as
    `node.meta["vkdg_is_scalar_tensor"] = True`.

    Nodes that carry no "val" metadata, or whose "val" is not a FakeTensor, are
    left untouched.
    """
    # Use .get() so that nodes without "val" metadata are skipped instead of
    # raising KeyError; this function is invoked on every node in the graph.
    tensor_val = node.meta.get("val")
    if not isinstance(tensor_val, FakeTensor):
        return

    # Scalar tensors must have only one element
    if tensor_val.numel() != 1:
        return

    # A single matching user is enough; stop scanning at the first one.
    if any(node_is_local_scalar_dense_chain(user) for user in node.users):
        node.meta["vkdg_is_scalar_tensor"] = True


def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
    """
    Remove the select_copy + local_scalar_dense pattern in the graph in favor of passing
    the original scalar tensor directly.

    Args:
        graph: The graph that owns `node`. It is not used by the rewrite itself
            and is kept so that the signature stays stable for callers.
        node: A `_local_scalar_dense` call node whose uses should be redirected.

    The node itself is not erased here; it is only left without users.
    """
    replace_node = node.args[0]
    assert isinstance(replace_node, torch.fx.Node)
    # If the argument to the local_scalar_dense op is a select op, and the
    # argument to the select op is a tensor with only one element (i.e. a
    # scalar tensor), then replace the entire pattern with the scalar tensor.
    # Otherwise only the local_scalar_dense op is bypassed.
    if (
        replace_node.op == "call_function"
        and replace_node.target == exir_ops.edge.aten.select_copy.int
    ):
        select_input = replace_node.args[0]
        assert isinstance(select_input, torch.fx.Node)
        # Skip the collapse (rather than raising) when the select input carries
        # no "val" metadata to inspect.
        input_val = select_input.meta.get("val")
        if input_val is not None and input_val.numel() == 1:
            replace_node = select_input
            # A missing tag is tolerated (default True); only an explicitly
            # falsy tag trips the assertion.
            assert replace_node.meta.get("vkdg_is_scalar_tensor", True)

    # Redirecting uses does not create new nodes, so no insertion point is needed.
    node.replace_all_uses_with(replace_node)


def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    The purpose of this pass is twofold:
    1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria)
    2. Remove the select_copy + local_scalar_dense pattern in the graph in favor of
       passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`)

    This makes it easier to deal with scalar tensors in the Vulkan backend. In particular,
    it allows serializing scalar tensors as SymInt objects instead of Tensor objects.
    Because scalar tensors are often used to inform tensor shapes, their values need to
    be easily accessed by the CPU during resizing logic, while also being able to reflect
    updates to their value in any GPU shaders that reference them.

    The graph is modified in place and the same object is returned.
    """
    target_op = torch.ops.aten._local_scalar_dense.default
    for node in graph.nodes:
        # Tagging happens for every node, before any rewrite of that node.
        tag_node_if_scalar_tensor(node)

        if node.op == "call_function" and node.target == target_op:
            remove_local_scalar_dense_chain(graph, node)

    # The rewrites above only redirect uses; dead code elimination then drops
    # the nodes that were left without users.
    graph.eliminate_dead_code()
    return graph


class RemoveLocalScalarDenseOpsTransform(ExportPass):
    """Export pass that applies `remove_local_scalar_dense_ops()` to a graph module."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on `graph_module.graph` and assign the result back.

        The returned PassResult always reports the graph module as modified.
        """
        graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph)
        return PassResult(graph_module, True)