xref: /aosp_15_r20/external/XNNPACK/tools/xngen.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import io
10import os
11import re
12import sys
13from itertools import chain
14
15
def key_value_pair(line):
  """Parse one ``KEY=VALUE`` define into a ``(key, value)`` tuple.

  Only the first "=" separates key from value, so the value may itself
  contain "=". The value is returned as an ``int`` when ``int()`` accepts
  it and as the original string otherwise.

  Raises:
    ValueError: if ``line`` contains no "=" (argparse, which uses this as a
      ``type=`` converter, turns that into a usage error).
  """
  name, raw = line.split("=", 1)
  try:
    return name, int(raw)
  except ValueError:
    return name, raw
24
25
# Command-line interface: one template path, any number of -D KEY=VALUE
# defines, and the path of the file to generate.
parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument("input", metavar="FILE", nargs=1,
                    help="Input file")
# nargs="*" + action="append" stores one list of (key, value) pairs per -D
# occurrence; main() flattens these lists into a single dict of globals.
parser.add_argument("-D", dest="defines", metavar="KEY=VALUE", nargs="*",
                    type=key_value_pair, action="append",
                    help="Predefined variables")
# main() always reads and writes this path, so it must be supplied; without
# required=True an omitted -o crashed later with a TypeError on None.
parser.add_argument("-o", "--output", required=True,
                    help='Output file')
parser.set_defaults(defines=list())
35
36
# Matches the (possibly empty) run of whitespace at the start of a line.
LEADING_WHITESPACE_REGEX = re.compile(r"^\s*", flags=0)


def extract_leading_whitespace(line):
  """Return the leading whitespace of ``line`` ("" when there is none).

  Uses the precompiled LEADING_WHITESPACE_REGEX. The pattern can match the
  empty string, so the match always succeeds and needs no None check.
  """
  return LEADING_WHITESPACE_REGEX.match(line).group(0)
43
44
def escape(line):
  """Translate one line of template text into a Python string expression.

  Literal text becomes a Python string literal built with ``repr()``, so
  quotes and backslashes in the template are reproduced verbatim rather
  than being reinterpreted as escape sequences. Each ``${expr}``
  placeholder becomes ``str(expr)``. The pieces are joined with " + ".
  preprocess() wraps the result in ``print(..., file=OUT_STREAM)``.

  An empty ``line`` yields ``''`` so the generated print() call is still
  syntactically valid.

  Raises:
    ValueError: if a "${" has no closing "}".
  """
  output_parts = []
  while "${" in line:
    start_pos = line.index("${")
    end_pos = line.index("}", start_pos + 2)
    if start_pos != 0:
      output_parts.append(repr(line[:start_pos]))
    output_parts.append("str(" + line[start_pos + 2:end_pos] + ")")
    line = line[end_pos + 1:]
  if line:
    output_parts.append(repr(line))
  if not output_parts:
    # Without this, an empty line would generate "print(, file=...)".
    return repr("")
  return " + ".join(output_parts)
