xref: /aosp_15_r20/external/mesa3d/src/nouveau/headers/class_parser.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#! /usr/bin/env python3
2
# Script to parse NVIDIA class (CL) headers and generate C inline helpers for
# pushbuffer encoding, plus method-dump helpers and Rust bindings.
# Requires Python 3.9+ (uses str.removeprefix/str.removesuffix).
5
6import argparse
7import os.path
8import sys
9import re
10import subprocess
11
12from mako.template import Template
13
# Element counts for array methods (defined as NAME(i) / NAME(j) in the class
# header), keyed by a glob_match() pattern: a trailing '*' matches any suffix,
# otherwise the name must match exactly.  method.array_size returns the value
# of the FIRST matching entry, so keep more specific patterns ahead of broader
# ones.  The generated C parse/dump code enumerates exactly that many indices;
# array methods without an entry get 0 and are skipped there.
METHOD_ARRAY_SIZES = {
    'BIND_GROUP_CONSTANT_BUFFER'                        : 16,
    'CALL_MME_DATA'                                     : 256,
    'CALL_MME_MACRO'                                    : 256,
    'LOAD_CONSTANT_BUFFER'                              : 16,
    'LOAD_INLINE_QMD_DATA'                              : 64,
    'SET_ANTI_ALIAS_SAMPLE_POSITIONS'                   : 4,
    'SET_BLEND'                                         : 8,
    'SET_BLEND_PER_TARGET_*'                            : 8,
    'SET_COLOR_TARGET_*'                                : 8,
    'SET_COLOR_COMPRESSION'                             : 8,
    'SET_COLOR_CLEAR_VALUE'                             : 4,
    'SET_CT_WRITE'                                      : 8,
    'SET_MME_SHADOW_SCRATCH'                            : 256,
    'SET_PIPELINE_*'                                    : 6,
    'SET_SCG_COMPUTE_SCHEDULING_PARAMETERS'             : 16,
    'SET_SCISSOR_*'                                     : 16,
    'SET_SHADER_PERFORMANCE_SNAPSHOT_COUNTER_VALUE*'    : 8,
    'SET_SHADER_PERFORMANCE_COUNTER_VALUE*'             : 8,
    'SET_SHADER_PERFORMANCE_COUNTER_EVENT'              : 8,
    'SET_SHADER_PERFORMANCE_COUNTER_CONTROL_A'          : 8,
    'SET_SHADER_PERFORMANCE_COUNTER_CONTROL_B'          : 8,
    'SET_STREAM_OUT_BUFFER_*'                           : 4,
    'SET_STREAM_OUT_CONTROL_*'                          : 4,
    'SET_VIEWPORT_*'                                    : 16,
    'SET_VERTEX_ATTRIBUTE_*'                            : 16,
    'SET_VERTEX_STREAM_*'                               : 16,
}
42
# glob_match() patterns for methods whose data word is a 32-bit float.
# Such methods must have exactly one field (asserted by method.is_float).
# The C dumper prints them through uif(), and the Rust generator types their
# field as f32.
METHOD_IS_FLOAT = [
    'SET_BLEND_CONST_*',
    'SET_DEPTH_BIAS',
    'SET_SLOPE_SCALE_DEPTH_BIAS',
    'SET_DEPTH_BIAS_CLAMP',
    'SET_DEPTH_BOUNDS_M*',
    'SET_LINE_WIDTH_FLOAT',
    'SET_ALIASED_LINE_WIDTH_FLOAT',
    'SET_VIEWPORT_SCALE_*',
    'SET_VIEWPORT_OFFSET_*',
    'SET_VIEWPORT_CLIP_MIN_Z',
    'SET_VIEWPORT_CLIP_MAX_Z',
    'SET_Z_CLEAR_VALUE',
]
57
# Mako template for the generated C header.  For every parsed method it emits:
#   * struct nv_<class>_<mthd> with one uint32_t member per bitfield;
#   * __<CLASS>_<MTHD>(), which asserts each member fits in its bitfield and
#     ORs it into place (32-bit fields are copied unchecked);
#   * V_<CLASS>_<MTHD>(val, ...), which packs the arguments into `val`, with
#     each field's enumerant names available as local constants.
#     Multi-field methods expect `args` to be a struct initializer;
#     single-field methods take the bare value;
#   * VA_<CLASS>_<MTHD>, an alias of V_ (taking and ignoring an index for
#     array methods);
#   * P_<CLASS>_<MTHD>(push, [idx,] ...), which packs the value via V_ and
#     passes it to nv_push_val() with the method address (indexed by idx for
#     array methods).
# It also declares the P_PARSE_/P_DUMP_ helpers implemented by TEMPLATE_C.
# `${bs}` renders a single backslash (set up in main()) for macro line
# continuations.
TEMPLATE_H = Template("""\
/* parsed class ${nvcl} */

#include "nvtypes.h"
#include "${clheader}"

#include <assert.h>
#include <stdio.h>
#include "util/u_math.h"

%for mthd in mthddict:
struct nv_${nvcl.lower()}_${mthd} {
  %for field_name in mthddict[mthd].field_name_start:
    uint32_t ${field_name.lower()};
  %endfor
};

static inline void
__${nvcl}_${mthd}(uint32_t *val_out, struct nv_${nvcl.lower()}_${mthd} st)
{
    uint32_t val = 0;
  %for field_name in mthddict[mthd].field_name_start:
    <%
        field_start = int(mthddict[mthd].field_name_start[field_name])
        field_end = int(mthddict[mthd].field_name_end[field_name])
        field_width = field_end - field_start + 1
    %>
    %if field_width == 32:
    val |= st.${field_name.lower()};
    %else:
    assert(st.${field_name.lower()} < (1ULL << ${field_width}));
    val |= st.${field_name.lower()} << ${field_start};
    %endif
  %endfor
    *val_out = val;
}

#define V_${nvcl}_${mthd}(val, args...) { ${bs}
  %for field_name in mthddict[mthd].field_name_start:
    %for d in mthddict[mthd].field_defs[field_name]:
    UNUSED uint32_t ${field_name}_${d} = ${nvcl}_${mthd}_${field_name}_${d}; ${bs}
    %endfor
  %endfor
  %if len(mthddict[mthd].field_name_start) > 1:
    struct nv_${nvcl.lower()}_${mthd} __data = args; ${bs}
  %else:
<% field_name = next(iter(mthddict[mthd].field_name_start)).lower() %>\
    struct nv_${nvcl.lower()}_${mthd} __data = { .${field_name} = (args) }; ${bs}
  %endif
    __${nvcl}_${mthd}(&val, __data); ${bs}
}

%if mthddict[mthd].is_array:
#define VA_${nvcl}_${mthd}(i) V_${nvcl}_${mthd}
%else:
#define VA_${nvcl}_${mthd} V_${nvcl}_${mthd}
%endif

%if mthddict[mthd].is_array:
#define P_${nvcl}_${mthd}(push, idx, args...) do { ${bs}
%else:
#define P_${nvcl}_${mthd}(push, args...) do { ${bs}
%endif
  %for field_name in mthddict[mthd].field_name_start:
    %for d in mthddict[mthd].field_defs[field_name]:
    UNUSED uint32_t ${field_name}_${d} = ${nvcl}_${mthd}_${field_name}_${d}; ${bs}
    %endfor
  %endfor
    uint32_t nvk_p_ret; ${bs}
    V_${nvcl}_${mthd}(nvk_p_ret, args); ${bs}
    %if mthddict[mthd].is_array:
    nv_push_val(push, ${nvcl}_${mthd}(idx), nvk_p_ret); ${bs}
    %else:
    nv_push_val(push, ${nvcl}_${mthd}, nvk_p_ret); ${bs}
    %endif
} while(0)

%endfor

const char *P_PARSE_${nvcl}_MTHD(uint16_t idx);
void P_DUMP_${nvcl}_MTHD_DATA(FILE *fp, uint16_t idx, uint32_t data,
                              const char *prefix);
""")
141
# Mako template for the generated C source:
#   * P_PARSE_<CLASS>_MTHD(idx) maps a method address to its name string
#     ("unknown method" otherwise);
#   * P_DUMP_<CLASS>_MTHD_DATA() prints every bitfield of a method's data word,
#     using the enumerant name when the value matches one.  Float methods are
#     printed through uif(); unknown methods print the raw value.
# Array methods whose array_size is 0 (no METHOD_ARRAY_SIZES entry) are
# skipped.  `${bs}` renders a single backslash (set up in main()), used here to
# spell C escape sequences such as "\n".
TEMPLATE_C = Template("""\
#include "${header}"

#include <stdio.h>

const char*
P_PARSE_${nvcl}_MTHD(uint16_t idx)
{
    switch (idx) {
%for mthd in mthddict:
  %if mthddict[mthd].is_array and mthddict[mthd].array_size == 0:
    <% continue %>
  %endif
  %if mthddict[mthd].is_array:
    %for i in range(mthddict[mthd].array_size):
    case ${nvcl}_${mthd}(${i}):
        return "${nvcl}_${mthd}(${i})";
    %endfor
  % else:
    case ${nvcl}_${mthd}:
        return "${nvcl}_${mthd}";
  %endif
%endfor
    default:
        return "unknown method";
    }
}

void
P_DUMP_${nvcl}_MTHD_DATA(FILE *fp, uint16_t idx, uint32_t data,
                         const char *prefix)
{
    uint32_t parsed;
    switch (idx) {
%for mthd in mthddict:
  %if mthddict[mthd].is_array and mthddict[mthd].array_size == 0:
    <% continue %>
  %endif
  %if mthddict[mthd].is_array:
    %for i in range(mthddict[mthd].array_size):
    case ${nvcl}_${mthd}(${i}):
    %endfor
  % else:
    case ${nvcl}_${mthd}:
  %endif
  %for field_name in mthddict[mthd].field_name_start:
    <%
        field_start = int(mthddict[mthd].field_name_start[field_name])
        field_end = int(mthddict[mthd].field_name_end[field_name])
        field_width = field_end - field_start + 1
    %>
    %if field_width == 32:
        parsed = data;
    %else:
        parsed = (data >> ${field_start}) & ((1u << ${field_width}) - 1);
    %endif
        fprintf(fp, "%s.${field_name} = ", prefix);
    %if len(mthddict[mthd].field_defs[field_name]):
        switch (parsed) {
      %for d in mthddict[mthd].field_defs[field_name]:
        case ${nvcl}_${mthd}_${field_name}_${d}:
            fprintf(fp, "${d}${bs}n");
            break;
      %endfor
        default:
            fprintf(fp, "0x%x${bs}n", parsed);
            break;
        }
    %else:
      %if mthddict[mthd].is_float:
        fprintf(fp, "%ff (0x%x)${bs}n", uif(parsed), parsed);
      %else:
        fprintf(fp, "(0x%x)${bs}n", parsed);
      %endif
    %endif
  %endfor
        break;
%endfor
    default:
        fprintf(fp, "%s.VALUE = 0x%x${bs}n", prefix, data);
        break;
    }
}
""")
226
# Mako template for the Rust class file.  It emits a `pub const` for the
# version define found by parse_header() (a (name, value) pair), if any.
TEMPLATE_RS = Template("""\
// parsed class ${nvcl}

% if version is not None:
pub const ${version[0]}: u16 = ${version[1]};
% endif
""")
234
# Mako template for the Rust methods file.  For every parsed method it:
#   * picks a Rust type for each field: a generated enum when the field has
#     enumerants, bool for TRUE/FALSE pairs, f32 for METHOD_IS_FLOAT methods,
#     and u32 otherwise;
#   * emits the enums and a struct with one member per field;
#   * implements Mthd (single methods) or ArrayMthd (array methods), whose
#     to_bits() packs the fields into the method's 32-bit data word.
# f32 fields are packed with `f32::to_bits()` (their IEEE-754 bit pattern,
# which is what the C dumper decodes with uif()).  A plain `as u32` cast would
# convert the value numerically (1.5 -> 1).
TEMPLATE_RS_MTHD = Template("""\

// parsed class ${nvcl}

## Write out the methods in Rust
%for mthd_name, mthd in mthddict.items():
## Identify the field type.
<%
for field_name, field_value in mthd.field_defs.items():
    if field_name == 'V' and len(field_value) > 0:
        mthd.field_rs_types[field_name] = to_camel(mthd_name) + 'V'
        mthd.field_is_rs_enum[field_name] = True
    elif len(field_value) > 0:
        assert(field_name != "")
        mthd.field_rs_types[field_name] = to_camel(mthd_name) + to_camel(field_name)
        mthd.field_is_rs_enum[field_name] = True
    elif mthd.is_float:
        mthd.field_rs_types[field_name] = "f32"
        mthd.field_is_rs_enum[field_name] = False
    else:
        mthd.field_rs_types[field_name] = "u32"
        mthd.field_is_rs_enum[field_name] = False

    # TRUE and FALSE are special cases.
    if len(field_value) == 2:
        for enumerant in field_value:
            if enumerant.lower() == 'true' or enumerant.lower() == 'false':
                mthd.field_rs_types[field_name] = "bool"
                mthd.field_is_rs_enum[field_name] = False
                break
%>

## If there are a range of values for a field, we define an enum.
%for field_name in mthd.field_defs:
    %if mthd.field_is_rs_enum[field_name]:
#[repr(u16)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ${mthd.field_rs_types[field_name]} {
    %for field_name, field_value in mthd.field_defs[field_name].items():
    ${to_camel(rs_field_name(field_name))} = ${field_value.lower()},
    %endfor
}
    %endif
%endfor

## We also define a struct with the fields for the mthd.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct ${to_camel(mthd_name)} {
  %for field_name in mthddict[mthd_name].field_name_start:
    pub ${rs_field_name(field_name.lower())}: ${mthd.field_rs_types[field_name]},
  %endfor
}

## Notice that the "to_bits" implementation is identical, so the first brace is
## not closed.
% if not mthd.is_array:
## This trait lays out how the conversion to u32 happens
impl Mthd for ${to_camel(mthd_name)} {
    const ADDR: u16 = ${mthd.addr.replace('(', '').replace(')', '')};
    const CLASS: u16 = ${version[1].lower() if version is not None else nvcl.lower().replace("nv", "0x")};

%else:
impl ArrayMthd for ${to_camel(mthd_name)} {
    const CLASS: u16 = ${version[1].lower() if version is not None else nvcl.lower().replace("nv", "0x")};

    fn addr(i: usize) -> u16 {
        <% assert not ('i' in mthd.addr and 'j' in mthd.addr) %>
        (${mthd.addr.replace('j', 'i').replace('(', '').replace(')', '')}).try_into().unwrap()
    }
%endif

    #[inline]
    fn to_bits(self) -> u32 {
        let mut val = 0;
        %for field_name in mthddict[mthd_name].field_name_start:
            <%
                field_start = int(mthd.field_name_start[field_name])
                field_end = int(mthd.field_name_end[field_name])
                field_width = field_end - field_start + 1
                field_rs_type = mthd.field_rs_types[field_name]
                if field_rs_type == "u32":
                    field = rs_field_name(field_name.lower())
                elif field_rs_type == "f32":
                    # Pack the IEEE-754 bit pattern; `as u32` would convert
                    # the float numerically instead.
                    field = f"{rs_field_name(field_name)}.to_bits()"
                else:
                    field = f"{rs_field_name(field_name)} as u32"
            %>
            %if field_width == 32:
        val |= self.${field};
            %else:
                %if "as u32" in field:
        assert!((self.${field}) < (1 << ${field_width}));
        val |= (self.${field}) << ${field_start};
                %else:
        assert!(self.${field} < (1 << ${field_width}));
        val |= self.${field} << ${field_start};
                %endif
            %endif
        %endfor

        val
    }
## Close the first brace.
}
%endfor
""")
335
def to_camel(snake_str):
    """Convert a snake_case (or UPPER_SNAKE) identifier to CamelCase.

    Each '_'-separated word goes through str.title(), so 'SET_BLEND' becomes
    'SetBlend'.  Rust identifiers cannot start with a digit, so a result that
    would start with one is prefixed with '_' ('3D_MODE' -> '_3DMode').
    An empty result (empty input, or only underscores) is returned as ''.
    """
    result = ''.join(word.title() for word in snake_str.split('_'))
    # Slice rather than index so an empty result does not raise IndexError.
    if result[:1].isdigit():
        return '_' + result
    return result
