#ifdef USE_XNNPACK

#include <vector>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/Factory.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Convolution.h>
#include <ATen/native/xnnpack/Engine.h>
#include <c10/util/irange.h>

namespace at::native::xnnpack {
namespace internal {
namespace convolution2d {

namespace {

// Supports NHWC and NCHW FP32 convolutions with any valid
//  - kernel size
//  - padding
//  - stride
//  - dilation
//  - grouping

// TODO: Decouple and improve error handling and messages.
bool available(
    const Tensor& weight,
    const at::OptionalIntArrayRef bias_sizes_opt,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed,
    const float output_min,
    const float output_max) {
         // XNNPACK
  return xnnpack::available() &&
         // Weight
         (4 == weight.ndimension()) &&
         (weight.size(Layout::Filter::height) > 0) &&
         (weight.size(Layout::Filter::width) > 0) &&
         (weight.device().is_cpu()) &&
         (kFloat == weight.scalar_type()) &&
         // Bias
         (bias_sizes_opt.has_value() ? ((1 == bias_sizes_opt->size()) &&
                ((transposed ? (weight.size(Layout::Filter::input) ==
                                ((*bias_sizes_opt)[0] / groups))
                  : (weight.size(Layout::Filter::output) == ((*bias_sizes_opt)[0])))))
            : true) &&
         // Padding
         (padding[Layout::Parameter::height] >= 0) &&
         (padding[Layout::Parameter::width] >= 0) &&
         // Stride
         (stride[Layout::Parameter::height] > 0) &&
         (stride[Layout::Parameter::width] > 0) &&
         // Dilation
         (dilation[Layout::Parameter::height] > 0) &&
         (dilation[Layout::Parameter::width] > 0) &&
         // Groups
         (groups > 0) &&
         // Input
         (weight.size(Layout::Filter::input) > 0) &&
         // Output
         (weight.size(Layout::Filter::output) > 0) &&
         // Output - Groups
         ((weight.size(Layout::Filter::output) % groups) == 0) &&
         // Output Min / Max
         (output_max > output_min) &&
         true;
}

// TODO: Decouple and improve error handling and messages.
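// Checks the run-time input activation; available() above validates the
// prepack-time parameters. Both must pass for the XNNPACK path to be used.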
bool usable(const Tensor& input) {
       // Input
  return (4 == input.ndimension()) &&
         (input.device().is_cpu()) &&
         (kFloat == input.scalar_type()) &&
         (input.size(Layout::Activation4D::batch) >= 0) &&
         (input.size(Layout::Activation4D::channels) > 0) &&
         (input.size(Layout::Activation4D::height) > 0) &&
         (input.size(Layout::Activation4D::width) > 0) &&
         !input.requires_grad() &&
         true;
}

Tensor create_and_run(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const IntArrayRef padding,
    const IntArrayRef output_padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed,
    const float output_min,
    const float output_max) {
  auto op_context = create(
      weight,
      bias,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      transposed,
      output_min,
      output_max);
  return run(op_context, input);
}

// XNNPACK's deconvolution operator expects weights to be indexed in the following order:
//   * Groups
//   * Group Output Channels
//   * Kernel Height
//   * Kernel Width
//   * Group Input Channels
//
// (ref: https://github.com/google/XNNPACK/blob/ecd8311c8fd3d9ab47edbc3df5f2b5de7dabe75f/test/deconvolution-operator-tester.h#L678)
//
// This function takes a contiguous NHWC PyTorch tensor (i.e. MemoryFormat == ChannelsLast) and rearranges its weights in preparation for use with XNNPACK.
// By default, PyTorch transpose conv2d weights have shape {input_channels, output_channels_per_group, kernel_height, kernel_width}.
// In addition, it condenses the tensor from 5 to 4 dimensions, as expected by the rest of the PyTorch framework, by combining the groups and input_channels dimensions.
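//
// As an illustrative sketch (hypothetical sizes, not taken from any call site):
// for a transpose conv2d weight of logical shape {input_channels = 4,
// output_channels_per_group = 2, kernel_height = 3, kernel_width = 3} with
// groups = 2, the reordered buffer holds all elements of group 0 first; within
// a group, the group-output-channel index varies slowest and the
// group-input-channel index varies fastest, per the XNNPACK ordering above.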
const Tensor reorder_weights_for_transpose_conv(const Tensor& weight_nhwc,
    int num_groups) {

  TORCH_CHECK(weight_nhwc.size(0) % num_groups == 0, "The number of groups cannot be satisfied by the provided weight tensor.");

  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  int input_channels_per_group = weight_nhwc.size(0) / num_groups;
  int output_channels_per_group = weight_nhwc.size(1);
  int kernel_width = weight_nhwc.size(3);
  int kernel_height = weight_nhwc.size(2);

  int o_offset = 1;
  int h_offset = (output_channels_per_group);
  int w_offset = (output_channels_per_group)*(kernel_height);
  int i_offset = (output_channels_per_group)*(kernel_height)*(kernel_width);
  int g_offset = (output_channels_per_group)*(kernel_height)*(kernel_width)*(input_channels_per_group);
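  // The values above are per-dimension linear offsets into the contiguous
  // source buffer; the gather loop below combines them to locate the source
  // element for each destination position.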

  Tensor reordered = mobile::empty_with_tail_padding(
     weight_nhwc.sizes(),
     weight_nhwc.options().dtype(),
     MemoryFormat::ChannelsLast,
     weight_nhwc.opt_names());

  float* out_ptr = reordered.data_ptr<float>();
  float* in_ptr = weight_nhwc.data_ptr<float>();

  int out_index = 0;
  for (const auto g : c10::irange(num_groups)) {
    for (const auto o : c10::irange(output_channels_per_group)) {
      for (const auto w : c10::irange(kernel_width)) {
        for (const auto h : c10::irange(kernel_height)) {
          for (const auto i : c10::irange(input_channels_per_group)) {
            int in_index = (g*g_offset) + (i*i_offset) + (h*h_offset) + (w*w_offset) + (o*o_offset);
            out_ptr[out_index] = in_ptr[in_index];
            out_index++;
          }
        }
      }
    }
  }

  return reordered;
}

} // namespace

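// A minimal usage sketch of the create()/run() pair below (hypothetical
// tensors and parameter values, shown for orientation only):
//
//   ContextConv2D context = create(
//       weight, bias,
//       /*padding=*/{1, 1}, /*output_padding=*/{0, 0},
//       /*stride=*/{1, 1}, /*dilation=*/{1, 1},
//       /*groups=*/1, /*transposed=*/false,
//       ContextConv2D::kMin, ContextConv2D::kMax);
//   Tensor output = run(context, input);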
ContextConv2D create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef padding,
    const IntArrayRef output_padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed,
    const float output_min,
    const float output_max) {
  const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
  const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", 2);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);
  const Tensor weight_nhwc = weight.contiguous(MemoryFormat::ChannelsLast);

  TORCH_CHECK(
      available(
          weight_nhwc,
          (bias.has_value() && bias->defined()) ? at::OptionalIntArrayRef(bias->sizes()) : std::nullopt,
          padding_expanded,
          stride_expanded,
          dilation_expanded,
          groups,
          transposed,
          output_min,
          output_max),
      "xnnpack::convolution not available! "
      "Reason: The provided (weight, bias, padding, stride, dilation, groups, transposed, output_min, output_max) "
      "parameters are either invalid individually or their combination is not supported by XNNPACK.");


  xnn_operator_t convolution_op{};
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  xnn_status create_status;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<int64_t, 4> weight_sizes;

  if (transposed) {
    const Tensor weight_reordered = reorder_weights_for_transpose_conv(weight_nhwc, groups);
    for (const auto i : c10::irange(4)) {
      weight_sizes[i] = weight_reordered.size(i);
    }
    create_status = xnn_create_deconvolution2d_nhwc_f32(
      padding_expanded[Layout::Parameter::height],                    // output_padding_top
      padding_expanded[Layout::Parameter::width],                     // output_padding_right
      padding_expanded[Layout::Parameter::height],                    // output_padding_bottom
      padding_expanded[Layout::Parameter::width],                     // output_padding_left
      weight_reordered.size(Layout::Filter::height),                  // kernel_height
      weight_reordered.size(Layout::Filter::width),                   // kernel_width
      stride_expanded[Layout::Parameter::height],                     // subsampling_height
      stride_expanded[Layout::Parameter::width],                      // subsampling_width
      dilation_expanded[Layout::Parameter::height],                   // dilation_height
      dilation_expanded[Layout::Parameter::width],                    // dilation_width
      groups,                                                         // groups
      weight_reordered.size(Layout::Filter::output) / groups,         // group_input_channels
      weight_reordered.size(Layout::Filter::input),                   // group_output_channels
      weight_reordered.size(Layout::Filter::output),                  // input_pixel_stride
      weight_reordered.size(Layout::Filter::input) * groups,          // output_pixel_stride
      weight_reordered.data_ptr<float>(),                             // kernel
      (bias && bias->defined())
          ? bias->contiguous().data_ptr<float>()
          : nullptr,                                                  // bias
      output_min,                                                     // output_min
      output_max,                                                     // output_max
      0u,                                                             // flags
      nullptr,                                                        // xnn_caches_t
      nullptr,                                                        // xnn_weights_cache_t
      &convolution_op);                                               // operator
  } else {
    for (const auto i : c10::irange(4)) {
      weight_sizes[i] = weight_nhwc.size(i);
    }
    create_status = xnn_create_convolution2d_nhwc_f32(
      padding_expanded[Layout::Parameter::height],                    // input_padding_top
      padding_expanded[Layout::Parameter::width],                     // input_padding_right
      padding_expanded[Layout::Parameter::height],                    // input_padding_bottom
      padding_expanded[Layout::Parameter::width],                     // input_padding_left
      weight_nhwc.size(Layout::Filter::height),                       // kernel_height
      weight_nhwc.size(Layout::Filter::width),                        // kernel_width
      stride_expanded[Layout::Parameter::height],                     // subsampling_height
      stride_expanded[Layout::Parameter::width],                      // subsampling_width
      dilation_expanded[Layout::Parameter::height],                   // dilation_height
      dilation_expanded[Layout::Parameter::width],                    // dilation_width
      groups,                                                         // groups
      weight_nhwc.size(Layout::Filter::input),                        // group_input_channels
      weight_nhwc.size(Layout::Filter::output) / groups,              // group_output_channels
      weight_nhwc.size(Layout::Filter::input) * groups,               // input_pixel_stride
      weight_nhwc.size(Layout::Filter::output),                       // output_pixel_stride
      weight_nhwc.data_ptr<float>(),                                  // kernel
      (bias && bias->defined())
          ? bias->contiguous().data_ptr<float>()
          : nullptr,                                                  // bias
      output_min,                                                     // output_min
      output_max,                                                     // output_max
      0u,                                                             // flags
      nullptr,                                                        // xnn_caches_t
      nullptr,                                                        // xnn_weights_cache_t
      &convolution_op);                                               // operator
  }

  TORCH_CHECK(
      xnn_status_success == create_status,
      (transposed ? "xnn_create_deconvolution2d_nhwc_f32 failed!"
                  : "xnn_create_convolution2d_nhwc_f32 failed!"));

  return ContextConv2D{
      Operator(convolution_op),
      weight_sizes,
      {padding_expanded[0], padding_expanded[1]},
      {output_padding_expanded[0], output_padding_expanded[1]},
      {stride_expanded[0], stride_expanded[1]},
      {dilation_expanded[0], dilation_expanded[1]},
      transposed, groups
  };
}

Tensor run(
    ContextConv2D& context,
    const Tensor& input) {
  using namespace internal;

  const Tensor padded_input_nhwc = mobile::allocate_padded_contiguous_if_needed(
      input, MemoryFormat::ChannelsLast);

  TORCH_CHECK(
      usable(padded_input_nhwc),
      "XNNPACK Convolution not usable! "
      "Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");

  Tensor output;
  if (context.transposed_) {
    output = mobile::empty_with_tail_padding(
      conv_input_size(padded_input_nhwc.sizes(),
        context.weight_size_,
        context.padding_,
        context.output_padding_,
        context.stride_,
        context.dilation_,
        context.groups_),
      padded_input_nhwc.options().dtype(),
      MemoryFormat::ChannelsLast,
      padded_input_nhwc.opt_names());
  } else {
    output = mobile::empty_with_tail_padding(
      conv_output_size(
          padded_input_nhwc.sizes(),
          context.weight_size_,
          context.padding_,
          context.stride_,
          context.dilation_),
      padded_input_nhwc.options().dtype(),
      MemoryFormat::ChannelsLast,
      padded_input_nhwc.opt_names());
  }

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  xnn_status setup_status;

  /*
   * Input Pointer Caching:
   * Previously, we cached the input/output pointers and dimension parameters
   * so that if the same pointers and parameters were used, this setup could be
   * skipped.
   * However, XNNPACK has integrated offsets with its indirection buffer, so the
   * buffer does not need to be recalculated even if the activation tensor pointer
   * changes, as long as the tensor dimensions stay the same. Thus, the
   * aforementioned manual caching is no longer needed here.
   */
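  /*
   * XNNPACK splits operator preparation into a shape-dependent "reshape" step
   * and a pointer-binding "setup" step; both are performed below on each call
   * to run(), before the operator is executed.
   */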

  if (context.transposed_) {
    setup_status = xnn_reshape_deconvolution2d_nhwc_f32(
      context.op.get(),
      padded_input_nhwc.size(Layout::Activation4D::batch),   // batch_size
      padded_input_nhwc.size(Layout::Activation4D::height),  // input_height
      padded_input_nhwc.size(Layout::Activation4D::width),   // input_width
      context.output_padding_[0],                            // adjustment_height
      context.output_padding_[1],                            // adjustment_width
      nullptr,                                               // output_height_out
      nullptr,                                               // output_width_out
      caffe2::pthreadpool_());                               // threadpool
    TORCH_CHECK(
        xnn_status_success == setup_status,
        "xnn_reshape_deconvolution2d_nhwc_f32 failed!");

    setup_status = xnn_setup_deconvolution2d_nhwc_f32(
      context.op.get(),                                      // operator
      padded_input_nhwc.data_ptr<float>(),                   // input
      output.data_ptr<float>());                             // output
  } else {
    size_t workspace_size = SIZE_MAX;
    size_t workspace_alignment = SIZE_MAX;

    setup_status = xnn_reshape_convolution2d_nhwc_f32(
      context.op.get(),
      padded_input_nhwc.size(Layout::Activation4D::batch),   // batch_size
      padded_input_nhwc.size(Layout::Activation4D::height),  // input_height
      padded_input_nhwc.size(Layout::Activation4D::width),   // input_width
      &workspace_size,                                       // workspace_size
      &workspace_alignment,                                  // workspace_alignment
      nullptr,                                               // output_height_out
      nullptr,                                               // output_width_out
      caffe2::pthreadpool_());
    TORCH_CHECK(
        xnn_status_success == setup_status,
        "xnn_reshape_convolution2d_nhwc_f32 failed!");

    setup_status = xnn_setup_convolution2d_nhwc_f32(
      context.op.get(),                                      // operator
      nullptr,                                               // workspace
      padded_input_nhwc.data_ptr<float>(),                   // input
      output.data_ptr<float>());                             // output
  }

  TORCH_CHECK(
      xnn_status_success == setup_status,
      (context.transposed_ ? "xnn_setup_deconvolution2d_nhwc_f32 failed!"
                           : "xnn_setup_convolution2d_nhwc_f32 failed!"));

  const xnn_status run_status = xnn_run_operator(
      context.op.get(),         // operator
      caffe2::pthreadpool_());  // threadpool

  TORCH_INTERNAL_ASSERT(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

  return output.contiguous(input.suggest_memory_format());
}

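// Factory helpers for the prepacked op contexts. The op registrations that
// expose them (and bind them to the corresponding clamp_run functions below)
// live outside this file.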
c10::intrusive_ptr<xnnpack::Conv2dOpContext>
    createConv2dClampPrePackOpContext(
        Tensor weight,
        std::optional<Tensor> bias,
        std::vector<int64_t> stride,
        std::vector<int64_t> padding,
        std::vector<int64_t> dilation,
        int64_t groups,
        const std::optional<Scalar>& output_min,
        const std::optional<Scalar>& output_max) {
      return xnnpack::XNNPackConv2dOpContext::create_context(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max);
}

c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>
    createConv2dTransposeClampPrePackOpContext(
        Tensor weight,
        std::optional<Tensor> bias,
        std::vector<int64_t> stride,
        std::vector<int64_t> padding,
        std::vector<int64_t> output_padding,
        std::vector<int64_t> dilation,
        int64_t groups,
        const std::optional<Scalar>& output_min,
        const std::optional<Scalar>& output_max) {
      return xnnpack::XNNPackTransposeConv2dOpContext::create_context(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(output_padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max);
}

Tensor conv2d_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context) {
  return op_context->run(input);
}

// This op is registered with an Any argument because we plan to reuse it for
// prepacked conv2d of other backends.
IValue
unpack_prepacked_sizes_conv2d(const IValue& ivalue) {
  auto op_context = ivalue.toCustomClass<xnnpack::Conv2dOpContext>();
  const auto tuple = op_context->unpack();
  const auto& bias = std::get<1>(tuple);
  return IValue(std::make_tuple(
      std::get<0>(tuple).sizes(),
      (bias && bias->defined()) ? at::OptionalIntArrayRef(bias->sizes()) : std::nullopt,
      std::get<2>(tuple),
      std::get<3>(tuple),
      std::get<4>(tuple),
      std::get<5>(tuple)));
}

Tensor conv2d_transpose_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>& op_context) {
  return op_context->run(input);
}

} // namespace convolution2d
} // namespace internal

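// Combined fast-path predicate: true only when both the prepack-time parameters
// (available) and the run-time input (usable) are supported by XNNPACK.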
bool use_convolution2d(
    const Tensor& input,
    const Tensor& weight,
    const at::OptionalIntArrayRef bias_sizes_opt,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed) {
  return internal::convolution2d::available(
            weight,
            bias_sizes_opt,
            padding,
            stride,
            dilation,
            groups,
            transposed,
            ContextConv2D::kMin,
            ContextConv2D::kMax) &&
         internal::convolution2d::usable(input);
}

Tensor convolution2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups) {
  return internal::convolution2d::create_and_run(
      input,
      weight,
      bias,
      padding,
      {0, 0}, // output_padding
      stride,
      dilation,
      groups,
      false,  // transposed
      ContextConv2D::kMin,
      ContextConv2D::kMax);
}

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */