xref: /aosp_15_r20/external/mesa3d/src/gfxstream/codegen/scripts/cereal/transform.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# Copyright 2018 Google LLC
2# SPDX-License-Identifier: MIT
3
4from .common.codegen import CodeGen
5from .common.vulkantypes import \
6        VulkanCompoundType, VulkanAPI, makeVulkanTypeSimple, vulkanTypeNeedsTransform, vulkanTypeGetNeededTransformTypes, VulkanTypeIterator, iterateVulkanType, vulkanTypeforEachSubType, TRIVIAL_TRANSFORMED_TYPES, NON_TRIVIAL_TRANSFORMED_TYPES, TRANSFORMED_TYPES
7
8from .wrapperdefs import VulkanWrapperGenerator
9from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE
10
def deviceMemoryTransform(resourceTrackerVarName, structOrApiInfo, getExpr, getLen, cgen, variant="tohost"):
    """Emit one ``deviceMemoryTransform_*`` call per device-memory parameter group.

    For every entry of ``structOrApiInfo.deviceMemoryInfoParameterIndices``, the
    sub-types of ``structOrApiInfo`` whose index equals the entry's ``handle``,
    ``offset``, ``size``, ``typeIndex`` or ``typeBits`` field are located with
    ``vulkanTypeforEachSubType``. Each matched slot is passed to the resource
    tracker as a ``(cast)pointer, count`` pair. The count is ``getLen(...)``, or
    ``"1"`` when that returns None. Unmatched slots are passed as
    ``(cast)nullptr, 0``.

    Args:
        resourceTrackerVarName: name of the resource tracker variable in the
            generated code.
        structOrApiInfo: struct or API description exposing
            ``deviceMemoryInfoParameterIndices``.
        getExpr: callable mapping a sub-type to its (pointer) access expression.
        getLen: callable mapping a sub-type to its length expression, or None.
        cgen: code generator that receives the emitted statement(s).
        variant: ``"tohost"`` emits ``deviceMemoryTransform_tohost``; any other
            value emits ``deviceMemoryTransform_fromhost``.
    """
    # Slot order of the emitted (pointer, count) argument pairs.
    orderedKeys = ("handle", "offset", "size", "typeIndex", "typeBits")

    # C pointer type each slot's access expression is cast to.
    casts = {
        "handle" : "VkDeviceMemory*",
        "offset" : "VkDeviceSize*",
        "size" : "VkDeviceSize*",
        "typeIndex" : "uint32_t*",
        "typeBits" : "uint32_t*",
    }

    # Anything other than "tohost" has always selected the fromhost entry point.
    suffix = "tohost" if variant == "tohost" else "fromhost"

    for info in structOrApiInfo.deviceMemoryInfoParameterIndices.values():
        # Fresh per group: slots without a matching sub-type stay nullptr / 0.
        accesses = dict.fromkeys(orderedKeys, "nullptr")
        lenAccesses = dict.fromkeys(orderedKeys, "0")

        # Called by vulkanTypeforEachSubType during this iteration, so it sees
        # this iteration's `info`, `accesses` and `lenAccesses`.
        def doParam(i, vulkanType):
            access = getExpr(vulkanType)
            lenAccess = getLen(vulkanType)

            for k in orderedKeys:
                if i == getattr(info, k):
                    accesses[k] = access
                    lenAccesses[k] = lenAccess if lenAccess is not None else "1"

        vulkanTypeforEachSubType(structOrApiInfo, doParam)

        callParams = ", ".join(
            "(%s)%s, %s" % (casts[k], accesses[k], lenAccesses[k])
            for k in orderedKeys)

        cgen.stmt("%s->deviceMemoryTransform_%s(%s)" %
                  (resourceTrackerVarName, suffix, callParams))
