xref: /aosp_15_r20/external/libopus/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31from collections import OrderedDict
32
class CWriter:
    """Writer for C header/source files that hold exported model weights.

    The header (``<name>.h``) and, unless ``header_only`` is set, the source
    (``<name>.c``) are opened on construction and stay open as ``self.header``
    and ``self.source`` until :meth:`close` is called, so content can be
    written to them in between. :meth:`close` appends the generated trailer
    (model struct, ``WeightArray`` table, init function, closing header guard)
    built from ``layer_dict`` and ``weight_arrays`` and then closes the files.

    The writer can also be used as a context manager; leaving the ``with``
    block calls :meth:`close`.
    """

    def __init__(self,
                 filename_without_extension,
                 message=None,
                 header_only=False,
                 create_state_struct=False,
                 enable_binary_blob=True,
                 model_struct_name="Model",
                 nnet_header="nnet.h",
                 add_typedef=False):
        """
        Writer class for creating source and header files for weight exports to C

        Parameters:
        -----------

        filename_without_extension: str
            path prefix from which the .h and .c files are created; its base name,
            upper-cased and with every character that is not valid in a C identifier
            replaced by '_', followed by "_H", is used as the header guard

        message: str, optional
            if given and not None, this message is written as a comment at the top of
            the header file and, unless header_only is True, of the source file

        header_only: bool, optional
            if True, only a header file is created; defaults to False

        create_state_struct: bool, optional
            stored as self.create_state_struct; CWriter itself does not read it.
            Intended meaning: if True, a state struct type is created in the header
            file; if False, state sizes are defined as macros. Defaults to False

        enable_binary_blob: bool, optional
            if True, close() emits the model struct type and init prototype into the
            header and, unless header_only is True, the WeightArray table and the init
            function definition into the source file; defaults to True

        model_struct_name: str, optional
            name used for the model struct type; its lower-cased form names the init
            function (init_<name>) and the weight table (<name>_arrays). Only relevant
            when enable_binary_blob is True; defaults to "Model"

        nnet_header: str, optional
            name of the nnet header file #included by the generated header;
            defaults to "nnet.h"

        add_typedef: bool, optional
            if True, the model struct is emitted as an anonymous
            ``typedef struct {...} <model_struct_name>;`` instead of
            ``struct <model_struct_name> {...};``; defaults to False
        """

        self.header_only = header_only
        self.enable_binary_blob = enable_binary_blob
        self.create_state_struct = create_state_struct
        self.model_struct_name = model_struct_name
        self.add_typedef = add_typedef

        # key=<layer name>, value=(<C layer type>, <C init expression>, ...).
        # close() emits one struct member per entry and one
        # `if (<init expression>) return 1;` line in the init function, so the
        # expression is expected to evaluate to non-zero on failure.
        self.layer_dict = OrderedDict()

        # names of the C weight arrays listed in the WeightArray table; must be unique
        self.weight_arrays = []

        # key=<layer name>, value=<number of elements>; not read by CWriter itself
        self.state_dict = OrderedDict()

        base_name = os.path.basename(filename_without_extension)
        header_name = base_name + '.h'
        # sanitized so that e.g. "my-model" yields the valid guard MY_MODEL_H
        self.header_guard = self._c_identifier(base_name.upper()) + "_H"

        # explicit encoding: generated C must not depend on the host locale
        self.header = open(filename_without_extension + ".h", "w", encoding="utf-8")

        if message is not None:
            self.header.write(f"/* {message} */\n\n")

        self.header.write(
            f'\n#ifndef {self.header_guard}\n'
            f'#define {self.header_guard}\n'
            '\n'
            f'#include "{nnet_header}"\n'
            '\n'
        )

        if not self.header_only:
            self.source = open(filename_without_extension + ".c", "w", encoding="utf-8")
            if message is not None:
                self.source.write(f"/* {message} */\n\n")

            self.source.write('\n#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n')
            self.source.write(f'#include "{header_name}"\n\n')

        # set last: __del__ skips finalization of a writer whose __init__ failed
        self._closed = False

    @staticmethod
    def _c_identifier(text):
        """Return ``text`` with every character that is not an ASCII letter, digit or
        underscore replaced by '_' (for use in a preprocessor identifier)."""
        return "".join(
            char if char == "_" or (char.isalnum() and ord(char) < 128) else "_"
            for char in text
        )

    def _init_prototype(self):
        """Return the C prototype (without trailing ';') of the generated init function."""
        struct_name = self.model_struct_name
        return f"int init_{struct_name.lower()}({struct_name} *model, const WeightArray *arrays)"

    def _finalize_header(self):
        """Append the model struct and init prototype (binary blob only) and the closing guard."""

        if self.enable_binary_blob:
            if self.add_typedef:
                self.header.write("\ntypedef struct {")
            else:
                self.header.write(f"\nstruct {self.model_struct_name} {{")
            for name, data in self.layer_dict.items():
                layer_type = data[0]
                self.header.write(f"\n    {layer_type} {name};")
            if self.add_typedef:
                self.header.write(f"\n}} {self.model_struct_name};\n")
            else:
                self.header.write("\n};\n")

            self.header.write(f"\n{self._init_prototype()};\n")

        self.header.write(f"\n#endif /* {self.header_guard} */\n")

    def _finalize_source(self):
        """Append the WeightArray table and the init function definition (binary blob only).

        Raises:
            ValueError: if ``weight_arrays`` contains the same name more than once;
                raised before anything is written.
        """

        if not self.enable_binary_blob:
            return

        seen = set()
        duplicates = OrderedDict()  # ordered set of repeated names, for the error message
        for array_name in self.weight_arrays:
            if array_name in seen:
                duplicates[array_name] = None
            seen.add(array_name)
        if duplicates:
            raise ValueError("error: detected duplicates in weight arrays: "
                             + ", ".join(str(dup) for dup in duplicates))

        # weight table: each entry is guarded by WEIGHTS_<name>_DEFINED and the table
        # is terminated by a {NULL, 0, 0, NULL} sentinel entry
        self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
        self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
        for name in self.weight_arrays:
            self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
            self.source.write(f'    {{"{name}",  WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
            self.source.write("#endif\n")
        self.source.write("    {NULL, 0, 0, NULL}\n")
        self.source.write("};\n")
        self.source.write("#endif /* USE_WEIGHTS_FILE */\n")

        # init function: returns 1 as soon as one layer's init expression is non-zero
        self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
        self.source.write(f"{self._init_prototype()} {{\n")
        for data in self.layer_dict.values():
            self.source.write(f"    if ({data[1]}) return 1;\n")
        self.source.write("    return 0;\n")
        self.source.write("}\n")
        self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")

    def close(self):
        """Finalize and close the generated files.

        Safe to call more than once; every call after the first is a no-op. The
        files are closed even if finalization raises.

        Raises:
            ValueError: if ``weight_arrays`` contains duplicate names (only checked
                when a source file is written and enable_binary_blob is True).
        """
        if self._closed:
            return
        # mark closed up front: a failed finalization cannot be retried on closed files
        self._closed = True

        try:
            if not self.header_only:
                try:
                    self._finalize_source()
                finally:
                    self.source.close()
            self._finalize_header()
        finally:
            self.header.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __del__(self):
        # Best-effort finalization for writers that were never closed explicitly.
        # Already-closed writers and writers whose __init__ failed (no _closed
        # attribute yet) are skipped.
        if getattr(self, "_closed", True):
            return
        try:
            self.close()
        except Exception:
            # __del__ must not raise; finalization errors (e.g. duplicate weight
            # arrays) are deliberately ignored here.
            pass
182            pass