xref: /aosp_15_r20/external/executorch/backends/vulkan/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom enum import IntEnum
8*523fa7a6SAndroid Build Coastguard Workerfrom typing import Optional, Set, Tuple
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerimport torch
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.vulkan.serialization.vulkan_graph_schema import (
13*523fa7a6SAndroid Build Coastguard Worker    VkMemoryLayout,
14*523fa7a6SAndroid Build Coastguard Worker    VkStorageType,
15*523fa7a6SAndroid Build Coastguard Worker)
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import TensorSpec
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import is_buffer, is_param
20*523fa7a6SAndroid Build Coastguard Worker
21*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensor
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker##
26*523fa7a6SAndroid Build Coastguard Worker## Node type determination
27*523fa7a6SAndroid Build Coastguard Worker##
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a torch.fx Node whose op is ``"get_attr"``.

    Anything that is not a ``torch.fx.Node`` (e.g. a literal argument) yields False
    instead of raising.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Workerdef is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
35*523fa7a6SAndroid Build Coastguard Worker    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if the given node is a parameter within the exported program.

    "Parameter" is used broadly here: a node qualifies if it is a ``get_attr`` node,
    or if the exported program classifies it as a parameter, a buffer, or a lifted
    tensor constant.
    """
    if is_get_attr_node(node):
        return True
    # any() stops at the first True, preserving the original left-to-right
    # short-circuit order of the checks.
    return any(
        predicate(program, node) for predicate in (is_param, is_buffer, is_constant)
    )
48*523fa7a6SAndroid Build Coastguard Worker
49*523fa7a6SAndroid Build Coastguard Worker
def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` is a ``torch.SymInt``.

    A node without a ``"val"`` meta entry is reported as not producing a SymInt.
    """
    return isinstance(node.meta.get("val"), torch.SymInt)
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker
def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node produces a tensor, or a list/tuple made up solely of tensors.

    A ``"spec"`` meta entry is taken as sufficient evidence on its own (the original
    note: all nodes with tensor values are tagged by the SpecPropPass transform).
    Otherwise ``meta["val"]`` must be a FakeTensor, or a list/tuple whose every
    element is a FakeTensor. An empty list/tuple therefore counts as a tensor node.
    """
    meta = node.meta
    if "spec" in meta:
        return True
    if "val" not in meta:
        return False

    val = meta["val"]
    if isinstance(val, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in val)
    return isinstance(val, FakeTensor)
81*523fa7a6SAndroid Build Coastguard Worker
82*523fa7a6SAndroid Build Coastguard Worker
83*523fa7a6SAndroid Build Coastguard Worker##
84*523fa7a6SAndroid Build Coastguard Worker## Memory Layout, Storage Type Determination
85*523fa7a6SAndroid Build Coastguard Worker##
86*523fa7a6SAndroid Build Coastguard Worker
# Three-component image extent. required_image_extents() returns values in the
# order (width, height, channels * batch).
ImageExtents = Tuple[int, int, int]

# Default per-component texture extent limits; extents_are_valid() compares
# extents against these element-wise. NOTE(review): presumably a conservative
# device default — confirm against the callers that override it.
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)

# Default GPU buffer limit, counted in elements (see within_buffer_limit):
# 128 Mi elements.
DEFAULT_BUFFER_LIMIT = 128 * 1024 * 1024
91*523fa7a6SAndroid Build Coastguard Worker
92*523fa7a6SAndroid Build Coastguard Worker
class PackedDim(IntEnum):
    """
    Names the tensor dimension that a memory layout packs: width, height or channels.

    Integer values are fixed (0, 1, 2) because this is an IntEnum and the values may
    be compared or serialized as plain ints — keep them stable.
    """

    WIDTH = 0  # innermost (last) tensor dimension
    HEIGHT = 1  # second-to-last tensor dimension
    CHANNELS = 2  # third-to-last tensor dimension
97*523fa7a6SAndroid Build Coastguard Worker
98*523fa7a6SAndroid Build Coastguard Worker
# Every PackedDim member.
all_packed_dims: Set[PackedDim] = {PackedDim.WIDTH, PackedDim.HEIGHT, PackedDim.CHANNELS}

# Storage types enumerated by this module: GPU buffers and 3D textures.
all_storage_types: Set[VkStorageType] = {VkStorageType.BUFFER, VkStorageType.TEXTURE_3D}

# Candidate memory layouts; valid_texture_memory_layouts() tries each of these
# against the device texture limits.
all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
115*523fa7a6SAndroid Build Coastguard Worker
116*523fa7a6SAndroid Build Coastguard Worker
def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether every tensor produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Args:
        node: a node for which ``is_tensor_node`` holds.
        buffer_limit: element-count limit to compare each produced tensor against.

    Returns:
        True if each produced tensor has strictly fewer than ``buffer_limit`` elements
        (a tensor with exactly ``buffer_limit`` elements is rejected).

    Raises:
        RuntimeError: if ``node.meta["val"]`` is missing, or is neither a FakeTensor
            nor a list/tuple of tensors.
    """
    assert is_tensor_node(node)

    # is_tensor_node() also accepts nodes that only carry a "spec" entry, so "val"
    # may be absent; route that case to the explicit RuntimeError below rather than
    # failing with a bare KeyError.
    val = node.meta.get("val")
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    elif isinstance(val, (list, tuple)):
        return all(x.numel() < buffer_limit for x in val)
    else:
        raise RuntimeError(f"Cannot get numel for val of type {type(val)}")
