/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using IntArrayRef = exec_aten::ArrayRef<int64_t>;
using SizesArrayRef = exec_aten::ArrayRef<exec_aten::SizesType>;
using DimOrderArrayRef = exec_aten::ArrayRef<exec_aten::DimOrderType>;
using StridesArrayRef = exec_aten::ArrayRef<exec_aten::StridesType>;

namespace {

/**
 * Computes 2D convolution output results for a given group and output
 * channel. The computation can be thought of as a stencil computation: we
 * iterate over an input of size in_C_per_group x in_H x in_W with a stencil
 * (kernel) of size in_C_per_group x w_H x w_W to compute one output channel
 * of size 1 x out_H x out_W.
 */
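// In the non-transposed path, each output position (out_y, out_x) reads the
// input window at in_y = stride_y * out_y + dilation_y * w_y - padding_y and
// in_x = stride_x * out_x + dilation_x * w_x - padding_x, skipping taps that
// fall outside the input. As an illustrative example (values not tied to any
// particular caller): with stride = 2, dilation = 1, padding = 1, output row
// out_y = 3 reads input rows in_y = 2 * 3 + w_y - 1 = 5 + w_y for
// w_y in [0, w_H).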
template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
void conv2d_impl(
    const CTYPE* const in_ptr,
    SizesArrayRef in_sizes,
    StridesArrayRef in_strides,
    const CTYPE* const w_ptr,
    SizesArrayRef w_sizes,
    StridesArrayRef w_strides,
    const exec_aten::optional<Tensor>& bias,
    const char* const bias_ptr,
    LoadFn load_bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const int64_t groups,
    CTYPE* const out_ptr,
    SizesArrayRef out_sizes,
    StridesArrayRef out_strides,
    const size_t batch,
    const size_t group,
    const size_t out_c,
    bool transposed) {
  size_t in_C = in_sizes[1];
  size_t out_C = out_sizes[1];

  size_t out_H = out_sizes[2];
  size_t in_H = in_sizes[2];
  size_t w_H = w_sizes[2];

  size_t out_W = out_sizes[3];
  size_t in_W = in_sizes[3];
  size_t w_W = w_sizes[3];

  size_t in_C_per_group = in_C / groups;
  size_t in_c_start = group * in_C_per_group;

  size_t out_C_per_group = out_C / groups;
  size_t out_c_start = group * out_C_per_group;

  exec_aten::SizesType in_coord[kTensorDimensionLimit];
  in_coord[0] = batch;
  exec_aten::SizesType out_coord[kTensorDimensionLimit];
  out_coord[0] = batch;
  out_coord[1] = out_c;
  exec_aten::SizesType w_coord[kTensorDimensionLimit];

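  // Read the per-axis convolution hyperparameters: index 0 is the height (y)
  // axis and index 1 is the width (x) axis; padding falls back to 0 when not
  // provided (see the explicit default_value argument below).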
  const int64_t stride_y = val_at(stride, 0);
  const int64_t padding_y = val_at(padding, 0, /*default_value=*/0);
  const int64_t dilation_y = val_at(dilation, 0);
  const int64_t stride_x = val_at(stride, 1);
  const int64_t padding_x = val_at(padding, 1, /*default_value=*/0);
  const int64_t dilation_x = val_at(dilation, 1);

  if (!transposed) {
    w_coord[0] = out_c;
    // Compute 2D output region
    for (size_t out_y = 0; out_y < out_H; ++out_y) {
      out_coord[2] = out_y;
      for (size_t out_x = 0; out_x < out_W; ++out_x) {
        out_coord[3] = out_x;

        CTYPE accum = 0.0f;
        for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
             ++in_c) {
          in_coord[1] = in_c;
          w_coord[1] = in_c - in_c_start;

          for (size_t w_y = 0; w_y < w_H; ++w_y) {
            w_coord[2] = w_y;

            size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
            in_coord[2] = in_y;
            // Only proceed if input y coordinate is within bounds
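            // (in_y is unsigned, so a logically negative coordinate wraps to
            // a huge value and fails the in_H check; the same applies to the
            // in_x check below.)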
            if (in_y >= 0 && in_y < in_H) {
              for (size_t w_x = 0; w_x < w_W; ++w_x) {
                w_coord[3] = w_x;

                size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
                in_coord[3] = in_x;

                // Only proceed if input x coordinate is within bounds
                if (in_x >= 0 && in_x < in_W) {
                  size_t in_idx =
                      calculate_linear_index(in_coord, in_strides.data(), 4);
                  CTYPE in_val = in_ptr[in_idx];

                  size_t w_idx =
                      calculate_linear_index(w_coord, w_strides.data(), 4);
                  CTYPE w_val = w_ptr[w_idx];

                  accum += in_val * w_val;
                }
              }
            }
          }
        }

        if (bias_ptr != nullptr) {
          accum += load_bias(&bias_ptr[out_c * bias.value().element_size()]);
        }
        size_t out_idx =
            calculate_linear_index(out_coord, out_strides.data(), 4);
        out_ptr[out_idx] = accum;
      }
    }
  } else { // transposed convolution
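    // In the transposed path we iterate over input positions and scatter
    // accumulated products into the output, which the caller has already
    // initialized with zeros or the bias. The weight is laid out as
    // (in_C, out_C_per_group, w_H, w_W), so dim 1 selects the output channel
    // within the group and dim 0 (set in the loop below) selects the input
    // channel.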
    w_coord[1] = out_c - out_c_start;

    for (size_t in_y = 0; in_y < in_H; ++in_y) {
      in_coord[2] = in_y;

      for (size_t in_x = 0; in_x < in_W; ++in_x) {
        in_coord[3] = in_x;

        for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
             ++in_c) {
          in_coord[1] = in_c;

          size_t in_idx =
              calculate_linear_index(in_coord, in_strides.data(), 4);
          CTYPE in_val = in_ptr[in_idx];

          w_coord[0] = in_c;
          for (size_t w_y = 0; w_y < w_H; ++w_y) {
            w_coord[2] = w_y;
            size_t out_y = stride_y * in_y + dilation_y * w_y - padding_y;
            out_coord[2] = out_y;

            // Only proceed if output y coordinate is within bounds
            if (out_y >= 0 && out_y < out_H) {
              for (size_t w_x = 0; w_x < w_W; ++w_x) {
                w_coord[3] = w_x;
                size_t out_x = stride_x * in_x + dilation_x * w_x - padding_x;
                out_coord[3] = out_x;

                // Only proceed if output x coordinate is within bounds
                if (out_x >= 0 && out_x < out_W) {
                  size_t w_idx =
                      calculate_linear_index(w_coord, w_strides.data(), 4);
                  CTYPE w_val = w_ptr[w_idx];

                  size_t out_idx =
                      calculate_linear_index(out_coord, out_strides.data(), 4);

                  out_ptr[out_idx] += in_val * w_val;
                }
              }
            }
          }
        }
      }
    }
  }
}

template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
void convolution_wrapper(
    const Tensor& in,
    const Tensor& weight,
    const exec_aten::optional<Tensor>& bias,
    LoadFn load_bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    int64_t groups,
    Tensor& out) {
  SizesArrayRef in_sizes = in.sizes();
  SizesArrayRef weight_sizes = weight.sizes();
  SizesArrayRef out_sizes = out.sizes();

  DimOrderArrayRef in_dim_order = in.dim_order();
  DimOrderArrayRef weight_dim_order = weight.dim_order();
  DimOrderArrayRef out_dim_order = out.dim_order();

  IntArrayRef stride_ = stride;
  IntArrayRef padding_ = padding;
  IntArrayRef dilation_ = dilation;

  // Define arrays for modified sizes, etc. which will potentially be used
  exec_aten::SizesType in_sizes_arr[kTensorDimensionLimit];
  exec_aten::DimOrderType in_dim_order_arr[kTensorDimensionLimit];
  size_t in_ndim;
  exec_aten::SizesType weight_sizes_arr[kTensorDimensionLimit];
  exec_aten::DimOrderType weight_dim_order_arr[kTensorDimensionLimit];
  size_t weight_ndim;
  exec_aten::SizesType out_sizes_arr[kTensorDimensionLimit];
  exec_aten::DimOrderType out_dim_order_arr[kTensorDimensionLimit];
  size_t out_ndim;

  int64_t stride_arr[2];
  int64_t padding_arr[2];
  int64_t dilation_arr[2];

  // If in has a dim of 3, then a 1D convolution will be performed. A 1D
  // convolution is equivalent to a 2D convolution where the height dim of
  // all tensors is 1, and stride = 1, padding = 0, and dilation = 1 for
  // the height dimension. Therefore the tensor sizes are unsqueezed and
  // the stride, padding, and dilation are adjusted so that a 2D
  // convolution implementation can be used.
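  // For example (illustrative): an input of size (N, C, L) is treated as
  // (N, C, 1, L), the weight and output are unsqueezed at dim 2 in the same
  // way, and stride {s}, padding {p}, dilation {d} become {1, s}, {0, p},
  // {1, d}.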
  if (in.dim() == 3) {
    get_unsqueezed_sizes(in, 2, in_sizes_arr, in_ndim);
    in_sizes = {in_sizes_arr, in_ndim};
    get_unsqueezed_dim_order(in, 2, in_dim_order_arr);
    in_dim_order = {in_dim_order_arr, in_ndim};

    get_unsqueezed_sizes(weight, 2, weight_sizes_arr, weight_ndim);
    weight_sizes = {weight_sizes_arr, weight_ndim};
    get_unsqueezed_dim_order(weight, 2, weight_dim_order_arr);
    weight_dim_order = {weight_dim_order_arr, weight_ndim};

    get_unsqueezed_sizes(out, 2, out_sizes_arr, out_ndim);
    out_sizes = {out_sizes_arr, out_ndim};
    get_unsqueezed_dim_order(out, 2, out_dim_order_arr);
    out_dim_order = {out_dim_order_arr, out_ndim};

    stride_arr[0] = 1;
    stride_arr[1] = stride[0];
    stride_ = {stride_arr, 2};

    padding_arr[0] = 0;
    padding_arr[1] = padding[0];
    padding_ = {padding_arr, 2};

    dilation_arr[0] = 1;
    if (dilation.size() > 0) {
      dilation_arr[1] = dilation[0];
    } else {
      dilation_arr[1] = 1;
    }
    dilation_ = {dilation_arr, 2};
  }

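  // Compute strides from the (possibly unsqueezed) sizes and dim orders;
  // conv2d_impl uses them to index the raw data pointers directly.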
  exec_aten::StridesType in_strides[kTensorDimensionLimit];
  dim_order_to_stride_nocheck(
      in_sizes.data(), in_dim_order.data(), in_sizes.size(), in_strides);

  exec_aten::StridesType weight_strides[kTensorDimensionLimit];
  dim_order_to_stride_nocheck(
      weight_sizes.data(),
      weight_dim_order.data(),
      weight_sizes.size(),
      weight_strides);

  exec_aten::StridesType out_strides[kTensorDimensionLimit];
  dim_order_to_stride_nocheck(
      out_sizes.data(), out_dim_order.data(), out_sizes.size(), out_strides);

  CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
  const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
  const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
  const char* const bias_ptr = bias.has_value()
      ? reinterpret_cast<const char*>(bias.value().const_data_ptr())
      : nullptr;

  size_t out_N = out.size(0);
  size_t out_C = out.size(1);
  size_t out_C_per_group = out_C / groups;

  if (transposed) {
    // For transposed convolution, we need to initialize the output before we
    // can accumulate into it.
    if (bias_ptr == nullptr) {
      // If bias is not present, we need to initialize the output to 0
      memset(out_ptr, 0, out.nbytes());
    } else {
      // If bias is present, we initialize the output to the bias value
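      // (out_strides[1] is the stride of the channel dimension, so
      // (out_ix / out_strides[1]) % out_C recovers the channel index of each
      // flat element and selects the matching bias entry.)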
      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
        out_ptr[out_ix] = load_bias(&bias_ptr
                                        [((out_ix / out_strides[1]) % out_C) *
                                         bias.value().element_size()]);
      }
    }
  }

  for (size_t batch = 0; batch < out_N; ++batch) {
    for (size_t group = 0; group < groups; ++group) {
      // Align channel offset based on the group
      size_t out_c_start = group * out_C_per_group;
      // Populate all the out channels in the group
      for (size_t out_c = out_c_start; out_c < out_c_start + out_C_per_group;
           ++out_c) {
        conv2d_impl(
            in_ptr,
            in_sizes,
            {in_strides, 4},
            w_ptr,
            weight_sizes,
            {weight_strides, 4},
            bias,
            bias_ptr,
            load_bias,
            stride_,
            padding_,
            dilation_,
            groups,
            out_ptr,
            out_sizes,
            {out_strides, 4},
            batch,
            group,
            out_c,
            transposed);
      }
    }
  }
}

} // namespace

Tensor& convolution_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const exec_aten::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_convolution_args(
          in,
          weight,
          bias,
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_convolution_out_target_size(
      in,
      weight,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      output_sizes,
      &output_ndim);

  ET_KERNEL_CHECK(
      ctx,
      output_size_is_valid({output_sizes, output_ndim}, in.dim() - 2),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  if (out.numel() == 0) {
    return out;
  }

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char name[] = "convolution.out";

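  // Dispatch on the input dtype (real + half types). The bias may have a
  // different dtype (REALHBF16), so it is read through a load function that
  // converts each element to the compute type CTYPE.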
  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
    const auto load_bias = bias.has_value()
        ? utils::internal::get_load_to_common_fn<CTYPE, name>(
              bias.value(), utils::SupportedTensorDtypes::REALHBF16)
        : nullptr;
    convolution_wrapper<CTYPE>(
        in,
        weight,
        bias,
        load_bias,
        stride,
        padding,
        dilation,
        transposed,
        groups,
        out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch