xref: /aosp_15_r20/external/mesa3d/src/gfxstream/codegen/scripts/cereal/subdecode.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# Copyright 2018 Google LLC
2# SPDX-License-Identifier: MIT
3from .common.codegen import CodeGen, VulkanWrapperGenerator
4from .common.vulkantypes import VulkanAPI, iterateVulkanType, VulkanType
5
6from .reservedmarshaling import VulkanReservedMarshalingCodegen
7from .transform import TransformCodegen
8
9from .wrapperdefs import API_PREFIX_RESERVEDUNMARSHAL
10from .wrapperdefs import MAX_PACKET_LENGTH
11from .wrapperdefs import ROOT_TYPE_DEFAULT_VALUE
12
13
# Preambles for the generated decoder declaration and implementation files
# (each is just a single newline).
decoder_decl_preamble = "\n"
decoder_impl_preamble = "\n"

# Prefix handed to CodeGen.vkApiCall (customPrefix / globalStatePrefix) for
# commands routed through the global-state handlers.
global_state_prefix = "this->on_"

# Names of the stream variables in the generated C++.
READ_STREAM = "readStream"
WRITE_STREAM = "vkStream"

# Driver workarounds for APIs that don't work well multithreaded: the direct
# driver call for these is bracketed by lock() / unlock() in generated code.
driver_workarounds_global_lock_apis = [
    "vkCreatePipelineLayout",
    "vkDestroyPipelineLayout",
]

# Kept as a string because it is spliced verbatim into a C "#define".
MAX_STACK_ITEMS = "16"
32
33
def emit_param_decl_for_reading(param, cgen):
    """Emit the C++ declaration of a decoded parameter and its stack buffer.

    Static arrays are declared through their non-const form. Pointer
    parameters additionally get a ``stack_<name>`` array: one slot when the
    length expression is absent or "1", otherwise ``MAX_STACK_ITEMS`` slots.
    ``void`` pointees are declared as ``uint8_t*`` elements.

    Args:
        param: VulkanType describing the parameter.
        cgen: CodeGen that receives the statements.
    """
    declared = param.getForNonConstAccess() if param.staticArrExpr else param
    cgen.stmt(cgen.makeRichCTypeDecl(declared))

    levels = param.pointerIndirectionLevels
    if levels <= 0:
        return

    length = cgen.generalLengthAccess(param) or "1"
    slots = "1" if length == "1" else "MAX_STACK_ITEMS"
    elemType = "uint8_t*" if param.typeName == "void" else param.typeName
    cgen.stmt("%s%s stack_%s[%s]" % (
        elemType, "*" * (levels - 1), param.paramName, slots))