341
def rs_field_name(name):
    """Turn a class-header field or enumerant name into a Rust identifier.

    The name is lowercased.  A name that collides with a Rust keyword gets a
    trailing '_' ('TYPE' -> 'type_'), and a name starting with a digit (which
    Rust identifiers cannot) gets a leading '_' ('3D' -> '_3d').
    """
    # Strict and reserved keywords (Rust Reference, "Keywords"); none of these
    # can be used as a plain identifier, so emitting one verbatim would
    # produce Rust that does not compile.
    rust_keywords = {
        'as', 'async', 'await', 'break', 'const', 'continue', 'crate', 'dyn',
        'else', 'enum', 'extern', 'false', 'fn', 'for', 'if', 'impl', 'in',
        'let', 'loop', 'match', 'mod', 'move', 'mut', 'pub', 'ref', 'return',
        'self', 'static', 'struct', 'super', 'trait', 'true', 'type',
        'unsafe', 'use', 'where', 'while',
        'abstract', 'become', 'box', 'do', 'final', 'macro', 'override',
        'priv', 'try', 'typeof', 'unsized', 'virtual', 'yield',
    }
    name = name.lower()

    if name in rust_keywords:
        return name + '_'
    # Slice rather than index so an empty name does not raise IndexError.
    if name[:1].isdigit():
        return '_' + name
    return name
