#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/utils/Factory.h>
#include <ATen/native/xnnpack/Linear.h>

namespace at::native::xnnpack {
namespace internal::linear {

namespace {

// Supports NHWC and NCHW FP32 linear operators.

// TODO: Decouple and improve error handling and messages.
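// Checks the statically known parameters: a 2-D float CPU weight without
// gradient tracking, an optional 1-D float CPU bias whose length matches the
// weight's output channels, and a non-empty clamp range.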
bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max) {
  // XNNPACK
  return xnnpack::available() &&
         // Weight
         (2 == weight.ndimension()) &&
         (weight.device().is_cpu()) &&
         (kFloat == weight.scalar_type()) &&
         !weight.requires_grad() &&
         // Bias
         ((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
                                       (bias->device().is_cpu()) &&
                                       (kFloat == bias->scalar_type()) &&
                                       (weight.size(Layout::Filter::output) == bias->size(0)) &&
                                       !bias->requires_grad())
                                    : true) &&
         // Output Min / Max
         (output_max > output_min) &&
         true;
}

// TODO: Decouple and improve error handling and messages.
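// Checks the run-time input: a float CPU tensor with at least one dimension
// and no gradient tracking.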
bool usable(const Tensor& input) {
  // Input
  return (1 <= input.ndimension()) &&
         (input.device().is_cpu()) &&
         (kFloat == input.scalar_type()) &&
         !input.requires_grad() &&
         true;
}

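// One-shot convenience path: packs the weight, runs once, and discards the
// operator. The prepacked op-context path further below avoids repacking when
// the same weight is reused across calls.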
Tensor create_and_run(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const float output_min,
    const float output_max) {
  return run(
      create(
          weight,
          bias,
          output_min,
          output_max),
      input);
}

} // anonymous namespace

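// Packs the (contiguous) weight and optional bias into an XNNPACK
// fully-connected operator and wraps it, together with the output channel
// count, in a ContextLinear. Typical one-shot usage, sketched:
//
//   auto ctx = create(weight, bias, ContextLinear::kMin, ContextLinear::kMax);
//   Tensor out = run(ctx, input);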
ContextLinear create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max) {
  const Tensor weight_contig = weight.contiguous();

  TORCH_CHECK(
      available(
          weight_contig,
          bias,
          output_min,
          output_max),
      "XNNPACK Linear not available! "
      "Reason: The provided (weight, bias, output_min, output_max) parameters are "
      "either invalid individually or their combination is not supported by XNNPACK.");

  xnn_operator_t linear_op{};

  const xnn_status create_status = xnn_create_fully_connected_nc_f32(
      weight_contig.size(Layout::Filter::input),   // input_channels
      weight_contig.size(Layout::Filter::output),  // output_channels
      weight_contig.size(Layout::Filter::input),   // input_pixel_stride
      weight_contig.size(Layout::Filter::output),  // output_pixel_stride
      weight_contig.data_ptr<float>(),             // kernel
      (bias && bias->defined()) ?
          bias->contiguous().data_ptr<float>() :
          nullptr,                                 // bias
      output_min,                                  // output_min
      output_max,                                  // output_max
      0u,                                          // flags
      nullptr,                                     // xnn_caches_t
      nullptr,                                     // xnn_weights_cache_t
      &linear_op);                                 // operator

  TORCH_CHECK(
      xnn_status_success == create_status,
      "xnn_create_fully_connected_nc_f32 failed!");

  return ContextLinear(
      Operator(linear_op),
      weight_contig.size(Layout::Filter::output));
}

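// Executes a prepacked ContextLinear on `input`, following the XNNPACK
// operator lifecycle: reshape (bind the batch size), setup (bind the input
// and output pointers), then run on the shared pthreadpool. A 1-D input is
// temporarily promoted to 2-D for compatibility with aten::linear.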
Tensor run(
    const ContextLinear& context,
    const Tensor& input) {
  using namespace internal;

  // For compatibility with aten::linear
  auto ip = input;
  if (input.ndimension() == 1) {
    ip = input.unsqueeze(0);
  }

  const Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
      ip, ip.suggest_memory_format());

  TORCH_CHECK(
      usable(padded_input),
      "XNNPACK Linear not usable! "
      "Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");

  const IntArrayRef input_size = padded_input.sizes();
  std::vector<int64_t> output_size(input_size.cbegin(), input_size.cend());
  output_size.back() = context.output_channels;

  Tensor output = mobile::empty_with_tail_padding(
      output_size,
      padded_input.options().dtype(),
      padded_input.suggest_memory_format(),
      padded_input.opt_names());

  const xnn_status reshape_status = xnn_reshape_fully_connected_nc_f32(
      context.op.get(),                                   // operator
      Layout::ActivationND::batch(padded_input.sizes()),  // batch_size
      caffe2::pthreadpool_());                            // threadpool

  TORCH_CHECK(
      xnn_status_success == reshape_status,
      "xnn_reshape_fully_connected_nc_f32 failed!");

  const xnn_status setup_status = xnn_setup_fully_connected_nc_f32(
      context.op.get(),                // operator
      padded_input.data_ptr<float>(),  // input
      output.data_ptr<float>());       // output

  TORCH_CHECK(
      xnn_status_success == setup_status,
      "xnn_setup_fully_connected_nc_f32 failed!");

  const xnn_status run_status = xnn_run_operator(
      context.op.get(),         // operator
      caffe2::pthreadpool_());  // threadpool

  TORCH_INTERNAL_ASSERT(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

  // For compatibility with aten::linear
  if (input.ndimension() == 1) {
    output.squeeze_(0);
  }

  return output;
}

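// Prepack entry point: stores the weight, bias, and clamp bounds in a
// LinearOpContext custom class so the packed XNNPACK operator can be created
// once and reused across calls.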
c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
    Tensor weight,
    std::optional<Tensor> bias,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return xnnpack::XNNPackLinearOpContext::create_context(
      std::move(weight), std::move(bias), output_min, output_max);
}

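// Runs a previously prepacked LinearOpContext on the given input.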
Tensor linear_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::LinearOpContext>& op_context) {
  return op_context->run(input);
}

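// Recovers the original weight sizes and (optional) bias sizes from a
// prepacked LinearOpContext, e.g. for shape reasoning over prepacked ops.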
IValue
unpack_prepacked_sizes_linear(const IValue& ivalue) {
  auto op_context = ivalue.toCustomClass<xnnpack::LinearOpContext>();
  const auto tuple = op_context->unpack();
  const auto& bias = std::get<1>(tuple);
  return IValue(std::make_tuple(
      std::get<0>(tuple).sizes(),
      (bias && bias->defined()) ? at::OptionalIntArrayRef(bias->sizes()) : std::nullopt));
}

} // namespace internal::linear

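// Gate for the XNNPACK linear path: true when both the packed parameters and
// the input satisfy the checks above.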
bool use_linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  return internal::linear::available(
             weight,
             bias,
             ContextLinear::kMin,
             ContextLinear::kMax) &&
         internal::linear::usable(input);
}

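// One-shot XNNPACK linear: packs the weight, runs on the input, and discards
// the packed operator (see create_and_run above).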
Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  return internal::linear::create_and_run(
      input,
      weight,
      bias,
      ContextLinear::kMin,
      ContextLinear::kMax);
}

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */