// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>


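// Creates the underlying XNNPACK resize-bilinear-2d operator for this node.
// The layout recorded on the input value selects between the NCHW (fp32-only)
// and NHWC operator variants, and the node's compute type selects the data
// type. Subgraph tensors are dense, so the channel count doubles as both the
// input and output pixel stride. On success, the operator object and the
// input/output shape information are stashed in `opdata` for use at setup time.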
static enum xnn_status create_resize_bilinear_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

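  // Subgraph value shapes are expressed in NHWC order, so dim[3] is the
  // channel count for both layout branches below; resizing never changes it.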
  const size_t channel_dim = values[input_id].shape.dim[3];
  assert(channel_dim == values[output_id].shape.dim[3]);

  enum xnn_status status;
  if (values[input_id].layout == xnn_layout_type_nchw) {
    assert(values[output_id].layout == xnn_layout_type_nchw);
    assert(node->compute_type == xnn_compute_type_fp32);
    status = xnn_create_resize_bilinear2d_nchw_f32(
      channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
      node->flags,
      &opdata->operator_objects[0]);
  } else {
    assert(values[input_id].layout == xnn_layout_type_nhwc);
    assert(values[output_id].layout == xnn_layout_type_nhwc);
    switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
      case xnn_compute_type_fp16:
        status = xnn_create_resize_bilinear2d_nhwc_f16(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->flags,
          &opdata->operator_objects[0]);
        break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
      case xnn_compute_type_fp32:
        status = xnn_create_resize_bilinear2d_nhwc_f32(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->flags,
          &opdata->operator_objects[0]);
        break;
#ifndef XNN_NO_S8_OPERATORS
      case xnn_compute_type_qs8:
        status = xnn_create_resize_bilinear2d_nhwc_s8(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->flags,
          &opdata->operator_objects[0]);
        break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
      case xnn_compute_type_qu8:
        status = xnn_create_resize_bilinear2d_nhwc_u8(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->flags,
          &opdata->operator_objects[0]);
        break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
      default:
        XNN_UNREACHABLE;
    }
  }
  if (status == xnn_status_success) {
    opdata->batch_size = values[input_id].shape.dim[0];
    opdata->input_height = values[input_id].shape.dim[1];
    opdata->input_width = values[input_id].shape.dim[2];
    opdata->output_height = values[output_id].shape.dim[1];
    opdata->output_width = values[output_id].shape.dim[2];
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

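// Attaches the input and output tensor data to the previously created operator
// and configures it for the batch size and spatial dimensions recorded at
// create time.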
static enum xnn_status setup_resize_bilinear_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

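  // Dispatch on the concrete operator type chosen at create time so that each
  // layout/data-type combination uses its matching setup function.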
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_resize_bilinear_nchw_f32:
      return xnn_setup_resize_bilinear2d_nchw_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
      break;
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_resize_bilinear_nhwc_f16:
      return xnn_setup_resize_bilinear2d_nhwc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_resize_bilinear_nhwc_f32:
      return xnn_setup_resize_bilinear2d_nhwc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_resize_bilinear_nhwc_s8:
      return xnn_setup_resize_bilinear2d_nhwc_s8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_resize_bilinear_nhwc_u8:
      return xnn_setup_resize_bilinear2d_nhwc_u8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        opdata->output_height,
        opdata->output_width,
        input_data,
        output_data,
        threadpool);
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

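// Defines a static resize-bilinear-2d node in the subgraph. The output height
// and width are fixed at definition time; the function validates the requested
// size, the flag combination, and the input/output tensor types before the
// node is recorded, and the create/setup callbacks above instantiate the
// operator when a runtime is built from the subgraph.
//
// A minimal usage sketch (illustrative only, not taken from this file: the
// tensor shapes, external IDs, and surrounding setup are assumptions):
//
//   xnn_initialize(/*allocator=*/NULL);
//
//   xnn_subgraph_t subgraph = NULL;
//   xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
//
//   uint32_t input_id = XNN_INVALID_VALUE_ID;
//   const size_t input_dims[4] = {1, 32, 32, 8};  // N x H x W x C
//   xnn_define_tensor_value(
//     subgraph, xnn_datatype_fp32, 4, input_dims, /*data=*/NULL,
//     /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
//
//   uint32_t output_id = XNN_INVALID_VALUE_ID;
//   const size_t output_dims[4] = {1, 64, 64, 8};  // H/W match new_height/new_width
//   xnn_define_tensor_value(
//     subgraph, xnn_datatype_fp32, 4, output_dims, /*data=*/NULL,
//     /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
//
//   xnn_define_static_resize_bilinear_2d(
//     subgraph, /*new_height=*/64, /*new_width=*/64, input_id, output_id,
//     /*flags=*/0);
//
// Error handling is omitted for brevity; every call above returns an
// xnn_status that should be checked.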
enum xnn_status xnn_define_static_resize_bilinear_2d(
  xnn_subgraph_t subgraph,
  size_t new_height,
  size_t new_width,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_static_resize_bilinear_2d)) != xnn_status_success) {
    return status;
  }

  if (new_width == 0 || new_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %zux%zu output: output dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), new_width, new_height);
    return xnn_status_invalid_parameter;
  }

  if (max(new_width, new_height) >= 16777216) {
    xnn_log_error(
      "failed to define %s operator with %zux%zu output: output dimensions must be below 2**24",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), new_width, new_height);
    return xnn_status_unsupported_parameter;
  }

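  // Only the TensorFlow legacy and align-corners resize semantics are exposed;
  // any other flag bit is rejected.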
  const uint32_t supported_flags = XNN_FLAG_TENSORFLOW_LEGACY_MODE | XNN_FLAG_ALIGN_CORNERS;
  const uint32_t invalid_flags = flags & ~supported_flags;
  if (invalid_flags != 0) {
    xnn_log_error(
      "failed to define %s operator with 0x%08" PRIx32 " flags: invalid flags 0x%08" PRIx32,
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), flags, invalid_flags);
    return xnn_status_invalid_parameter;
  }

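  // The two supported flags select conflicting coordinate-transformation
  // conventions, so they cannot be combined.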
  const uint32_t exclusive_flags = XNN_FLAG_TENSORFLOW_LEGACY_MODE | XNN_FLAG_ALIGN_CORNERS;
  if ((flags & exclusive_flags) == exclusive_flags) {
    xnn_log_error(
      "failed to define %s operator with both XNN_FLAG_TENSORFLOW_LEGACY_MODE and XNN_FLAG_ALIGN_CORNERS flags: "
      "the two flags are mutually exclusive",
      xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d));
    return xnn_status_invalid_parameter;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_static_resize_bilinear_2d, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_static_resize_bilinear_2d, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_static_resize_bilinear_2d, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_static_resize_bilinear_2d, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

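  // The quantized (s8/u8) resize operators interpolate raw quantized values,
  // which is only valid when the input and output tensors share the same scale
  // and zero point.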
#if !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_resize_bilinear_2d), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.static_resize.new_height = new_height;
  node->params.static_resize.new_width = new_width;

  node->type = xnn_node_type_static_resize_bilinear_2d;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_resize_bilinear_operator;
  node->setup = setup_resize_bilinear_operator;

  return xnn_status_success;
}