xref: /aosp_15_r20/external/mesa3d/src/freedreno/registers/gen_header.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#!/usr/bin/python3
2#
3# Copyright © 2019-2024 Google, Inc.
4#
5# SPDX-License-Identifier: MIT
6
7import xml.parsers.expat
8import sys
9import os
10import collections
11import argparse
12import time
13import datetime
14
class Error(Exception):
	"""Fatal error raised while parsing or validating the register XML.

	The message is kept in ``message`` and is also handed to ``Exception``
	so that ``str(err)``, ``err.args`` and tracebacks all show it.
	"""

	def __init__(self, message):
		super().__init__(message)
		self.message = message
18
class Enum(object):
	"""A named C enum built from an ``<enum>`` element.

	``values`` holds ``(name, value)`` tuples in the order the ``<value>``
	elements were seen.
	"""

	def __init__(self, name):
		self.name = name
		self.values = []

	def has_name(self, name):
		"""Return True if one of the enum's values is called ``name``."""
		return any(entry_name == name for entry_name, _ in self.values)

	def names(self):
		"""Return the value names in declaration order."""
		return [entry_name for entry_name, _ in self.values]

	def dump(self, is_deprecated):
		"""Print the enum as a C ``enum`` declaration.

		Every value is printed as zero-padded hex when any value exceeds
		0x1000, otherwise as decimal. ``is_deprecated`` is accepted only for
		interface parity with the other dumpable classes and is not used.
		"""
		if any(value > 0x1000 for _, value in self.values):
			line_fmt = "\t%s = 0x%08x,"
		else:
			line_fmt = "\t%s = %d,"

		print("enum %s {" % self.name)
		for entry_name, value in self.values:
			print(line_fmt % (entry_name, value))
		print("};\n")

	def dump_pack_struct(self, is_deprecated):
		"""Enums have no pack structs; this keeps the dump interface uniform."""
		pass
49
class Field(object):
	"""One bitfield spanning bits ``low``..``high`` (inclusive) of a register.

	The constructor validates the bit range against ``parser.current_bitsize``
	and the type against the builtin field types plus ``parser.enums``. On
	failure it raises the exception object returned by ``parser.error()``.
	For ``fixed``/``ufixed`` fields the caller assigns ``radix`` after
	construction; ``ctype()`` reads it for those types.
	"""

	def __init__(self, name, low, high, shr, type, parser):
		self.name = name
		self.low = low
		self.high = high
		self.shr = shr
		self.type = type

		known_types = (
			None, "a3xx_regid", "boolean", "uint", "hex", "int",
			"fixed", "ufixed", "float", "address", "waddress",
		)
		top_bit = parser.current_bitsize - 1

		if not 0 <= low <= top_bit:
			raise parser.error("low attribute out of range: %d" % low)
		if not 0 <= high <= top_bit:
			raise parser.error("high attribute out of range: %d" % high)
		if low > high:
			raise parser.error("low is greater than high: low=%d, high=%d" % (low, high))

		span = high - low
		if type == "boolean":
			if span != 0:
				raise parser.error("booleans should be 1 bit fields")
		elif type == "float":
			# Only half (16-bit) and single (32-bit) floats can be packed.
			if span not in (15, 31):
				raise parser.error("floats should be 16 or 32 bit fields")
		elif type not in known_types and type not in parser.enums:
			raise parser.error("unknown type '%s'" % type)

	def ctype(self, var_name):
		"""Return ``(c_type, c_expr)`` for packing this field from ``var_name``.

		``c_type`` is the C type of the input value. ``c_expr`` converts
		``var_name`` to the raw integer bits before they are shifted into
		place. A non-zero ``shr`` wraps the expression in a right shift.
		"""
		kind = self.type
		span = self.high - self.low

		if kind in (None, "uint", "hex", "a3xx_regid"):
			c_type, expr = "uint32_t", var_name
		elif kind == "boolean":
			c_type, expr = "bool", var_name
		elif kind == "int":
			c_type, expr = "int32_t", var_name
		elif kind == "fixed":
			c_type = "float"
			expr = "((int32_t)(%s * %d.0))" % (var_name, 1 << self.radix)
		elif kind == "ufixed":
			c_type = "float"
			expr = "((uint32_t)(%s * %d.0))" % (var_name, 1 << self.radix)
		elif kind == "float" and span == 31:
			c_type, expr = "float", "fui(%s)" % var_name
		elif kind == "float" and span == 15:
			c_type, expr = "float", "_mesa_float_to_half(%s)" % var_name
		elif kind in ("address", "waddress"):
			c_type, expr = "uint64_t", var_name
		else:
			# Any other type should be an enum name (checked in __init__).
			c_type, expr = "enum %s" % kind, var_name

		if self.shr > 0:
			expr = "(%s >> %d)" % (expr, self.shr)

		return (c_type, expr)
