1 /* Copyright (c) 2023 Amazon */
2 /*
3 Redistribution and use in source and binary forms, with or without
4 modification, are permitted provided that the following conditions
5 are met:
6
7 - Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9
10 - Redistributions in binary form must reproduce the above copyright
11 notice, this list of conditions and the following disclaimer in the
12 documentation and/or other materials provided with the distribution.
13
14 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
18 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
21 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
22 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27 #ifdef HAVE_CONFIG_H
28 #include "config.h"
29 #endif
30
31 #include <string.h>
32 #include <stdlib.h>
33 #include "nnet.h"
34 #include "os_support.h"
35
36 #define SPARSE_BLOCK_SIZE 32
37
/* Parse one record from a weight blob and advance the cursor past it.
 *
 * A record is a WEIGHT_BLOCK_SIZE-byte WeightHead followed by h->block_size
 * payload bytes, of which the first h->size are reported as the array data.
 *
 * data:  in/out cursor into the blob; on success it is moved past the record.
 * len:   in/out number of bytes remaining at *data; decremented on success.
 * array: receives name/type/size/data of the record. name and data point
 *        into the blob (nothing is copied), so the blob must outlive them.
 *
 * Returns h->size (>= 0) on success, -1 if the header is truncated,
 * inconsistent, has a non NUL-terminated name, or overruns *len.
 * *data and *len are left untouched on failure.
 */
int parse_record(const void **data, int *len, WeightArray *array) {
  /* The blob is only read, so keep the header const rather than casting the
     qualifier away. The header is not dereferenced before the length check. */
  const WeightHead *h = (const WeightHead *)*data;
  const unsigned char *payload;
  if (*len < WEIGHT_BLOCK_SIZE) return -1;
  /* The padded block must hold at least the meaningful bytes. */
  if (h->block_size < h->size) return -1;
  /* Safe: *len >= WEIGHT_BLOCK_SIZE here, so the subtraction cannot underflow. */
  if (h->block_size > *len-WEIGHT_BLOCK_SIZE) return -1;
  /* array->name aliases h->name, so it must be a valid C string. */
  if (h->name[sizeof(h->name)-1] != 0) return -1;
  if (h->size < 0) return -1;

  payload = (const unsigned char *)*data + WEIGHT_BLOCK_SIZE;
  array->name = h->name;
  array->type = h->type;
  array->size = h->size;
  array->data = payload;

  /* Skip the whole padded block, not just h->size bytes. */
  *data = payload + h->block_size;
  *len -= h->block_size+WEIGHT_BLOCK_SIZE;
  return array->size;
}
54
/* Parse an entire weight blob into a heap-allocated WeightArray list.
 *
 * list: receives an array of nb_arrays entries followed by a sentinel entry
 *       whose name is NULL (all its other fields are zeroed as well). The
 *       caller owns the list and releases it with opus_free(). Entries point
 *       into `data`, which must outlive the list. On failure *list is NULL.
 * data: start of the blob.
 * len:  blob size in bytes.
 *
 * Returns the number of arrays parsed, or -1 if a record is malformed, has a
 * zero payload size, or memory allocation fails.
 */
int parse_weights(WeightArray **list, const void *data, int len)
{
  int nb_arrays=0;
  int capacity=20;
  *list = opus_alloc(capacity*sizeof(WeightArray));
  if (*list == NULL) return -1;
  while (len > 0) {
    int ret;
    WeightArray array = {NULL, 0, 0, 0};
    ret = parse_record(&data, &len, &array);
    if (ret <= 0) goto fail;
    if (nb_arrays+1 >= capacity) {
      /* Make sure there's room for the ending NULL element too. */
      WeightArray *grown;
      capacity = capacity*3/2;
      /* Keep the old buffer until realloc succeeds so it can still be freed. */
      grown = opus_realloc(*list, capacity*sizeof(WeightArray));
      if (grown == NULL) goto fail;
      *list = grown;
    }
    (*list)[nb_arrays++] = array;
  }
  /* Fully initialise the sentinel: lookups that miss return this entry, and
     callers may inspect its size/data fields. */
  (*list)[nb_arrays].name = NULL;
  (*list)[nb_arrays].type = 0;
  (*list)[nb_arrays].size = 0;
  (*list)[nb_arrays].data = NULL;
  return nb_arrays;

fail:
  opus_free(*list);
  *list = NULL;
  return -1;
}
80
find_array_entry(const WeightArray * arrays,const char * name)81 static const void *find_array_entry(const WeightArray *arrays, const char *name) {
82 while (arrays->name && strcmp(arrays->name, name) != 0) arrays++;
83 return arrays;
84 }
85
find_array_check(const WeightArray * arrays,const char * name,int size)86 static const void *find_array_check(const WeightArray *arrays, const char *name, int size) {
87 const WeightArray *a = find_array_entry(arrays, name);
88 if (a->name && a->size == size) return a->data;
89 else return NULL;
90 }
91
opt_array_check(const WeightArray * arrays,const char * name,int size,int * error)92 static const void *opt_array_check(const WeightArray *arrays, const char *name, int size, int *error) {
93 const WeightArray *a = find_array_entry(arrays, name);
94 *error = (a->name != NULL && a->size != size);
95 if (a->name && a->size == size) return a->data;
96 else return NULL;
97 }
98
find_idx_check(const WeightArray * arrays,const char * name,int nb_in,int nb_out,int * total_blocks)99 static const void *find_idx_check(const WeightArray *arrays, const char *name, int nb_in, int nb_out, int *total_blocks) {
100 int remain;
101 const int *idx;
102 const WeightArray *a = find_array_entry(arrays, name);
103 *total_blocks = 0;
104 if (a == NULL) return NULL;
105 idx = a->data;
106 remain = a->size/sizeof(int);
107 while (remain > 0) {
108 int nb_blocks;
109 int i;
110 nb_blocks = *idx++;
111 if (remain < nb_blocks+1) return NULL;
112 for (i=0;i<nb_blocks;i++) {
113 int pos = *idx++;
114 if (pos+3 >= nb_in || (pos&0x3)) return NULL;
115 }
116 nb_out -= 8;
117 remain -= nb_blocks+1;
118 *total_blocks += nb_blocks;
119 }
120 if (nb_out != 0) return NULL;
121 return a->data;
122 }
123
/* Bind a LinearLayer's parameter pointers to entries of a parsed weight list.
 *
 * Every pointer field is first reset to NULL. Each non-NULL name is then
 * looked up in `arrays` and its byte size is checked against
 * nb_inputs/nb_outputs:
 *   - bias, subias, diag: nb_outputs elements each (required if named).
 *   - weights_idx: when given, weights/float_weights are sparse, holding
 *     SPARSE_BLOCK_SIZE elements per block counted in the index array.
 *     Otherwise they are dense with nb_inputs*nb_outputs elements.
 *   - weights: required if named, and then `scale` (nb_outputs elements) is
 *     required too.
 *   - float_weights: optional. Absence is fine, but a size mismatch is an error.
 *
 * Returns 0 on success, 1 if a required array is missing or any array has the
 * wrong size. Pointers refer into `arrays`' backing blob; nothing is copied.
 */