71
def directTransform(resourceTrackerVarName, vulkanType, getExpr, getLen, cgen, variant="tohost"):
    """Emit a direct ``transformImpl_<TypeName>_<variant>`` call on the resource tracker.

    The emitted statement has the shape
    ``<tracker>->transformImpl_<TypeName>_<variant>(<expr>, <count>)``, where
    ``<expr>`` is ``getExpr(vulkanType)`` and ``<count>`` is
    ``getLen(vulkanType)``. When ``getLen`` yields a falsy value (None or an
    empty string), ``"1"`` is used instead.
    """
    pointerExpr = getExpr(vulkanType)
    # Falsy length expression -> treat as a single element.
    countExpr = getLen(vulkanType) or "1"

    implName = "transformImpl_%s_%s" % (vulkanType.typeName, variant)
    cgen.stmt("%s->%s(%s, %s)" % (resourceTrackerVarName, implName, pointerExpr, countExpr))
83
def genTransformsForVulkanType(resourceTrackerVarName, structOrApiInfo, getExpr, getLen, cgen, variant="tohost"):
    """Emit every transform that ``structOrApiInfo`` needs.

    The needed transform kinds come from ``vulkanTypeGetNeededTransformTypes``.
    Only the ``"devicememory"`` kind produces code here, via
    ``deviceMemoryTransform``; any other kind is skipped.
    """
    neededKinds = vulkanTypeGetNeededTransformTypes(structOrApiInfo)
    for kind in neededKinds:
        if kind != "devicememory":
            continue
        deviceMemoryTransform(resourceTrackerVarName, structOrApiInfo,
                              getExpr, getLen, cgen, variant=variant)
91
class TransformCodegen(VulkanTypeIterator):
    """Visitor that emits in-place transform code for the members of one struct.

    An instance is passed to ``iterateVulkanType`` for each struct member. The
    ``on*`` callbacks append C++ statements to ``self.cgen``. Only compound
    (struct/union) members and the struct-extension chain produce code; the
    other callbacks intentionally emit nothing.
    """

    def __init__(self, cgen, inputVar, resourceTrackerVarName, prefix, variant):
        """
        Args:
            cgen: code generator to emit into. It may be None at construction
                and assigned later, because the accessors below read
                ``self.cgen`` lazily.
            inputVar: name of the struct pointer variable being transformed.
            resourceTrackerVarName: name of the resource tracker variable that
                is passed to every emitted call.
            prefix: prefix of the per-type transform functions that get called
                (e.g. ``"transform_tohost_"``).
            variant: transform direction suffix (``"tohost"``/``"fromhost"`` in
                this file), forwarded to ``directTransform``.
        """
        self.cgen = cgen
        self.inputVar = inputVar
        self.prefix = prefix
        self.resourceTrackerVarName = resourceTrackerVarName

        def makeAccess(varName, asPtr = True):
            return lambda t: self.cgen.generalAccess(t, parentVarName = varName, asPtr = asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(t, parentVarName = varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName)

        # Accessor callables: member type -> C++ expression string (or None for
        # the length/guard accessors when the member has no length).
        self.exprAccessor = makeAccess(self.inputVar)
        self.exprAccessorValue = makeAccess(self.inputVar, asPtr = False)
        self.lenAccessor = makeLengthAccess(self.inputVar)
        self.lenAccessorGuard = makeLengthAccessGuard(self.inputVar)

        self.checked = False

        self.variant = variant

    def makeCastExpr(self, vulkanType):
        """Return a C-style cast prefix ``(<type>)`` for ``vulkanType``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def asNonConstCast(self, access, vulkanType):
        """Wrap ``access`` in a cast to the non-const pointer form of ``vulkanType``.

        Types that are accessible as a pointer (and are not static arrays) are
        cast to their own non-const type. Static arrays and everything else are
        cast to the non-const form of their address type.
        """
        # Same short-circuit order as the previous if/elif/else: staticArrExpr
        # is checked before accessibleAsPointer() is called.
        if not vulkanType.staticArrExpr and vulkanType.accessibleAsPointer():
            castType = vulkanType.getForNonConstAccess()
        else:
            castType = vulkanType.getForAddressAccess().getForNonConstAccess()
        return "%s(%s)" % (self.makeCastExpr(castType), access)

    def onCheck(self, vulkanType):
        pass

    def endCheck(self, vulkanType):
        pass

    def onCompoundType(self, vulkanType):
        """Emit transform code for a struct/union member, or an array of them.

        The emitted code is wrapped in the member's length guard (if any) and,
        for pointer members, a null check on the access expression. If the type
        has a direct ``transformImpl_*`` hook (``isTransformed``), that hook is
        emitted once for the whole member. Then ``<prefix><TypeName>`` is called
        on each element (looping over ``lenAccess`` elements when a length
        exists).
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        isPtr = vulkanType.pointerIndirectionLevels > 0

        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)

        if isPtr:
            self.cgen.beginIf(access)

        # Emit the direct transform outside the per-element loop. It is given
        # the member's full element count via self.lenAccessor, so emitting it
        # inside the loop would re-transform the whole array once per element.
        if vulkanType.isTransformed:
            directTransform(self.resourceTrackerVarName, vulkanType, self.exprAccessor, self.lenAccessor, self.cgen, variant=self.variant)

        if lenAccess is not None:
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
            forIncr = "++%s" % loopVar

            self.cgen.beginFor(forInit, forCond, forIncr)

        accessCasted = self.asNonConstCast(access, vulkanType)

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [self.resourceTrackerVarName, accessCasted])

        if lenAccess is not None:
            self.cgen.endFor()

        if isPtr:
            self.cgen.endIf()

        if lenAccessGuard is not None:
            self.cgen.endIf()

    def onString(self, vulkanType):
        pass

    def onStringArray(self, vulkanType):
        pass

    def onStaticArr(self, vulkanType):
        pass

    def onStructExtension(self, vulkanType):
        """Emit a guarded call to ``<prefix>extension_struct`` for a non-null extension pointer."""
        access = self.exprAccessor(vulkanType)

        castedAccessExpr = "(%s)(%s)" % ("void*", access)
        self.cgen.beginIf(access)
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [self.resourceTrackerVarName, castedAccessExpr])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        pass

    def onValue(self, vulkanType):
        pass