57
58
def preprocess(input_text, input_globals, input_path="codegen"):
  """Expand an xngen template and return the generated text.

  The template is translated line by line into a Python program:
    * a line whose first non-blank character is "$" (but not "${") is a
      Python statement, copied with its leading "$" removed; if it ends with
      ":" it opens a block whose body is the following lines, indented
      relative to it;
    * any other non-empty line is output text, emitted via print(), with
      "${expr}" placeholders substituted by escape();
    * empty lines become blank output lines, flushed at the indentation of
      the next content line (or of the last one, at end of input);
    * lines containing the substring "LINT" are dropped.
  The program is then compiled and executed with a copy of input_globals as
  its globals, plus OUT_STREAM bound to an in-memory stream.

  Args:
    input_text: template source.
    input_globals: mapping of predefined variables visible to the template.
    input_path: file name reported in tracebacks and assertion messages.

  Returns:
    Everything the generated program printed to OUT_STREAM.

  Raises:
    AssertionError: on inconsistent template indentation.
    SyntaxError / any exception raised by the generated program itself.
  """
  input_lines = input_text.splitlines()
  python_lines = ["from __future__ import print_function"]

  # Empty template lines seen but not yet emitted.
  blank_lines = 0

  # Template indentation of the previous content line.
  last_indent = ""
  # Python indentation of the current statement. Initialized so that a
  # template with only empty/LINT lines does not hit an unbound local when
  # the trailing blank lines are flushed after the loop.
  python_indent = ""

  # Stack of (template_indent, python_indent) tuples, one per open block.
  indent_stack = [("", "")]

  # True when the next content line is the first line inside a Python block
  # (i.e. for, while, if, elif, else) and therefore defines its indentation.
  python_block_start = True
  for line_number, input_line in enumerate(input_lines, start=1):
    if input_line == "":
      blank_lines += 1
      continue
    # Skip lint markers.
    if "LINT" in input_line:
      continue

    input_indent = extract_leading_whitespace(input_line)
    if python_block_start:
      # Whatever indentation this line adds on top of the previous line
      # becomes the extra Python indentation of the new block.
      assert input_indent.startswith(last_indent), (
          "%s:%d: block body is not indented relative to the preceding line"
          % (input_path, line_number))
      extra_python_indent = input_indent[len(last_indent):]
      indent_stack.append(
          (input_indent, indent_stack[-1][1] + extra_python_indent))
    else:
      # Close every block whose template indentation this line does not
      # reach; the bottom ("", "") entry always matches.
      while not input_indent.startswith(indent_stack[-1][0]):
        del indent_stack[-1]

    python_indent = indent_stack[-1][1]
    stripped_input_line = input_line.strip()
    is_statement = (stripped_input_line.startswith("$") and
                    not stripped_input_line.startswith("${"))
    # Only a "$" statement ending in ":" opens a new block.
    python_block_start = is_statement and stripped_input_line.endswith(":")

    if not is_statement:
      assert input_line.startswith(python_indent), (
          "%s:%d: text line is indented less than its enclosing block"
          % (input_path, line_number))

    # Flush pending empty lines at this line's indentation, so they are
    # emitted inside the same Python block as the line that follows them.
    python_lines.extend(
        [python_indent + "print(file=OUT_STREAM)"] * blank_lines)
    blank_lines = 0

    if is_statement:
      # Remove only the leading "$"; a "$" elsewhere (e.g. inside a string
      # literal) belongs to the Python code.
      python_lines.append(python_indent + stripped_input_line[1:])
    else:
      python_lines.append(python_indent + "print(%s, file=OUT_STREAM)" %
                          escape(input_line[len(python_indent):]))
    last_indent = input_indent

  # Trailing empty lines go at the indentation of the last statement.
  python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)

  exec_globals = dict(input_globals)
  # Text buffer on Python 3; byte buffer on Python 2.
  if sys.version_info > (3, 0):
    output_stream = io.StringIO()
  else:
    output_stream = io.BytesIO()
  exec_globals["OUT_STREAM"] = output_stream
  # NOTE: the template is executed as arbitrary Python code; it must come
  # from a trusted source (build-time input).
  python_bytecode = compile("\n".join(python_lines), input_path, "exec")
  exec(python_bytecode, exec_globals)

  return output_stream.getvalue()
126
127
# Header prepended to every generated file; main() fills in the template
# path and the generator command via str.format().
PREAMBLE = (
    "// Auto-generated file. Do not edit!\n"
    "//   Template: {template}\n"
    "//   Generator: {generator}\n"
    "//\n"
)
134
135
def main(args):
  """Render the template named on the command line into the output file.

  Args:
    args: command-line arguments, excluding the program name.

  The output file is written only when its content would differ from what
  is already on disk, so an unchanged file is left untouched.
  """
  options = parser.parse_args(args)

  # Read the template under `with` so the handle is closed promptly.
  with codecs.open(options.input[0], "r", encoding="utf-8") as input_file:
    input_text = input_file.read()
  # Each -D occurrence contributed a list of (key, value) pairs; flatten
  # them into one dict (a later definition of the same key wins).
  python_globals = dict(chain.from_iterable(options.defines))
  output_text = (
      PREAMBLE.format(template=options.input[0], generator=sys.argv[0]) +
      preprocess(input_text, python_globals, options.input[0]))

  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != output_text

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(output_text)
151
if __name__ == "__main__":
  # Hand every argument after the program name to the CLI driver.
  cli_args = sys.argv[1:]
  main(cli_args)
154