354
def glob_match(glob, name):
    """Match *name* against a minimal glob pattern.

    The only wildcard supported is a single trailing '*', which matches any
    suffix.  A pattern without a trailing '*' must equal *name* exactly and
    may not contain '*' anywhere.
    """
    stem = glob.removesuffix('*')
    if stem != glob:
        # Trailing wildcard: prefix match on everything before the '*'.
        return name.startswith(stem)
    assert '*' not in glob
    return name == glob
361
class method(object):
    """One method parsed from a class header.

    parse_header() creates instances and assigns their attributes directly
    (name, addr, is_array, field_name_start, field_name_end, field_defs,
    field_rs_types, field_is_rs_enum); the properties below derive extra
    information from the method name.
    """

    @property
    def array_size(self):
        """Element count from the first METHOD_ARRAY_SIZES pattern matching
        this method's name, or 0 if no pattern matches."""
        return next(
            (size for pattern, size in METHOD_ARRAY_SIZES.items()
             if glob_match(pattern, self.name)),
            0,
        )

    @property
    def is_float(self):
        """Whether this method's name matches a METHOD_IS_FLOAT pattern.

        A matching method is asserted to have exactly one field.
        """
        float_method = any(glob_match(pattern, self.name)
                           for pattern in METHOD_IS_FLOAT)
        if float_method:
            assert len(self.field_defs) == 1
        return float_method
