xref: /aosp_15_r20/external/executorch/backends/vulkan/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from enum import IntEnum
8from typing import Optional, Set, Tuple
9
10import torch
11
12from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
13    VkMemoryLayout,
14    VkStorageType,
15)
16
17from executorch.exir.tensor import TensorSpec
18
19from torch._export.utils import is_buffer, is_param
20
21from torch._subclasses.fake_tensor import FakeTensor
22
23from torch.export import ExportedProgram
24
25##
26## Node type determination
27##
28
29
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True when ``node`` is an FX node whose opcode is ``"get_attr"``.

    Anything that is not a ``torch.fx.Node`` (e.g. a literal taken from an argument
    list) yields False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
32
33
34def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
35    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
36
37
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is any of: a ``get_attr`` node, a parameter of
    ``program``, a buffer of ``program``, or a lifted tensor constant of ``program``.

    The checks are evaluated lazily in that order and stop at the first match.
    """
    if is_get_attr_node(node):
        return True
    # Generator keeps the remaining checks short-circuiting, like a chained ``or``.
    return any(
        predicate(program, node) for predicate in (is_param, is_buffer, is_constant)
    )
48
49
def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node.meta["val"]`` exists and is a ``torch.SymInt``.

    Nodes without a ``"val"`` entry are reported as not producing a SymInt.
    """
    # A missing "val" maps to None, which is never a SymInt.
    value = node.meta.get("val")
    return isinstance(value, torch.SymInt)