51
52
def emit_unmarshal(typeInfo, param, cgen, output=False, destroy=False, noUnbox=False):
    """Emit C++ that reads ``param`` from ``readStream`` via reserved unmarshaling.

    Args:
        typeInfo: Vulkan type registry used to walk ``param``'s type.
        param: VulkanType describing the parameter to decode.
        cgen: CodeGen that receives the generated statements.
        output: When true, the handle prefix passed to the marshaling codegen
            is "" instead of "unbox_".
        destroy: When true, ``param`` holds handle(s) being destroyed. They are
            read with the "" prefix, the boxed value(s) are copied into
            ``boxed_<name>_preserve`` (which the caller must already have
            declared), and ``param`` is then overwritten with
            ``unbox_<Type>(...)``.
        noUnbox: When true (and ``destroy`` is false), a marker comment is
            emitted and handles are read with the "" prefix.
    """
    if destroy:
        iterateVulkanType(typeInfo, param, VulkanReservedMarshalingCodegen(
            cgen,
            "host",
            READ_STREAM,
            ROOT_TYPE_DEFAULT_VALUE,
            param.paramName,
            "readStreamPtrPtr",
            API_PREFIX_RESERVEDUNMARSHAL,
            "",
            direction="read",
            dynAlloc=True))
        lenAccess = cgen.generalLengthAccess(param)
        lenAccessGuard = cgen.generalLengthAccessGuard(param)
        if lenAccess is None or lenAccess == "1":
            # Single handle: keep the boxed value, then unbox in place.
            cgen.stmt("boxed_%s_preserve = %s" %
                      (param.paramName, param.paramName))
            cgen.stmt("%s = unbox_%s(%s)" %
                      (param.paramName, param.typeName, param.paramName))
        else:
            # Handle array: preserve and unbox element by element, optionally
            # behind the length guard. This is a module-level function, so the
            # guard is emitted through ``cgen`` (there is no ``self`` here).
            if lenAccessGuard is not None:
                cgen.beginIf(lenAccessGuard)
            cgen.beginFor("uint32_t i = 0", "i < %s" % lenAccess, "++i")
            cgen.stmt("boxed_%s_preserve[i] = %s[i]" %
                      (param.paramName, param.paramName))
            cgen.stmt("((%s*)(%s))[i] = unbox_%s(%s[i])" % (
                param.typeName, param.paramName, param.typeName, param.paramName))
            cgen.endFor()
            if lenAccessGuard is not None:
                cgen.endIf()
        return

    if noUnbox:
        cgen.line("// No unbox for %s" % (param.paramName))

    # Stack scratch size must match the declaration emitted by
    # emit_param_decl_for_reading: one slot for scalars, else MAX_STACK_ITEMS.
    lenAccess = cgen.generalLengthAccess(param) or "1"
    arrSize = "1" if lenAccess == "1" else "MAX_STACK_ITEMS"

    iterateVulkanType(typeInfo, param, VulkanReservedMarshalingCodegen(
        cgen,
        "host",
        READ_STREAM,
        ROOT_TYPE_DEFAULT_VALUE,
        param.paramName,
        "readStreamPtrPtr",
        API_PREFIX_RESERVEDUNMARSHAL,
        "" if (output or noUnbox) else "unbox_",
        direction="read",
        dynAlloc=True,
        stackVar="stack_%s" % param.paramName,
        stackArrSize=arrSize))
106
107
def emit_dispatch_unmarshal(typeInfo, param, cgen, globalWrapped):
    """Emit C++ that reads the command's dispatchable handle ``param``.

    The handle is always read with the "" handle prefix. Global-state-wrapped
    commands stop there and keep the boxed handle. Otherwise the generated
    code also declares ``unboxed_<name>`` (via ``unbox_<Type>``) and ``vk``
    (via ``dispatch_<Type>``).

    Args:
        typeInfo: Vulkan type registry used to walk ``param``'s type.
        param: VulkanType of the dispatchable handle.
        cgen: CodeGen that receives the statements.
        globalWrapped: True for commands routed through global-state handlers.
    """
    if globalWrapped:
        banner = "// Begin global wrapped dispatchable handle unboxing for %s"
    else:
        banner = "// Begin non wrapped dispatchable handle unboxing for %s"
    cgen.stmt(banner % param.paramName)

    reader = VulkanReservedMarshalingCodegen(
        cgen,
        "host",
        READ_STREAM,
        ROOT_TYPE_DEFAULT_VALUE,
        param.paramName,
        "readStreamPtrPtr",
        API_PREFIX_RESERVEDUNMARSHAL,
        "",
        direction="read",
        dynAlloc=True)
    iterateVulkanType(typeInfo, param, reader)

    if globalWrapped:
        return

    name, typeName = param.paramName, param.typeName
    cgen.stmt("auto unboxed_%s = unbox_%s(%s)" % (name, typeName, name))
    cgen.stmt("auto vk = dispatch_%s(%s)" % (typeName, name))
    cgen.stmt("// End manual dispatchable handle unboxing for %s" % name)