377
def parse_header(nvcl, f):
    """Parse an NVIDIA class header into method descriptions.

    Args:
        nvcl: Prefix used by the header's method defines, e.g. "NV9097".
        f: Iterable of header lines (an open text file works).

    Returns:
        A (version, mthddict) tuple.  version is the (name, value) token pair
        of the first ``#define`` whose name does not start with nvcl and whose
        value starts with "0x", or None if there is none.  mthddict maps each
        method name (nvcl prefix and any "(i)"/"(j)" suffix removed) to a
        ``method`` object, in insertion order.  Methods that never gain a
        bitfield are dropped.

    Raises:
        AssertionError: if more than one version-like define is found.
    """
    # Simple state machine:
    #   state 0: looking for a new method define
    #   state 1: looking for new fields in the current method
    #   state 2: looking for enums for the current field
    # Blank lines reset the state machine to 0.

    version = None
    state = 0
    mthddict = {}
    curmthd = None
    curfield = None
    for line in f:
        if line.strip() == "":
            state = 0
            # A method that never gained a field is not emitted.
            if curmthd and not curmthd.field_name_start:
                del mthddict[curmthd.name]
            curmthd = None
            continue

        if not line.startswith("#define"):
            continue

        tokens = line.split()
        # A bare "#define" has no name to inspect.
        if len(tokens) < 2 or "_cl_" in tokens[1]:
            continue

        if not tokens[1].startswith(nvcl):
            if len(tokens) > 2 and tokens[2].startswith("0x"):
                assert version is None
                version = (tokens[1], tokens[2])
            continue

        if tokens[1].endswith("TYPEDEF"):
            continue

        # Everything below inspects the define's value; skip value-less
        # defines instead of crashing with an IndexError.
        if len(tokens) < 3:
            continue
        define_name, value = tokens[1], tokens[2]

        if state == 2:
            prefix = nvcl + "_" + curmthd.name + "_" + curfield + "_"
            if ":" not in value and prefix in define_name:
                # Enumerant of the current field.
                curmthd.field_defs[curfield][define_name.removeprefix(prefix)] = value
            else:
                # Anything else ends the enum list; look for the next field.
                state = 1

        if state == 1:
            prefix = nvcl + "_" + curmthd.name + "_"
            if prefix in define_name:
                # Values containing "0x" are not bitfield ranges: ignore them
                # and keep looking for fields (stay in state 1).
                if "0x" not in value:
                    field = define_name.removeprefix(prefix)
                    # Bitfields are written "high:low".
                    bitfield = value.split(":")
                    curmthd.field_name_start[field] = bitfield[1]
                    curmthd.field_name_end[field] = bitfield[0]
                    curmthd.field_defs[field] = {}
                    curfield = field
                    state = 2
            else:
                # Not part of the current method: drop it if it has no
                # fields, then look for a new method.
                if not curmthd.field_name_start:
                    del mthddict[curmthd.name]
                    curmthd = None
                state = 0

        if state == 0:
            if curmthd and not curmthd.field_name_start:
                del mthddict[curmthd.name]
                # Forget the deleted method so it cannot be deleted twice.
                curmthd = None
            # A bitfield define cannot start a method.
            if ':' in value:
                continue
            name = define_name.removeprefix(nvcl + "_")
            is_array = 0
            for suffix in ("(i)", "(j)"):
                if name.endswith(suffix):
                    is_array = 1
                    name = name.removesuffix(suffix)
            x = method()
            x.name = name
            x.addr = value
            x.is_array = is_array
            x.field_name_start = {}
            x.field_name_end = {}
            x.field_defs = {}
            x.field_rs_types = {}
            x.field_is_rs_enum = {}
            mthddict[x.name] = x

            curmthd = x
            state = 1

    return (version, mthddict)