int linear_init(LinearLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *subias,
  const char *weights,
  const char *float_weights,
  const char *weights_idx,
  const char *diag,
  const char *scale,
  int nb_inputs,
  int nb_outputs)
{
  int err;
  layer->bias = NULL;
  layer->subias = NULL;
  layer->weights = NULL;
  layer->float_weights = NULL;
  layer->weights_idx = NULL;
  layer->diag = NULL;
  layer->scale = NULL;
  if (bias != NULL) {
    if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
  }
  if (subias != NULL) {
    if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
  }
  if (weights_idx != NULL) {
    int total_blocks;
    if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_inputs, nb_outputs, &total_blocks)) == NULL) return 1;
    if (weights != NULL) {
      if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1;
    }
    if (float_weights != NULL) {
      layer->float_weights = opt_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]), &err);
      if (err) return 1;
    }
  } else {
    if (weights != NULL) {
      if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1;
    }
    if (float_weights != NULL) {
      layer->float_weights = opt_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]), &err);
      if (err) return 1;
    }
  }
  if (diag != NULL) {
    if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
  }
  /* The scale table is looked up whenever `weights` is present (presumably the
     per-output factors for those weights — confirm against the matmul code).
     A NULL scale name would reach strcmp() inside the lookup (undefined
     behaviour), so report it as an error instead. */
  if (weights != NULL) {
    if (scale == NULL) return 1;
    if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
  }
  layer->nb_inputs = nb_inputs;
  layer->nb_outputs = nb_outputs;
  return 0;
}
178
/* Bind a Conv2dLayer's parameters to entries of a parsed weight list.
 *
 * bias (out_channels elements) is required when named. float_weights
 * (in_channels*out_channels*ktime*kheight elements) is optional: a missing
 * entry leaves the pointer NULL, but a size mismatch is an error.
 * The dimension fields are stored only on success.
 *
 * Returns 0 on success, 1 on a missing or mis-sized array.
 */
int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *float_weights,
  int in_channels,
  int out_channels,
  int ktime,
  int kheight)
{
  layer->bias = NULL;
  layer->float_weights = NULL;

  if (bias != NULL) {
    layer->bias = find_array_check(arrays, bias, out_channels*sizeof(layer->bias[0]));
    if (layer->bias == NULL) {
      return 1;
    }
  }

  if (float_weights != NULL) {
    int size_mismatch;
    layer->float_weights = opt_array_check(arrays, float_weights,
        in_channels*out_channels*ktime*kheight*sizeof(layer->float_weights[0]),
        &size_mismatch);
    if (size_mismatch) {
      return 1;
    }
  }

  layer->in_channels = in_channels;
  layer->out_channels = out_channels;
  layer->ktime = ktime;
  layer->kheight = kheight;
  return 0;
}
203
204
205
#if 0
/* Disabled debugging harness: maps "weights_blob.bin" and prints the name and
   size of every array parse_weights() finds in it. */
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <sys/stat.h>
#include <stdio.h>

int main(void)
{
  int fd;
  void *data;
  int len;
  int nb_arrays;
  int i;
  WeightArray *list;
  struct stat st;
  const char *filename = "weights_blob.bin";
  if (stat(filename, &st) != 0) {
    perror(filename);
    return 1;
  }
  len = st.st_size;
  fd = open(filename, O_RDONLY);
  if (fd < 0) {
    perror(filename);
    return 1;
  }
  data = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
  if (data == MAP_FAILED) {
    perror("mmap");
    close(fd);
    return 1;
  }
  printf("size is %d\n", len);
  nb_arrays = parse_weights(&list, data, len);
  if (nb_arrays < 0) {
    /* parse_weights() sets list to NULL on failure; do not index it. */
    fprintf(stderr, "failed to parse %s\n", filename);
  } else {
    for (i=0;i<nb_arrays;i++) {
      printf("found %s: size %d\n", list[i].name, list[i].size);
    }
    /* Sentinel entry: expected to print a null pointer. */
    printf("%p\n", (void *)list[i].name);
    opus_free(list);
  }
  munmap(data, len);
  close(fd);
  return nb_arrays < 0;
}
#endif
239