144
145
def emit_transform(typeInfo, param, cgen, variant="tohost"):
    """Emit the ``transform_<variant>_`` call(s) for ``param``.

    If iterateVulkanType returns a falsy result, a ``(void)<name>`` statement
    is emitted instead. Presumably this silences unused-variable warnings in
    the generated C++; confirm if relied upon.
    """
    transformer = TransformCodegen(
        cgen, param.paramName, "globalstate", "transform_%s_" % variant, variant)
    emitted = iterateVulkanType(typeInfo, param, transformer)
    if emitted:
        return
    cgen.stmt("(void)%s" % param.paramName)
152
# Everything below skips each API's first parameter (api.parameters[0], the command buffer); it is supplied separately as the dispatch handle.
154
155
class DecodingParameters:
    """Classify the parameters of ``api`` for decoding.

    The first API parameter is skipped. Every remaining one is recorded and,
    as a side effect, annotated in place with flags that the decode emitters
    read later:

    * ``dispatchHandle``: set on the first recorded parameter if it is a
      dispatchable handle.
    * ``nonDispatchableHandleCreate`` / ``nonDispatchableHandleDestroy``:
      set on non-dispatchable handles created / destroyed by ``api``.
    * ``dispatchableHandleCreate`` / ``dispatchableHandleDestroy``: set on
      dispatchable handles created / destroyed by ``api``.

    Flags are only ever set to True here, never cleared.

    Attributes:
        params: recorded parameters in declaration order.
        toRead: parameters to read from the stream (currently all of them).
        toWrite: parameters for which ``possiblyOutput()`` is true.
    """

    def __init__(self, api: VulkanAPI):
        self.params: list[VulkanType] = []
        self.toRead: list[VulkanType] = []
        self.toWrite: list[VulkanType] = []

        for position, param in enumerate(api.parameters[1:]):
            dispatchable = param.isDispatchableHandleType()
            nonDispatchable = param.isNonDispatchableHandleType()

            if position == 0 and dispatchable:
                param.dispatchHandle = True

            if nonDispatchable:
                if param.isCreatedBy(api):
                    param.nonDispatchableHandleCreate = True
                if param.isDestroyedBy(api):
                    param.nonDispatchableHandleDestroy = True

            if dispatchable:
                if param.isCreatedBy(api):
                    param.dispatchableHandleCreate = True
                if param.isDestroyedBy(api):
                    param.dispatchableHandleDestroy = True

            self.toRead.append(param)
            if param.possiblyOutput():
                self.toWrite.append(param)
            self.params.append(param)
184
185
def emit_call_log(api, cgen, enabled=False):
    """Optionally emit an ``fprintf`` trace of a sub-decoded command.

    The trace prints the read stream, the command name, the boxed dispatch
    handle and every parameter after the first, each cast to
    ``unsigned long long``. ``readStream`` and ``boxed_dispatchHandle`` are
    parameters of the generated subDecode(). Logging is off by default, in
    which case nothing is emitted.

    Args:
        api: VulkanAPI being decoded.
        cgen: CodeGen that receives the statement.
        enabled: Emit the trace statement when true.
    """
    if not enabled:
        return

    params = api.parameters[1:]
    # One space-separated conversion per argument: "%p" for the boxed
    # dispatch handle, then "0x%llx" for each parameter. Without the
    # separator the first two conversions would run together.
    paramLogFormat = " ".join(["%p"] + ["0x%llx"] * len(params))
    paramLogArgs = ["(void*)boxed_dispatchHandle"]
    paramLogArgs.extend("(unsigned long long)%s" % p.paramName for p in params)
    cgen.stmt("fprintf(stderr, \"substream %%p: call %s %s\\n\", readStream, %s)" %
              (api.name, paramLogFormat, ", ".join(paramLogArgs)))
