1*a58d3d2aSXin Li /* Copyright (c) 2023 Amazon */
2*a58d3d2aSXin Li /*
3*a58d3d2aSXin Li Redistribution and use in source and binary forms, with or without
4*a58d3d2aSXin Li modification, are permitted provided that the following conditions
5*a58d3d2aSXin Li are met:
6*a58d3d2aSXin Li
7*a58d3d2aSXin Li - Redistributions of source code must retain the above copyright
8*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer.
9*a58d3d2aSXin Li
10*a58d3d2aSXin Li - Redistributions in binary form must reproduce the above copyright
11*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer in the
12*a58d3d2aSXin Li documentation and/or other materials provided with the distribution.
13*a58d3d2aSXin Li
14*a58d3d2aSXin Li THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15*a58d3d2aSXin Li ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16*a58d3d2aSXin Li LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17*a58d3d2aSXin Li A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
18*a58d3d2aSXin Li CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19*a58d3d2aSXin Li EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20*a58d3d2aSXin Li PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
21*a58d3d2aSXin Li PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
22*a58d3d2aSXin Li LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23*a58d3d2aSXin Li NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24*a58d3d2aSXin Li SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25*a58d3d2aSXin Li */
26*a58d3d2aSXin Li
27*a58d3d2aSXin Li #ifdef HAVE_CONFIG_H
28*a58d3d2aSXin Li #include "config.h"
29*a58d3d2aSXin Li #endif
30*a58d3d2aSXin Li
31*a58d3d2aSXin Li #include <string.h>
32*a58d3d2aSXin Li #include <stdlib.h>
33*a58d3d2aSXin Li #include "nnet.h"
34*a58d3d2aSXin Li #include "os_support.h"
35*a58d3d2aSXin Li
36*a58d3d2aSXin Li #define SPARSE_BLOCK_SIZE 32
37*a58d3d2aSXin Li
/* Decode the weight record located at *data.
 *
 * A record is a WEIGHT_BLOCK_SIZE-byte header (viewed as a WeightHead)
 * followed by h->block_size bytes of storage, of which the first h->size
 * bytes are the array contents.
 *
 * On success, `array` is filled with pointers into the caller's buffer (no
 * copy is made), *data is advanced past the record, *len is reduced by the
 * number of bytes consumed, and the array size in bytes is returned.
 * On a malformed record, -1 is returned and *data, *len and *array are left
 * untouched. */
int parse_record(const void **data, int *len, WeightArray *array) {
  unsigned char *base = (unsigned char *)*data;
  WeightHead *head = (WeightHead *)*data;
  int consumed;

  /* The header itself must fit before any of its fields may be read. */
  if (*len < WEIGHT_BLOCK_SIZE) return -1;
  /* Payload must be non-negative, fit in its block, and the block must fit
     in what remains of the buffer. */
  if (head->size < 0 || head->block_size < head->size
      || head->block_size > *len-WEIGHT_BLOCK_SIZE) return -1;
  /* The name is used as a C string later on, so it has to be terminated. */
  if (head->name[sizeof(head->name)-1] != 0) return -1;

  array->name = head->name;
  array->type = head->type;
  array->size = head->size;
  array->data = (void*)(base+WEIGHT_BLOCK_SIZE);

  consumed = head->block_size+WEIGHT_BLOCK_SIZE;
  *data = (void*)(base+consumed);
  *len -= consumed;
  return array->size;
}
54*a58d3d2aSXin Li
/* Parse a whole weight blob into a newly allocated, NULL-name-terminated
 * list of WeightArray entries.
 *
 * list: receives the allocated list (release it with opus_free()), or NULL
 *       on failure. Entries point into `data`; nothing is copied, so `data`
 *       must stay valid for as long as the list is used.
 * data: start of the blob.
 * len:  size of the blob in bytes.
 *
 * Returns the number of arrays found, or -1 if any record is malformed or
 * empty (parse_record() returned <= 0) or if memory allocation fails. */
