xref: /aosp_15_r20/external/pytorch/aten/src/ATen/nnapi/nnapi_model_loader.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // NOLINTNEXTLINE(modernize-deprecated-headers)
2 #include <stdint.h>
3 
4 #include <ATen/nnapi/NeuralNetworks.h>
5 #include <ATen/nnapi/nnapi_wrapper.h>
6 #include <ATen/nnapi/nnapi_model_loader.h>
7 #include <c10/util/irange.h>
8 
9 
10 #ifndef NNAPI_LOADER_STANDALONE
11 
12 # include <c10/util/Logging.h>
13 
14 #else
15 
// Standalone fallback used when c10 logging is not available: a failed check
// makes the enclosing function return -1.  Any message arguments are ignored.
// `cond` is parenthesized so that the negation applies to the whole
// expression, e.g. CAFFE_ENFORCE(a >= b) tests !(a >= b), not (!a) >= b.
#define CAFFE_ENFORCE(cond, ...) do { if (!(cond)) { return -1; } } while (0)
17 
18 #endif
19 
20 
// Enforce that an NNAPI call returned ANEURALNETWORKS_NO_ERROR.  `res` is
// parenthesized so that a compound argument is compared as a whole.
#define NNAPI_CHECK(res) CAFFE_ENFORCE((res) == ANEURALNETWORKS_NO_ERROR, "NNAPI returned error: ", (res))
22 
23 
24 namespace caffe2 {
25 namespace nnapi {
26 
27 namespace {
28 
/*
Serialized format for NNAPI models.  It is basically just a list of arguments
for calls to be made to NNAPI.
*/
33 
// Where the payload of a SerializedValue comes from.
//
// The underlying type is fixed to int32_t (the type of the serialized
// source_type field) so that converting any serialized value to SourceType is
// well-defined, including values that match no enumerator.
enum SourceType : int32_t {
  // The value bytes are stored inline in the serialized model.
  SOURCE_IMMEDIATE = 0,
  // The stored bytes describe a range inside one of the caller's numbered buffers.
  SOURCE_NUMBERED_BUFFER = 2,
  // Refers to a numbered memory object; the loader rejects this as not implemented.
  SOURCE_NUMBERED_MEMORY = 3,
};
39 
// One entry of the serialized operand table.  The fields are copied into an
// ANeuralNetworksOperandType; the operand's dimensions are stored separately
// as dimension_count 32-bit words.
struct SerializedOperand {
  int32_t type;
  uint32_t dimension_count;
  float scale;
  int32_t zero_point;
};
// The loader reinterprets raw serialized bytes as this struct, so it must be
// exactly four packed 32-bit fields.
static_assert(
    sizeof(SerializedOperand) == 16,
    "SerializedOperand must be four packed 32-bit fields");
46 
// One entry of the serialized value table: assigns a value to the operand at
// `index`.  `source_type` is a SourceType and `source_length` is the unpadded
// byte length of the data stored for this value.
struct SerializedValue {
  int32_t index;
  int32_t source_type;
  uint32_t source_length;
};
// The loader reinterprets raw serialized bytes as this struct, so it must be
// exactly three packed 32-bit fields.
static_assert(
    sizeof(SerializedValue) == 12,
    "SerializedValue must be three packed 32-bit fields");
52 
// One entry of the serialized operation table.  The operation's input and
// output operand indices are stored separately as input_count + output_count
// 32-bit words.
struct SerializedOperation {
  int32_t operation_type;
  uint32_t input_count;
  uint32_t output_count;
};
// The loader reinterprets raw serialized bytes as this struct, so it must be
// exactly three packed 32-bit fields.
static_assert(
    sizeof(SerializedOperation) == 12,
    "SerializedOperation must be three packed 32-bit fields");
58 
// Fixed-size header at the start of a serialized model.  The tables and
// variable-length payloads listed in the comments below follow it directly,
// in that order.
struct SerializedModel {
  int32_t version;
  int32_t operand_count;
  int32_t value_count;
  int32_t operation_count;
  int32_t input_count;
  int32_t output_count;
  // SerializedOperand operands[operand_count];
  // SerializedValue values[value_count];
  // SerializedOperation operations[operation_count];
  // uint32_t operand_dimensions[sum(dimension_count)]
  // uint32_t value_data[sum(source_length+pad)/4]
  // uint32_t operation_args[sum(input_count + output_count)]
  // uint32_t model_inputs[input_count]
  // uint32_t model_outputs[output_count]
};
// The loader reinterprets raw serialized bytes as this struct, so it must be
// exactly six packed 32-bit fields.
static_assert(
    sizeof(SerializedModel) == 24,
    "SerializedModel must be six packed 32-bit fields");
75 
76 
/**
 * Round a value's byte length up to the next multiple of 4.  Values are
 * stored padded like this so that the value after them starts 4-byte aligned.
 * The arithmetic is plain uint32_t, so a length within 3 of UINT32_MAX that
 * is not already a multiple of 4 wraps around to 0.
 */
static uint32_t value_physical_size(uint32_t len) {
  const uint32_t remainder = len % 4;
  return remainder == 0 ? len : len + (4 - remainder);
}
88 
89 } // namespace
90 
91 
load_nnapi_model(struct nnapi_wrapper * nnapi,ANeuralNetworksModel * model,const void * serialized_model,int64_t model_length,size_t num_buffers,const void ** buffer_ptrs,int32_t * buffer_sizes,size_t,ANeuralNetworksMemory **,int32_t *,int32_t * out_input_count,int32_t * out_output_count,size_t * out_bytes_consumed)92 int load_nnapi_model(
93     struct nnapi_wrapper* nnapi,
94     ANeuralNetworksModel* model,
95     const void* serialized_model,
96     int64_t model_length,
97     size_t num_buffers,
98     const void** buffer_ptrs,
99     int32_t* buffer_sizes,
100     size_t /*num_memories*/,
101     ANeuralNetworksMemory** /*memories*/,
102     int32_t* /*memory_sizes*/,
103     int32_t* out_input_count,
104     int32_t* out_output_count,
105     size_t* out_bytes_consumed) {
106   int64_t required_size = 0;
107   const uint8_t* next_pointer = (const uint8_t*)serialized_model;
108   const uint8_t* end_of_buf = (const uint8_t*)serialized_model + model_length;
109 
110   required_size += sizeof(SerializedModel);
111   CAFFE_ENFORCE(model_length >= required_size, "Model is too small.  Size = ", model_length);
112   const SerializedModel* ser_model = (SerializedModel*)next_pointer;
113   next_pointer = (uint8_t*)serialized_model + required_size;
114   CAFFE_ENFORCE(next_pointer <= end_of_buf);
115 
116   CAFFE_ENFORCE(ser_model->version == 1);
117   // Keep these small to avoid integer overflow.
118   CAFFE_ENFORCE(ser_model->operand_count    < (1 << 24));
119   CAFFE_ENFORCE(ser_model->value_count      < (1 << 24));
120   CAFFE_ENFORCE(ser_model->operation_count  < (1 << 24));
121   CAFFE_ENFORCE(ser_model->input_count      < (1 << 24));
122   CAFFE_ENFORCE(ser_model->output_count     < (1 << 24));
123 
124   required_size += sizeof(SerializedOperand) * ser_model->operand_count;
125   CAFFE_ENFORCE(model_length >= required_size, "Model is too small.  Size = ", model_length);
126   const SerializedOperand* operands = (const SerializedOperand*)next_pointer;
127   next_pointer = (uint8_t*)serialized_model + required_size;
128   CAFFE_ENFORCE(next_pointer <= end_of_buf);
129 
130   required_size += sizeof(SerializedValue) * ser_model->value_count;
131   CAFFE_ENFORCE(model_length >= required_size, "Model is too small.  Size = ", model_length);
132   const SerializedValue* values = (const SerializedValue*)next_pointer;
133   next_pointer = (uint8_t*)serialized_model + required_size;
134   CAFFE_ENFORCE(next_pointer <= end_of_buf);
135 
136   required_size += sizeof(SerializedOperation) * ser_model->operation_count;
137   CAFFE_ENFORCE(model_length >= required_size, "Model is too small.  Size = ", model_length);
138   const SerializedOperation* operations = (const SerializedOperation*)next_pointer;
139   next_pointer = (uint8_t*)serialized_model + required_size;
140   CAFFE_ENFORCE(next_pointer <= end_of_buf);
141 
142   for (const auto i : c10::irange(ser_model->operand_count)) {
143     required_size += 4 * operands[i].dimension_count;
144   }
145 
146   for (const auto i : c10::irange(ser_model->value_count)) {
147     required_size += value_physical_size(values[i].source_length);
148   }
149 
150   for (const auto i : c10::irange(ser_model->operation_count)) {
151     required_size += 4 * (operations[i].input_count + operations[i].output_count);
152   }
153 
154   required_size += 4 * (ser_model->input_count + ser_model->output_count);
155 
156   CAFFE_ENFORCE(model_length >= required_size, "Model is too small.  Size = ", model_length);
157   CAFFE_ENFORCE(next_pointer <= end_of_buf);
158 
159   for (const auto i : c10::irange(ser_model->operand_count)) {
160     ANeuralNetworksOperandType operand;
161     operand.type = operands[i].type;
162     operand.scale = operands[i].scale;
163     operand.zeroPoint = operands[i].zero_point;
164     operand.dimensionCount = operands[i].dimension_count;
165     operand.dimensions = operands[i].dimension_count ? (const uint32_t*)next_pointer : nullptr;
166 
167     next_pointer += 4 * operands[i].dimension_count;
168     CAFFE_ENFORCE(next_pointer <= end_of_buf);
169 
170     int result = nnapi->Model_addOperand(model, &operand);
171     NNAPI_CHECK(result);
172   }
173 
174   for (const auto i : c10::irange(ser_model->value_count)) {
175     uint32_t len = values[i].source_length;
176     const uint8_t* stored_pointer = next_pointer;
177     const void* value_pointer = nullptr;
178     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
179     size_t value_length;
180 
181     switch ((SourceType)values[i].source_type) {
182       case SOURCE_IMMEDIATE:
183         {
184           value_pointer = stored_pointer;
185           value_length = len;
186         }
187         break;
188       case SOURCE_NUMBERED_BUFFER:
189         {
190           CAFFE_ENFORCE(len == 12);
191           uint32_t buffer_number = *(uint32_t*)stored_pointer;
192           uint32_t buffer_offset = *(uint32_t*)(stored_pointer + 4);
193           uint32_t operand_length = *(uint32_t*)(stored_pointer + 8);
194           CAFFE_ENFORCE(buffer_number < num_buffers);
195           CAFFE_ENFORCE(buffer_offset + operand_length >= buffer_offset);  // No integer overflow
196           CAFFE_ENFORCE(buffer_offset + operand_length <= (uint32_t)buffer_sizes[buffer_number]);  // No buffer overflow
197           value_pointer = (uint8_t*)buffer_ptrs[buffer_number] + buffer_offset;
198           value_length = operand_length;
199         }
200         break;
201       case SOURCE_NUMBERED_MEMORY:
202         CAFFE_ENFORCE(false, "Memory inputs not implemented yet.");
203         break;
204       default:
205         CAFFE_ENFORCE(false, "Unknown source type: ", values[i].source_type);
206     }
207 
208     CAFFE_ENFORCE(value_pointer != nullptr);
209 
210     next_pointer += value_physical_size(len);
211     CAFFE_ENFORCE(next_pointer <= end_of_buf);
212 
213     int result = nnapi->Model_setOperandValue(
214         model,
215         values[i].index,
216         value_pointer,
217         value_length);
218     NNAPI_CHECK(result);
219   }
220 
221   for (const auto i : c10::irange(ser_model->operation_count)) {
222     const uint32_t* inputs = (const uint32_t*)next_pointer;
223     next_pointer += 4 * operations[i].input_count;
224     CAFFE_ENFORCE(next_pointer <= end_of_buf);
225     const uint32_t* outputs = (const uint32_t*)next_pointer;
226     next_pointer += 4 * operations[i].output_count;
227     CAFFE_ENFORCE(next_pointer <= end_of_buf);
228 
229     int result = nnapi->Model_addOperation(
230         model,
231         operations[i].operation_type,
232         operations[i].input_count,
233         inputs,
234         operations[i].output_count,
235         outputs);
236     NNAPI_CHECK(result);
237   }
238 
239   const uint32_t* model_inputs = (const uint32_t*)next_pointer;
240   next_pointer += 4 * ser_model->input_count;
241   CAFFE_ENFORCE(next_pointer <= end_of_buf);
242   const uint32_t* model_outputs = (const uint32_t*)next_pointer;
243   next_pointer += 4 * ser_model->output_count;
244   CAFFE_ENFORCE(next_pointer <= end_of_buf);
245 
246   int result = nnapi->Model_identifyInputsAndOutputs(
247       model,
248       ser_model->input_count,
249       model_inputs,
250       ser_model->output_count,
251       model_outputs);
252   NNAPI_CHECK(result);
253 
254   *out_input_count = ser_model->input_count;
255   *out_output_count = ser_model->output_count;
256 
257   // TODO: Maybe eliminate required_size and just rely on next_pointer for bounds checking.
258   CAFFE_ENFORCE(next_pointer <= end_of_buf);
259   CAFFE_ENFORCE(next_pointer == (const uint8_t*)serialized_model + required_size);
260   if (out_bytes_consumed != nullptr) {
261     *out_bytes_consumed = next_pointer - (const uint8_t*)serialized_model;
262   }
263 
264   return 0;
265 }
266 
267 }} // namespace caffe2::nnapi
268