111
def tab_to(name, value):
	"""Print ``name`` then ``value``, tab-separated so ``value`` lines up.

	With 8-column tab stops the value lands at column 64. At least one tab
	is always emitted, even when ``name`` already reaches past that column.
	"""
	tabs = max(1, (68 - (len(name) & ~7)) // 8)
	print(name + "\t" * tabs + value)
117
def mask(low, high):
	"""Return an integer with bits ``low``..``high`` (inclusive) set.

	The mask is cut from a 64-bit all-ones value, so spans wider than 64
	bits raise ValueError (negative shift count).
	"""
	width = high + 1 - low
	all_ones_64 = (1 << 64) - 1
	return (all_ones_64 >> (64 - width)) << low
120
def field_name(reg, f):
	"""Return the lowercase C member name for field ``f`` of register ``reg``.

	Unnamed fields borrow the register's name. This happens for a register
	declared without bitfields, e.g.
	<reg32 offset="0x88db" name="RB_BLIT_DST_ARRAY_PITCH" low="0" high="28" shr="6" type="uint"/>
	Names equal to ``double``/``float``/``int``, or not starting with a
	letter, get a leading underscore so they are valid C identifiers.
	"""
	name = (f.name or reg.name).lower()
	needs_prefix = name in ("double", "float", "int") or not name[0].isalpha()
	return "_" + name if needs_prefix else name
133
# indices - list of (ctype, stride, offset_fn) tuples, one per index; offset_fn is "__offset_NAME" when stride is None
def indices_varlist(indices):
	"""Return the comma-separated index argument names ``i0, i1, ...``."""
	return ", ".join("i%d" % position for position, _ in enumerate(indices))
137
def indices_prototype(indices):
	"""Return a C parameter list ``"<ctype> i0, <ctype> i1, ..."``."""
	params = []
	for position, entry in enumerate(indices):
		ctype, _stride, _offset = entry
		params.append("%s i%d" % (ctype, position))
	return ", ".join(params)
141
def indices_strides(indices):
	"""Return a C expression summing each index's contribution to an offset.

	Entries with a truthy stride contribute ``0x<stride>*i<n>``; entries
	without one call their offset function instead, e.g. ``__offset_FOO(i<n>)``.
	"""
	terms = []
	for position, (_ctype, stride, offset_fn) in enumerate(indices):
		if stride:
			terms.append("0x%x*i%d" % (stride, position))
		else:
			terms.append("%s(i%d)" % (offset_fn, position))
	return " + ".join(terms)
147
class Bitset(object):
	"""An ordered list of bitfields describing the layout of a register value.

	A bitset is either declared standalone (``<bitset>``) or built inline for
	a single register. ``inline`` selects whether per-register defines and
	pack helpers are generated for it.
	"""

	def __init__(self, name, template):
		self.name = name
		self.inline = False
		# Copy the template's list so appending fields here leaves the
		# template bitset untouched.
		if template:
			self.fields = template.fields[:]
		else:
			self.fields = []

	def get_address_field(self):
		"""Return the first ``address``/``waddress`` field, or None if there is none."""
		for f in self.fields:
			if f.type in [ "address", "waddress" ]:
				return f
		return None

	def dump_regpair_builder(self, reg):
		"""Print the body of a C function returning ``struct fd_reg_pair`` for ``reg``.

		Under ``#ifndef NDEBUG`` it asserts that each packed (non-boolean,
		non-address) field value fits in its width, and that
		``fields.unknown`` does not overlap any declared field. It then
		returns a compound literal ORing all non-address fields into
		``.value``. An address field is attached through ``.bo``/``.bo_offset``.
		"""
		print("#ifndef NDEBUG")
		known_mask = 0
		for f in self.fields:
			known_mask |= mask(f.low, f.high)
			if f.type in [ "boolean", "address", "waddress" ]:
				continue
			_, val = f.ctype("fields.%s" % field_name(reg, f))
			print("    assert((%-40s & 0x%08x) == 0);" % (val, 0xffffffff ^ mask(0 , f.high - f.low)))
		print("    assert((%-40s & 0x%08x) == 0);" % ("fields.unknown", known_mask))
		print("#endif\n")

		print("    return (struct fd_reg_pair) {")
		print("        .reg = (uint32_t)%s," % reg.reg_offset())
		print("        .value =")
		for f in self.fields:
			if f.type in [ "address", "waddress" ]:
				continue
			_, val = f.ctype("fields.%s" % field_name(reg, f))
			print("            (%-40s << %2d) |" % (val, f.low))
		value_name = "qword" if reg.bit_size == 64 else "dword"
		print("            fields.unknown | fields.%s," % (value_name,))

		address = self.get_address_field()
		if address:
			print("        .bo = fields.bo,")
			print("        .is_address = true,")
			# Check the address field itself. The old code tested the stale
			# loop variable `f` (the last field), which gave the wrong answer
			# whenever the address field was not the last one.
			if address.type == "waddress":
				print("        .bo_write = true,")
			print("        .bo_offset = fields.bo_offset,")
			print("        .bo_shift = %d," % address.shr)
			print("        .bo_low = %d," % address.low)

		print("    };")

	def dump_pack_struct(self, is_deprecated, reg=None):
		"""Print the ``struct <reg>`` field struct, its ``pack_<reg>()`` helper and macro.

		Does nothing unless ``reg`` is given. Address fields become a
		``bo``/``bo_offset`` pair. When an address field exists, the macro
		also appends an empty ``{ .reg = 0 }`` pair.
		"""
		if not reg:
			return

		prefix = reg.full_name

		print("struct %s {" % prefix)
		for f in self.fields:
			if f.type in [ "address", "waddress" ]:
				tab_to("    __bo_type", "bo;")
				tab_to("    uint32_t", "bo_offset;")
				continue
			name = field_name(reg, f)

			ctype_name, _ = f.ctype("var")

			tab_to("    %s" % ctype_name, "%s;" % name)
		if reg.bit_size == 64:
			tab_to("    uint64_t", "unknown;")
			tab_to("    uint64_t", "qword;")
		else:
			tab_to("    uint32_t", "unknown;")
			tab_to("    uint32_t", "dword;")
		print("};\n")

		depcrstr = ""
		if is_deprecated:
			depcrstr = " __attribute__((deprecated))"
		if reg.array:
			print("static inline%s struct fd_reg_pair\npack_%s(uint32_t __i, struct %s fields)\n{" %
				  (depcrstr, prefix, prefix))
		else:
			print("static inline%s struct fd_reg_pair\npack_%s(struct %s fields)\n{" %
				  (depcrstr, prefix, prefix))

		self.dump_regpair_builder(reg)

		print("\n}\n")

		if self.get_address_field():
			skip = ", { .reg = 0 }"
		else:
			skip = ""

		if reg.array:
			print("#define %s(__i, ...) pack_%s(__i, __struct_cast(%s) { __VA_ARGS__ })%s\n" %
				  (prefix, prefix, prefix, skip))
		else:
			print("#define %s(...) pack_%s(__struct_cast(%s) { __VA_ARGS__ })%s\n" %
				  (prefix, prefix, prefix, skip))

	def dump(self, is_deprecated, prefix=None):
		"""Print C #defines (and inline packers) for every field, named ``<prefix>_<FIELD>``.

		``prefix`` defaults to the bitset name. Per field:
		- an unnamed field at bit 0 with no shift, not float/fixed/ufixed:
		  nothing is emitted;
		- a boolean, or an untyped single-bit field: one bit-value #define;
		- otherwise: ``__MASK``/``__SHIFT`` defines plus an inline function
		  that packs a value, asserting the low ``shr`` bits are zero.
		"""
		if prefix is None:
			prefix = self.name
		for f in self.fields:
			if f.name:
				name = prefix + "_" + f.name
			else:
				name = prefix

			if not f.name and f.low == 0 and f.shr == 0 and f.type not in ["float", "fixed", "ufixed"]:
				pass
			elif f.type == "boolean" or (f.type is None and f.low == f.high):
				tab_to("#define %s" % name, "0x%08x" % (1 << f.low))
			else:
				tab_to("#define %s__MASK" % name, "0x%08x" % mask(f.low, f.high))
				tab_to("#define %s__SHIFT" % name, "%d" % f.low)
				ctype_name, val = f.ctype("val")

				print("static inline uint32_t %s(%s val)\n{" % (name, ctype_name))
				if f.shr > 0:
					print("\tassert(!(val & 0x%x));" % mask(0, f.shr - 1))
				print("\treturn ((%s) << %s__SHIFT) & %s__MASK;\n}" % (val, name, name))
		print()
277
class Array(object):
	"""An ``<array>`` of registers, possibly nested inside a parent array.

	The layout is either regular (``offset`` + ``stride``) or an explicit
	per-index list. That list comes from ``offsets`` (numbers) or
	``doffsets`` (C expressions); for it, a ``__offset_<name>()`` lookup
	function is generated. Fixed-offset arrays do not set ``offset``/``stride``.
	"""

	def __init__(self, attrs, domain, variant, parent, index_type):
		self.local_name = attrs.get("name", "")
		self.domain = domain
		self.variant = variant
		self.parent = parent
		if self.parent:
			self.name = self.parent.name + "_" + self.local_name
		else:
			self.name = self.local_name
		if "offsets" in attrs:
			# Stored as a list: a lazy map() can only be iterated once, so a
			# second dump() would have produced an empty switch.
			self.offsets = ["0x%08x" % int(i, 0) for i in attrs["offsets"].split(",")]
			self.fixed_offsets = True
		elif "doffsets" in attrs:
			self.offsets = ["(%s)" % s for s in attrs["doffsets"].split(",")]
			self.fixed_offsets = True
		else:
			self.offset = int(attrs["offset"], 0)
			self.stride = int(attrs["stride"], 0)
			self.fixed_offsets = False
		# The index enum only applies when the XML names one.
		if "index" in attrs:
			self.index_type = index_type
		else:
			self.index_type = None
		self.length = int(attrs["length"], 0)
		if "usage" in attrs:
			self.usages = attrs["usage"].split(',')
		else:
			self.usages = None

	def index_ctype(self):
		"""Return the C type of this array's index parameter."""
		if not self.index_type:
			return "uint32_t"
		else:
			return "enum %s" % self.index_type.name

	def indices(self):
		"""Return a new list of (ctype, stride, offset_fn) tuples, outermost first.

		A length-1 array adds no index. A fixed-offset array adds stride None
		and the name of its generated ``__offset_<local_name>()`` function.
		"""
		if self.parent:
			indices = self.parent.indices()
		else:
			indices = []
		if self.length != 1:
			if self.fixed_offsets:
				indices.append((self.index_ctype(), None, f"__offset_{self.local_name}"))
			else:
				indices.append((self.index_ctype(), self.stride, None))
		return indices

	def total_offset(self):
		"""Return the base offset summed over this array and its parents.

		Fixed-offset arrays contribute 0; their offsets come from the lookup
		function instead.
		"""
		offset = 0
		if not self.fixed_offsets:
			offset += self.offset
		if self.parent:
			offset += self.parent.total_offset()
		return offset

	def dump(self, is_deprecated):
		"""Print the optional ``__offset_<name>()`` lookup and the ``REG_<domain>_<name>`` define."""
		depcrstr = ""
		if is_deprecated:
			depcrstr = " __attribute__((deprecated))"
		indices = self.indices()
		proto = indices_varlist(indices)
		strides = indices_strides(indices)
		array_offset = self.total_offset()
		if self.fixed_offsets:
			print("static inline%s uint32_t __offset_%s(%s idx)" % (depcrstr, self.local_name, self.index_ctype()))
			print("{\n\tswitch (idx) {")
			if self.index_type:
				for val, offset in zip(self.index_type.names(), self.offsets):
					print("\t\tcase %s: return %s;" % (val, offset))
			else:
				for idx, offset in enumerate(self.offsets):
					print("\t\tcase %d: return %s;" % (idx, offset))
			print("\t\tdefault: return INVALID_IDX(idx);")
			print("\t}\n}")
		if proto == '':
			tab_to("#define REG_%s_%s" % (self.domain, self.name), "0x%08x\n" % array_offset)
		else:
			tab_to("#define REG_%s_%s(%s)" % (self.domain, self.name, proto), "(0x%08x + %s )\n" % (array_offset, strides))

	def dump_pack_struct(self, is_deprecated):
		"""Arrays have no pack struct of their own; their registers do."""
		pass

	def dump_regpair_builder(self):
		"""No-op; present for interface parity with Reg."""
		pass
366
class Reg(object):
	"""A ``<reg32>``/``<reg64>`` register, optionally nested inside an array.

	``offset`` is relative to the enclosing array when there is one, and
	the array's name is folded into ``name``. A ``stride`` attribute makes
	the register itself indexable; ``length`` is then also required.
	"""

	def __init__(self, attrs, domain, array, bit_size):
		self.name = attrs["name"]
		self.domain = domain
		self.array = array
		self.offset = int(attrs["offset"], 0)
		self.type = None
		self.bit_size = bit_size
		if array:
			self.name = array.name + "_" + self.name
		self.full_name = self.domain + "_" + self.name
		self.stride = None
		self.length = None
		if "stride" in attrs:
			self.stride = int(attrs["stride"], 0)
			self.length = int(attrs["length"], 0)

	def indices(self):
		"""Return the (ctype, stride, offset_fn) index tuples addressing this register."""
		result = self.array.indices() if self.array else []
		if self.stride:
			result.append(("uint32_t", self.stride, None))
		return result

	def total_offset(self):
		"""Return the absolute base offset (array base plus relative offset)."""
		base = self.array.total_offset() if self.array else 0
		return base + self.offset

	def reg_offset(self):
		"""Return the C address expression used by the pack helpers.

		Inside an array the expression is parameterized on ``__i``. Only
		the immediate array's ``offset`` and ``stride`` are used, so parent
		arrays are not accounted for. Fixed-offset arrays define neither
		attribute.
		"""
		if not self.array:
			return "0x%08x" % self.offset
		base = self.array.offset + self.offset
		return "(0x%08x + 0x%x*__i)" % (base, self.array.stride)

	def dump(self, is_deprecated):
		"""Print ``REG_<full_name>`` (a define, or an inline accessor when indexed) and inline bitset defines."""
		depcrstr = " __attribute__((deprecated)) " if is_deprecated else ""
		idx_list = self.indices()
		proto = indices_prototype(idx_list)
		strides = indices_strides(idx_list)
		offset = self.total_offset()
		if not proto:
			tab_to("#define REG_%s" % self.full_name, "0x%08x" % offset)
		else:
			print("static inline%s uint32_t REG_%s(%s) { return 0x%08x + %s; }" % (depcrstr, self.full_name, proto, offset, strides))

		if self.bitset.inline:
			self.bitset.dump(is_deprecated, self.full_name)
		print("")

	def dump_pack_struct(self, is_deprecated):
		"""Emit the pack struct/helpers, but only for inline bitsets."""
		if not self.bitset.inline:
			return
		self.bitset.dump_pack_struct(is_deprecated, self)

	def dump_regpair_builder(self):
		"""Emit the fd_reg_pair builder body, but only for inline bitsets."""
		if not self.bitset.inline:
			return
		self.bitset.dump_regpair_builder(self)

	def dump_py(self):
		"""Print a Python enum member line with this register's (relative) offset."""
		print("\tREG_%s = 0x%08x" % (self.full_name, self.offset))
432		print("\tREG_%s = 0x%08x" % (self.full_name, self.offset))
433
434
class Parser(object):
	"""Expat-driven parser for rules-ng-ng register database XML files.

	``parse()`` reads the top-level XML file and follows ``<import>``
	elements recursively. Enums, bitsets, register variants and usages are
	collected from every parsed file. Only elements from the top-level
	file (stack depth 1) are recorded in ``self.file``, which the ``dump*``
	methods emit.
	"""

	def __init__(self):
		self.current_array = None
		self.current_domain = None
		self.current_prefix = None
		self.current_prefix_type = None
		self.current_stripe = None
		self.current_bitset = None
		self.current_bitsize = 32
		# Initialized up front so the attributes exist before the first
		# matching element is seen.
		self.current_reg = None
		self.current_enum = None
		self.current_enum_value = 0
		# The varset attribute on the domain specifies the enum which
		# specifies all possible hw variants:
		self.current_varset = None
		# Regs that have multiple variants.. we only generated the C++
		# template based struct-packers for these
		self.variant_regs = {}
		# Information in which contexts regs are used, to be used in
		# debug options
		self.usage_regs = collections.defaultdict(list)
		self.bitsets = {}
		self.enums = {}
		self.variants = set()
		self.file = []
		self.xml_files = []
		self.copyright_year = None
		self.authors = []
		self.license = None
		# (expat parser, filename) for each file being parsed, innermost
		# last; error() reports the position in stack[-1].
		self.stack = []
		# Directory that <import file="..."> paths are joined to; set by parse().
		self.path = None
		# Character data seen since the most recent start tag.
		self.cdata = ""

	def error(self, message):
		"""Return (not raise) an Error prefixed with the current file:line:column."""
		parser, filename = self.stack[-1]
		return Error("%s:%d:%d: %s" % (filename, parser.CurrentLineNumber, parser.CurrentColumnNumber, message))

	def prefix(self, variant=None):
		"""Return the name prefix for new regs/arrays in the current context."""
		if self.current_prefix_type == "variant" and variant:
			return variant
		elif self.current_stripe:
			return self.current_stripe + "_" + self.current_domain
		elif self.current_prefix:
			return self.current_prefix + "_" + self.current_domain
		else:
			return self.current_domain

	def parse_field(self, name, attrs):
		"""Build a Field from ``attrs`` and append it to the current bitset.

		The bit range comes from ``pos``, or from ``low``/``high``, else the
		whole register. Malformed numbers are reported as Error.
		"""
		try:
			if "pos" in attrs:
				high = low = int(attrs["pos"], 0)
			elif "high" in attrs and "low" in attrs:
				high = int(attrs["high"], 0)
				low = int(attrs["low"], 0)
			else:
				low = 0
				high = self.current_bitsize - 1

			if "type" in attrs:
				type = attrs["type"]
			else:
				type = None

			if "shr" in attrs:
				shr = int(attrs["shr"], 0)
			else:
				shr = 0

			b = Field(name, low, high, shr, type, self)

			if type == "fixed" or type == "ufixed":
				b.radix = int(attrs["radix"], 0)

			self.current_bitset.fields.append(b)
		except ValueError as e:
			raise self.error(e) from e

	def parse_varset(self, attrs):
		"""Return the varset enum for ``attrs``, defaulting to the enclosing domain's."""
		# Inherit the varset from the enclosing domain if not overriden:
		varset = self.current_varset
		if "varset" in attrs:
			varset = self.enums[attrs["varset"]]
		return varset

	def parse_variants(self, attrs):
		"""Return the first variant named by the ``variants`` attribute, or None.

		For a range such as ``"A6XX-"`` only the start is used. The variant
		must be a value of the applicable varset enum, otherwise Error is
		raised.
		"""
		if "variants" not in attrs:
			return None
		variant = attrs["variants"].split(",")[0]
		if "-" in variant:
			variant = variant[:variant.index("-")]

		varset = self.parse_varset(attrs)

		# Report bad XML as Error (with file:line:col) instead of a bare
		# assert, which is stripped under -O and bypasses dump_c's handler.
		if varset is None:
			raise self.error("variants '%s' used without a varset" % attrs["variants"])
		if not varset.has_name(variant):
			raise self.error("unknown variant '%s' in varset %s" % (variant, varset.name))

		return variant

	def add_all_variants(self, reg, attrs, parent_variant):
		"""Record ``reg`` under its variant in ``variant_regs[reg.name]``."""
		# TODO this should really handle *all* variants, including dealing
		# with open ended ranges (ie. "A2XX,A4XX-") (we have the varset
		# enum now to make that possible)
		variant = self.parse_variants(attrs)
		if not variant:
			variant = parent_variant

		if reg.name not in self.variant_regs:
			self.variant_regs[reg.name] = {}
		else:
			# All variants must be same size:
			v = next(iter(self.variant_regs[reg.name]))
			assert self.variant_regs[reg.name][v].bit_size == reg.bit_size

		self.variant_regs[reg.name][variant] = reg

	def add_all_usages(self, reg, usages):
		"""Record ``reg`` under each usage name and note its domain as a variant."""
		if not usages:
			return

		for usage in usages:
			self.usage_regs[usage].append(reg)

		self.variants.add(reg.domain)

	def do_validate(self, schemafile):
		"""Validate the file being parsed against its XSD schema, if lxml is installed.

		``schemafile`` is the raw ``xsi:schemaLocation`` value, expected to
		look like "<namespace url> <schema.xsd>". Only its last
		whitespace-separated token is used. The schema is looked up next to
		the XML file, then one directory up.
		"""
		try:
			from lxml import etree
		except ImportError:
			print("lxml not found, skipping validation", file=sys.stderr)
			return

		_parser, filename = self.stack[-1]
		dirname = os.path.dirname(filename)

		# Technically this is supposed to be just a URL, but that doesn't
		# quite match what we do, so take the last token. split() also
		# copes with a value that has no space (rindex(" ") raised).
		tokens = schemafile.split()
		if not tokens:
			raise self.error("Empty xsi:schemaLocation in: " + filename)
		schemafile = tokens[-1]

		# The xml file to validate could be in a child directory, i.e. we
		# don't really know where the schema file is the way the rnn C code
		# does, so if it isn't next to the file look one level up.
		# os.path.join keeps the path relative when dirname is "" (a plain
		# dirname + "/" turned that into the absolute "/schema.xsd").
		schemapath = os.path.join(dirname, schemafile)
		if not os.path.exists(schemapath):
			schemapath = os.path.join(dirname, "..", schemafile)

		if not os.path.exists(schemapath):
			raise self.error("Cannot find schema for: " + filename)

		xmlschema_doc = etree.parse(schemapath)
		xmlschema = etree.XMLSchema(xmlschema_doc)

		xml_doc = etree.parse(filename)
		if not xmlschema.validate(xml_doc):
			error_str = str(xmlschema.error_log.filter_from_errors()[0])
			raise self.error("Schema validation failed for: " + filename + "\n" + error_str)

	def do_parse(self, filename):
		"""Parse ``filename`` once, feeding elements to the handlers.

		Files already parsed (compared by absolute path) are skipped. The
		file is closed and the stack entry is popped even if parsing raises.
		"""
		filepath = os.path.abspath(filename)
		if filepath in self.xml_files:
			return
		self.xml_files.append(filepath)
		parser = xml.parsers.expat.ParserCreate()
		parser.StartElementHandler = self.start_element
		parser.EndElementHandler = self.end_element
		parser.CharacterDataHandler = self.character_data
		parser.buffer_text = True
		with open(filename, "rb") as file:
			self.stack.append((parser, filename))
			try:
				parser.ParseFile(file)
			finally:
				self.stack.pop()

	def parse(self, rnn_path, filename):
		"""Parse ``filename``; ``<import>`` paths are resolved relative to ``rnn_path``."""
		self.path = rnn_path
		self.stack = []
		self.do_parse(filename)

	def parse_reg(self, attrs, bit_size):
		"""Handle ``<reg32>``/``<reg64>``: build the Reg and its bitset, then record it."""
		self.current_bitsize = bit_size
		if "type" in attrs and attrs["type"] in self.bitsets:
			bitset = self.bitsets[attrs["type"]]
			if bitset.inline:
				self.current_bitset = Bitset(attrs["name"], bitset)
				self.current_bitset.inline = True
			else:
				self.current_bitset = bitset
		else:
			self.current_bitset = Bitset(attrs["name"], None)
			self.current_bitset.inline = True
			if "type" in attrs:
				self.parse_field(None, attrs)

		variant = self.parse_variants(attrs)
		if not variant and self.current_array:
			variant = self.current_array.variant

		self.current_reg = Reg(attrs, self.prefix(variant), self.current_array, bit_size)
		self.current_reg.bitset = self.current_bitset

		if len(self.stack) == 1:
			self.file.append(self.current_reg)

		if variant is not None:
			self.add_all_variants(self.current_reg, attrs, variant)

		usages = None
		if "usage" in attrs:
			usages = attrs["usage"].split(',')
		elif self.current_array:
			usages = self.current_array.usages

		self.add_all_usages(self.current_reg, usages)

	def start_element(self, name, attrs):
		"""Expat start-tag handler: update parser state for element ``name``."""
		self.cdata = ""
		if name == "import":
			filename = attrs["file"]
			self.do_parse(os.path.join(self.path, filename))
		elif name == "domain":
			self.current_domain = attrs["name"]
			if "prefix" in attrs:
				self.current_prefix = self.parse_variants(attrs)
				self.current_prefix_type = attrs["prefix"]
			else:
				self.current_prefix = None
				self.current_prefix_type = None
			if "varset" in attrs:
				self.current_varset = self.enums[attrs["varset"]]
		elif name == "stripe":
			self.current_stripe = self.parse_variants(attrs)
		elif name == "enum":
			self.current_enum_value = 0
			self.current_enum = Enum(attrs["name"])
			self.enums[attrs["name"]] = self.current_enum
			if len(self.stack) == 1:
				self.file.append(self.current_enum)
		elif name == "value":
			if "value" in attrs:
				value = int(attrs["value"], 0)
			else:
				value = self.current_enum_value
			self.current_enum.values.append((attrs["name"], value))
			# As in C, a value without an explicit "value" follows the
			# previous one. The counter used to never advance, so every
			# implicit value came out as 0.
			self.current_enum_value = value + 1
		elif name == "reg32":
			self.parse_reg(attrs, 32)
		elif name == "reg64":
			self.parse_reg(attrs, 64)
		elif name == "array":
			self.current_bitsize = 32
			variant = self.parse_variants(attrs)
			index_type = self.enums[attrs["index"]] if "index" in attrs else None
			self.current_array = Array(attrs, self.prefix(variant), variant, self.current_array, index_type)
			if len(self.stack) == 1:
				self.file.append(self.current_array)
		elif name == "bitset":
			self.current_bitset = Bitset(attrs["name"], None)
			if "inline" in attrs and attrs["inline"] == "yes":
				self.current_bitset.inline = True
			self.bitsets[self.current_bitset.name] = self.current_bitset
			if len(self.stack) == 1 and not self.current_bitset.inline:
				self.file.append(self.current_bitset)
		elif name == "bitfield" and self.current_bitset:
			self.parse_field(attrs["name"], attrs)
		elif name == "database":
			self.do_validate(attrs["xsi:schemaLocation"])
		elif name == "copyright":
			self.copyright_year = attrs["year"]
		elif name == "author":
			# NOTE(review): the name is appended twice; the nick (a child
			# element in rnn XML) may have been intended. Output kept as-is.
			self.authors.append(attrs["name"] + " <" + attrs["email"] + "> " + attrs["name"])

	def end_element(self, name):
		"""Expat end-tag handler: pop the state opened by the matching start tag."""
		if name == "domain":
			self.current_domain = None
			self.current_prefix = None
			self.current_prefix_type = None
		elif name == "stripe":
			self.current_stripe = None
		elif name == "bitset":
			self.current_bitset = None
		elif name in ("reg32", "reg64"):
			# Both register widths open current_reg in parse_reg().
			self.current_reg = None
		elif name == "array":
			self.current_array = self.current_array.parent
		elif name == "enum":
			self.current_enum = None
		elif name == "license":
			self.license = self.cdata

	def character_data(self, data):
		"""Expat character-data handler: accumulate text for the current element."""
		self.cdata += data

	def dump_reg_usages(self):
		"""Print C++ ``<USAGE>_REGS<variant>[]`` tables of register offsets per usage."""
		d = collections.defaultdict(list)
		for usage, regs in self.usage_regs.items():
			for reg in regs:
				variants = self.variant_regs.get(reg.name)
				if variants:
					for variant, vreg in variants.items():
						if reg == vreg:
							d[(usage, variant)].append(reg)
				else:
					for variant in self.variants:
						d[(usage, variant)].append(reg)

		print("#ifdef __cplusplus")

		for usage in self.usage_regs:
			print("template<chip CHIP> constexpr inline uint16_t %s_REGS[] = {};" % (usage.upper()))

		for (usage, variant), regs in d.items():
			offsets = []

			for reg in regs:
				if reg.array:
					for i in range(reg.array.length):
						offsets.append(reg.array.offset + reg.offset + i * reg.array.stride)
						if reg.bit_size == 64:
							offsets.append(offsets[-1] + 1)
				else:
					offsets.append(reg.offset)
					if reg.bit_size == 64:
						offsets.append(offsets[-1] + 1)

			offsets.sort()

			print("template<> constexpr inline uint16_t %s_REGS<%s>[] = {" % (usage.upper(), variant))
			for offset in offsets:
				print("\t%s," % hex(offset))
			print("};")

		print("#endif")

	def has_variants(self, reg):
		"""Return True if more than one variant was recorded under ``reg.name``."""
		return reg.name in self.variant_regs and len(self.variant_regs[reg.name]) > 1

	def dump(self):
		"""Print all top-level definitions: enums, then bitsets, then regs/arrays, then usages."""
		enums = []
		bitsets = []
		regs = []
		for e in self.file:
			if isinstance(e, Enum):
				enums.append(e)
			elif isinstance(e, Bitset):
				bitsets.append(e)
			else:
				regs.append(e)

		for e in enums + bitsets + regs:
			e.dump(self.has_variants(e))

		self.dump_reg_usages()

	def dump_regs_py(self):
		"""Print one Python enum member line per top-level Reg."""
		for e in self.file:
			if isinstance(e, Reg):
				e.dump_py()

	def dump_reg_variants(self, regname, variants):
		"""Print the C++ template pack helper for a register with several variants."""
		# Don't bother for things that only have a single variant:
		if len(variants) == 1:
			return
		print("#ifdef __cplusplus")
		print("struct __%s {" % regname)
		# TODO be more clever.. we should probably figure out which
		# fields have the same type in all variants (in which they
		# appear) and stuff everything else in a variant specific
		# sub-structure.
		seen_fields = []
		bit_size = 32
		array = False
		address = None
		for variant in variants.keys():
			print("    /* %s fields: */" % variant)
			reg = variants[variant]
			bit_size = reg.bit_size
			array = reg.array
			for f in reg.bitset.fields:
				fld_name = field_name(reg, f)
				if fld_name in seen_fields:
					continue
				seen_fields.append(fld_name)
				name = fld_name.lower()
				if f.type in [ "address", "waddress" ]:
					if address:
						continue
					address = f
					tab_to("    __bo_type", "bo;")
					tab_to("    uint32_t", "bo_offset;")
					continue
				ctype_name, _ = f.ctype("var")
				tab_to("    %s" % ctype_name, "%s;" % name)
		print("    /* fallback fields: */")
		if bit_size == 64:
			tab_to("    uint64_t", "unknown;")
			tab_to("    uint64_t", "qword;")
		else:
			tab_to("    uint32_t", "unknown;")
			tab_to("    uint32_t", "dword;")
		print("};")
		# TODO don't hardcode the varset enum name
		varenum = "chip"
		print("template <%s %s>" % (varenum, varenum.upper()))
		print("static inline struct fd_reg_pair")
		xtra = ""
		xtravar = ""
		if array:
			xtra = "int __i, "
			xtravar = "__i, "
		print("__%s(%sstruct __%s fields) {" % (regname, xtra, regname))
		for variant in variants.keys():
			print("  if (%s == %s) {" % (varenum.upper(), variant))
			reg = variants[variant]
			reg.dump_regpair_builder()
			print("  } else")
		print("    assert(!\"invalid variant\");")
		print("}")

		if bit_size == 64:
			skip = ", { .reg = 0 }"
		else:
			skip = ""

		print("#define %s(VARIANT, %s...) __%s<VARIANT>(%s{__VA_ARGS__})%s" % (regname, xtravar, regname, xtravar, skip))
		print("#endif /* __cplusplus */")

	def dump_structs(self):
		"""Print pack structs for top-level entries, then multi-variant C++ helpers."""
		for e in self.file:
			e.dump_pack_struct(self.has_variants(e))

		for regname in self.variant_regs:
			self.dump_reg_variants(regname, self.variant_regs[regname])
865
866
def dump_c(args, guard, func):
	"""Parse ``args.xml`` and print a complete C header to stdout.

	The output is wrapped in an include guard named ``guard``. It opens with
	a provenance comment (source files, authors, license) and the
	assert/``__struct_cast`` compatibility preamble; ``func(parser)`` is
	then called to print the body. On a parse Error the message goes to
	stderr and the process exits with status 1.
	"""
	p = Parser()

	try:
		p.parse(args.rnn, args.xml)
	except Error as e:
		print(e, file=sys.stderr)
		# sys.exit rather than the site-injected exit(), which is missing
		# when Python runs with -S or is embedded.
		sys.exit(1)

	print("#ifndef %s\n#define %s\n" % (guard, guard))

	print("""/* Autogenerated file, DO NOT EDIT manually!

This file was generated by the rules-ng-ng gen_header.py tool in this git repository:
http://gitlab.freedesktop.org/mesa/mesa/
git clone https://gitlab.freedesktop.org/mesa/mesa.git

The rules-ng-ng source files this header was generated from are:
""")
	# Pad paths to a common width and right-align sizes in 7 columns.
	maxlen = max((len(filepath) for filepath in p.xml_files), default=0)
	for filepath in p.xml_files:
		pad = " " * (maxlen - len(filepath))
		filesize = str(os.path.getsize(filepath)).rjust(7)
		filetime = time.ctime(os.path.getmtime(filepath))
		print("- " + filepath + pad + " (" + filesize + " bytes, from " + filetime + ")")
	if p.copyright_year:
		current_year = str(datetime.date.today().year)
		print()
		print("Copyright © %s-%s by the following authors:" % (p.copyright_year, current_year))
		for author in p.authors:
			print("- " + author)
	if p.license:
		print(p.license)
	print("*/")

	print()
	print("#ifdef __KERNEL__")
	print("#include <linux/bug.h>")
	print("#define assert(x) BUG_ON(!(x))")
	print("#else")
	print("#include <assert.h>")
	print("#endif")
	print()

	print("#ifdef __cplusplus")
	print("#define __struct_cast(X)")
	print("#else")
	print("#define __struct_cast(X) (struct X)")
	print("#endif")
	print()

	func(p)

	print("\n#endif /* %s */" % guard)
924
925
def dump_c_defines(args):
	"""Print the register #define header for ``args.xml`` to stdout."""
	base = os.path.basename(args.xml)
	guard = base.replace('.', '_').upper()
	dump_c(args, guard, lambda parser: parser.dump())
929
930
def dump_c_pack_structs(args):
	"""Print the pack-struct header for ``args.xml`` to stdout (guard suffixed _STRUCTS)."""
	base = os.path.basename(args.xml)
	guard = "%s_STRUCTS" % base.replace('.', '_').upper()
	dump_c(args, guard, lambda parser: parser.dump_structs())
934
935
def dump_py_defines(args):
	"""Print a Python ``IntEnum`` of register offsets for ``args.xml`` to stdout.

	The class is named ``<BASENAME>Regs``, where BASENAME is the XML file
	name without extension, uppercased. On a parse Error the message goes
	to stderr and the process exits with status 1.
	"""
	p = Parser()

	try:
		p.parse(args.rnn, args.xml)
	except Error as e:
		print(e, file=sys.stderr)
		# sys.exit rather than the site-injected exit(), which is missing
		# when Python runs with -S or is embedded.
		sys.exit(1)

	file_name = os.path.splitext(os.path.basename(args.xml))[0]

	print("from enum import IntEnum")
	print("class %sRegs(IntEnum):" % file_name.upper())

	p.dump_regs_py()
953
954
def main():
	"""Command-line entry point: parse arguments and run the chosen generator."""
	parser = argparse.ArgumentParser()
	parser.add_argument('--rnn', type=str, required=True)
	parser.add_argument('--xml', type=str, required=True)

	subparsers = parser.add_subparsers(required=True)
	commands = (
		('c-defines', dump_c_defines),
		('c-pack-structs', dump_c_pack_structs),
		('py-defines', dump_py_defines),
	)
	for command, handler in commands:
		subparsers.add_parser(command).set_defaults(func=handler)

	args = parser.parse_args()
	args.func(args)


if __name__ == '__main__':
	main()
977