"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
from collections import OrderedDict


class CWriter:
    """Write a C header (and, unless header_only, a C source file) for exported weights.

    Callers register layers in ``layer_dict`` and weight array names in
    ``weight_arrays`` (and write any array definitions to ``self.source`` /
    ``self.header`` directly), then call ``close()`` to emit the model struct,
    the weight table and the init function and to close both files.
    """

    def __init__(self,
                 filename_without_extension,
                 message=None,
                 header_only=False,
                 create_state_struct=False,
                 enable_binary_blob=True,
                 model_struct_name="Model",
                 nnet_header="nnet.h",
                 add_typedef=False):
        """
        Writer class for creating source and header files for weight exports to C

        Parameters:
        -----------

        filename_without_extension: str
            filename from which .c and .h files are created

        message: str, optional
            if given and not None, this message is written as a comment at the top of
            the header (and source) file; any "*/" in it is rewritten as "* /" so the
            generated comment cannot be terminated early

        header_only: bool, optional
            if True, only a header file is created; defaults to False

        enable_binary_blob: bool, optional
            if True, export is done in binary blob format: a model type, a weight
            array table and an init function are created; defaults to True

        create_state_struct: bool, optional
            stored as ``self.create_state_struct``; this class does not read it
            itself (presumably consumed by callers together with ``state_dict``);
            defaults to False

        model_struct_name: str, optional
            name used for the model struct type; only relevant when enable_binary_blob is True; defaults to "Model"

        nnet_header: str, optional
            name of the nnet header file included by the generated header; defaults to nnet.h

        add_typedef: bool, optional
            if True, the model type is emitted as ``typedef struct {...} <model_struct_name>;``;
            otherwise as ``struct <model_struct_name> {...};``; defaults to False
        """

        # Treat the writer as closed until the header file is actually open, so that
        # close() and __del__ are no-ops if opening a file below raises.
        self._closed = True

        self.header_only = header_only
        self.enable_binary_blob = enable_binary_blob
        self.create_state_struct = create_state_struct
        self.model_struct_name = model_struct_name
        self.add_typedef = add_typedef

        # for binary blob format, format is key=<layer name>, value=(<layer type>, <init call>);
        # <layer type> becomes a struct member type, <init call> a C expression whose
        # nonzero value makes the generated init function fail
        self.layer_dict = OrderedDict()

        # for binary blob format, list of weight array names (C identifiers); must be unique
        self.weight_arrays = []

        # for model struct, format is key=<layer name>, value=<number of elements>;
        # not read by this class itself
        self.state_dict = OrderedDict()

        base_name = os.path.basename(filename_without_extension)
        header_name = base_name + '.h'

        # Macro names may only contain ASCII letters, digits and underscores; map
        # anything else (e.g. '-' or '.' in the file name) to '_'.
        self.header_guard = "".join(
            c if c.isalnum() and ord(c) < 128 else "_" for c in base_name.upper()
        ) + "_H"

        self.header = open(filename_without_extension + ".h", "w")
        self._closed = False

        if message is not None:
            self.header.write(self._c_comment(message) + "\n\n")

        self.header.write(
            f"\n#ifndef {self.header_guard}\n"
            f"#define {self.header_guard}\n"
            "\n"
            f'#include "{nnet_header}"\n'
            "\n"
        )

        if not self.header_only:
            self.source = open(filename_without_extension + ".c", "w")
            if message is not None:
                self.source.write(self._c_comment(message) + "\n\n")

            self.source.write(
                "\n#ifdef HAVE_CONFIG_H\n"
                '#include "config.h"\n'
                "#endif\n"
                "\n"
            )
            self.source.write(f'#include "{header_name}"\n\n')

    @staticmethod
    def _c_comment(message):
        """Return message as a C block comment, defanging any embedded '*/'."""
        return "/* " + str(message).replace("*/", "* /") + " */"

    def _init_prototype(self):
        """Return the C signature (without ';' or body) of the model init function."""
        # With add_typedef the struct is anonymous, so only the typedef name is valid;
        # without it the type only exists under its struct tag.
        if self.add_typedef:
            model_type = self.model_struct_name
        else:
            model_type = f"struct {self.model_struct_name}"
        return (f"int init_{self.model_struct_name.lower()}"
                f"({model_type} *model, const WeightArray *arrays)")

    def _finalize_header(self):
        """Write the model type and init prototype (binary blob mode), then end the include guard."""

        # create model type: one member per registered layer, in registration order
        if self.enable_binary_blob:
            if self.add_typedef:
                self.header.write("\ntypedef struct {")
            else:
                self.header.write(f"\nstruct {self.model_struct_name} {{")
            for name, data in self.layer_dict.items():
                layer_type = data[0]
                self.header.write(f"\n    {layer_type} {name};")
            if self.add_typedef:
                self.header.write(f"\n}} {self.model_struct_name};\n")
            else:
                self.header.write("\n};\n")

            self.header.write(f"\n{self._init_prototype()};\n")

        self.header.write(f"\n#endif /* {self.header_guard} */\n")

    def _finalize_source(self):
        """Write the weight array table and the init function definition (binary blob mode only).

        Raises:
            ValueError: if the same weight array name was registered more than once.
        """
        if not self.enable_binary_blob:
            return

        if len(set(self.weight_arrays)) != len(self.weight_arrays):
            raise ValueError("error: detected duplicates in weight arrays")

        # weight array table, omitted when USE_WEIGHTS_FILE is defined;
        # each entry is only compiled in if its WEIGHTS_<name>_DEFINED macro exists
        self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
        self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
        for name in self.weight_arrays:
            self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
            self.source.write(f'    {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
            self.source.write("#endif\n")
        # NULL entry terminates the table
        self.source.write("    {NULL, 0, 0, NULL}\n")
        self.source.write("};\n")
        self.source.write("#endif /* USE_WEIGHTS_FILE */\n")

        # init function (omitted when DUMP_BINARY_WEIGHTS is defined): returns 1 as
        # soon as any layer init expression is nonzero, 0 otherwise
        self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
        self.source.write(f"{self._init_prototype()} {{\n")
        for name, data in self.layer_dict.items():
            self.source.write(f"    if ({data[1]}) return 1;\n")
        self.source.write("    return 0;\n")
        self.source.write("}\n")
        self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")

    def close(self):
        """Finalize and close the generated file(s).

        Calling close() more than once is harmless; later calls do nothing. Both
        files are closed even if finalization raises (e.g. ValueError on duplicate
        weight array names), and the error is propagated.
        """
        if getattr(self, "_closed", True):
            return
        # mark first so a failed finalization is not re-attempted by __del__
        self._closed = True

        try:
            if not self.header_only:
                try:
                    self._finalize_source()
                finally:
                    self.source.close()
            self._finalize_header()
        finally:
            self.header.close()

    def __del__(self):
        # Best-effort finalization for writers that were never closed explicitly.
        # Exceptions must not escape a finalizer; only Exception subclasses are
        # suppressed so that e.g. KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.close()
        except Exception:
            pass