xref: /aosp_15_r20/external/pytorch/tools/gen_vulkan_spv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3from __future__ import annotations
4
5import argparse
6import array
7import codecs
8import copy
9import glob
10import io
11import os
12import re
13import sys
14from itertools import product
15
16sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
17import subprocess
18import textwrap
19from dataclasses import dataclass
20from typing import Any
21
22import yaml
23from yaml.constructor import ConstructorError
24from yaml.nodes import MappingNode
25
26try:
27    from yaml import CLoader as Loader
28except ImportError:
29    from yaml import Loader  # type: ignore[assignment, misc]
30
# Names of the generated C++ files; main() joins them onto --output-path.
# NOTE(review): genCppFiles receives the CPP_H_NAME path but only writes the
# CPP_SRC_NAME file.
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

# Base template variables available to every shader. main() combines these
# with TYPES_ENV, FUNCS_ENV and any --env KEY=VALUE overrides (which win).
# The *_IMAGE_FORMAT entries are read by SPVGenerator.create_shader_params to
# choose the FORMAT variable from a variant's DTYPE.
DEFAULT_ENV: dict[str, Any] = {
    "PRECISION": "highp",
    "FLOAT_IMAGE_FORMAT": "rgba16f",
    "INT_IMAGE_FORMAT": "rgba32i",
    "UINT_IMAGE_FORMAT": "rgba32ui",
}

# Per-dtype lookup tables exposed to the templates as variables.
# IMAGE_T and SAMPLER_T are keyed first by an int (3 or 2; presumably the
# texture dimensionality -- confirm against the .glsl templates) and then by
# dtype name.
# NOTE(review): IMAGE_T / SAMPLER_T have no "int8"/"uint8" entries although
# IMAGE_FORMAT, VEC4_T and T do.
TYPES_ENV: dict[str, Any] = {
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
    "IMAGE_T": {
        3: {
            "float": "image3D",
            "half": "image3D",
            "int": "iimage3D",
            "uint": "uimage3D",
        },
        2: {
            "float": "image2D",
            "half": "image2D",
            "int": "iimage2D",
            "uint": "uimage2D",
        },
    },
    "SAMPLER_T": {
        3: {
            "float": "sampler3D",
            "half": "sampler3D",
            "int": "isampler3D",
            "uint": "usampler3D",
        },
        2: {
            "float": "sampler2D",
            "half": "sampler2D",
            "int": "isampler2D",
            "uint": "usampler2D",
        },
    },
    "VEC4_T": {
        "float": "vec4",
        "half": "vec4",
        "int": "ivec4",
        "uint": "uvec4",
        # NOTE(review): int8 maps to the float type "vec4" while uint8 maps
        # to "uvec4"; confirm this asymmetry is intended.
        "int8": "vec4",
        "uint8": "uvec4",
    },
    "T": {
        "float": "float",
        "half": "float",
        "int": "int",
        "uint": "uint",
        "int8": "int",
        # NOTE(review): "uint8" is emitted verbatim, unlike int8 -> "int";
        # confirm the templates expect this spelling.
        "uint8": "uint8",
    },
}

# Template-callable helpers. GET_POS[3](pos) returns pos unchanged and
# GET_POS[2](pos) returns the string f"{pos}.xy".
FUNCS_ENV: dict[str, Any] = {
    "GET_POS": {
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    }
}
102
103
def extract_filename(path: str, keep_ext: bool = True) -> Any:
    """Return the last path component of ``path``.

    With ``keep_ext`` False, everything from the first "." of that component
    onward is dropped, e.g. ``"a/b/x.y.glsl"`` -> ``"x"``.
    """
    base = os.path.basename(path)
    if not keep_ext:
        base = base.partition(".")[0]
    return base
109
110
111############################
112#  SPIR-V Code Generation  #
113############################
114
115
116# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
    """YAML loader that raises ``ConstructorError`` on a repeated mapping key.

    Used to load the shader template ``.yaml`` files so that a key written
    twice is reported instead of being accepted.

    NOTE(review): unlike the base constructor's ``construct_mapping``, this
    override does not call ``flatten_mapping``; YAML merge keys (``<<``) are
    presumably not expanded -- confirm if the templates ever use them.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build a dict from ``node``, rejecting unhashable or duplicate keys."""
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            try:
                hash(key)
            except TypeError as e:
                # Include the reason; the message used to end in a dangling
                # space with no detail.
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found unacceptable key ({e})",
                    key_node.start_mark,
                ) from e
            # Check for duplicate keys and name the offender.
            if key in mapping:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found duplicate key {key!r}",
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
            mapping[key] = value
        return mapping