200
201
def emit_decode_parameters(typeInfo, api, cgen, globalWrapped=False):
    """Emit the C++ that decodes every parameter of ``api`` after the first.

    Three passes over the parameters, in order:

    1. declare each parameter plus its stack scratch buffer;
    2. read each one from the stream:
       * the dispatch handle goes through ``emit_dispatch_unmarshal``
         (``globalWrapped`` selects its variant);
       * everything else goes through ``emit_unmarshal``;
       * handles being destroyed first get a ``boxed_<name>_preserve``
         declaration;
       * possible outputs are preceded by
         ``readStream->unsetHandleMapping()``;
    3. apply the ``tohost`` transform to each one.

    Finally ``emit_call_log`` is invoked.
    """
    toRead = DecodingParameters(api).toRead

    for param in toRead:
        emit_param_decl_for_reading(param, cgen)

    for param in toRead:
        lenAccess = cgen.generalLengthAccess(param)

        if param.dispatchHandle:
            emit_dispatch_unmarshal(typeInfo, param, cgen, globalWrapped)
            continue

        destroy = bool(param.nonDispatchableHandleDestroy or
                       param.dispatchableHandleDestroy)
        if destroy:
            cgen.stmt(
                "// Begin manual non dispatchable handle destroy unboxing for %s" % param.paramName)
            if lenAccess is None or lenAccess == "1":
                cgen.stmt("%s boxed_%s_preserve" %
                          (param.typeName, param.paramName))
            else:
                cgen.stmt("%s* boxed_%s_preserve; %s->alloc((void**)&boxed_%s_preserve, %s * sizeof(%s))" %
                          (param.typeName, param.paramName, READ_STREAM,
                           param.paramName, lenAccess, param.typeName))

        isOutput = param.possiblyOutput()
        if isOutput:
            cgen.stmt(
                "// Begin manual dispatchable handle unboxing for %s" % param.paramName)
            cgen.stmt("%s->unsetHandleMapping()" % READ_STREAM)

        emit_unmarshal(typeInfo, param, cgen, output=isOutput,
                       destroy=destroy, noUnbox=False)

    for param in toRead:
        emit_transform(typeInfo, param, cgen, variant="tohost")

    emit_call_log(api, cgen)
245
246
def emit_dispatch_call(api, cgen):
    """Emit the direct ``vk->`` driver call for a command.

    The first argument is always ``(VkCommandBuffer)dispatchHandle``. A
    parameter flagged as the dispatch handle is passed as its
    ``unboxed_<name>`` local. Commands listed in
    ``driver_workarounds_global_lock_apis`` are bracketed by ``lock()`` /
    ``unlock()``.
    """
    classified = DecodingParameters(api).params

    args = ["(VkCommandBuffer)dispatchHandle"]
    for declared, info in zip(api.parameters[1:], classified):
        if info.dispatchHandle:
            args.append("unboxed_%s" % declared.paramName)
        else:
            args.append(declared.paramName)

    needsLock = api.name in driver_workarounds_global_lock_apis
    if needsLock:
        cgen.stmt("lock()")

    cgen.vkApiCall(api, customPrefix="vk->", customParameters=args,
                   checkForDeviceLost=True, globalStatePrefix=global_state_prefix,
                   checkForOutOfMemory=True)

    if needsLock:
        cgen.stmt("unlock()")
268
269
def emit_global_state_wrapped_call(api, cgen, context=False):
    """Emit a call to the global-state handler for ``api``.

    Arguments are ``pool`` and the boxed command buffer, followed by every
    parameter after the first. When ``context`` is true, ``context`` is
    appended as the final argument.
    """
    args = ["pool", "(VkCommandBuffer)(boxed_dispatchHandle)"]
    args.extend(p.paramName for p in api.parameters[1:])
    if context:
        args.append("context")
    cgen.vkApiCall(api,
                   customPrefix=global_state_prefix,
                   customParameters=args,
                   checkForDeviceLost=True,
                   checkForOutOfMemory=True,
                   globalStatePrefix=global_state_prefix)
