1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17
18
create_resize_bilinear_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)19 static enum xnn_status create_resize_bilinear_operator(
20 const struct xnn_node* node,
21 const struct xnn_value* values,
22 size_t num_values,
23 struct xnn_operator_data* opdata,
24 const struct xnn_caches* caches)
25 {
26 assert(node->num_inputs == 1);
27 const uint32_t input_id = node->inputs[0];
28 assert(input_id != XNN_INVALID_VALUE_ID);
29 assert(input_id < num_values);
30
31 assert(node->num_outputs == 1);
32 const uint32_t output_id = node->outputs[0];
33 assert(output_id != XNN_INVALID_VALUE_ID);
34 assert(output_id < num_values);
35
36 const size_t channel_dim = values[input_id].shape.dim[3];
37 assert(channel_dim == values[output_id].shape.dim[3]);
38
39 enum xnn_status status;
40 if (values[input_id].layout == xnn_layout_type_nchw) {
41 assert(values[output_id].layout == xnn_layout_type_nchw);
42 assert(node->compute_type == xnn_compute_type_fp32);
43 status = xnn_create_resize_bilinear2d_nchw_f32(
44 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
45 node->flags,
46 &opdata->operator_objects[0]);
47 } else {
48 assert(values[input_id].layout == xnn_layout_type_nhwc);
49 assert(values[output_id].layout == xnn_layout_type_nhwc);
50 switch (node->compute_type) {
51 #ifndef XNN_NO_F16_OPERATORS
52 case xnn_compute_type_fp16:
53 status = xnn_create_resize_bilinear2d_nhwc_f16(
54 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
55 node->flags,
56 &opdata->operator_objects[0]);
57 break;
58 #endif // !defined(XNN_NO_F16_OPERATORS)
59 case xnn_compute_type_fp32:
60 status = xnn_create_resize_bilinear2d_nhwc_f32(
61 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
62 node->flags,
63 &opdata->operator_objects[0]);
64 break;
65 #ifndef XNN_NO_S8_OPERATORS
66 case xnn_compute_type_qs8:
67 status = xnn_create_resize_bilinear2d_nhwc_s8(
68 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
69 node->flags,
70 &opdata->operator_objects[0]);
71 break;
72 #endif // !defined(XNN_NO_S8_OPERATORS)
73 #ifndef XNN_NO_U8_OPERATORS
74 case xnn_compute_type_qu8:
75 status = xnn_create_resize_bilinear2d_nhwc_u8(
76 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
77 node->flags,
78 &opdata->operator_objects[0]);
79 break;
80 #endif // !defined(XNN_NO_U8_OPERATORS)
81 default:
82 XNN_UNREACHABLE;
83 }
84 }
85 if (status == xnn_status_success) {
86 opdata->batch_size = values[input_id].shape.dim[0];
87 opdata->input_height = values[input_id].shape.dim[1];
88 opdata->input_width = values[input_id].shape.dim[2];
89 opdata->output_height = values[output_id].shape.dim[1];
90 opdata->output_width = values[output_id].shape.dim[2];
91 opdata->inputs[0] = input_id;
92 opdata->outputs[0] = output_id;
93 }
94 return status;
95 }
96
// Binds the input/output buffers, and the sizes captured by
// create_resize_bilinear_operator(), to the operator stored in
// opdata->operator_objects[0].
//
// Dispatch is on the operator's type, i.e. on the variant chosen at creation
// time. Each variant has its own xnn_setup_resize_bilinear2d_* entry point
// because the entry points take differently typed data pointers.
//
// Parameters:
//   opdata     - operator data filled in by create_resize_bilinear_operator().
//   blobs      - the runtime's blob table, indexed by Value ID.
//   num_blobs  - number of entries in `blobs`.
//   threadpool - thread pool forwarded to the setup call.
//
// Returns the status of the matching xnn_setup_resize_bilinear2d_* call.
static enum xnn_status setup_resize_bilinear_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_resize_bilinear_nchw_f32:
      return xnn_setup_resize_bilinear2d_nchw_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_resize_bilinear_nhwc_f16:
      return xnn_setup_resize_bilinear2d_nhwc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_resize_bilinear_nhwc_f32:
      return xnn_setup_resize_bilinear2d_nhwc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_resize_bilinear_nhwc_s8:
      return xnn_setup_resize_bilinear2d_nhwc_s8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_resize_bilinear_nhwc_u8:
      return xnn_setup_resize_bilinear2d_nhwc_u8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
      // Explicit return so control never falls off the end of this non-void
      // function if XNN_UNREACHABLE expands to a no-op on some toolchain.
      return xnn_status_invalid_parameter;
  }
}
190
// Defines a Static Resize Bilinear 2D node in `subgraph`.
//
// All arguments are validated before any node is allocated. On success, a new
// node is appended whose create/setup callbacks are
// create_resize_bilinear_operator() and setup_resize_bilinear_operator().
//
// Parameters:
//   subgraph   - subgraph to add the node to.
//   new_height - output height; must be non-zero and below 2**24.
//   new_width  - output width; must be non-zero and below 2**24.
//   input_id   - ID of a dense, 4-D Value with an fp32, qint8 or quint8
//                datatype (the quantized types only when compiled in).
//   output_id  - ID of a dense, 4-D Value with the same datatype as the input.
//                For quantized datatypes, zero point and scale must match the
//                input's exactly.
//   flags      - any combination of XNN_FLAG_TENSORFLOW_LEGACY_MODE and
//                XNN_FLAG_ALIGN_CORNERS, except both together.
//
// Returns:
//   xnn_status_success on success;
//   xnn_status_invalid_parameter for zero sizes, unknown or conflicting flags,
//     a wrong rank, an unsupported or mismatched datatype, or mismatched
//     quantization;
//   xnn_status_unsupported_parameter when an output dimension is >= 2**24;
//   xnn_status_out_of_memory if the node cannot be allocated;
//   otherwise, the failing status from the xnn_subgraph_check_* helpers.
enum xnn_status xnn_define_static_resize_bilinear_2d(
  xnn_subgraph_t subgraph,
  size_t new_height,
  size_t new_width,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_static_resize_bilinear_2d)) != xnn_status_success) {
    return status;
  }

  if (new_width == 0 || new_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %zux%zu output: output dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), new_width, new_height);
    return xnn_status_invalid_parameter;
  }

  if (max(new_width, new_height) >= 16777216) {
    xnn_log_error(
      "failed to define %s operator with %zux%zu output: output dimensions must be below 2**24",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), new_width, new_height);
    return xnn_status_unsupported_parameter;
  }

  const uint32_t supported_flags = XNN_FLAG_TENSORFLOW_LEGACY_MODE | XNN_FLAG_ALIGN_CORNERS;
  const uint32_t invalid_flags = flags & ~supported_flags;
  if (invalid_flags != 0) {
    xnn_log_error(
      "failed to define %s operator with 0x%08" PRIx32 " flags: invalid flags 0x%08" PRIx32,
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), flags, invalid_flags);
    return xnn_status_invalid_parameter;
  }

  const uint32_t exclusive_flags = XNN_FLAG_TENSORFLOW_LEGACY_MODE | XNN_FLAG_ALIGN_CORNERS;
  if ((flags & exclusive_flags) == exclusive_flags) {
    xnn_log_error(
      "failed to define %s operator with both XNN_FLAG_TENSORFLOW_LEGACY_MODE and XNN_FLAG_ALIGN_CORNERS flags: "
      "the two flags are mutually exclusive",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d));
    return xnn_status_invalid_parameter;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_static_resize_bilinear_2d, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_static_resize_bilinear_2d, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  // create_resize_bilinear_operator() reads shape.dim[0..3] unconditionally,
  // so anything other than a 4-D tensor would be misinterpreted there.
  if (input_value->shape.num_dims != 4) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported number of dimensions %zu (expected 4)",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id,
      (size_t) input_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_static_resize_bilinear_2d, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_static_resize_bilinear_2d, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  if (output_value->shape.num_dims != 4) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported number of dimensions %zu (expected 4)",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), output_id,
      (size_t) output_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  // The compute type is derived from the output datatype only, so the input
  // datatype must match it (checked below).
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Without this check an fp32 input paired with a quantized output, or the
  // reverse, would be accepted. The operator would then read the input buffer
  // as the wrong element type.
  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_S8_OPERATORS) || !defined(XNN_NO_U8_OPERATORS)
  // Interpolation is done directly on the quantized values, so input and
  // output must share identical quantization parameters.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_S8_OPERATORS) || !defined(XNN_NO_U8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.static_resize.new_height = new_height;
  node->params.static_resize.new_width = new_width;

  node->type = xnn_node_type_static_resize_bilinear_2d;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_resize_bilinear_operator;
  node->setup = setup_resize_bilinear_operator;

  return xnn_status_success;
}
341