149
150
151# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def extract_leading_whitespace(line: str) -> str:
    """Return the run of whitespace characters that prefixes ``line``."""
    content = line.lstrip()
    return line[: len(line) - len(content)]
155
156
157# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def escape(line: str) -> str:
    r"""Translate one template line into a Python string expression.

    Literal text becomes double-quoted string literals and each ``${expr}``
    becomes ``str(expr)``; the pieces are joined with ``+``, e.g.
    ``'a ${X} b'`` -> ``'"a " + str(X) + " b"'``.

    Backslashes are escaped along with double quotes, so literal text such as
    a GLSL macro continuation (``#define F(x) \``) or a ``\n`` inside a
    comment is reproduced verbatim instead of terminating the literal early or
    being read as an escape sequence. An empty line yields ``""`` so the result
    is always a valid expression.

    Raises:
        ValueError: if a ``${`` has no closing ``}`` on the same line.
    """

    def quote(text: str) -> str:
        # Backslashes first, so the ones added for quotes are not doubled.
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    output_parts = []
    while "${" in line:
        start_pos = line.index("${")
        end_pos = line.find("}", start_pos + 2)
        if end_pos == -1:
            raise ValueError(f"unterminated '${{' in template line: {line!r}")
        if start_pos != 0:
            output_parts.append(quote(line[:start_pos]))
        output_parts.append("str(" + line[start_pos + 2 : end_pos] + ")")
        line = line[end_pos + 1 :]
    if line:
        output_parts.append(quote(line))
    if not output_parts:
        return '""'
    return " + ".join(output_parts)