278
279
def emit_default_decoding(typeInfo, api, cgen):
    """Decode all parameters, then call the driver entry point through ``vk->``."""
    emit_decode_parameters(typeInfo, api, cgen, globalWrapped=False)
    emit_dispatch_call(api, cgen)
283
284
def emit_global_state_wrapped_decoding(typeInfo, api, cgen):
    """Decode parameters (dispatch handle kept boxed), then call the global-state handler."""
    emit_decode_parameters(typeInfo, api, cgen, globalWrapped=True)
    emit_global_state_wrapped_call(api, cgen, context=False)
288
def emit_global_state_wrapped_decoding_with_context(typeInfo, api, cgen):
    """Like emit_global_state_wrapped_decoding, but also passes ``context`` to the handler."""
    emit_decode_parameters(typeInfo, api, cgen, globalWrapped=True)
    emit_global_state_wrapped_call(api, cgen, context=True)
292
# Commands decoded through the global-state handlers (global_state_prefix)
# instead of a direct "vk->" driver call, keyed by command name. Commands
# whose handler also takes the decoder context use the "_with_context"
# emitter; all others use the plain wrapped emitter.
custom_decodes = dict.fromkeys((
    "vkCmdCopyBufferToImage",
    "vkCmdCopyBufferToImage2",
    "vkCmdCopyBufferToImage2KHR",
    "vkBeginCommandBuffer",
    "vkEndCommandBuffer",
    "vkBeginCommandBufferAsyncGOOGLE",
    "vkEndCommandBufferAsyncGOOGLE",
), emit_global_state_wrapped_decoding_with_context)
custom_decodes.update(dict.fromkeys((
    "vkCmdCopyImage",
    "vkCmdCopyImageToBuffer",
    "vkCmdCopyImage2",
    "vkCmdCopyImageToBuffer2",
    "vkCmdCopyImage2KHR",
    "vkCmdCopyImageToBuffer2KHR",
    "vkCmdExecuteCommands",
    "vkResetCommandBuffer",
    "vkCmdPipelineBarrier",
    "vkCmdPipelineBarrier2",
    "vkCmdBindPipeline",
    "vkCmdBindDescriptorSets",
    "vkCmdCopyQueryPoolResults",
    "vkResetCommandBufferAsyncGOOGLE",
    "vkCommandBufferHostSyncGOOGLE",
    "vkCmdBeginRenderPass",
    "vkCmdBeginRenderPass2",
    "vkCmdBeginRenderPass2KHR",
), emit_global_state_wrapped_decoding))
320
321
class VulkanSubDecoder(VulkanWrapperGenerator):
    """Generate ``subDecode()``, which decodes a buffer of command packets.

    Only commands whose first parameter is named ``commandBuffer`` get a
    ``case`` in the generated opcode switch. Each packet starts with a 4-byte
    opcode and a 4-byte packet length that includes those 8 header bytes; the
    encoded parameters follow.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)
        self.typeInfo = typeInfo
        self.cgen = CodeGen()

    def onBegin(self):
        """Emit the #defines and subDecode() prologue; open the loop and opcode switch."""
        self.module.appendImpl(
            "#define MAX_STACK_ITEMS %s\n" % MAX_STACK_ITEMS)

        self.module.appendImpl(
            "#define MAX_PACKET_LENGTH %s\n" % MAX_PACKET_LENGTH)

        self.module.appendImpl(
            "size_t subDecode(VulkanMemReadingStream* readStream, VulkanDispatch* vk, void* boxed_dispatchHandle, void* dispatchHandle, VkDeviceSize dataSize, const void* pData, const VkDecoderContext& context)\n")

        self.cgen.beginBlock()  # function body

        self.cgen.stmt("auto& metricsLogger = *context.metricsLogger")
        self.cgen.stmt("uint32_t count = 0")
        self.cgen.stmt("unsigned char *buf = (unsigned char *)pData")
        self.cgen.stmt("android::base::BumpPool* pool = readStream->pool()")
        self.cgen.stmt("unsigned char *ptr = (unsigned char *)pData")
        self.cgen.stmt(
            "const unsigned char* const end = (const unsigned char*)buf + dataSize")
        self.cgen.stmt(
            "VkDecoderGlobalState* globalstate = VkDecoderGlobalState::get()")

        self.cgen.line("while (end - ptr >= 8)")
        self.cgen.beginBlock()  # while loop

        self.cgen.stmt("uint32_t opcode = *(uint32_t *)ptr")
        self.cgen.stmt("uint32_t packetLen = *(uint32_t *)(ptr + 4)")
        # packetLen is a uint32_t, so it is printed with %u.
        self.cgen.line("""
        // packetLen should be at least 8 (op code and packet length) and should not be excessively large
        if (packetLen < 8 || packetLen > MAX_PACKET_LENGTH) {
            WARN("Bad packet length %u detected, subdecode may fail", packetLen);
            metricsLogger.logMetricEvent(MetricEventBadPacketLength{ .len = packetLen });
        }
        """)
        # Stop decoding when:
        #  - the length is shorter than the header. Such a packet cannot be
        #    skipped (a length of 0 would never advance ptr and loop forever);
        #  - the packet is truncated.
        # Like the normal exit, release the pool before returning.
        for stopCondition in ("packetLen < 8", "end - ptr < packetLen"):
            self.cgen.beginIf(stopCondition)
            self.cgen.stmt("pool->freeAll()")
            self.cgen.stmt("return ptr - (unsigned char*)buf")
            self.cgen.endIf()

        self.cgen.stmt("%s->setBuf((uint8_t*)(ptr + 8))" % READ_STREAM)
        self.cgen.stmt(
            "uint8_t* readStreamPtr = %s->getBuf(); uint8_t** readStreamPtrPtr = &readStreamPtr" % READ_STREAM)
        self.cgen.line("switch (opcode)")
        self.cgen.beginBlock()  # switch stmt

        self.module.appendImpl(self.cgen.swapCode())

    def onGenCmd(self, cmdinfo, name, alias):
        """Emit the ``case OP_<name>:`` arm decoding one command-buffer command.

        Commands without parameters, or whose first parameter is not
        ``commandBuffer``, are skipped. Commands listed in ``custom_decodes``
        use their custom emitter; all others use ``emit_default_decoding``.
        """
        typeInfo = self.typeInfo
        cgen = self.cgen
        api = typeInfo.apis[name]

        if not api.parameters or api.parameters[0].paramName != "commandBuffer":
            return

        cgen.line("case OP_%s:" % name)
        cgen.beginBlock()
        cgen.stmt("GFXSTREAM_TRACE_EVENT(GFXSTREAM_TRACE_DECODER_CATEGORY, \"VkSubDecoder %s\")" % name)

        emitter = custom_decodes.get(api.name, emit_default_decoding)
        emitter(typeInfo, api, cgen)

        cgen.stmt("break")
        cgen.endBlock()
        self.module.appendImpl(self.cgen.swapCode())

    def onEnd(self):
        """Close the switch and loop (with periodic pool resets) and the function."""
        self.cgen.line("default:")
        self.cgen.beginBlock()
        self.cgen.stmt(
            "GFXSTREAM_ABORT(::emugl::FatalError(::emugl::ABORT_REASON_OTHER)) << \"Unrecognized opcode \" << opcode")
        self.cgen.endBlock()

        self.cgen.endBlock()  # switch stmt

        self.cgen.stmt("++count; if (count % 1000 == 0) { pool->freeAll(); }")
        self.cgen.stmt("ptr += packetLen")
        self.cgen.endBlock()  # while loop

        self.cgen.stmt("pool->freeAll()")
        # stmt() supplies the terminating ';' (avoids emitting ";;").
        self.cgen.stmt("return ptr - (unsigned char*)buf")
        self.cgen.endBlock()  # function body
        self.module.appendImpl(self.cgen.swapCode())
411