# Copyright 2018 Google LLC
# SPDX-License-Identifier: MIT

from .common.codegen import CodeGen
from .common.vulkantypes import (
    VulkanCompoundType, VulkanAPI, makeVulkanTypeSimple, vulkanTypeNeedsTransform,
    vulkanTypeGetNeededTransformTypes, VulkanTypeIterator, iterateVulkanType,
    vulkanTypeforEachSubType, TRIVIAL_TRANSFORMED_TYPES,
    NON_TRIVIAL_TRANSFORMED_TYPES, TRANSFORMED_TYPES)

from .wrapperdefs import VulkanWrapperGenerator
from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE


def deviceMemoryTransform(resourceTrackerVarName, structOrApiInfo, getExpr, getLen, cgen, variant="tohost"):
    """Emit resourceTracker->deviceMemoryTransform_<variant>(...) calls.

    One call is emitted per entry of
    ``structOrApiInfo.deviceMemoryInfoParameterIndices``. Each entry names,
    by sub-type index, which members/parameters hold the memory handle,
    offset, size, typeIndex and typeBits. All five fields are passed, in
    that fixed order, as ``(castPointer, count)`` pairs; a field that no
    sub-type matches is passed as ``(nullptr, 0)``.

    Args:
        resourceTrackerVarName: C expression naming the resource tracker.
        structOrApiInfo: struct or API description providing
            ``deviceMemoryInfoParameterIndices``; its sub-types are visited
            with ``vulkanTypeforEachSubType``.
        getExpr: maps a sub-type to its C access expression.
        getLen: maps a sub-type to its C element-count expression, or None
            (a matched field with no length is passed with a count of "1").
        cgen: code generator that receives the emitted statements.
        variant: "tohost" emits ``deviceMemoryTransform_tohost``; any other
            value emits ``deviceMemoryTransform_fromhost``.
    """
    # Argument order and C pointer cast for each device-memory field. These
    # are the same for every entry, so they are built once.
    orderedKeys = ["handle", "offset", "size", "typeIndex", "typeBits"]
    casts = {
        "handle": "VkDeviceMemory*",
        "offset": "VkDeviceSize*",
        "size": "VkDeviceSize*",
        "typeIndex": "uint32_t*",
        "typeBits": "uint32_t*",
    }

    if variant == "tohost":
        methodName = "deviceMemoryTransform_tohost"
    else:
        methodName = "deviceMemoryTransform_fromhost"

    for info in structOrApiInfo.deviceMemoryInfoParameterIndices.values():
        # Fresh defaults per entry: fields this entry does not reference stay
        # (nullptr, 0).
        accesses = dict.fromkeys(orderedKeys, "nullptr")
        lenAccesses = dict.fromkeys(orderedKeys, "0")

        def doParam(i, vulkanType):
            access = getExpr(vulkanType)
            lenAccess = getLen(vulkanType)

            for k in orderedKeys:
                if i == vars(info)[k]:
                    accesses[k] = access
                    lenAccesses[k] = lenAccess if lenAccess is not None else "1"

        # accesses/lenAccesses are read right after this call, which assumes
        # vulkanTypeforEachSubType invokes doParam before returning.
        vulkanTypeforEachSubType(structOrApiInfo, doParam)

        callParams = ", ".join(
            "(%s)%s, %s" % (casts[k], accesses[k], lenAccesses[k])
            for k in orderedKeys)

        cgen.stmt("%s->%s(%s)" % (resourceTrackerVarName, methodName, callParams))


def directTransform(resourceTrackerVarName, vulkanType, getExpr, getLen, cgen, variant="tohost"):
    """Emit resourceTracker->transformImpl_<TypeName>_<variant>(access, count).

    ``access`` is ``getExpr(vulkanType)``; ``count`` is ``getLen(vulkanType)``
    when that is truthy and "1" otherwise.
    """
    access = getExpr(vulkanType)
    finalLenAccess = getLen(vulkanType) or "1"

    cgen.stmt("%s->transformImpl_%s_%s(%s, %s)" % (
        resourceTrackerVarName, vulkanType.typeName, variant, access, finalLenAccess))


def genTransformsForVulkanType(resourceTrackerVarName, structOrApiInfo, getExpr, getLen, cgen, variant="tohost"):
    """Emit the whole-struct transforms that ``structOrApiInfo`` needs.

    Only the "devicememory" transform kind produces code (through
    deviceMemoryTransform); any other kind reported by
    vulkanTypeGetNeededTransformTypes is ignored here.
    """
    for transform in vulkanTypeGetNeededTransformTypes(structOrApiInfo):
        if transform == "devicememory":
            deviceMemoryTransform(
                resourceTrackerVarName,
                structOrApiInfo,
                getExpr, getLen, cgen, variant=variant)


