xref: /aosp_15_r20/external/mesa3d/src/gfxstream/codegen/scripts/cereal/common/codegen.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# Copyright 2023 Google LLC
2# SPDX-License-Identifier: MIT
3from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
4from collections import OrderedDict
5from copy import copy
6from pathlib import Path, PurePosixPath
7
8import os
9import sys
10import shutil
11import subprocess
12import re
13
14# Class capturing a single file
15
16
class SingleFileModule(object):
    """A single generated output file.

    The file is written to ``<dir>/<basename><suffix>`` where ``<dir>`` is
    ``customAbsDir`` when given, otherwise ``globalDir/directory`` (with
    ``globalDir`` supplied to ``begin()``).  When ``suppress`` is true,
    ``begin()``, ``append()`` and ``end()`` are all no-ops and no file is
    created.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        # Open file object; set by begin() and kept (closed) after end().
        self.file = None

        # Text written right after opening / right before closing the file.
        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open the output file for writing (UTF-8) and emit the preamble.

        The destination directory is created if it does not exist yet.
        """
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create subdirectory, if needed.  An empty path means "current
        # directory", which os.makedirs() rejects, so skip it in that case.
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the file; requires a prior ``begin()``."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Emit the postamble and close the file.

        ``self.file`` is left pointing at the closed file object so that
        its ``name`` stays available afterwards.
        """
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list entry (none by default)."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source-list entry (none by default)."""
        return ""
63
64# Class capturing a .cpp file and a .h file (a "C++ module")
65
66
class Module(object):
    """A generated C++ module: one ``.h`` file plus one ``.cpp`` file.

    Both files share ``directory``/``basename`` and are backed by a
    ``SingleFileModule`` each.  ``implOnly`` suppresses the header,
    ``headerOnly`` suppresses the implementation file and ``suppress``
    suppresses both.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        # Stored for users of this module; not read inside this class.
        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        # Header and impl may be suppressed independently, so there is no
        # single value to report.
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # implOnly / headerOnly keep their respective file suppressed no
        # matter what value is assigned.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the ``.cpp`` entry for a Makefile source list."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return "    " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the ``.cpp`` entry for a CMake source list."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        dirName = Path(self.directory)
        baseName = Path(self.basename)
        joined = PurePosixPath(dirName / baseName)
        return "\n    " + str(joined) + ".cpp "

    def begin(self, globalDir):
        """Open both files (unless suppressed) and write their preambles."""
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then post-process each non-suppressed one.

        Post-processing strips empty ``#ifdef``/``#endif`` pairs and runs
        ``clang-format -i --style=file`` on the file.  Formatting is
        skipped when ``GFXSTREAM_NO_CLANG_FMT`` is set in the environment.

        Raises:
            RuntimeError: if formatting is required but clang-format is
                not on PATH or exits with a non-zero status.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        # Removes empty ifdef blocks with a regex query over the file
        # which are mainly introduced by extensions with no functions or variables
        def remove_empty_ifdefs(filename: Path):
            """Removes empty #ifdef blocks from a C++ file."""

            # The file was written as UTF-8 by SingleFileModule.begin(), so
            # read and write it back with the same encoding instead of the
            # locale default.
            with open(filename, 'r', encoding="utf-8") as file:
                content = file.read()

            # An #ifdef line directly followed by #endif, each optionally
            # carrying a trailing // comment.
            pattern = r"#ifdef\s+(\w+)\s*(?://.*)?\s*\n\s*#endif\s*(?://.*)?\s*"

            modified_content = re.sub(pattern, "", content)

            with open(filename, 'w', encoding="utf-8") as file:
                file.write(modified_content)

        clang_format_command = shutil.which('clang-format')

        def formatFile(filename: Path):
            if "GFXSTREAM_NO_CLANG_FMT" in os.environ:
                return
            # Explicit checks rather than `assert`: asserts are stripped
            # under `python -O`, which would silently skip the formatter
            # invocation itself.
            if clang_format_command is None:
                raise RuntimeError(
                    "clang-format not found in PATH; install it or set "
                    "GFXSTREAM_NO_CLANG_FMT to skip formatting")
            status = subprocess.call(
                [clang_format_command, "-i", "--style=file", str(filename.resolve())])
            if status != 0:
                raise RuntimeError(
                    "clang-format failed with exit code %d on %s" % (status, filename))

        for fileModule in (self._headerFileModule, self._implFileModule):
            if fileModule.suppress:
                continue
            # The file object is closed by now, but still knows its path.
            filename = Path(fileModule.file.name)
            remove_empty_ifdefs(filename)
            formatFile(filename)
