xref: /aosp_15_r20/external/executorch/backends/vulkan/_passes/remove_local_scalar_dense_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass, PassResult
12
13from torch._subclasses.fake_tensor import FakeTensor
14
15
def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
    """
    Return True if `node` begins a select_copy -> _local_scalar_dense chain.

    Reading a scalar out of a tensor with `tensor[0].item()` shows up in the graph
    as an `aten.select_copy.int` call whose result is consumed by
    `aten._local_scalar_dense`. This helper recognizes the first node of that pair:
    a `select_copy.int` call with exactly one user, where that user is a
    `_local_scalar_dense` call.
    """
    if node.op != "call_function":
        return False
    if node.target != exir_ops.edge.aten.select_copy.int:
        return False
    if len(node.users) != 1:
        return False

    # `node.users` is a dict keyed by the consuming nodes; it has exactly one key here.
    (sole_user,) = node.users
    return sole_user.target == torch.ops.aten._local_scalar_dense.default
30
31
def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
    """
    Mark `node` as a Vulkan "scalar tensor" when it qualifies.

    A scalar tensor in the Vulkan backend is a tensor that can be represented as a
    scalar value instead of a Tensor object. A node qualifies only if ALL of the
    following hold:

    1. Its `meta["val"]` is a FakeTensor with exactly one element.
    2. At least one of its users begins a select_copy + _local_scalar_dense chain,
       i.e. the pattern produced by `tensor[0].item()` (see
       `node_is_local_scalar_dense_chain()`).

    Qualifying nodes get `meta["vkdg_is_scalar_tensor"] = True` so that the tensor
    is added as a scalar value during serialization. Nodes that do not qualify are
    left untouched.

    Args:
        node: The graph node to inspect and possibly tag.
    """
    # Not every node is guaranteed to carry a "val" entry (the graph's output node,
    # for instance, may not). Treat a missing entry like any other non-FakeTensor
    # value instead of raising KeyError.
    tensor_val = node.meta.get("val")
    if not isinstance(tensor_val, FakeTensor):
        return

    # Scalar tensors must have exactly one element.
    if tensor_val.numel() != 1:
        return

    if any(node_is_local_scalar_dense_chain(user) for user in node.users):
        node.meta["vkdg_is_scalar_tensor"] = True
56
57
def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
    """
    Redirect users of a `_local_scalar_dense` node to the tensor it reads from.

    If the input of `node` is an `aten.select_copy.int` call whose own input has
    exactly one element (the `tensor[0].item()` pattern on a scalar tensor), the
    users are redirected to that original one-element tensor, which bypasses the
    select as well. Otherwise, the users are redirected to the direct input of
    `node`.

    No nodes are deleted here. `node`, and the select if it ends up unused, are
    left for dead-code elimination.

    Args:
        graph: The graph containing `node`. The rewrite itself does not use it,
            because no new nodes are inserted. It is kept for interface
            compatibility.
        node: A `_local_scalar_dense` call node.
    """
    replace_node = node.args[0]
    assert isinstance(replace_node, torch.fx.Node)

    if (
        replace_node.op == "call_function"
        and replace_node.target == exir_ops.edge.aten.select_copy.int
    ):
        select_input = replace_node.args[0]
        assert isinstance(select_input, torch.fx.Node)
        # Only bypass the select when its input is itself a one-element (scalar)
        # tensor. The number of users of the select is not checked: only the uses
        # of `node` are rewritten below, so any other users of the select are
        # unaffected.
        if select_input.meta["val"].numel() == 1:
            replace_node = select_input
            # NOTE(review): with a default of True, this only fails if the node was
            # explicitly tagged False; an untagged node passes. Confirm intent.
            assert replace_node.meta.get("vkdg_is_scalar_tensor", True)

    # Rewriting uses inserts no new nodes, so no insertion-point context is needed.
    node.replace_all_uses_with(replace_node)
81
82
def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Tag scalar tensors and strip `.item()` chains from `graph`, in place.

    In a single forward walk over the nodes, this pass:

    1. Tags scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria).
    2. Redirects the users of every `aten._local_scalar_dense` call to the
       underlying tensor (see `remove_local_scalar_dense_chain()`), removing the
       select_copy + _local_scalar_dense pattern in favor of the original scalar
       tensor.

    It then runs dead-code elimination to drop nodes left without users.

    This makes scalar tensors easier to handle in the Vulkan backend. In
    particular, it allows serializing scalar tensors as SymInt objects instead of
    Tensor objects. Scalar tensors often inform tensor shapes, so their values must
    be easily accessible to the CPU during resizing logic, while updates to their
    value must still be reflected in any GPU shaders that reference them.

    Returns:
        The same `graph` object that was passed in, after mutation.
    """
    local_scalar_dense = torch.ops.aten._local_scalar_dense.default
    for current in graph.nodes:
        tag_node_if_scalar_tensor(current)

        is_item_call = (
            current.op == "call_function" and current.target == local_scalar_dense
        )
        if is_item_call:
            remove_local_scalar_dense_chain(graph, current)

    graph.eliminate_dead_code()
    return graph
105
106
class RemoveLocalScalarDenseOpsTransform(ExportPass):
    """
    ExportPass wrapper around `remove_local_scalar_dense_ops()`.

    The module's graph is rewritten in place, and the pass always reports that the
    module was modified.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        rewritten = remove_local_scalar_dense_ops(graph_module.graph)
        # The original pass assigns the graph back through the `graph` attribute
        # rather than relying only on the in-place mutation. Presumably this lets
        # GraphModule refresh its generated code; confirm.
        graph_module.graph = rewritten
        return PassResult(graph_module, True)
111