170
171
172# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
    input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand a ``$``-annotated text template.

    The template is translated line by line into a small Python program that
    is then executed with ``variables`` as its globals:

    * A line whose first non-blank character is ``$`` (but which does not
      start with ``${``) is Python code with every ``$`` removed. If it ends
      with ``:`` it opens a block; the next line's extra indentation defines
      the block body.
    * Any other line is printed (minus the enclosing Python indentation) with
      each ``${expr}`` replaced by ``str(expr)``.
    * Empty and whitespace-only lines are printed as empty lines.
    * Lines containing ``LINT`` anywhere are dropped.

    Args:
        input_text: the template text.
        variables: names visible to ``$`` lines and ``${...}`` expressions.
        input_path: file name recorded on the compiled program (shown in
            tracebacks) and used in indentation error messages.

    Returns:
        Everything the generated program printed.

    Raises:
        AssertionError: if the template's indentation is inconsistent.

    NOTE: the template is run with ``exec``; only pass trusted templates.
    """
    python_lines: list[str] = []
    blank_lines = 0
    last_indent = ""
    # Initialized up front so trailing blank lines can be flushed even when
    # the template has no content lines (previously an UnboundLocalError).
    python_indent = ""

    # Stack of (template indent, Python indent) pairs, one per open block.
    indent_stack = [("", "")]

    # True when the next content line is the first one inside a Python
    # block (opened by a "$for ...:", "$if ...:", "$else:", ... line).
    python_block_start = True

    def flush_blank_lines(indent: str) -> None:
        # Emit pending blank lines at the next line's indentation so they
        # belong to the same Python block as that line.
        nonlocal blank_lines
        python_lines.extend([indent + "print(file=OUT_STREAM)"] * blank_lines)
        blank_lines = 0

    for line_no, input_line in enumerate(input_text.splitlines(), start=1):
        # Whitespace-only lines count as blank. Letting their indentation take
        # part in block tracking closed open blocks early, and a line equal to
        # the current indentation produced the invalid "print(, ...)".
        if not input_line.strip():
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            # First line of a new block: its indentation beyond the opening
            # line becomes the Python indentation of the block body.
            if not input_indent.startswith(last_indent):
                raise AssertionError(
                    f"{input_path}:{line_no}: block body indentation "
                    f"{input_indent!r} does not extend {last_indent!r}"
                )
            extra_python_indent = input_indent[len(last_indent) :]
            indent_stack.append(
                (input_indent, indent_stack[-1][1] + extra_python_indent)
            )
        else:
            # Dedenting closes every block whose indentation is no longer a
            # prefix of this line's indentation.
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]
        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            if stripped_input_line.endswith(":"):
                python_block_start = True
            flush_blank_lines(python_indent)
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            if not input_line.startswith(python_indent):
                raise AssertionError(
                    f"{input_path}:{line_no}: line is not indented by the "
                    f"enclosing block's indentation {python_indent!r}"
                )
            flush_blank_lines(python_indent)
            text_expr = escape(input_line[len(python_indent) :])
            python_lines.append(
                python_indent + f"print({text_expr}, file=OUT_STREAM)"
            )
        last_indent = input_indent

    flush_blank_lines(python_indent)

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()
243
244
class SPVGenerator:
    """Expands GLSL templates into shader variants and compiles them to SPIR-V.

    Construction collects every ``*.glsl*`` template and ``*.yaml`` variant
    file below the source directories, expands the YAML descriptions into one
    parameter dict per variant and fills ``output_shader_map``
    (output shader name -> (template path, template variables)).
    ``generateSPV`` then renders those outputs and optionally compiles them.
    """

    def __init__(
        self,
        src_dir_paths: str | list[str],
        env: dict[Any, Any],
        glslc_path: str | None,
    ) -> None:
        """Scan ``src_dir_paths`` and build ``output_shader_map``.

        Args:
            src_dir_paths: a directory, or list of directories, searched
                recursively for templates and YAML files; also passed to
                glslc as include paths.
            env: template variables shared by every shader.
            glslc_path: glslc executable, or None to only emit expanded GLSL.
        """
        if isinstance(src_dir_paths, str):
            self.src_dir_paths = [src_dir_paths]
        else:
            self.src_dir_paths = src_dir_paths

        self.env = env
        self.glslc_path = glslc_path

        # template name (file name up to its first ".") -> template path
        self.glsl_src_files: dict[str, str] = {}
        self.template_yaml_files: list[str] = []

        self.addSrcAndYamlFiles(self.src_dir_paths)
        # template name -> list of resolved parameter dicts, one per variant
        self.shader_template_params: dict[Any, Any] = {}
        for yaml_file in self.template_yaml_files:
            self.parseTemplateYaml(yaml_file)

        # output shader name -> (template path, template variables)
        self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
        self.constructOutputMap()

    def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
        """Record every template and YAML file found under ``src_dir_paths``.

        Glob results are sorted so that the generated output, and which file
        wins when two templates share a name, do not depend on the
        filesystem's directory-listing order.
        """
        for src_path in src_dir_paths:
            # Collect glsl source files. "*.glsl*" also matches names such as
            # "x.glslh", and the key drops everything after the first ".".
            glsl_files = sorted(
                glob.glob(os.path.join(src_path, "**", "*.glsl*"), recursive=True)
            )
            for file in glsl_files:
                self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
            # Collect template yaml files
            yaml_files = sorted(
                glob.glob(os.path.join(src_path, "**", "*.yaml"), recursive=True)
            )
            self.template_yaml_files.extend(yaml_files)

    def generateVariantCombinations(
        self,
        iterated_params: dict[str, Any],
        exclude_params: set[str] | None = None,
    ) -> list[Any]:
        """Return the cartesian product of the iterated parameter values.

        ``iterated_params`` maps a parameter name to a list of dicts with a
        ``VALUE`` and an optional ``SUFFIX`` (defaulting to VALUE). Names in
        ``exclude_params`` are skipped. Each combination is a tuple of
        ``(param_name, suffix, value)`` triples.
        """
        if exclude_params is None:
            exclude_params = set()
        all_iterated_params = [
            [
                (param_name, value.get("SUFFIX", value["VALUE"]), value["VALUE"])
                for value in value_list
            ]
            for param_name, value_list in iterated_params.items()
            if param_name not in exclude_params
        ]
        return list(product(*all_iterated_params))

    @staticmethod
    def _variantName(base_name: Any, combination: Any) -> Any:
        """Append ``_<suffix>`` to ``base_name`` for every non-empty suffix.

        Suffixes are stringified before the emptiness test: a numeric VALUE
        with no explicit SUFFIX is an int, and ``len()`` of it used to raise.
        """
        name = base_name
        for _param_name, suffix, _value in combination:
            suffix_str = str(suffix)
            if suffix_str:
                name = f"{name}_{suffix_str}"
        return name

    def parseTemplateYaml(self, yaml_file: str) -> None:
        """Expand the variants described in ``yaml_file``.

        Each top-level key names a template and provides
        ``parameter_names_with_default_values`` and ``shader_variants``. An
        optional ``generate_variant_forall`` (template-wide, or overridden per
        variant) multiplies a variant across the listed values, appending each
        value's suffix to the variant NAME. Parameters a variant sets itself
        are not iterated.

        Raises:
            KeyError: if a template is described more than once.
            AssertionError: if a variant sets a parameter that has no default.
        """
        # NOTE(review): UniqueKeyLoader subclasses yaml.Loader/CLoader rather
        # than a safe loader; only use it on trusted, in-tree YAML files.
        with open(yaml_file, encoding="utf-8") as f:
            contents = yaml.load(f, Loader=UniqueKeyLoader)

        # An empty YAML file loads as None; treat it as describing nothing.
        for template_name, params_dict in (contents or {}).items():
            if template_name in self.shader_template_params:
                raise KeyError(f"{template_name} params file is defined twice")

            default_params = params_dict["parameter_names_with_default_values"]
            params_names = set(default_params.keys()).union({"NAME"})

            self.shader_template_params[template_name] = []

            default_iterated_params = params_dict.get("generate_variant_forall", None)

            for variant in params_dict["shader_variants"]:
                variant_params_names = set(variant.keys())
                invalid_keys = (
                    variant_params_names - params_names - {"generate_variant_forall"}
                )
                if invalid_keys:
                    raise AssertionError(
                        f"{yaml_file}: variant {variant.get('NAME')!r} of "
                        f"{template_name} sets parameters without defaults: "
                        f"{sorted(map(str, invalid_keys))}"
                    )

                iterated_params = variant.get(
                    "generate_variant_forall", default_iterated_params
                )

                if iterated_params is not None:
                    variant_combinations = self.generateVariantCombinations(
                        iterated_params, variant_params_names
                    )

                    for combination in variant_combinations:
                        default_params_copy = copy.deepcopy(default_params)
                        for key in variant:
                            if key != "generate_variant_forall":
                                default_params_copy[key] = variant[key]

                        for param_name, _suffix, value in combination:
                            default_params_copy[param_name] = value

                        default_params_copy["NAME"] = self._variantName(
                            variant["NAME"], combination
                        )

                        self.shader_template_params[template_name].append(
                            default_params_copy
                        )
                else:
                    default_params_copy = copy.deepcopy(default_params)
                    for key in variant:
                        default_params_copy[key] = variant[key]

                    self.shader_template_params[template_name].append(
                        default_params_copy
                    )

    def create_shader_params(
        self, variant_params: dict[str, Any] | None = None
    ) -> dict[str, str]:
        """Return a deep copy of ``self.env`` overlaid with ``variant_params``.

        Also sets FORMAT from the DTYPE (default "float"):
        - "int", "uint" and any unlisted dtype use the INT_/UINT_/
          FLOAT_IMAGE_FORMAT entries of ``self.env``.
        - The sized dtypes map to fixed formats.
        """
        if variant_params is None:
            variant_params = {}
        shader_params = copy.deepcopy(self.env)
        shader_params.update(variant_params)

        shader_dtype = shader_params.get("DTYPE", "float")

        if shader_dtype == "int":
            shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
        elif shader_dtype == "uint":
            shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
        elif shader_dtype == "int32":
            shader_params["FORMAT"] = "rgba32i"
        elif shader_dtype == "uint32":
            shader_params["FORMAT"] = "rgba32ui"
        elif shader_dtype == "int8":
            shader_params["FORMAT"] = "rgba8i"
        elif shader_dtype == "uint8":
            shader_params["FORMAT"] = "rgba8ui"
        elif shader_dtype == "float32":
            shader_params["FORMAT"] = "rgba32f"
        # Assume float by default
        else:
            shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]

        return shader_params

    def constructOutputMap(self) -> None:
        """Fill ``output_shader_map`` from YAML variants and plain templates.

        Every YAML-described variant becomes an output named by its NAME.
        Every template without a YAML description becomes a single output
        using only the base environment.

        Raises:
            KeyError: if a YAML file describes a template that has no
                matching ``.glsl`` source file.
        """
        for shader_name, params in self.shader_template_params.items():
            for variant in params:
                try:
                    source_glsl = self.glsl_src_files[shader_name]
                except KeyError:
                    raise KeyError(
                        f"{shader_name} is described in a YAML file but no "
                        "matching GLSL template was found"
                    ) from None

                self.output_shader_map[variant["NAME"]] = (
                    source_glsl,
                    self.create_shader_params(variant),
                )

        for shader_name, source_glsl in self.glsl_src_files.items():
            if shader_name not in self.shader_template_params:
                self.output_shader_map[shader_name] = (
                    source_glsl,
                    self.create_shader_params(),
                )

    def generateSPV(self, output_dir: str) -> dict[str, str]:
        """Render every output shader into ``output_dir`` and compile it.

        Writes ``<name>.glsl`` for each entry of ``output_shader_map``. When
        ``glslc_path`` is set, it also runs glslc to produce ``<name>.spv``.

        Returns:
            Mapping of each ``.spv`` path to the ``.glsl`` it was compiled
            from (empty when no compiler is configured).

        Raises:
            subprocess.CalledProcessError: if glslc fails.
        """
        output_file_map = {}
        for shader_name, (source_glsl, shader_params) in self.output_shader_map.items():
            # newline="" leaves line endings untouched on read and write, as
            # the previously used codecs.open (binary underneath) did.
            with open(source_glsl, encoding="utf-8", newline="") as input_file:
                input_text = input_file.read()
            output_text = preprocess(input_text, shader_params)

            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
            with open(glsl_out_path, "w", encoding="utf-8", newline="") as output_file:
                output_file.write(output_text)

            # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
            # This is mainly for testing purposes.
            if self.glslc_path is not None:
                spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

                cmd = [
                    self.glslc_path,
                    "-fshader-stage=compute",
                    glsl_out_path,
                    "-o",
                    spv_out_path,
                    "--target-env=vulkan1.0",
                    "-Werror",
                ] + [
                    arg
                    for src_dir_path in self.src_dir_paths
                    for arg in ["-I", src_dir_path]
                ]

                print("glslc cmd:", cmd)
                subprocess.check_call(cmd)

                output_file_map[spv_out_path] = glsl_out_path

        return output_file_map
453
454
455##############################################
456#  Shader Info and Shader Registry Handling  #
457##############################################
458
459
@dataclass
class ShaderInfo:
    """Metadata that getShaderInfo scrapes from a generated GLSL file."""

    # Values of a " * TILE_SIZE = (x, y, z)" line; left empty if absent.
    tile_size: list[int]
    # One VK_DESCRIPTOR_TYPE_* name per "layout(set" line, in file order.
    layouts: list[str]
    # Value of a " * WEIGHT_STORAGE = ..." line (a storageTypeToEnum key).
    weight_storage_type: str = ""
    # Value of a " * BIAS_STORAGE = ..." line (a storageTypeToEnum key).
    bias_storage_type: str = ""
    # (op name, dispatch keys) parsed from a " * REGISTER_FOR = ..." line.
    register_for: tuple[str, list[str]] | None = None
467
468
def getName(filePath: str) -> str:
    """Turn the base name of ``filePath`` into an identifier-like string.

    Every "." (and "/") in the base name becomes "_", e.g.
    ``"out/conv.spv"`` -> ``"conv_spv"``.
    """
    to_underscore = str.maketrans({"/": "_", ".": "_"})
    return os.path.basename(filePath).translate(to_underscore)
471
472
def isDescriptorLine(lineStr: str) -> bool:
    """Return True if the line begins (at column 0) with ``layout(set``."""
    return lineStr.startswith("layout(set")
476
477
def isTileSizeLine(lineStr: str) -> bool:
    """Return True for a ``" * TILE_SIZE = ("`` annotation line."""
    return lineStr.startswith(" * TILE_SIZE = (")
481
482
def findTileSizes(lineStr: str) -> list[int]:
    """Parse the three integers of a ``" * TILE_SIZE = (x, y, z)"`` line.

    Raises:
        AssertionError: if the line does not start with that exact form.
    """
    found = re.match(r" \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)", lineStr)
    if not found:
        raise AssertionError("matches is None in findTileSizes")
    return [int(group) for group in found.groups()]
489
490
def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Return True for a ``" * WEIGHT_STORAGE = "`` annotation line."""
    return lineStr.startswith(" * WEIGHT_STORAGE = ")
494
495
def getWeightStorageType(lineStr: str) -> str:
    """Return the storage type named on a ``" * WEIGHT_STORAGE = ..."`` line.

    Accepts texture types such as ``TEXTURE_2D`` as well as ``BUFFER``; both
    are keys of ``storageTypeToEnum``. The ``_<digit>D`` form alone used to be
    required, so ``BUFFER`` could not be parsed.

    Raises:
        AssertionError: if no recognizable storage type follows the marker.
    """
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD|BUFFER)"
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return matches.group(1)
502
503
def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Return True for a ``" * BIAS_STORAGE = "`` annotation line."""
    return lineStr.startswith(" * BIAS_STORAGE = ")
507
508
def getBiasStorageType(lineStr: str) -> str:
    """Return the storage type named on a ``" * BIAS_STORAGE = ..."`` line.

    Accepts texture types such as ``TEXTURE_2D`` as well as ``BUFFER``; both
    are keys of ``storageTypeToEnum``. The ``_<digit>D`` form alone used to be
    required, so ``BUFFER`` could not be parsed.

    Raises:
        AssertionError: if no recognizable storage type follows the marker.
    """
    bias_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD|BUFFER)"
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return matches.group(1)
515
516
def isRegisterForLine(lineStr: str) -> bool:
    """Return True for a ``" * REGISTER_FOR = ('op', ['key', ...])"`` line.

    Requires a quoted op name and at least one quoted dispatch key.
    """
    pattern = r" \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
    return bool(re.match(pattern, lineStr))
523
524
def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
    """Extract ``(op name, dispatch keys)`` from a REGISTER_FOR line.

    Every single-quoted identifier on the line is collected. The first is the
    op name and the rest are the dispatch keys.

    Raises:
        AssertionError: if the line contains no quoted identifier.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    matches = re.findall(register_for_pattern, lineStr)
    # re.findall returns a possibly empty list, never None; an empty result
    # previously surfaced as an IndexError below.
    if not matches:
        raise AssertionError("no quoted names found in findRegisterFor")
    return (matches[0], matches[1:])
532
533
# Regex -> Vulkan descriptor type for a "layout(set" declaration line.
# determineDescriptorType returns the FIRST pattern that matches, so the
# order is significant: a line containing both "uniform" and e.g.
# "sampler3D" is classified by the sampler pattern before "uniform" is tried.
# The image/sampler patterns have no leading \b, so they also match the
# "i"/"u"-prefixed names (iimage3D, usampler2D, ...).
typeIdMapping = {
    r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
    r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
    r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
    r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
}

# Storage-type names from WEIGHT_STORAGE / BIAS_STORAGE annotations mapped to
# the C++ enum emitted by generateShaderInfoStr. "" (the ShaderInfo default
# when no annotation is present) maps to UNKNOWN.
storageTypeToEnum = {
    "TEXTURE_2D": "api::StorageType::TEXTURE_2D",
    "TEXTURE_3D": "api::StorageType::TEXTURE_3D",
    "BUFFER": "api::StorageType::BUFFER",
    "": "api::StorageType::UNKNOWN",
}
547
548
def determineDescriptorType(lineStr: str) -> str:
    """Map a descriptor declaration line to its ``VK_DESCRIPTOR_TYPE_*`` name.

    The patterns in ``typeIdMapping`` are tried in insertion order and the
    first match wins.

    Raises:
        AssertionError: if no pattern matches.
    """
    found = next(
        (
            descriptor
            for pattern, descriptor in typeIdMapping.items()
            if re.search(pattern, lineStr)
        ),
        None,
    )
    if found is None:
        raise AssertionError(
            "No matching descriptor type for " + lineStr + " in determineDescriptorType"
        )
    return found
556
557
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan a generated GLSL file for the metadata needed to register it.

    Collects one descriptor type per ``layout(set`` line, in order, plus the
    TILE_SIZE, WEIGHT_STORAGE, BIAS_STORAGE and REGISTER_FOR annotations. A
    repeated annotation overwrites the earlier one.
    """
    shader_info = ShaderInfo([], [], "")
    # SPVGenerator.generateSPV writes these files as UTF-8; read them the same
    # way rather than with the platform's default encoding.
    with open(srcFilePath, encoding="utf-8") as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info
574
575
576##########################
577#  C++ File Generation  #
578#########################
579
# Skeleton of the generated C++ source, filled in by genCppFiles with
# str.format. Literal C++ braces are therefore doubled ("{{" / "}}"); the
# placeholders are {spv_bin_arrays}, {register_shader_infos} and
# {shader_info_registry}. register_fn is passed to a static
# api::ShaderRegisterInit object; presumably that runs it during static
# initialization -- confirm in ShaderRegistry.h.
cpp_template = """
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>

using namespace at::native::vulkan;

namespace at {{
namespace native {{
namespace vulkan {{

namespace {{

{spv_bin_arrays}

}}

static void register_fn() {{

{register_shader_infos}

{shader_info_registry}

}}

static const api::ShaderRegisterInit register_shaders(&register_fn);

}}
}}
}}

"""
612
613
def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
    """Render a SPIR-V file as a C++ ``const uint32_t <name>_bin[]`` array.

    The file is read as native-endian 32-bit words, one decimal per line.

    Returns:
        (size of the binary in bytes, C++ array definition).

    Raises:
        ValueError: if the file is not a whole number of 32-bit words, or if
            the platform's ``array`` type "I" is not 4 bytes wide.
    """
    with open(spvPath, "rb") as fr:
        raw = fr.read()

    words = array.array("I")
    # "I" is a C unsigned int, whose width is platform-defined; the emitted
    # array is uint32_t, so refuse to produce silently wrong data.
    if words.itemsize != 4:
        raise ValueError(
            f"array type 'I' is {words.itemsize} bytes on this platform; expected 4"
        )
    words.frombytes(raw)  # raises ValueError if len(raw) % 4 != 0

    body = textwrap.indent(",\n".join(str(word) for word in words), "  ")
    spv_bin_str = f"const uint32_t {name}_bin[] = {{\n{body}\n}};"
    return len(raw), spv_bin_str
624
625
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Emit the C++ ``register_shader`` call for one compiled shader.

    The ``api::ShaderInfo`` arguments, in order, are:
    - the quoted shader name;
    - its ``<name>_bin`` array and its size in bytes;
    - the brace-enclosed descriptor layout list;
    - the tile size, or ``std::vector<uint32_t>()`` when none was declared;
    - the weight and bias storage enums.
    """
    if shader_info.tile_size:
        tile_size = "{" + ", ".join(map(str, shader_info.tile_size)) + "}"
    else:
        tile_size = "std::vector<uint32_t>()"

    layouts = "{" + ",\n ".join(shader_info.layouts) + "}"

    args = ",\n".join(
        [
            f'"{name}"',
            f"{name}_bin",
            str(sizeBytes),
            layouts,
            tile_size,
            storageTypeToEnum[shader_info.weight_storage_type],
            storageTypeToEnum[shader_info.bias_storage_type],
        ]
    )
    call = (
        "api::shader_registry().register_shader(\n  api::ShaderInfo(\n"
        + textwrap.indent(args, "     ")
        + "));\n"
    )
    return textwrap.indent(call, "    ")
653
654
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Emit one C++ ``register_op_dispatch`` call per REGISTER_FOR key.

    Each key is upper-cased to form the ``api::DispatchKey`` enumerator. The
    calls are newline-separated and indented by four spaces. Returns "" when
    the shader declared no REGISTER_FOR line or listed no keys.
    """
    if shader_info.register_for is None:
        return ""

    (op_name, registry_keys) = shader_info.register_for
    # Build one line per key. The previous loop reassigned a single variable,
    # so only the last key was registered, and an empty key list raised
    # UnboundLocalError.
    dispatch_lines = [
        f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");'
        for registry_key in registry_keys
    ]
    return textwrap.indent("\n".join(dispatch_lines), "    ")
667
668
def genCppFiles(
    spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds and registers every compiled shader.

    For each ``.spv`` file this emits:
    - its ``<name>_bin`` word array;
    - a ``register_shader`` call built from the metadata in the matching
      generated ``.glsl``;
    - ``register_op_dispatch`` calls, if the shader declared any.

    It then fills ``cpp_template`` and writes the result.

    Args:
        spv_files: ``.spv`` path -> path of the ``.glsl`` it was compiled from.
        cpp_header_path: currently unused; no header file is written.
        cpp_src_file_path: destination of the generated C++ source.
    """
    spv_bin_strs = []
    register_shader_info_strs = []
    shader_registry_strs = []

    for spvPath, srcPath in spv_files.items():
        # getName(".../conv2d.spv") == "conv2d_spv". Strip only that trailing
        # "_spv": replacing every occurrence turned e.g. "foo_spv_bar" into
        # "foo_bar", registering the shader under the wrong name.
        name = getName(spvPath)
        if name.endswith("_spv"):
            name = name[: -len("_spv")]

        sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
        spv_bin_strs.append(spv_bin_str)

        shader_info = getShaderInfo(srcPath)

        register_shader_info_strs.append(
            generateShaderInfoStr(shader_info, name, sizeBytes)
        )

        if shader_info.register_for is not None:
            shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))

    cpp = cpp_template.format(
        spv_bin_arrays="\n".join(spv_bin_strs),
        register_shader_infos="\n".join(register_shader_info_strs),
        shader_info_registry="\n".join(shader_registry_strs),
    )

    with open(cpp_src_file_path, "w", encoding="utf-8") as fw:
        fw.write(cpp)
