xref: /aosp_15_r20/external/mesa3d/src/vulkan/util/vk_synchronization_helpers_gen.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
# License header prepended to the generated C source (TEMPLATE_C starts with
# COPYRIGHT + ...).
COPYRIGHT=u"""
/* Copyright © 2023 Collabora, Ltd.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
"""
24
25import argparse
26import os
27import textwrap
28import xml.etree.ElementTree as et
29
30from mako.template import Template
31from vk_extensions import get_api_list
32
# Mako template for the generated C file.  main() renders it with
# ``group_stages``, ``all_commands_stages``, ``stages_read_access`` and
# ``stages_write_access``.  Mako only treats a line as a control line when
# nothing but whitespace precedes its ``%``, so those lines must stay bare.
# The ``\\n`` in the join strings becomes ``\n`` inside the Mako expression,
# which wraps long flag lists onto aligned continuation lines in the C output.
TEMPLATE_C = Template(COPYRIGHT + """\
#include "vk_synchronization.h"

VkPipelineStageFlags2
vk_expand_pipeline_stage_flags2(VkPipelineStageFlags2 stages)
{
% for (group_stage, stages) in group_stages.items():
    if (stages & ${group_stage})
        stages |= ${' |\\n                  '.join(stages)};

% endfor
    if (stages & VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT) {
% for (guard, stage) in all_commands_stages:
% if guard is not None:
#ifdef ${guard}
% endif
        stages |= ${stage};
% if guard is not None:
#endif
% endif
% endfor
    }

    return stages;
}

VkAccessFlags2
vk_read_access2_for_pipeline_stage_flags2(VkPipelineStageFlags2 stages)
{
    VkAccessFlags2 access = 0;

% for ((guard, stages), access) in stages_read_access.items():
% if guard is not None:
#ifdef ${guard}
% endif
    if (stages & (${' |\\n                  '.join(stages)}))
        access |= ${' |\\n                  '.join(access)};
% if guard is not None:
#endif
% endif

% endfor
    return access;
}

VkAccessFlags2
vk_write_access2_for_pipeline_stage_flags2(VkPipelineStageFlags2 stages)
{
    VkAccessFlags2 access = 0;

% for ((guard, stages), access) in stages_write_access.items():
% if guard is not None:
#ifdef ${guard}
% endif
    if (stages & (${' |\\n                  '.join(stages)}))
        access |= ${' |\\n                  '.join(access)};
% if guard is not None:
#endif
% endif

% endfor
    return access;
}
""")
97
def get_guards(xml, api):
    """Collect preprocessor guards for pipeline-stage and access flag enums.

    Looks at every <extension> whose ``supported`` list contains ``api``.
    For each <require>d <enum> that extends VkPipelineStageFlagBits2 or
    VkAccessFlagBits2 and has a ``protect`` attribute, records the enum
    name together with that ``protect`` value.

    Args:
        xml: parsed Vulkan registry.
        api: API name to filter extensions by (e.g. 'vulkan').

    Returns:
        Dict mapping enum name -> guard macro.  Enums without ``protect``
        are not present.
    """
    flag_types = ('VkPipelineStageFlagBits2', 'VkAccessFlagBits2')
    name_to_guard = {}
    for extension in xml.findall('./extensions/extension'):
        if api not in get_api_list(extension.attrib['supported']):
            continue
        for enum_elem in extension.findall('./require/enum[@extends]'):
            attrs = enum_elem.attrib
            if attrs['extends'] in flag_types and 'protect' in attrs:
                name_to_guard[attrs['name']] = attrs['protect']
    return name_to_guard
118
def get_all_commands_stages(xml, guards):
    """Return the stages that VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT expands to.

    Walks every <syncstage> in the registry's <sync> section, in registry
    order, and keeps each one that is not in the exclusion set below.

    Args:
        xml: parsed Vulkan registry.
        guards: enum name -> preprocessor guard map (as built by get_guards).

    Returns:
        List of ``(guard, stage_name)`` tuples.  ``guard`` is None when the
        stage has no entry in ``guards``.
    """
    # Built once instead of on every iteration of the loop below.
    excluded = frozenset([
        # This isn't a real stage.
        'VK_PIPELINE_STAGE_2_NONE',

        # These are real stages, but they're a bit weird to include in
        # ALL_COMMANDS because they're context-dependent: their meaning
        # depends on whether they appear in srcStageMask or dstStageMask.
        #
        # We could skip every grouped stage instead, but then someone adding
        # a new group to the registry later would silently change what this
        # function returns.  The other grouped stages don't hurt anything
        # when included.
        'VK_PIPELINE_STAGE_2_TOP_OF_PIPE_BIT',
        'VK_PIPELINE_STAGE_2_BOTTOM_OF_PIPE_BIT',

        # This is all COMMANDS, not host.
        'VK_PIPELINE_STAGE_2_HOST_BIT',
    ])

    stages = []
    for stage in xml.findall('./sync/syncstage'):
        stage_name = stage.attrib['name']
        if stage_name in excluded:
            continue
        stages.append((guards.get(stage_name, None), stage_name))

    return stages