class TransformCodegen(VulkanTypeIterator):
    """Type iterator that emits the body of one transform_<variant>_<Struct> function.

    Compound members are recursed into (one call per element for arrays) and
    extension members are dispatched to <prefix>extension_struct. Strings,
    string arrays, static arrays, plain pointers and values emit nothing.
    """

    def __init__(self, cgen, inputVar, resourceTrackerVarName, prefix, variant):
        """
        Args:
            cgen: code generator to emit into. May be None at construction;
                the accessors below read self.cgen lazily, so it can be
                assigned later (VulkanTransform does this per function).
            inputVar: name of the C variable holding the struct to transform.
            resourceTrackerVarName: C expression naming the resource tracker.
            prefix: function-name prefix for recursive calls, e.g.
                "transform_tohost_".
            variant: variant string forwarded to directTransform.
        """
        self.cgen = cgen
        self.inputVar = inputVar
        self.prefix = prefix
        self.resourceTrackerVarName = resourceTrackerVarName

        def makeAccess(varName, asPtr = True):
            return lambda t: self.cgen.generalAccess(t, parentVarName = varName, asPtr = asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(t, parentVarName = varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName)

        self.exprAccessor = makeAccess(self.inputVar)
        self.exprAccessorValue = makeAccess(self.inputVar, asPtr = False)
        self.lenAccessor = makeLengthAccess(self.inputVar)
        self.lenAccessorGuard = makeLengthAccessGuard(self.inputVar)

        self.checked = False

        self.variant = variant

    def makeCastExpr(self, vulkanType):
        """Return a C cast prefix "(<type>)" for vulkanType, without a parameter name."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def asNonConstCast(self, access, vulkanType):
        """Wrap ``access`` in a cast to a non-const form of vulkanType.

        Static arrays and types not accessible as pointers are cast via
        their address-access form; other types are cast directly.
        """
        if vulkanType.staticArrExpr or not vulkanType.accessibleAsPointer():
            castType = vulkanType.getForAddressAccess().getForNonConstAccess()
        else:
            castType = vulkanType.getForNonConstAccess()
        return "%s(%s)" % (self.makeCastExpr(castType), access)

    def onCheck(self, vulkanType):
        pass

    def endCheck(self, vulkanType):
        pass

    def onCompoundType(self, vulkanType):
        """Emit transforms for a struct-typed member (single value or array)."""
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        isPtr = vulkanType.pointerIndirectionLevels > 0

        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)

        if isPtr:
            # Skip everything when the pointer member is null.
            self.cgen.beginIf(access)

        # The resource tracker's transformImpl already receives the whole
        # member and its element count, so emit it once, outside the
        # per-element loop; emitting it inside would re-transform the entire
        # array on every iteration.
        if vulkanType.isTransformed:
            directTransform(self.resourceTrackerVarName, vulkanType, self.exprAccessor, self.lenAccessor, self.cgen, variant=self.variant)

        if lenAccess is not None:
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
            forIncr = "++%s" % loopVar

            self.cgen.beginFor(forInit, forCond, forIncr)

        accessCasted = self.asNonConstCast(access, vulkanType)

        # Recurse into the member's own generated transform function.
        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [self.resourceTrackerVarName, accessCasted])

        if lenAccess is not None:
            self.cgen.endFor()

        if isPtr:
            self.cgen.endIf()

        if lenAccessGuard is not None:
            self.cgen.endIf()

    def onString(self, vulkanType):
        pass

    def onStringArray(self, vulkanType):
        pass

    def onStaticArr(self, vulkanType):
        pass

    def onStructExtension(self, vulkanType):
        """Emit a null-checked call to <prefix>extension_struct for an extension member."""
        access = self.exprAccessor(vulkanType)

        castedAccessExpr = "(%s)(%s)" % ("void*", access)
        self.cgen.beginIf(access)
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [self.resourceTrackerVarName, castedAccessExpr])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        pass

    def onValue(self, vulkanType):
        pass


class VulkanTransform(VulkanWrapperGenerator):
    """Generates transform_tohost_* / transform_fromhost_* functions.

    For every non-aliased struct/union, one function per variant is declared
    in the module header and defined in the implementation. Aliased
    structs/unions get function aliases instead. A
    transform_<variant>_extension_struct function per variant dispatches an
    extension struct to the matching per-struct transform.
    """

    def __init__(self, module, typeInfo, resourceTrackerTypeName="ResourceTracker", resourceTrackerVarName="resourceTracker"):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        self.transformPrefix = "transform_"

        self.tohostpart = "tohost"
        self.fromhostpart = "fromhost"
        self.variants = [self.tohostpart, self.fromhostpart]

        self.toTransformVar = "toTransform"
        self.resourceTrackerTypeName = resourceTrackerTypeName
        self.resourceTrackerVarName = resourceTrackerVarName
        # First parameter of every generated function (the 1 is presumably
        # one level of pointer indirection, i.e. ResourceTracker*).
        self.transformParam = makeVulkanTypeSimple(
            False, self.resourceTrackerTypeName, 1, self.resourceTrackerVarName)
        self.voidType = makeVulkanTypeSimple(False, "void", 0)

        # One transform_<variant>_extension_struct prototype per variant.
        self.extensionTransformPrototypes = [
            VulkanAPI(self.transformPrefix + variant + "_extension_struct",
                      self.voidType,
                      [self.transformParam, STRUCT_EXTENSION_PARAM_FOR_WRITE])
            for variant in self.variants]

        self.knownStructs = {}
        self.needsTransform = set()

    def onBegin(self,):
        VulkanWrapperGenerator.onBegin(self)
        # Forward-declare the resource tracker class and set up convenience
        # X-macros listing the transformed structs. Each empty line ends the
        # preceding backslash-continued #define.
        self.codegen.stmt("class %s" % self.resourceTrackerTypeName)
        self.codegen.line("#define LIST_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        for name in TRIVIAL_TRANSFORMED_TYPES:
            self.codegen.line("f(%s) \\" % name)
        self.codegen.line("")

        self.codegen.line("#define LIST_NON_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        for name in NON_TRIVIAL_TRANSFORMED_TYPES:
            self.codegen.line("f(%s) \\" % name)
        self.codegen.line("")

        self.codegen.line("#define LIST_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("LIST_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("LIST_NON_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("")

        self.module.appendHeader(self.codegen.swapCode())

        # Declare the extension dispatchers in the implementation up front:
        # per-struct transforms call them, but onEnd defines them last.
        for prototype in self.extensionTransformPrototypes:
            self.module.appendImpl(self.codegen.makeFuncDecl(prototype))

    def onGenType(self, typeXml, name, alias):
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownStructs:
            return

        category = self.typeInfo.categoryOf(name)
        if category not in ["struct", "union"]:
            return

        if alias:
            # An aliased struct reuses the transform of the type it aliases.
            for variant in self.variants:
                self.module.appendHeader(
                    self.codegen.makeFuncAlias(self.transformPrefix + variant + "_" + name,
                                               self.transformPrefix + variant + "_" + alias))
            return

        structInfo = self.typeInfo.structs[name]
        self.knownStructs[name] = structInfo

        for variant in self.variants:
            # void transform_<variant>_<name>(ResourceTracker*, <name>* toTransform)
            api = VulkanAPI(
                self.transformPrefix + variant + "_" + name,
                self.voidType,
                [self.transformParam] +
                [makeVulkanTypeSimple(False, name, 1, self.toTransformVar)])

            # cgen is supplied later, when the function body is generated.
            transformer = TransformCodegen(
                None,
                self.toTransformVar,
                self.resourceTrackerVarName,
                self.transformPrefix + variant + "_",
                variant)

            self.module.appendHeader(
                self.codegen.makeFuncDecl(api))
            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    api, self._makeStructTransformBody(api, structInfo, transformer, variant)))

    def _makeStructTransformBody(self, api, structInfo, transformer, variant):
        """Return a cgen callback that emits the body of one per-struct transform.

        The values are bound through this call rather than captured from the
        caller's loop, so the callback stays correct even if it runs after
        the loop has advanced to the next variant.
        """
        def funcDefGenerator(cgen):
            transformer.cgen = cgen
            # (void)-reference every parameter so unused ones do not warn.
            for p in api.parameters:
                cgen.stmt("(void)%s" % p.paramName)

            genTransformsForVulkanType(
                self.resourceTrackerVarName,
                structInfo,
                transformer.exprAccessor,
                transformer.lenAccessor,
                cgen,
                variant=variant)

            for member in structInfo.members:
                iterateVulkanType(
                    self.typeInfo, member,
                    transformer)

        return funcDefGenerator

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self,):
        VulkanWrapperGenerator.onEnd(self)

        # Define transform_<variant>_extension_struct for each variant.
        for variant, prototype in zip(self.variants, self.extensionTransformPrototypes):
            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    prototype, self._makeExtensionTransformBody(variant)))

    def _makeExtensionTransformBody(self, variant):
        """Return a cgen callback that emits transform_<variant>_extension_struct's body.

        For each extension struct visited by emitForEachStructExtension, the
        emitted code first applies the resource tracker's transformImpl (for
        transformed types, with a count of 1). It then calls
        transform_<variant>_<ext> on the casted extension pointer.
        """
        def forEachExtensionTransform(ext, castedAccess, cgen):
            if ext.isTransformed:
                directTransform(self.resourceTrackerVarName, ext,
                                lambda _: castedAccess, lambda _: "1",
                                cgen, variant)
            cgen.funcCall(None, self.transformPrefix + variant + "_" + ext.name,
                          [self.resourceTrackerVarName, castedAccess])

        def funcDefGenerator(cgen):
            return self.emitForEachStructExtension(
                cgen,
                self.voidType,
                STRUCT_EXTENSION_PARAM_FOR_WRITE,
                forEachExtensionTransform)

        return funcDefGenerator