int parse_weights(WeightArray **list, const void *data, int len)
{
  int nb_arrays=0;
  int capacity=20;
  WeightArray terminator = {NULL, 0, 0, 0};
  *list = opus_alloc(capacity*sizeof(WeightArray));
  if (*list == NULL) return -1;
  while (len > 0) {
    int ret;
    WeightArray array = {NULL, 0, 0, 0};
    ret = parse_record(&data, &len, &array);
    if (ret <= 0) {
      opus_free(*list);
      *list = NULL;
      return -1;
    }
    if (nb_arrays+1 >= capacity) {
      WeightArray *grown;
      /* Make sure there's room for the ending NULL element too. */
      capacity = capacity*3/2;
      /* Keep the old pointer until the reallocation is known to have
         succeeded, otherwise a failure would leak it. */
      grown = opus_realloc(*list, capacity*sizeof(WeightArray));
      if (grown == NULL) {
        opus_free(*list);
        *list = NULL;
        return -1;
      }
      *list = grown;
    }
    (*list)[nb_arrays++] = array;
  }
  /* Fully initialize the terminator so that lookups that land on it never
     read indeterminate size/data fields. */
  (*list)[nb_arrays] = terminator;
  return nb_arrays;
}
80*a58d3d2aSXin Li
find_array_entry(const WeightArray * arrays,const char * name)81*a58d3d2aSXin Li static const void *find_array_entry(const WeightArray *arrays, const char *name) {
82*a58d3d2aSXin Li while (arrays->name && strcmp(arrays->name, name) != 0) arrays++;
83*a58d3d2aSXin Li return arrays;
84*a58d3d2aSXin Li }
85*a58d3d2aSXin Li
find_array_check(const WeightArray * arrays,const char * name,int size)86*a58d3d2aSXin Li static const void *find_array_check(const WeightArray *arrays, const char *name, int size) {
87*a58d3d2aSXin Li const WeightArray *a = find_array_entry(arrays, name);
88*a58d3d2aSXin Li if (a->name && a->size == size) return a->data;
89*a58d3d2aSXin Li else return NULL;
90*a58d3d2aSXin Li }
91*a58d3d2aSXin Li
opt_array_check(const WeightArray * arrays,const char * name,int size,int * error)92*a58d3d2aSXin Li static const void *opt_array_check(const WeightArray *arrays, const char *name, int size, int *error) {
93*a58d3d2aSXin Li const WeightArray *a = find_array_entry(arrays, name);
94*a58d3d2aSXin Li *error = (a->name != NULL && a->size != size);
95*a58d3d2aSXin Li if (a->name && a->size == size) return a->data;
96*a58d3d2aSXin Li else return NULL;
97*a58d3d2aSXin Li }
98*a58d3d2aSXin Li
find_idx_check(const WeightArray * arrays,const char * name,int nb_in,int nb_out,int * total_blocks)99*a58d3d2aSXin Li static const void *find_idx_check(const WeightArray *arrays, const char *name, int nb_in, int nb_out, int *total_blocks) {
100*a58d3d2aSXin Li int remain;
101*a58d3d2aSXin Li const int *idx;
102*a58d3d2aSXin Li const WeightArray *a = find_array_entry(arrays, name);
103*a58d3d2aSXin Li *total_blocks = 0;
104*a58d3d2aSXin Li if (a == NULL) return NULL;
105*a58d3d2aSXin Li idx = a->data;
106*a58d3d2aSXin Li remain = a->size/sizeof(int);
107*a58d3d2aSXin Li while (remain > 0) {
108*a58d3d2aSXin Li int nb_blocks;
109*a58d3d2aSXin Li int i;
110*a58d3d2aSXin Li nb_blocks = *idx++;
111*a58d3d2aSXin Li if (remain < nb_blocks+1) return NULL;
112*a58d3d2aSXin Li for (i=0;i<nb_blocks;i++) {
113*a58d3d2aSXin Li int pos = *idx++;
114*a58d3d2aSXin Li if (pos+3 >= nb_in || (pos&0x3)) return NULL;
115*a58d3d2aSXin Li }
116*a58d3d2aSXin Li nb_out -= 8;
117*a58d3d2aSXin Li remain -= nb_blocks+1;
118*a58d3d2aSXin Li *total_blocks += nb_blocks;
119*a58d3d2aSXin Li }
120*a58d3d2aSXin Li if (nb_out != 0) return NULL;
121*a58d3d2aSXin Li return a->data;
122*a58d3d2aSXin Li }
123*a58d3d2aSXin Li
/* Bind the fields of a LinearLayer to arrays found in `arrays`.
 *
 * Each const char * argument is the name of the array to look up, or NULL
 * when the layer does not use that field (the field is then left NULL).
 *   bias, subias, diag, scale: must hold nb_outputs elements.
 *   weights_idx:   sparse index, validated by find_idx_check(); when given,
 *                  the weight arrays must hold SPARSE_BLOCK_SIZE elements
 *                  per block instead of nb_inputs*nb_outputs elements.
 *   weights:       required when named; it must be accompanied by `scale`.
 *   float_weights: optional; a missing array leaves the field NULL without
 *                  error, only a size mismatch is an error.
 *
 * Returns 0 on success and 1 on failure. On failure the layer is left
 * partially initialized and nb_inputs/nb_outputs are not set. */