200
201
class VulkanTransform(VulkanWrapperGenerator):
    """Generator for the ``transform_<variant>_<Type>`` functions.

    For each variant (``tohost``, ``fromhost``):
    - every non-aliased struct/union gets a declaration in the module header and
      a definition in the module implementation;
    - aliased struct/union types get a function alias in the header;
    - ``onEnd`` emits a ``transform_<variant>_extension_struct`` dispatcher that
      walks the struct-extension parameter.
    """

    def __init__(self, module, typeInfo, resourceTrackerTypeName="ResourceTracker", resourceTrackerVarName="resourceTracker"):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        self.transformPrefix = "transform_"

        self.tohostpart = "tohost"
        self.fromhostpart = "fromhost"
        self.variants = [self.tohostpart, self.fromhostpart]

        # Name of the struct pointer parameter of every generated function.
        self.toTransformVar = "toTransform"
        self.resourceTrackerTypeName = resourceTrackerTypeName
        self.resourceTrackerVarName = resourceTrackerVarName
        # Leading resource-tracker parameter (pointer indirection level 1)
        # shared by every generated function.
        self.transformParam = \
            makeVulkanTypeSimple(False, self.resourceTrackerTypeName, 1,
                                 self.resourceTrackerVarName)
        self.voidType = makeVulkanTypeSimple(False, "void", 0)

        # One extension-struct dispatcher prototype per variant, in the same
        # order as self.variants (onEnd zips the two lists together).
        self.extensionTransformPrototypes = [
            VulkanAPI(self.transformPrefix + variant + "_extension_struct",
                      self.voidType,
                      [self.transformParam, STRUCT_EXTENSION_PARAM_FOR_WRITE])
            for variant in self.variants]

        # Struct/union names already emitted, so each is generated only once.
        self.knownStructs = {}
        self.needsTransform = set()

    def onBegin(self):
        VulkanWrapperGenerator.onBegin(self)
        # Set up convenience macros for the transformed structs
        # and forward-declare the resource tracker class.
        self.codegen.stmt("class %s" % self.resourceTrackerTypeName)
        self.codegen.line("#define LIST_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        for name in TRIVIAL_TRANSFORMED_TYPES:
            self.codegen.line("f(%s) \\" % name)
        self.codegen.line("")

        self.codegen.line("#define LIST_NON_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        for name in NON_TRIVIAL_TRANSFORMED_TYPES:
            self.codegen.line("f(%s) \\" % name)
        self.codegen.line("")

        self.codegen.line("#define LIST_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("LIST_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("LIST_NON_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("")

        self.module.appendHeader(self.codegen.swapCode())

        # Declare the extension-struct dispatchers up front in the impl file;
        # their definitions are emitted in onEnd.
        for prototype in self.extensionTransformPrototypes:
            self.module.appendImpl(self.codegen.makeFuncDecl(
                prototype))

    def onGenType(self, typeXml, name, alias):
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownStructs:
            return

        category = self.typeInfo.categoryOf(name)

        if category in ["struct", "union"] and alias:
            for variant in self.variants:
                self.module.appendHeader(
                    self.codegen.makeFuncAlias(self.transformPrefix + variant + "_" + name,
                                               self.transformPrefix + variant + "_" + alias))

        if category in ["struct", "union"] and not alias:
            structInfo = self.typeInfo.structs[name]
            self.knownStructs[name] = structInfo

            for variant in self.variants:
                api = VulkanAPI( \
                    self.transformPrefix + variant + "_" + name,
                    self.voidType,
                    [self.transformParam] + \
                    [makeVulkanTypeSimple( \
                        False, name, 1, self.toTransformVar)])

                transformer = TransformCodegen(
                    None,
                    self.toTransformVar,
                    self.resourceTrackerVarName,
                    self.transformPrefix + variant + "_",
                    variant)

                # Loop variables are bound as defaults so the generator does not
                # depend on being invoked before the loop advances.
                def funcDefGenerator(cgen, api=api, transformer=transformer, variant=variant):
                    transformer.cgen = cgen
                    # Silence unused-parameter warnings in the generated C++.
                    for p in api.parameters:
                        cgen.stmt("(void)%s" % p.paramName)

                    genTransformsForVulkanType(
                        self.resourceTrackerVarName,
                        structInfo,
                        transformer.exprAccessor,
                        transformer.lenAccessor,
                        cgen,
                        variant=variant)

                    for member in structInfo.members:
                        iterateVulkanType(
                            self.typeInfo, member,
                            transformer)

                self.module.appendHeader(
                    self.codegen.makeFuncDecl(api))
                self.module.appendImpl(
                    self.codegen.makeFuncImpl(api, funcDefGenerator))

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self):
        VulkanWrapperGenerator.onEnd(self)

        for variant, prototype in zip(self.variants, self.extensionTransformPrototypes):
            # `variant` is bound as a default so the callback keeps this
            # iteration's value regardless of when it is invoked.
            def forEachExtensionTransform(ext, castedAccess, cgen, variant=variant):
                if ext.isTransformed:
                    directTransform(self.resourceTrackerVarName, ext,
                                    lambda _: castedAccess, lambda _: "1",
                                    cgen, variant=variant)
                cgen.funcCall(None, self.transformPrefix + variant + "_" + ext.name,
                              [self.resourceTrackerVarName, castedAccess])

            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    prototype,
                    lambda cgen, forEach=forEachExtensionTransform: self.emitForEachStructExtension(
                        cgen,
                        self.voidType,
                        STRUCT_EXTENSION_PARAM_FOR_WRITE,
                        forEach)))
337