1 #ifdef USE_XNNPACK
2
3 #include <ATen/native/xnnpack/Common.h>
4 #include <ATen/native/xnnpack/Engine.h>
5 #include <ATen/native/utils/Factory.h>
6
7 namespace at::native::xnnpack {
8
9
use_hardswish(const Tensor & input)10 bool use_hardswish(
11 const Tensor& input) {
12 return xnnpack::available() &&
13 (1 <= input.ndimension()) &&
14 (input.device().is_cpu()) &&
15 (kFloat == input.scalar_type()) &&
16 !input.requires_grad() &&
17 true;
18 }
19
hardswish_impl(Tensor & input,Tensor & output)20 static Tensor& hardswish_impl(Tensor& input, Tensor& output) {
21 using namespace internal;
22
23 xnn_operator_t hardswish_op{};
24 const xnn_status create_status = xnn_create_hardswish_nc_f32(
25 0, // flags
26 &hardswish_op);
27
28 TORCH_CHECK(
29 xnn_status_success == create_status,
30 "xnn_create_hardswish_nc_f32 failed!");
31
32 Operator hardswish_scoped_op(hardswish_op);
33
34 const xnn_status reshape_status = xnn_reshape_hardswish_nc_f32(
35 hardswish_op,
36 input.numel(), // Batch
37 1, // channels
38 1, // input stride
39 1, // output stride
40 caffe2::pthreadpool_()); // threadpool
41
42 TORCH_CHECK(
43 xnn_status_success == reshape_status,
44 "xnn_reshape_hardswish_nc_f32 failed!");
45
46 const xnn_status setup_status = xnn_setup_hardswish_nc_f32(
47 hardswish_op,
48 input.data_ptr<float>(),
49 output.data_ptr<float>());
50
51 TORCH_CHECK(
52 xnn_status_success == setup_status,
53 "xnn_setup_hardswish_nc_f32 failed!");
54
55 const xnn_status run_status = xnn_run_operator(
56 hardswish_op,
57 caffe2::pthreadpool_()); // threadpool
58
59 TORCH_INTERNAL_ASSERT(
60 xnn_status_success == run_status,
61 "xnn_run_operator failed!");
62
63 return output;
64 }
65
// Out-of-place hardswish. Repacks `input` into an XNNPACK-friendly padded
// contiguous buffer when needed, runs the op into a tail-padded output,
// and returns the result in the caller's suggested memory format.
Tensor hardswish(const Tensor& input) {
  const auto format = input.suggest_memory_format();

  Tensor padded =
      mobile::allocate_padded_contiguous_if_needed(input, format);

  Tensor result = mobile::empty_with_tail_padding(
      padded.sizes(),
      padded.options().dtype(),
      format,
      padded.opt_names());

  hardswish_impl(padded, result);
  return result.contiguous(format);
}
79
hardswish_(Tensor & input)80 Tensor& hardswish_(Tensor& input) {
81 Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
82 input, input.suggest_memory_format());
83
84 // Don't need to allocate output if input is contiguous & already padded
85 if (input.data_ptr() == padded_input.data_ptr()) {
86 hardswish_impl(input, input);
87 return input;
88 } else {
89 Tensor output = mobile::empty_with_tail_padding(
90 padded_input.sizes(),
91 padded_input.options().dtype(),
92 input.suggest_memory_format(),
93 padded_input.opt_names());
94 hardswish_impl(padded_input, output);
95 return input.copy_(output);
96 }
97 }
98
99 } // namespace at::native::xnnpack
100
101 #endif /* USE_XNNPACK */
102