61
62
def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` produces a tensor, or a list/tuple of tensors.

    A node counts as tensor-valued when it carries a ``"spec"`` meta entry, when its
    ``"val"`` meta entry is a ``FakeTensor``, or when ``"val"`` is a list/tuple whose
    elements are all ``FakeTensor`` (an empty list/tuple therefore qualifies).
    """
    meta = node.meta
    # All nodes with tensor values are tagged by the SpecPropPass transform.
    if "spec" in meta:
        return True
    if "val" not in meta:
        return False

    value = meta["val"]
    if isinstance(value, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in value)
    return isinstance(value, FakeTensor)
81
82
83##
84## Memory Layout, Storage Type Determination
85##
86
# (x, y, z) extents of an image texture, as returned by required_image_extents.
ImageExtents = Tuple[int, int, int]

# Default per-axis (x, y, z) texture extent limits.
# NOTE(review): presumably used when real device limits are unavailable — confirm with callers.
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
# Default buffer limit (128 * 2**20). within_buffer_limit compares numel() against its
# buffer_limit argument, so this is presumably an element count — confirm with callers.
DEFAULT_BUFFER_LIMIT = 128 * 1024 * 1024
91
92
class PackedDim(IntEnum):
    """
    Tensor dimension along which elements are packed.

    NOTE(review): the numbering appears to follow WHCN order (width = innermost dim,
    cf. required_image_extents using sizes[-1] as width) — confirm against the runtime.
    """

    WIDTH = 0
    HEIGHT = 1
    CHANNELS = 2
97
98
# Every member of PackedDim.
all_packed_dims: Set[PackedDim] = {
    PackedDim.WIDTH,
    PackedDim.HEIGHT,
    PackedDim.CHANNELS,
}

# Storage types enumerated explicitly (rather than by iterating VkStorageType).
all_storage_types: Set[VkStorageType] = {
    VkStorageType.BUFFER,
    VkStorageType.TEXTURE_3D,
}

# Candidate memory layouts; valid_texture_memory_layouts checks each of these against
# the device texture limits.
all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
115
116
def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Args:
        node: a tensor-producing node (asserted via ``is_tensor_node``) whose
            ``meta["val"]`` is a FakeTensor or a list/tuple of tensors.
        buffer_limit: the element-count limit to compare against. The comparison is
            strict: a tensor fits only if ``numel() < buffer_limit``.

    Returns:
        True if every produced tensor is strictly below ``buffer_limit`` elements.

    Raises:
        RuntimeError: if ``meta["val"]`` is neither a FakeTensor nor a list/tuple.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    if isinstance(val, (list, tuple)):
        return all(t.numel() < buffer_limit for t in val)
    raise RuntimeError(f"Cannot get numel for val of type {type(val)}")
131
132
def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Compute the (x, y, z) image extents used to store a tensor of shape ``sizes`` with
    the given memory layout in the Vulkan delegate.

    Missing dimensions default to 1: width/height/channels come from the last three
    dims, and batch comes from ``sizes[0]`` only for tensors of rank 4 or more. The
    packed dimension is divided by 4, rounding up (presumably four elements per texel).
    The z extent is packed-or-unpacked channels multiplied by batch.

    Raises:
        RuntimeError: if ``layout`` is not width-, height- or channels-packed.
    """
    rank = len(sizes)
    w, h, c, n = (
        sizes[-1] if rank >= 1 else 1,
        sizes[-2] if rank >= 2 else 1,
        sizes[-3] if rank >= 3 else 1,
        sizes[0] if rank >= 4 else 1,
    )

    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        w = (w + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        h = (h + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        c = (c + 3) // 4
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    return w, h, c * n
153
154
def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Return True if each component of ``extents`` is <= the same-position component of
    ``limits``.

    ``limits`` is indexed by position, so it must be at least as long as ``extents``
    (otherwise IndexError is raised once a shorter ``limits`` is exhausted).
    """
    return all(extent <= limits[axis] for axis, extent in enumerate(extents))
157
158
def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.

    Args:
        tensor_sizes: shape of the tensor to be stored.
        texture_limits: maximum (x, y, z) texture extents of the device.

    Returns:
        The subset of ``all_memory_layouts`` whose required image extents (see
        ``required_image_extents``) satisfy ``extents_are_valid`` against the limits.
    """
    # Iterate the module-level set directly; no need to copy it into a list first.
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }
173
174
def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors involved in the computation.

    For a single FakeTensor value, this is the set of layouts whose texture fits within
    ``texture_limits``. For a list/tuple of tensors, a layout is returned only if it is
    valid for *every* tensor. ``set_node_spec_attr`` writes a single value to all specs
    of a node, so a layout that fits only some outputs is not usable.

    Args:
        node: a tensor-producing node (asserted via ``is_tensor_node``).
        texture_limits: maximum (x, y, z) texture extents of the device.

    Returns:
        The set of layouts usable for all of the node's tensors. It is empty when no
        layout fits every tensor, when ``meta["val"]`` is an empty list/tuple, or when
        ``meta["val"]`` is neither a FakeTensor nor a list/tuple.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if not isinstance(val, (list, tuple)) or not val:
        return set()

    # Intersect, not union: each layout must fit every tensor the node produces.
    return set.intersection(
        *(valid_texture_memory_layouts(t.shape, texture_limits) for t in val)
    )
193
194
195##
196## TensorSpec Utils
197##
198
199
def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set attribute ``attr`` to ``value`` on the TensorSpec(s) in ``node.meta["spec"]``.

    A single TensorSpec is updated directly. For a list/tuple, every element (each
    asserted to be a TensorSpec) receives the same value.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]

    if isinstance(spec, (list, tuple)):
        for sub_spec in spec:
            assert isinstance(sub_spec, TensorSpec)
            setattr(sub_spec, attr, value)
        return

    if not isinstance(spec, TensorSpec):
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
    setattr(spec, attr, value)
211
212
def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read attribute ``attr`` from the TensorSpec(s) in ``node.meta["spec"]``.

    Args:
        node: node whose meta must contain a ``"spec"`` entry.
        attr: attribute name to read.
        return_first: for a list/tuple of specs, return only the first element's value
            (True) or a list with one value per element (False).

    Returns:
        The attribute value, or None where a spec lacks the attribute. For a
        list/tuple with ``return_first=False``, a list of such values. For an empty
        list/tuple with ``return_first=True``, None.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    if isinstance(spec, (list, tuple)):
        if return_first:
            # Probe the first element, not the container: a list/tuple never carries
            # the spec attribute itself.
            return getattr(spec[0], attr, None) if spec else None
        return [getattr(s, attr, None) for s in spec]
    raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")
225
226
def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` attribute read through ``get_node_spec_attr``
    (default ``return_first=True``), or None if it is not set.
    """
    return get_node_spec_attr(node, attr="vk_storage_type")
229
230
def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` attribute read through ``get_node_spec_attr``
    (default ``return_first=True``), or None if it is not set.
    """
    return get_node_spec_attr(node, attr="vk_memory_layout")
233