131*523fa7a6SAndroid Build Coastguard Worker
132*523fa7a6SAndroid Build Coastguard Worker
def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Compute the image extents the Vulkan delegate uses to hold a tensor of ``sizes``
    stored with memory ``layout``.

    Width, height and channels are the last, second-to-last and third-to-last dims;
    any that are missing default to 1. Batch is ``sizes[0]`` only for tensors with at
    least 4 dims. The dimension named by ``layout`` is divided by 4, rounding up
    (presumably four elements share one texel along that dim).

    Returns:
        ``(width, height, channels * batch)``.

    Raises:
        RuntimeError: if ``layout`` is not one of the three packed layouts.
    """

    def div_up_4(value):
        return (value + 3) // 4

    ndim = len(sizes)
    # Innermost three dims, counted from the end; absent dims act as size 1.
    width, height, channels = (sizes[-k] if ndim >= k else 1 for k in (1, 2, 3))
    batch = sizes[0] if ndim >= 4 else 1

    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        width = div_up_4(width)
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        height = div_up_4(height)
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        channels = div_up_4(channels)
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    # Channels and batch are folded together into the third extent.
    depth = channels * batch
    return width, height, depth
153*523fa7a6SAndroid Build Coastguard Worker
154*523fa7a6SAndroid Build Coastguard Worker
def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Return True if every entry of ``extents`` is <= the entry of ``limits`` at the
    same position. Stops at the first entry that exceeds its limit.
    """
    for axis, extent in enumerate(extents):
        if not (extent <= limits[axis]):
            return False
    return True
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker
def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.

    Each layout in ``all_memory_layouts`` is tried: its extents come from
    ``required_image_extents`` and are checked element-wise with ``extents_are_valid``.

    Args:
        tensor_sizes: shape of the tensor to be stored.
        texture_limits: per-component maximum image extents.

    Returns:
        A new set of layouts whose textures fit; empty if none do.
    """
    # Iterate the module-level set directly; no need to copy it into a list first.
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }
173*523fa7a6SAndroid Build Coastguard Worker
174*523fa7a6SAndroid Build Coastguard Worker
def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors involved in the computation.

    For a single FakeTensor output, this is that tensor's valid texture layouts. For a
    list/tuple of outputs, a layout is kept only if it is valid for *every* output.
    An empty list/tuple, or a ``"val"`` of any other type, yields an empty set.

    Args:
        node: a node for which ``is_tensor_node`` holds.
        texture_limits: per-component maximum image extents.

    Returns:
        The set of memory layouts usable for all of the node's output tensors.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if not isinstance(val, (list, tuple)) or len(val) == 0:
        return set()

    # set_node_spec_attr in this module writes one value to every TensorSpec of a
    # multi-output node, so a chosen layout has to fit all outputs at once.
    # Intersect the per-tensor sets (a union would admit layouts that overflow
    # some output's texture limits).
    per_tensor_layouts = [
        valid_texture_memory_layouts(fake_tensor.shape, texture_limits)
        for fake_tensor in val
    ]
    return set.intersection(*per_tensor_layouts)
193*523fa7a6SAndroid Build Coastguard Worker
194*523fa7a6SAndroid Build Coastguard Worker
195*523fa7a6SAndroid Build Coastguard Worker##
196*523fa7a6SAndroid Build Coastguard Worker## TensorSpec Utils
197*523fa7a6SAndroid Build Coastguard Worker##
198*523fa7a6SAndroid Build Coastguard Worker
199*523fa7a6SAndroid Build Coastguard Worker
def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set ``attr`` to ``value`` on the TensorSpec(s) in ``node.meta["spec"]``.

    If the spec is a list/tuple, every element (each must be a TensorSpec) receives
    the same ``value``.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]

    if isinstance(spec, (list, tuple)):
        for entry in spec:
            assert isinstance(entry, TensorSpec)
            setattr(entry, attr, value)
        return

    if not isinstance(spec, TensorSpec):
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
    setattr(spec, attr, value)
211*523fa7a6SAndroid Build Coastguard Worker
212*523fa7a6SAndroid Build Coastguard Worker
def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read ``attr`` from the TensorSpec(s) stored in ``node.meta["spec"]``.

    Args:
        node: node that has been tagged with a ``"spec"`` meta entry.
        attr: attribute name to read from the spec(s).
        return_first: for a list/tuple of specs, return only the first spec's value
            (the default) rather than a list with one value per spec.

    Returns:
        The attribute value, or None when a spec lacks ``attr``. For a list/tuple of
        specs with ``return_first=False``, a list of such values. An empty list/tuple
        yields None (``return_first=True``) or ``[]``.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]

    # A TensorSpec is never a list/tuple, so testing the collection case first does
    # not change which branch a given spec takes.
    if isinstance(spec, (list, tuple)):
        if return_first:
            # Probe the first spec itself. The container never has the attribute,
            # so checking hasattr on it always returned None.
            return getattr(spec[0], attr, None) if spec else None
        return [getattr(s, attr, None) for s in spec]

    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)

    raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")
225*523fa7a6SAndroid Build Coastguard Worker
226*523fa7a6SAndroid Build Coastguard Worker
def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` attribute of the node's spec.

    Delegates to ``get_node_spec_attr`` with its default ``return_first=True``.
    """
    storage_type = get_node_spec_attr(node, "vk_storage_type")
    return storage_type
229*523fa7a6SAndroid Build Coastguard Worker
230*523fa7a6SAndroid Build Coastguard Worker
def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` attribute of the node's spec.

    Delegates to ``get_node_spec_attr`` with its default ``return_first=True``.
    """
    memory_layout = get_node_spec_attr(node, "vk_memory_layout")
    return memory_layout
233