703
704
705##########
706#  Main  #
707##########
708
709
def parse_arg_env(items: list[str] | None) -> dict[Any, Any]:
    """Parse ``--env KEY=VALUE`` arguments into a dict.

    Keys and values are stripped of surrounding whitespace. Only the first "="
    separates key from value, so values may themselves contain "=". Earlier
    entries are overridden by later ones with the same key.

    Args:
        items: the raw ``KEY=VALUE`` strings, or None when ``--env`` was not
            given.

    Raises:
        ValueError: if an item contains no "=".
    """
    d: dict[Any, Any] = {}
    for item in items or []:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"--env expects KEY=VALUE, got {item!r}")
        d[key.strip()] = value.strip()
    return d
719
720
def main(argv: list[str]) -> int:
    """Command-line entry point: generate GLSL/SPIR-V and the C++ registry.

    Args:
        argv: full argument vector including the program name, as in
            ``sys.argv``; ``argv[1:]`` is parsed.

    Returns:
        0 on success (failures raise).
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
        default=["."],
    )
    parser.add_argument("-c", "--glslc-path", required=True, help="")
    parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
    parser.add_argument("-o", "--output-path", required=True, help="")
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )
    # Parse the vector we were given; it used to be ignored in favour of an
    # implicit read of sys.argv.
    options = parser.parse_args(argv[1:])

    # Build a fresh environment instead of updating DEFAULT_ENV in place, so
    # the module-level defaults are not polluted across calls. Precedence is
    # unchanged: DEFAULT_ENV < TYPES_ENV < FUNCS_ENV < --env overrides.
    env: dict[Any, Any] = {**DEFAULT_ENV, **TYPES_ENV, **FUNCS_ENV}
    env.update(parse_arg_env(options.env))

    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        os.path.join(options.output_path, CPP_H_NAME),
        os.path.join(options.output_path, CPP_SRC_NAME),
    )

    return 0
761
762
def invoke_main() -> None:
    """Run ``main`` on ``sys.argv`` and exit the process with its status."""
    raise SystemExit(main(sys.argv))
765
766
# Script entry point: run invoke_main() only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    invoke_main()  # pragma: no cover
769