#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/utils/Factory.h>

namespace at::native::xnnpack {

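// Gate for the XNNPACK fast path: the input must be an at-least-1-D float32
// CPU tensor that does not require gradients, and XNNPACK must be available.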
bool use_hardswish(
    const Tensor& input) {
  return xnnpack::available() &&
          (1 <= input.ndimension()) &&
          (input.device().is_cpu()) &&
          (kFloat == input.scalar_type()) &&
          !input.requires_grad();
}

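// Applies hardswish(x) = x * min(max(x + 3, 0), 6) / 6 elementwise, using a
// one-shot XNNPACK operator: create, reshape, setup, then run.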
static Tensor& hardswish_impl(Tensor& input, Tensor& output) {
  using namespace internal;

  xnn_operator_t hardswish_op{};
  const xnn_status create_status = xnn_create_hardswish_nc_f32(
    0, // flags
    &hardswish_op);

  TORCH_CHECK(
    xnn_status_success == create_status,
    "xnn_create_hardswish_nc_f32 failed!");

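  // RAII wrapper: releases the raw operator even if a later check throws.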
  Operator hardswish_scoped_op(hardswish_op);

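  // Hardswish is elementwise, so the tensor is presented to XNNPACK as a
  // flat batch of numel() single-channel elements, regardless of shape.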
  const xnn_status reshape_status = xnn_reshape_hardswish_nc_f32(
    hardswish_op,
    input.numel(), // batch size
    1, // channels
    1, // input stride
    1, // output stride
    caffe2::pthreadpool_()); // threadpool

  TORCH_CHECK(
    xnn_status_success == reshape_status,
    "xnn_reshape_hardswish_nc_f32 failed!");

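  // Bind the data pointers. The input and output may alias; the in-place
  // variant below passes the same tensor for both.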
  const xnn_status setup_status = xnn_setup_hardswish_nc_f32(
    hardswish_op,
    input.data_ptr<float>(),
    output.data_ptr<float>());

  TORCH_CHECK(
    xnn_status_success == setup_status,
    "xnn_setup_hardswish_nc_f32 failed!");

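  // A failure at this point is an internal invariant violation rather than a
  // user error, hence TORCH_INTERNAL_ASSERT below instead of TORCH_CHECK.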
  const xnn_status run_status = xnn_run_operator(
    hardswish_op,
    caffe2::pthreadpool_()); // threadpool

  TORCH_INTERNAL_ASSERT(
    xnn_status_success == run_status,
    "xnn_run_operator failed!");

  return output;
}

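// Out-of-place hardswish. XNNPACK kernels may read slightly past the end of
// their buffers, so the input is re-allocated with tail padding when needed
// and the output is always allocated with it.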
Tensor hardswish(const Tensor& input) {
  Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
    input, input.suggest_memory_format());

  Tensor output = mobile::empty_with_tail_padding(
    padded_input.sizes(),
    padded_input.options().dtype(),
    input.suggest_memory_format(),
    padded_input.opt_names());

  hardswish_impl(padded_input, output);
  return output.contiguous(input.suggest_memory_format());
}

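// In-place hardswish. If the input is already contiguous and padded, the
// kernel writes directly into it; otherwise the result is computed into a
// padded temporary and copied back.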
Tensor& hardswish_(Tensor& input) {
  Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
    input, input.suggest_memory_format());

  // No separate output allocation is needed if the input is contiguous and
  // already padded.
  if (input.data_ptr() == padded_input.data_ptr()) {
    hardswish_impl(input, input);
    return input;
  } else {
    Tensor output = mobile::empty_with_tail_padding(
      padded_input.sizes(),
      padded_input.options().dtype(),
      input.suggest_memory_format(),
      padded_input.opt_names());
    hardswish_impl(padded_input, output);
    return input.copy_(output);
  }
}

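// A minimal usage sketch (illustrative only; in practice these functions are
// reached through PyTorch's hardswish dispatch rather than called directly):
//
//   at::Tensor x = at::rand({2, 3, 4});
//   if (at::native::xnnpack::use_hardswish(x)) {
//     at::Tensor y = at::native::xnnpack::hardswish(x); // out of place
//     at::native::xnnpack::hardswish_(x);               // in place
//   }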
} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */