# Copyright 2023 Google LLC
# SPDX-License-Identifier: MIT
from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
from collections import OrderedDict
from copy import copy
from pathlib import Path, PurePosixPath

import os
import sys
import shutil
import subprocess
import re

# Matches an "#ifdef NAME" line (optionally with a trailing // comment)
# followed only by whitespace/comment and then an "#endif" line. It also
# consumes trailing whitespace. The text after "#endif" is restricted to
# spaces/tabs before the optional comment. With a plain \s* there, a "//"
# comment on the *next* line would be swallowed as well.
_EMPTY_IFDEF_PATTERN = re.compile(
    r"#ifdef\s+(\w+)\s*(?://.*)?\s*\n\s*#endif[ \t]*(?://.*)?\s*")

# Class capturing a single file


class SingleFileModule(object):
    """A single generated output file with an optional preamble/postamble.

    When ``suppress`` is true, ``begin``/``append``/``end`` are no-ops and no
    file is opened.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        self.file = None

        # Written immediately after opening / immediately before closing.
        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open ``<dir>/<basename><suffix>`` for writing and emit the preamble.

        ``<dir>`` is ``customAbsDir`` when set, otherwise
        ``globalDir/directory``. The directory is not created here.
        """
        if self.suppress:
            return

        # Resolve the output directory.
        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Append text to the open file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Emit the postamble and close the file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        return ""

    def getCMakeSrcEntry(self):
        return ""

# Class capturing a .cpp file and a .h file (a "C++ module")


class Module(object):
    """A paired ``.h`` / ``.cpp`` output.

    ``implOnly`` suppresses the header and ``headerOnly`` suppresses the
    implementation file.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # headerOnly/implOnly keep their file suppressed regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return " " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        dirName = Path(self.directory)
        baseName = Path(self.basename)
        # CMake wants forward slashes on every host OS.
        joined = PurePosixPath(dirName / baseName)
        return "\n " + str(joined) + ".cpp "

    def begin(self, globalDir):
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then strip empty #ifdef blocks and clang-format them.

        Formatting is skipped when ``GFXSTREAM_NO_CLANG_FMT`` is set.

        Raises:
            RuntimeError: clang-format is required but not found on PATH, or
                it exits with a non-zero status.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        # Removes empty ifdef blocks with a regex query over the file
        # which are mainly introduced by extensions with no functions or variables
        def remove_empty_ifdefs(filename: Path):
            """Removes empty #ifdef blocks from a C++ file."""
            # Files are written as UTF-8 in SingleFileModule.begin.
            with open(filename, 'r', encoding="utf-8") as file:
                content = file.read()

            modified_content = _EMPTY_IFDEF_PATTERN.sub("", content)

            with open(filename, 'w', encoding="utf-8") as file:
                file.write(modified_content)

        clang_format_command = shutil.which('clang-format')

        def formatFile(filename: Path):
            if "GFXSTREAM_NO_CLANG_FMT" in os.environ:
                return
            # Explicit checks instead of `assert`: under `python -O` an
            # assert statement (including the subprocess call inside it)
            # would be skipped entirely.
            if clang_format_command is None:
                raise RuntimeError(
                    "clang-format not found on PATH; install it or set "
                    "GFXSTREAM_NO_CLANG_FMT to skip formatting")
            returncode = subprocess.call(
                [clang_format_command, "-i", "--style=file", str(filename.resolve())])
            if returncode != 0:
                raise RuntimeError(
                    "clang-format failed with exit code %d on %s" % (returncode, filename))

        for fileModule in (self._headerFileModule, self._implFileModule):
            if not fileModule.suppress:
                filename = Path(fileModule.file.name)
                remove_empty_ifdefs(filename)
                formatFile(filename)


class PyScript(SingleFileModule):
    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".py", directory, basename, customAbsDir, suppress)


# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return " " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "

        # Same POSIX-path handling as Module.getCMakeSrcEntry, so CMake
        # entries use forward slashes on every host OS.
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n " + str(joined) + ".proto "