int linear_init(LinearLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *subias,
  const char *weights,
  const char *float_weights,
  const char *weights_idx,
  const char *diag,
  const char *scale,
  int nb_inputs,
  int nb_outputs)
{
  int err;
  layer->bias = NULL;
  layer->subias = NULL;
  layer->weights = NULL;
  layer->float_weights = NULL;
  layer->weights_idx = NULL;
  layer->diag = NULL;
  layer->scale = NULL;
  if (bias != NULL) {
    if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
  }
  if (subias != NULL) {
    if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
  }
  if (weights_idx != NULL) {
    int total_blocks;
    if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_inputs, nb_outputs, &total_blocks)) == NULL) return 1;
    if (weights != NULL) {
      if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1;
    }
    if (float_weights != NULL) {
      layer->float_weights = opt_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]), &err);
      if (err) return 1;
    }
  } else {
    if (weights != NULL) {
      if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1;
    }
    if (float_weights != NULL) {
      layer->float_weights = opt_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]), &err);
      if (err) return 1;
    }
  }
  if (diag != NULL) {
    if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
  }
  /* The scale array is looked up whenever `weights` is named. */
  if (weights != NULL) {
    /* A NULL scale name would reach strcmp() in find_array_entry(); report
       it as an initialization failure instead. */
    if (scale == NULL) return 1;
    if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
  }
  layer->nb_inputs = nb_inputs;
  layer->nb_outputs = nb_outputs;
  return 0;
}
178*a58d3d2aSXin Li
/* Bind the fields of a Conv2dLayer to arrays found in `arrays`.
 *
 *   bias:          name of the bias array (out_channels elements), or NULL
 *                  to leave the field NULL. Required when named.
 *   float_weights: name of the kernel array
 *                  (in_channels*out_channels*ktime*kheight elements), or
 *                  NULL. A missing array leaves the field NULL without
 *                  error; only a size mismatch is an error.
 *
 * Returns 0 on success and 1 on failure. On failure the dimension fields of
 * the layer are not written. */
int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *float_weights,
  int in_channels,
  int out_channels,
  int ktime,
  int kheight)
{
  int size_mismatch;

  layer->bias = NULL;
  layer->float_weights = NULL;

  if (bias != NULL) {
    layer->bias = find_array_check(arrays, bias, out_channels*sizeof(layer->bias[0]));
    if (layer->bias == NULL) return 1;
  }

  if (float_weights != NULL) {
    layer->float_weights = opt_array_check(arrays, float_weights, in_channels*out_channels*ktime*kheight*sizeof(layer->float_weights[0]), &size_mismatch);
    if (size_mismatch) return 1;
  }

  /* Dimensions are recorded only once every lookup has succeeded. */
  layer->in_channels = in_channels;
  layer->out_channels = out_channels;
  layer->ktime = ktime;
  layer->kheight = kheight;
  return 0;
}
203*a58d3d2aSXin Li
204*a58d3d2aSXin Li
205*a58d3d2aSXin Li
206*a58d3d2aSXin Li #if 0
207*a58d3d2aSXin Li #include <fcntl.h>
208*a58d3d2aSXin Li #include <sys/mman.h>
209*a58d3d2aSXin Li #include <unistd.h>
210*a58d3d2aSXin Li #include <sys/stat.h>
211*a58d3d2aSXin Li #include <stdio.h>
212*a58d3d2aSXin Li
213*a58d3d2aSXin Li int main()
214*a58d3d2aSXin Li {
215*a58d3d2aSXin Li int fd;
216*a58d3d2aSXin Li void *data;
217*a58d3d2aSXin Li int len;
218*a58d3d2aSXin Li int nb_arrays;
219*a58d3d2aSXin Li int i;
220*a58d3d2aSXin Li WeightArray *list;
221*a58d3d2aSXin Li struct stat st;
222*a58d3d2aSXin Li const char *filename = "weights_blob.bin";
223*a58d3d2aSXin Li stat(filename, &st);
224*a58d3d2aSXin Li len = st.st_size;
225*a58d3d2aSXin Li fd = open(filename, O_RDONLY);
226*a58d3d2aSXin Li data = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
227*a58d3d2aSXin Li printf("size is %d\n", len);
228*a58d3d2aSXin Li nb_arrays = parse_weights(&list, data, len);
229*a58d3d2aSXin Li for (i=0;i<nb_arrays;i++) {
230*a58d3d2aSXin Li printf("found %s: size %d\n", list[i].name, list[i].size);
231*a58d3d2aSXin Li }
232*a58d3d2aSXin Li printf("%p\n", list[i].name);
233*a58d3d2aSXin Li opus_free(list);
234*a58d3d2aSXin Li munmap(data, len);
235*a58d3d2aSXin Li close(fd);
236*a58d3d2aSXin Li return 0;
237*a58d3d2aSXin Li }
238*a58d3d2aSXin Li #endif
239