471
def convert_to_rust_constants(filename):
    """Translate the ``#define``s of a C class header into Rust items.

    Four kinds of define are recognised, tried in this order:
      * ``#define NAME(arg) expr`` -> ``pub fn name(arg: u32) -> u32``
        (every parenthesis in expr is stripped);
      * ``#define NAME hi:lo`` (optionally wrapped as ``MW(hi:lo)``)
        -> ``pub const NAME: Range<u32> = lo..hi+1``;
      * ``#define NAME 0xHEX`` -> ``pub const NAME: u32 = 0xHEX``;
      * ``#define NAME DEC`` -> ``pub const NAME: u32 = DEC``.
    The header-specific prefix (e.g. ``NV9097_`` for ``cl9097.h``) is removed
    from names.  Function names are lowercased ('_' is inserted at
    lower-to-upper case transitions first).  A name that repeats gets a
    ``_<n>`` suffix so the output has no duplicate definitions.

    Returns the generated items joined with newlines.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    stem = os.path.splitext(os.path.basename(filename))[0]
    file_prefix = ("NV" + stem.upper() + "_").replace('CL', '')

    fn_re = re.compile(r'#define\s+(\w+)\((\w+)\)\s+(.+)')
    range_re = re.compile(r'#define\s+(\w+)\s+(?:MW\()?(\d+):(\d+)\)?')
    hex_re = re.compile(r'#define\s+(\w+)\s+\(?0x(\w+)\)?')
    dec_re = re.compile(r'#define\s+(\w+)\s+\(?(\d+)\)?')

    seen = {}

    def unique(name):
        # The first occurrence keeps its name; repeats get _1, _2, ...
        # The counter is shared by all four kinds of define.
        if name in seen:
            seen[name] += 1
            return f"{name}_{seen[name]}"
        seen[name] = 0
        return name

    rust_items = []
    for raw_line in lines:
        line = raw_line.strip()
        if match := fn_re.match(line):
            name, arg, expr = match.groups()
            name = unique(name).replace(file_prefix, '')
            # convert to snake case
            name = re.sub(r'(?<=[a-z])(?=[A-Z])', '_', name).lower()
            body = expr.replace('(', '').replace(')', '')
            rust_items.append(f"#[inline]\npub fn {name}  ({arg}: u32) -> u32 {{ {body} }} ")
        elif match := range_re.match(line):
            name, high, low = match.groups()
            name = unique(name).replace(file_prefix, '')
            # C bitfields are inclusive; Rust ranges exclude their end.
            rust_items.append(f"pub const {name}: Range<u32> = {low}..{int(high) + 1};")
        elif match := hex_re.match(line):
            name, value = match.groups()
            name = unique(name).replace(file_prefix, '')
            rust_items.append(f"pub const {name}: u32 = 0x{value};")
        elif match := dec_re.match(line):
            name, value = match.groups()
            name = unique(name).replace(file_prefix, '')
            rust_items.append(f"pub const {name}: u32 = {value};")

    return '\n'.join(rust_items)
530
def main():
    """Command-line entry point: parse a class header and render outputs.

    --in-h names the input class header (e.g. cl9097.h).  The class name used
    in the generated code is derived from it (cl9097.h -> NV9097).  Each of
    --out-h, --out-c, --out-rs and --out-rs-mthd that is given is rendered
    from the corresponding template.  --out-c requires --out-h, because the
    generated C file #includes the generated header by name.  Any error
    while writing outputs prints a mako traceback and exits with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--out-h', required=False, help='Output C header.')
    parser.add_argument('--out-c', required=False, help='Output C file.')
    parser.add_argument('--out-rs', required=False, help='Output Rust file.')
    parser.add_argument('--out-rs-mthd', required=False,
                        help='Output Rust file for methods.')
    parser.add_argument('--in-h',
                        help='Input class header file.',
                        required=True)
    args = parser.parse_args()

    # TEMPLATE_C renders `#include "${header}"`.  Without --out-h that
    # variable is undefined and mako would fail mid-render, so reject the
    # combination up front with a clear message.
    if args.out_c is not None and args.out_h is None:
        parser.error('--out-c requires --out-h')

    clheader = os.path.basename(args.in_h)
    nvcl = "NV" + clheader.removeprefix("cl").removesuffix(".h").upper()

    with open(args.in_h, 'r', encoding='utf-8') as f:
        (version, mthddict) = parse_header(nvcl, f)

    environment = {
        'clheader': clheader,
        'nvcl': nvcl,
        'version': version,
        'mthddict': mthddict,
        'rs_field_name': rs_field_name,
        'to_camel': to_camel,
        'bs': '\\'
    }
    if args.out_h is not None:
        # Name the generated C file uses to #include the generated header.
        environment['header'] = os.path.basename(args.out_h)

    try:
        if args.out_h is not None:
            with open(args.out_h, 'w', encoding='utf-8') as f:
                f.write(TEMPLATE_H.render(**environment))
        if args.out_c is not None:
            with open(args.out_c, 'w', encoding='utf-8') as f:
                f.write(TEMPLATE_C.render(**environment))
        if args.out_rs is not None:
            with open(args.out_rs, 'w', encoding='utf-8') as f:
                f.write(TEMPLATE_RS.render(**environment))
        if args.out_rs_mthd is not None:
            with open(args.out_rs_mthd, 'w', encoding='utf-8') as f:
                # Preamble: lint allowances for the C-style generated names
                # and the items the generated constants/impls refer to.
                f.write("#![allow(non_camel_case_types)]\n")
                f.write("#![allow(non_snake_case)]\n")
                f.write("#![allow(non_upper_case_globals)]\n\n")
                f.write("use std::ops::Range;\n")
                f.write("use crate::Mthd;\n")
                f.write("use crate::ArrayMthd;\n")
                f.write("\n")
                f.write(convert_to_rust_constants(args.in_h))
                f.write("\n")
                f.write(TEMPLATE_RS_MTHD.render(**environment))

    except Exception:
        # On any failure above, print mako's text-formatted traceback (which
        # is more readable than the default for template errors) to stderr
        # and exit with status 1.
        from mako import exceptions
        print(exceptions.text_error_template().render(), file=sys.stderr)
        sys.exit(1)
596
# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
599