class CodeGen(object):
    """Accumulates generated C/C++ text in ``self.code``.

    Tracks the current indentation level, plus a stack of per-block counters
    that ``var()`` uses to build unique temporary names.
    """

    # Primitive encoding size (bytes) -> unsigned C storage type.
    _STORAGE_TYPE_FOR_SIZE = {
        1: "uint8_t",
        2: "uint16_t",
        4: "uint32_t",
        8: "uint64_t",
    }

    # Primitive encoding size (bytes) -> suffix of the stream helper name.
    _STREAM_SUFFIX_FOR_SIZE = {
        1: "Byte",
        2: "Be16",
        4: "Be32",
        8: "Be64",
    }

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a temporary variable name unique within the current nesting."""
        self.gensymCounter[-1] += 1
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self,extra=0):
        return " " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self,bracketPrint=True):
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond = None):
        if cond is not None:
            self.code += \
                self.indent() + \
                "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked = False):
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint = blocked)

    def switchCaseBreak(self, switchval, blocked = False):
        # Emits "case <switchval>:" and then closes the current block; no
        # "break" statement is emitted here.
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint = blocked)

    def switchCaseDefault(self, blocked = False):
        # Previously formatted with an undefined `switchval`, which always
        # raised NameError.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint = blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def _makeBaseCTypeDecl(self, vulkanType, useParamName):
        """Shared core of makeCTypeDecl/makeRichCTypeDecl (no array suffix)."""
        constness = "const " if vulkanType.isConst else ""
        typeName = vulkanType.typeName

        if vulkanType.pointerIndirectionLevels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if vulkanType.pointerIndirectionLevels > 2:
                ptrSpec += "*" * (vulkanType.pointerIndirectionLevels - 2)
        else:
            ptrSpec = "*" * vulkanType.pointerIndirectionLevels

        if useParamName and (vulkanType.paramName is not None):
            paramStr = (" " + vulkanType.paramName)
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, typeName, ptrSpec, paramStr)

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        return self._makeBaseCTypeDecl(vulkanType, useParamName)

    # Like makeCTypeDecl, but appends "[<staticArrExpr>]" for static arrays.
    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        if vulkanType.staticArrExpr:
            staticArrInfo = "[%s]" % vulkanType.staticArrExpr
        else:
            staticArrInfo = ""

        return self._makeBaseCTypeDecl(vulkanType, useParamName) + staticArrInfo

    # Given a VulkanAPI object, generate the C function protype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):

        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        def getFuncArgDecl(param):
            if param.staticArrExpr:
                return self.makeCTypeDecl(param, useParamName=useParamName) + ("[%s]" % param.staticArrExpr)
            else:
                return self.makeCTypeDecl(param, useParamName=useParamName)

        protoParams = "(\n %s)" % ((",\n%s" % self.indent(1)).join(
            list(map(
                getFuncArgDecl,
                vulkanApi.parameters))))

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        # Discards any pending code, then returns the function text.
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):

        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        # Handle special cases first
        # Mostly when latexmath is involved
        def handleSpecialCases(structInfo, vulkanType, structVarName, asPtr):
            cases = [
                {
                    "structName": "VkShaderModuleCreateInfo",
                    "field": "pCode",
                    "lenExprMember": "codeSize",
                    "postprocess": lambda expr: "(%s / 4)" % expr
                },
                {
                    "structName": "VkPipelineMultisampleStateCreateInfo",
                    "field": "pSampleMask",
                    "lenExprMember": "rasterizationSamples",
                    "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
                },
                {
                    "structName": "VkAccelerationStructureVersionInfoKHR",
                    "field": "pVersionData",
                    "lenExprMember": "",
                    "postprocess": lambda _: "2*VK_UUID_SIZE"
                },
            ]

            for c in cases:
                if (structInfo.name, vulkanType.paramName) == (c["structName"],
                                                               c["field"]):
                    deref = "->" if asPtr else "."
                    expr = "%s%s%s" % (structVarName, deref,
                                       c["lenExprMember"])
                    lenAccessGuardExpr = "%s" % structVarName
                    return c["postprocess"](expr), lenAccessGuardExpr

            return None, None

        specialCaseAccess = \
            handleSpecialCases(
                structInfo, vulkanType, structVarName, asPtr)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        deref = "->" if asPtr else "."
        # deref is always non-empty, so the struct variable is always the guard.
        lenAccessGuardExpr = "%s" % structVarName
        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return (lengthExpr, guardExpr) for a length given by an API parameter."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # Special case: "struct::member" refers to a member of another
        # parameter, accessed through a pointer.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            lenAccessGuardExpr = "%s" % (structVarName)
            return "%s->%s" % (structVarName, memberVarName), lenAccessGuardExpr

        # Checked before the parameter lookup. The old code called
        # `paramName()` on what is a string attribute.
        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        if vulkanType.parent is None:
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            else:
                return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(vulkanType.parent, VulkanAPI):
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            else:
                return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        # os.abort() takes no arguments; report first, then abort.
        print("Could not find a way to access Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        # os.abort() takes no arguments; report first, then abort.
        print("Could not find a way to access length of Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            self.funcCall(
                callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
        else:
            self.funcCall(
                callLhs, customPrefix + api.name, customParameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
            ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return size is not None

    def _makeSizedStreamMethod(self, typeInfo, typeName, prefix):
        """Return prefix + size suffix for a primitive, or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        suffix = self._STREAM_SUFFIX_FOR_SIZE.get(size)
        if suffix is None:
            return None
        return prefix + suffix

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        prefix = "put" if direction == "write" else "get"
        return self._makeSizedStreamMethod(typeInfo, typeName, prefix)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        prefix = "to" if direction == "write" else "from"
        return self._makeSizedStreamMethod(typeInfo, typeName, prefix)

    def _storageTypeForSize(self, streamSize, accessTypeName):
        """Map an encoding size to its unsigned C storage type, aborting if unsupported."""
        storageType = self._STORAGE_TYPE_FOR_SIZE.get(streamSize)
        if storageType is None:
            # Previously this surfaced as an UnboundLocalError (or emitted
            # invalid code on the read path).
            print("Unsupported primitive encoding size %s for type: %s" %
                  (streamSize, accessTypeName))
            os.abort()
        return storageType

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = False

        if accessType.pointerIndirectionLevels > 0:
            # Pointers are always streamed as 64-bit values.
            streamStorageVarType = "uint64_t"
            needPtrCast = True
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._storageTypeForSize(streamSize, accessTypeName)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # Allocated in both directions so generated variable names stay stable.
        streamStorageVar = self.var()

        accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)

        ptrCast = "(uintptr_t)" if needPtrCast else ""

        if direction == "read":
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, variant, direction="write"):
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = False

        if accessType.pointerIndirectionLevels > 0:
            # Pointers are always encoded as 64-bit values.
            streamSize = 8
            streamStorageVarType = "uint64_t"
            needPtrCast = True
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._storageTypeForSize(streamSize, accessTypeName)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Allocated in both directions so generated variable names stay stable.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""
        if variant == "guest":
            streamNamespace = "gfxstream::aemu"
        else:
            streamNamespace = "android::base"

        if direction == "read":
            # Reads write into the destination, so cast away const.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)&%s)" % (
                streamNamespace,
                streamMethod,
                accessExpr))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)%s)" % (
                streamNamespace,
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes (pointers count as 8)."""
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        if accessType.pointerIndirectionLevels > 0:
            return 8
        return typeInfo.getPrimitiveEncodingSize(accessTypeName)

# Class to wrap a Vulkan API call.
#
# The user gives a generic callback, |codegenDef|,
# that takes a CodeGen object and a VulkanAPI object as arguments.
# codegenDef uses CodeGen along with the VulkanAPI object
# to generate the function body.
class VulkanAPIWrapper(object):

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Private function. It is stored as a plain function attribute, not a
        # bound method, so callers pass the wrapper explicitly:
        # self.makeApi(self, typeInfo, apiName).
        def makeApiFunc(self, typeInfo, apiName):
            # Shallow copy; parameters/retType are rebound, not mutated.
            customApi = copy(typeInfo.apis[apiName])
            customApi.name = self.customApiPrefix + customApi.name
            if self.extraParameters is not None:
                if isinstance(self.extraParameters, list):
                    customApi.parameters = \
                        self.extraParameters + customApi.parameters
                else:
                    # os.abort() takes no arguments; report first, then abort.
                    print(
                        "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                            self.extraParameters))
                    os.abort()

            if self.returnTypeOverride is not None:
                customApi.retType = self.returnTypeOverride
            return customApi

        self.makeApi = makeApiFunc

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"

# Base class for wrapping all Vulkan API objects. These work with Vulkan
# Registry generators and have gen* triggers. They tend to contain
# VulkanAPIWrapper objects to make it easier to generate the code.
class VulkanWrapperGenerator(object):
    """Base for registry-driven generators that emit into a Module.

    Subclasses override the ``on*`` hooks they need. The base ``onGenType``
    records every non-alias struct/union whose ``structExtendsExpr`` is set.
    ``emitForEachStructExtension*`` then iterates over those records.
    """

    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
        self.module: Module = module
        self.typeInfo: VulkanTypeInfo = typeInfo
        # Struct name -> struct info for extension structs, in the order
        # onGenType saw them.
        self.extensionStructTypes = OrderedDict()

    def onBegin(self):
        pass

    def onEnd(self):
        pass

    def onBeginFeature(self, featureName, featureType):
        pass

    def onFeatureNewCmd(self, cmdName):
        pass

    def onEndFeature(self):
        pass

    def onGenType(self, typeInfo, name, alias):
        """Record ``name`` if it is a non-alias struct/union that extends another."""
        category = self.typeInfo.categoryOf(name)
        if category in ["struct", "union"] and not alias:
            structInfo = self.typeInfo.structs[name]
            if structInfo.structExtendsExpr:
                self.extensionStructTypes[name] = structInfo

    def onGenStruct(self, typeInfo, name, alias):
        pass

    def onGenGroup(self, groupinfo, groupName, alias=None):
        pass

    def onGenEnum(self, enuminfo, name, alias):
        pass

    def onGenCmd(self, cmdinfo, name, alias):
        pass

    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }

    def _switchFeatureGuard(self, cgen, feature, currFeature):
        """Emit the #ifdef/#endif lines needed to move from currFeature to feature.

        Returns the feature whose #ifdef is open afterwards.
        """
        if not currFeature:
            cgen.leftline("#ifdef %s" % feature)
            return feature

        if feature != currFeature:
            cgen.leftline("#endif")
            cgen.leftline("#ifdef %s" % feature)
            return feature

        return currFeature

    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a switch over the structure type of ``triggerVar``.

        The generated code returns early (or runs ``nullEmit``) when
        ``triggerVar`` is null. Each extension struct gets a case that calls
        ``forEachFunc(structInfo, castedAccessExpr, cgen)``. Enums listed in
        ROOT_TYPE_MAPPING get a nested switch on ``rootTypeVar`` when it is
        given. Cases are wrapped in feature and ``protect`` #ifdefs.
        """
        def readStructType(structTypeName, structVarName, cgen):
            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
                (structTypeName, "goldfish_vk_struct_type", structVarName))

        def castAsStruct(varName, typeName, const=True):
            return "reinterpret_cast<%s%s*>(%s)" % \
                ("const " if const else "", typeName, varName)

        def doDefaultReturn(cgen):
            if retType.typeName == "void":
                cgen.stmt("return")
            else:
                cgen.stmt("return (%s)0" % retType.typeName)

        cgen.beginIf("!%s" % triggerVar.paramName)
        if nullEmit is None:
            doDefaultReturn(cgen)
        else:
            nullEmit(cgen)
        cgen.endIf()

        readStructType("structType", triggerVar.paramName, cgen)

        cgen.line("switch(structType)")
        cgen.beginBlock()

        currFeature = None

        for ext in self.extensionStructTypes.values():
            currFeature = self._switchFeatureGuard(cgen, ext.feature, currFeature)

            enum = ext.structEnumExpr
            protect = None
            if enum in self.typeInfo.enumElem:
                # Positional form works for both dicts and XML elements;
                # dict.get() does not accept a `default=` keyword.
                protect = self.typeInfo.enumElem[enum].get("protect")
                if protect is not None:
                    cgen.leftline("#ifdef %s" % protect)

            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
                # Same enum value maps to different structs depending on the
                # root struct type, so dispatch on the root type too.
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                rootMapping = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
                for rootEnum, structName in rootMapping.items():
                    structInfo = self.extensionStructTypes[structName]
                    if rootEnum == "default":
                        cgen.line("%s:" % rootEnum)
                    else:
                        cgen.line("case %s:" % rootEnum)
                    cgen.beginBlock()
                    castedAccess = castAsStruct(
                        triggerVar.paramName, structInfo.name, const=triggerVar.isConst)
                    forEachFunc(structInfo, castedAccess, cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()
            else:
                castedAccess = castAsStruct(
                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
                forEachFunc(ext, castedAccess, cgen)

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

            if protect is not None:
                cgen.leftline("#endif // %s" % protect)

        if currFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        if defaultEmit is None:
            doDefaultReturn(cgen)
        else:
            defaultEmit(cgen)
        cgen.endBlock()

        cgen.endBlock()

    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
        """Call ``forEachFunc(index, structInfo, cgen)`` for every extension struct.

        When ``doFeatureIfdefs`` is set, consecutive structs are wrapped in
        ``#ifdef <feature>`` / ``#endif`` guards.
        """
        currFeature = None

        for (i, ext) in enumerate(self.extensionStructTypes.values()):
            if doFeatureIfdefs:
                currFeature = self._switchFeatureGuard(cgen, ext.feature, currFeature)

            forEachFunc(i, ext, cgen)

        if doFeatureIfdefs:
            if currFeature:
                cgen.leftline("#endif")