149
def get_group_stages(xml):
    """Map each grouped pipeline stage to its equivalent individual stages.

    A <syncstage> is a group when it has a <syncequivalent> child.  The
    child's comma-separated ``stage`` attribute is split into a list.

    Args:
        xml: parsed Vulkan registry.

    Returns:
        Dict mapping group stage name -> list of equivalent stage names, in
        registry order.
    """
    # Pair every stage's name with its (possibly missing) equivalence element.
    named_entries = ((syncstage.attrib['name'],
                      syncstage.find('./syncequivalent'))
                     for syncstage in xml.findall('./sync/syncstage'))
    return {stage_name: equivalent.attrib['stage'].split(',')
            for stage_name, equivalent in named_entries
            if equivalent is not None}
160
def access_is_read(name):
    """Classify an access flag name as a read or a write access.

    The check is purely lexical: names containing ``READ`` are reads and
    names containing ``WRITE`` are writes.

    Args:
        name: access flag name, e.g. ``VK_ACCESS_2_SHADER_READ_BIT``.

    Returns:
        True for a read access, False for a write access.

    Raises:
        ValueError: if ``name`` contains both or neither of ``READ`` and
            ``WRITE``.  This is raised explicitly rather than via ``assert``
            so it still fires under ``python -O``.  Without it the function
            would return None and callers comparing against True/False would
            silently drop the access.
    """
    is_read = 'READ' in name
    is_write = 'WRITE' in name
    if is_read == is_write:
        raise ValueError("Invalid access bit name: {}".format(name))
    return is_read
170
def get_stages_access(xml, read, guards, all_commands_stages, group_stages):
    """Group read or write access flags by the stages that imply them.

    Considers each <syncaccess> in the registry's <sync> section that:

    * is not VK_ACCESS_2_NONE;
    * matches ``read`` according to access_is_read;
    * has a <syncsupport> child.

    For each such access it builds the list of stages whose presence in a
    stage mask implies the access:

    * the stages listed in <syncsupport>;
    * VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT plus TOP_OF_PIPE (reads) or
      BOTTOM_OF_PIPE (writes), if any supported stage is one that
      ALL_COMMANDS expands to;
    * every group stage in ``group_stages`` whose equivalent stages contain
      one of the stages collected so far.

    Args:
        xml: parsed Vulkan registry.
        read: True to collect read accesses, False to collect write accesses.
        guards: enum name -> preprocessor guard map (from get_guards).
        all_commands_stages: ``(guard, stage)`` list (from
            get_all_commands_stages).
        group_stages: group stage -> equivalent stages map (from
            get_group_stages).

    Returns:
        Dict mapping ``(access_guard, sorted_stage_tuple)`` to the list of
        access flag names sharing that guard and stage set.
    """
    # Match on stage names only.  Each entry of all_commands_stages carries
    # the *stage's* guard, which need not equal the *access's* guard.  A
    # guarded access supported on an unguarded stage must still be reachable
    # through ALL_COMMANDS.
    all_commands_names = {stage for (_, stage) in all_commands_stages}
    implicit_stage = ('VK_PIPELINE_STAGE_2_TOP_OF_PIPE_BIT' if read
                      else 'VK_PIPELINE_STAGE_2_BOTTOM_OF_PIPE_BIT')

    stages_access = {}
    for access in xml.findall('./sync/syncaccess'):
        access_name = access.attrib['name']
        if access_name == 'VK_ACCESS_2_NONE':
            continue

        if access_is_read(access_name) != read:
            continue

        guard = guards.get(access_name, None)
        support = access.find('./syncsupport')
        if support is None:
            continue

        stages = support.attrib['stage'].split(',')

        if any(stage in all_commands_names for stage in stages):
            stages.append('VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT')
            stages.append(implicit_stage)

        # Each group is tested against the list as it grows, so a group
        # appended here is visible to the groups tested after it.
        for group, equiv in group_stages.items():
            if any(stage in equiv for stage in stages):
                stages.append(group)

        key = (guard, tuple(sorted(stages)))
        stages_access.setdefault(key, []).append(access_name)

    return stages_access
207
def main():
    """Generate the vk_synchronization helper C file from the Vulkan XML.

    Parses the command line, reads the registry given by ``--xml``, computes
    the template inputs and renders TEMPLATE_C into ``--out-c``.  If opening
    the output file or rendering fails, prints Mako's text error report to
    stderr and exits with status 1.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --beta is required but never read below.  Removing it
    # would break any build rule that passes it, so it is kept as-is.
    parser.add_argument('--beta', required=True, help='Enable beta extensions.')
    parser.add_argument('--xml', required=True, help='Vulkan API XML file')
    parser.add_argument('--out-c', required=True, help='Output C file.')
    args = parser.parse_args()

    xml = et.parse(args.xml)

    guards = get_guards(xml, 'vulkan')
    all_commands_stages = get_all_commands_stages(xml, guards)
    group_stages = get_group_stages(xml)

    environment = {
        'all_commands_stages': all_commands_stages,
        'group_stages': group_stages,
        'stages_read_access': get_stages_access(
            xml, True, guards, all_commands_stages, group_stages),
        'stages_write_access': get_stages_access(
            xml, False, guards, all_commands_stages, group_stages),
    }

    try:
        with open(args.out_c, 'w', encoding='utf-8') as f:
            f.write(TEMPLATE_C.render(**environment))
    except Exception:
        # Always print Mako's text error report for the current exception to
        # stderr (it maps frames back to template source where applicable),
        # then exit with status 1.
        import sys
        from mako import exceptions
        print(exceptions.text_error_template().render(), file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
243