194
195
class PyScript(SingleFileModule):
    """A generated Python script: a SingleFileModule with a ``.py`` suffix."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(
            ".py",
            directory,
            basename,
            customAbsDir=customAbsDir,
            suppress=suppress,
        )
199
200
201# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """A generated ``.proto`` file that also reports build-system entries."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(
            ".proto",
            directory,
            basename,
            customAbsDir=customAbsDir,
            suppress=suppress,
        )

    def _relativeStem(self):
        # Path of the file, without suffix, relative to the output root.
        return os.path.join(self.directory, self.basename)

    def getMakefileSrcEntry(self):
        """Return the ``.proto`` entry for a Makefile source list."""
        # The base implementation's result is not used.
        super().getMakefileSrcEntry()
        if not self.customAbsDir:
            return "    " + self._relativeStem() + ".proto \\\n"
        return self.basename + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return the ``.proto`` entry for a CMake source list."""
        # The base implementation's result is not used.
        super().getCMakeSrcEntry()
        if not self.customAbsDir:
            return "\n    " + self._relativeStem() + ".proto "
        return "\n" + self.basename + ".proto "
225
class CodeGen(object):
    """Incrementally builds C/C++ source text.

    Emitted text accumulates in ``self.code``.  ``self.indentLevel`` is the
    current nesting depth (four spaces per level) and ``self.gensymCounter``
    is a stack with one counter per open block, used by ``var()`` to mint
    unique variable names.
    """

    # Stream method suffix for each supported primitive encoding size (bytes).
    _STREAM_METHOD_SUFFIXES = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh variable name, e.g. ``cgen_var_0_2``."""
        self.gensymCounter[-1] += 1
        # Counters of blocks that have not minted a name yet are still -1
        # and are left out of the name.
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the whitespace for the current indent level plus ``extra``."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a block: optionally emit ``{``, indent, push a name counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close a block: dedent, optionally emit ``}``, pop the name counter."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        """Emit ``else`` (or ``else if (cond)`` when ``cond`` is given)."""
        if cond is not None:
            self.code += \
                self.indent() + \
                "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        """Emit ``case <switchval>:`` and open its (optionally braced) block.

        No newline follows the label; whatever is emitted next continues
        on the same line.
        """
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        # NOTE(review): this emits another "case <switchval>:" label and then
        # closes the block; it never emits "break".  It looks like a copy of
        # switchCase(), but the intended output is not evident from this
        # file, so the behavior is left unchanged.
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        """Emit ``default:`` and open its (optionally braced) block."""
        # The label takes no value; formatting it with an undefined
        # `switchval` made every call fail with NameError.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Emit ``for (T v = init; v < bound; ++v)`` and open its block."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Emit ``code`` as an indented statement terminated by ``;``."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit ``code`` as an indented line."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit ``code`` as a line with no indentation."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit a call statement, assigned to ``lhs`` unless it is None."""
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return funcName(parameters);`` (``_lhs`` is ignored)."""
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        typeName = vulkanType.typeName

        if vulkanType.pointerIndirectionLevels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if vulkanType.pointerIndirectionLevels > 2:
                ptrSpec += "*" * (vulkanType.pointerIndirectionLevels - 2)
        else:
            ptrSpec = "*" * vulkanType.pointerIndirectionLevels

        if useParamName and (vulkanType.paramName is not None):
            paramStr = (" " + vulkanType.paramName)
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, typeName, ptrSpec, paramStr)

    # Same as makeCTypeDecl, with a trailing "[N]" when the type is a
    # static array (staticArrExpr is non-empty).
    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)

        if vulkanType.staticArrExpr:
            staticArrInfo = "[%s]" % vulkanType.staticArrExpr
        else:
            staticArrInfo = ""

        return decl + staticArrInfo

    # Given a VulkanAPI object, generate the C function prototype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):

        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        def getFuncArgDecl(param):
            if param.staticArrExpr:
                return self.makeCTypeDecl(param, useParamName=useParamName) + ("[%s]" % param.staticArrExpr)
            else:
                return self.makeCTypeDecl(param, useParamName=useParamName)

        # One parameter per line, separated by "," plus one extra indent.
        paramSeparator = ",\n%s" % self.indent(1)
        protoParams = "(\n    %s)" % paramSeparator.join(
            [getFuncArgDecl(param) for param in vulkanApi.parameters])

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        # Note the argument order in the emitted macro: source first.
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition as a string.

        Any code accumulated before the call is discarded, and the buffer
        is empty again afterwards.  ``codegenFunc(self)`` emits the body.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a full function definition to the current buffer."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member ``vulkanType`` of a struct.

        ``structAsPtr`` selects ``->`` vs ``.``; with ``asPtr`` the result
        is prefixed with ``&`` unless the member is already accessible as
        a pointer.  A truthy ``accessIndex`` is appended as `` + index``.
        """

        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` using the raw length expression.

        ``guardExpr`` is always None here; both are None when the type has
        no length expression.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return ``(lengthExpr, guardExpr)`` for a struct member.

        ``guardExpr`` is the struct variable whose member is read by
        ``lengthExpr`` (or None when nothing is dereferenced).
        """
        # Handle special cases first
        # Mostly when latexmath is involved
        def handleSpecialCases(structInfo, vulkanType, structVarName, asPtr):
            cases = [
                {
                    "structName": "VkShaderModuleCreateInfo",
                    "field": "pCode",
                    "lenExprMember": "codeSize",
                    "postprocess": lambda expr: "(%s / 4)" % expr
                },
                {
                    "structName": "VkPipelineMultisampleStateCreateInfo",
                    "field": "pSampleMask",
                    "lenExprMember": "rasterizationSamples",
                    "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
                },
                {
                    "structName": "VkAccelerationStructureVersionInfoKHR",
                    "field": "pVersionData",
                    "lenExprMember": "",
                    "postprocess": lambda _: "2*VK_UUID_SIZE"
                },
            ]

            for c in cases:
                if (structInfo.name, vulkanType.paramName) == (c["structName"],
                                                               c["field"]):
                    deref = "->" if asPtr else "."
                    expr = "%s%s%s" % (structVarName, deref,
                                       c["lenExprMember"])
                    lenAccessGuardExpr = "%s" % structVarName
                    return c["postprocess"](expr), lenAccessGuardExpr

            return None, None

        specialCaseAccess = \
            handleSpecialCases(
                structInfo, vulkanType, structVarName, asPtr)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        deref = "->" if asPtr else "."
        # The guard is the struct variable for both "->" and "." access
        # (the original conditional on `deref` was always true).
        lenAccessGuardExpr = "%s" % structVarName
        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        # Length expressions that do not name a member of this struct are
        # used verbatim.
        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` for an API parameter.

        ``guardExpr`` names the pointer that ``lengthExpr`` dereferences,
        or is None when nothing is dereferenced.
        """
        # Handle special cases first
        # Mostly when :: is involved
        def handleSpecialCases(vulkanType):
            lenExpr = vulkanType.getLengthExpression()

            if lenExpr is None:
                return None, None

            if "::" in lenExpr:
                structVarName, memberVarName = lenExpr.split("::")
                lenAccessGuardExpr = "%s" % (structVarName)
                return "%s->%s" % (structVarName, memberVarName), lenAccessGuardExpr
            return None, None

        specialCaseAccess = handleSpecialCases(vulkanType)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        if lenExpr == "null-terminated":
            # paramName is a plain string attribute, not a method.
            return "strlen(%s)" % vulkanType.paramName, None
        else:
            # A length passed by pointer must be dereferenced, and that
            # pointer becomes the guard expression.
            deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
            lenAccessGuardExpr = "%s" % lenExpr if deref else None
            return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, prefixed with ``&`` when a pointer is
        wanted and the parameter is not one already."""
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for ``vulkanType`` based on its parent.

        Raises:
            TypeError: if the parent is neither None, a VulkanCompoundType
                nor a VulkanAPI.
        """
        def accessAsParameter():
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            return self.accessParameter(
                vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        if vulkanType.parent is None:
            return accessAsParameter()

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(vulkanType.parent, VulkanAPI):
            return accessAsParameter()

        # os.abort() accepts no arguments, so passing it the message raised
        # a message-less TypeError.  Raise the same exception type, with
        # the intended message.
        raise TypeError("Could not find a way to access Vulkan type %s" %
                        getattr(vulkanType, "name", vulkanType.typeName))

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return ``(lengthExpr, guardExpr)`` for ``vulkanType``.

        Raises:
            TypeError: if the parent is neither None, a VulkanCompoundType
                nor a VulkanAPI.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        # See generalAccess(): os.abort() takes no message argument.
        raise TypeError("Could not find a way to access length of Vulkan type %s" %
                        getattr(vulkanType, "name", vulkanType.typeName))

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
        """Emit a call to ``customPrefix + api.name``.

        For non-void APIs a zero-initialized return variable is declared
        first and receives the result.  For VkResult APIs, optional
        device-lost and out-of-memory checks are emitted after the call.

        Returns:
            ``(retTypeName, retVar)``; ``retVar`` is None for void APIs.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            self.funcCall(
                callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
        else:
            self.funcCall(
                callLhs, customPrefix + api.name, customParameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                # vkAllocateMemory additionally reports the requested size.
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True when ``typeInfo`` reports an encoding size for the type."""
        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return size is not None

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the stream method name (e.g. ``putBe32``/``getBe32``).

        Returns None when the type is not a primitive or its encoding size
        is not 1, 2, 4 or 8 bytes.
        """
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        suffix = self._STREAM_METHOD_SUFFIXES.get(size)
        if suffix is None:
            return None

        prefix = "put" if direction == "write" else "get"
        return prefix + suffix

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the in-place conversion method name (e.g. ``toBe32``).

        Returns None when the type is not a primitive or its encoding size
        is not 1, 2, 4 or 8 bytes.
        """
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        suffix = self._STREAM_METHOD_SUFFIXES.get(size)
        if suffix is None:
            return None

        prefix = "to" if direction == "write" else "from"
        return prefix + suffix

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that reads/writes one primitive through ``streamVar``.

        Pointers are streamed as 64-bit values.  Aborts the process when
        asked to stream a non-pointer type with no primitive encoding size.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = False

        if accessType.pointerIndirectionLevels > 0:
            streamStorageVarType = "uint64_t"
            needPtrCast = True
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            if streamSize == 1:
                streamStorageVarType = "uint8_t"
            elif streamSize == 2:
                streamStorageVarType = "uint16_t"
            elif streamSize == 4:
                streamStorageVarType = "uint32_t"
            elif streamSize == 8:
                streamStorageVarType = "uint64_t"
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # Always minted, in both directions, so variable numbering stays
        # the same regardless of direction.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, variant, direction="write"):
        """Emit code that memcpy's one primitive to/from ``streamVar``.

        The copy is paired with a ``<namespace>::Stream::to*/from*`` call,
        where the namespace is ``gfxstream::aemu`` for the "guest" variant
        and ``android::base`` otherwise.  Pointers are handled as 64-bit
        values.  Aborts the process when asked to handle a non-pointer
        type with no primitive encoding size.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = False

        if accessType.pointerIndirectionLevels > 0:
            streamSize = 8
            streamStorageVarType = "uint64_t"
            needPtrCast = True
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            if streamSize == 1:
                streamStorageVarType = "uint8_t"
            elif streamSize == 2:
                streamStorageVarType = "uint16_t"
            elif streamSize == 4:
                streamStorageVarType = "uint32_t"
            elif streamSize == 8:
                streamStorageVarType = "uint64_t"
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Always minted, in both directions, so variable numbering stays
        # the same regardless of direction.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""
        if variant == "guest":
            streamNamespace = "gfxstream::aemu"
        else:
            streamNamespace = "android::base"

        if direction == "read":
            # The destination is written through a non-const view of the type.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)&%s)" % (
                streamNamespace,
                streamMethod,
                accessExpr))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)%s)" % (
                streamNamespace,
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the streamed size in bytes of ``accessType`` (8 for pointers).

        Aborts the process when asked about a non-pointer type with no
        primitive encoding size.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        if accessType.pointerIndirectionLevels > 0:
            streamSize = 8
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)

        return streamSize
844
845# Class to wrap a Vulkan API call.
846#
847# The user gives a generic callback, |codegenDef|,
848# that takes a CodeGen object and a VulkanAPI object as arguments.
849# codegenDef uses CodeGen along with the VulkanAPI object
850# to generate the function body.
class VulkanAPIWrapper(object):
    """Generates declarations/definitions for a prefixed variant of an API.

    The wrapped API is a copy of ``typeInfo.apis[apiName]`` whose name is
    prefixed with ``customApiPrefix``, whose parameter list is preceded by
    ``extraParameters`` (a list, if given) and whose return type is
    replaced by ``returnTypeOverride`` (if given).  The function body is
    produced by ``codegenDef(codegen, vulkanApi)``.
    """

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Private function.  It is stored on the instance as a plain
        # function (not a bound method), so it must be called with the
        # wrapper passed explicitly: self.makeApi(self, typeInfo, apiName).

        def makeApiFunc(self, typeInfo, apiName):
            # Shallow copy: the registry's API object itself is not
            # modified; the attributes below are rebound on the copy only.
            customApi = copy(typeInfo.apis[apiName])
            customApi.name = self.customApiPrefix + customApi.name
            if self.extraParameters is not None:
                if not isinstance(self.extraParameters, list):
                    # os.abort() accepts no arguments, so passing it the
                    # message raised a message-less TypeError.  Raise the
                    # same exception type, with the intended message.
                    raise TypeError(
                        "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                            self.extraParameters))
                customApi.parameters = \
                    self.extraParameters + customApi.parameters

            if self.returnTypeOverride is not None:
                customApi.retType = self.returnTypeOverride
            return customApi

        self.makeApi = makeApiFunc

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the wrapped API's prototype followed by ``;``."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the wrapped API's full definition as a string.

        Any code already accumulated in ``self.codegen`` is discarded.
        Exits the process with status 1 when no definition callback is set.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        # Check before touching the code buffer so a missing callback does
        # not leave a half-open block behind.
        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
910
911# Base class for wrapping all Vulkan API objects.  These work with Vulkan
912# Registry generators and have gen* triggers.  They tend to contain
913# VulkanAPIWrapper objects to make it easier to generate the code.
914class VulkanWrapperGenerator(object):
915
916    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
917        self.module: Module = module
918        self.typeInfo: VulkanTypeInfo = typeInfo
919        self.extensionStructTypes = OrderedDict()
920
921    def onBegin(self):
922        pass
923
924    def onEnd(self):
925        pass
926
927    def onBeginFeature(self, featureName, featureType):
928        pass
929
930    def onFeatureNewCmd(self, cmdName):
931        pass
932
933    def onEndFeature(self):
934        pass
935
936    def onGenType(self, typeInfo, name, alias):
937        category = self.typeInfo.categoryOf(name)
938        if category in ["struct", "union"] and not alias:
939            structInfo = self.typeInfo.structs[name]
940            if structInfo.structExtendsExpr:
941                self.extensionStructTypes[name] = structInfo
942        pass
943
944    def onGenStruct(self, typeInfo, name, alias):
945        pass
946
947    def onGenGroup(self, groupinfo, groupName, alias=None):
948        pass
949
950    def onGenEnum(self, enuminfo, name, alias):
951        pass
952
953    def onGenCmd(self, cmdinfo, name, alias):
954        pass
955
    # The Vulkan structure types below may each correspond to more than one
    # Vulkan struct, due to a conflict between different Vulkan registries.
    # To pick the correct struct type, the type of the "root" struct must be
    # checked as well.
    #
    # Layout: the outer key is the (ambiguous) structure type enum of the
    # extension struct; the inner dict maps the root struct's structure type
    # enum -- or the literal key "default" -- to the name of the struct to use.
    # emitForEachStructExtension() emits one inner switch label per entry, in
    # this dict's iteration order, when it is given a rootTypeVar.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }
979
980    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
981        def readStructType(structTypeName, structVarName, cgen):
982            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
983                (structTypeName, "goldfish_vk_struct_type", structVarName))
984
985        def castAsStruct(varName, typeName, const=True):
986            return "reinterpret_cast<%s%s*>(%s)" % \
987                   ("const " if const else "", typeName, varName)
988
989        def doDefaultReturn(cgen):
990            if retType.typeName == "void":
991                cgen.stmt("return")
992            else:
993                cgen.stmt("return (%s)0" % retType.typeName)
994
995        cgen.beginIf("!%s" % triggerVar.paramName)
996        if nullEmit is None:
997            doDefaultReturn(cgen)
998        else:
999            nullEmit(cgen)
1000        cgen.endIf()
1001
1002        readStructType("structType", triggerVar.paramName, cgen)
1003
1004        cgen.line("switch(structType)")
1005        cgen.beginBlock()
1006
1007        currFeature = None
1008
1009        for ext in self.extensionStructTypes.values():
1010            if not currFeature:
1011                cgen.leftline("#ifdef %s" % ext.feature)
1012                currFeature = ext.feature
1013
1014            if currFeature and ext.feature != currFeature:
1015                cgen.leftline("#endif")
1016                cgen.leftline("#ifdef %s" % ext.feature)
1017                currFeature = ext.feature
1018
1019            enum = ext.structEnumExpr
1020            protect = None
1021            if enum in self.typeInfo.enumElem:
1022                protect = self.typeInfo.enumElem[enum].get("protect", default=None)
1023                if protect is not None:
1024                    cgen.leftline("#ifdef %s" % protect)
1025
1026            cgen.line("case %s:" % enum)
1027            cgen.beginBlock()
1028
1029            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
1030                cgen.line("switch(%s)" % rootTypeVar.paramName)
1031                cgen.beginBlock()
1032                kv = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
1033                for k in kv:
1034                    v = self.extensionStructTypes[kv[k]]
1035                    if k == "default":
1036                        cgen.line("%s:" % k)
1037                    else:
1038                        cgen.line("case %s:" % k)
1039                    cgen.beginBlock()
1040                    castedAccess = castAsStruct(
1041                        triggerVar.paramName, v.name, const=triggerVar.isConst)
1042                    forEachFunc(v, castedAccess, cgen)
1043                    cgen.line("break;")
1044                    cgen.endBlock()
1045                cgen.endBlock()
1046            else:
1047                castedAccess = castAsStruct(
1048                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
1049                forEachFunc(ext, castedAccess, cgen)
1050
1051            if autoBreak:
1052                cgen.stmt("break")
1053            cgen.endBlock()
1054
1055            if protect is not None:
1056                cgen.leftline("#endif // %s" % protect)
1057
1058        if currFeature:
1059            cgen.leftline("#endif")
1060
1061        cgen.line("default:")
1062        cgen.beginBlock()
1063        if defaultEmit is None:
1064            doDefaultReturn(cgen)
1065        else:
1066            defaultEmit(cgen)
1067        cgen.endBlock()
1068
1069        cgen.endBlock()
1070
1071    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
1072        currFeature = None
1073
1074        for (i, ext) in enumerate(self.extensionStructTypes.values()):
1075            if doFeatureIfdefs:
1076                if not currFeature:
1077                    cgen.leftline("#ifdef %s" % ext.feature)
1078                    currFeature = ext.feature
1079
1080                if currFeature and ext.feature != currFeature:
1081                    cgen.leftline("#endif")
1082                    cgen.leftline("#ifdef %s" % ext.feature)
1083                    currFeature = ext.feature
1084
1085            forEachFunc(i, ext, cgen)
1086
1087        if doFeatureIfdefs:
1088            if currFeature:
1089                cgen.leftline("#endif")
1090