#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/utils/Factory.h>
#include <ATen/native/xnnpack/Linear.h>

namespace at::native::xnnpack {
namespace internal::linear {

namespace {

// Supports NHWC and NCHW FP32 linear operators.

// TODO: Decouple and improve error handling and messages.
// Checks creation-time constraints: can XNNPACK's fully-connected operator
// handle this weight/bias/clamp combination at all?
bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max) {
         // XNNPACK
  return xnnpack::available() &&
          // Weight
          (2 == weight.ndimension()) &&
          (weight.device().is_cpu()) &&
          (kFloat == weight.scalar_type()) &&
          !weight.requires_grad() &&
          // Bias
          ((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
                                       (bias->device().is_cpu()) &&
                                       (kFloat == bias->scalar_type()) &&
                                       (weight.size(Layout::Filter::output) == bias->size(0)) &&
                                       !bias->requires_grad())
                                     : true) &&
          // Output Min / Max
          (output_max > output_min) &&
          true;
}
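
// Illustrative sketch only (not part of this file's API surface): a freshly
// constructed [output, input] float CPU tensor passes the checks above, e.g.
//
//   at::Tensor w = at::randn({64, 128});  // 64 output channels, 128 inputs
//   available(w, std::nullopt, ContextLinear::kMin, ContextLinear::kMax);
//
// kMin/kMax are the "no clamping" bounds, so output_max > output_min holds.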

// TODO: Decouple and improve error handling and messages.
// Checks run-time constraints on the activation tensor; complements
// available() above.
bool usable(const Tensor& input) {
         // Input
  return (1 <= input.ndimension()) &&
         (input.device().is_cpu()) &&
         (kFloat == input.scalar_type()) &&
         !input.requires_grad() &&
         true;
}
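
// Illustrative only: because run() below unsqueezes 1D inputs, even a plain
// vector qualifies here, e.g.
//
//   usable(at::randn({128}));  // true: >= 1 dim, CPU, float, no grad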

// One-shot helper: packs the weights and immediately runs the operator.
// Prefer create()/run() when the same weights are applied repeatedly.
Tensor create_and_run(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const float output_min,
    const float output_max) {
  return run(
      create(
          weight,
          bias,
          output_min,
          output_max),
      input);
}
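
// Illustrative one-shot usage (a sketch; shapes are arbitrary and assume the
// aten [output, input] weight layout):
//
//   at::Tensor x = at::randn({8, 128});   // batch of 8, 128 input channels
//   at::Tensor w = at::randn({64, 128});  // 64 output channels
//   at::Tensor b = at::randn({64});
//   at::Tensor y = create_and_run(
//       x, w, b, ContextLinear::kMin, ContextLinear::kMax);  // y: [8, 64]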

} // anonymous namespace

// Packs the [output, input] float weight (and optional bias) into an XNNPACK
// fully-connected operator; the returned context can be run any number of
// times via run().
ContextLinear create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max) {
  const Tensor weight_contig = weight.contiguous();

  TORCH_CHECK(
        available(
          weight_contig,
          bias,
          output_min,
          output_max),
      "XNNPACK Linear not available! "
      "Reason: The provided (weight, bias, output_min, output_max) parameters are "
      "either invalid individually or their combination is not supported by XNNPACK.");

  xnn_operator_t linear_op{};

  const xnn_status create_status = xnn_create_fully_connected_nc_f32(
      weight_contig.size(Layout::Filter::input),                      // input_channels
      weight_contig.size(Layout::Filter::output),                     // output_channels
      weight_contig.size(Layout::Filter::input),                      // input_pixel_stride
      weight_contig.size(Layout::Filter::output),                     // output_pixel_stride
      weight_contig.data_ptr<float>(),                                // kernel
      (bias && bias->defined()) ?
          bias->contiguous().data_ptr<float>() :
          nullptr,                                                    // bias
      output_min,                                                     // output_min
      output_max,                                                     // output_max
      0u,                                                             // flags
      nullptr,                                                        // xnn_caches_t
      nullptr,                                                        // xnn_weights_cache_t
      &linear_op);                                                    // operator

  TORCH_CHECK(
      xnn_status_success == create_status,
      "xnn_create_fully_connected_nc_f32 failed!");

  return ContextLinear(
    Operator(linear_op),
    weight_contig.size(Layout::Filter::output)
  );
}
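
// Illustrative only: create() is the expensive half of the prepacked flow.
// With w: [64, 128] and b: [64] as above, a sketch looks like:
//
//   ContextLinear ctx = create(w, b, ContextLinear::kMin, ContextLinear::kMax);
//   // ctx now owns the packed xnn_operator_t and records
//   // ctx.output_channels == 64 for shaping outputs in run().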

// Runs a previously created ContextLinear on the given input.  The last
// input dimension is treated as channels; all leading dimensions are
// flattened into the batch.
Tensor run(
    const ContextLinear& context,
    const Tensor& input) {
  using namespace internal;

  // For compatibility with aten::linear
  auto ip = input;
  if (input.ndimension() == 1) {
    ip = input.unsqueeze(0);
  }

  const Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
      ip, ip.suggest_memory_format());

  TORCH_CHECK(
      usable(padded_input),
      "XNNPACK Linear not usable! "
      "Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");

  const IntArrayRef input_size = padded_input.sizes();
  std::vector<int64_t> output_size(input_size.cbegin(), input_size.cend());
  output_size.back() = context.output_channels;

  Tensor output = mobile::empty_with_tail_padding(
      output_size,
      padded_input.options().dtype(),
      padded_input.suggest_memory_format(),
      padded_input.opt_names());

  const xnn_status reshape_status = xnn_reshape_fully_connected_nc_f32(
      context.op.get(),                                   // operator
      Layout::ActivationND::batch(padded_input.sizes()),  // batch
      caffe2::pthreadpool_());                            // threadpool

  TORCH_CHECK(
      xnn_status_success == reshape_status,
      "xnn_reshape_fully_connected_nc_f32 failed!");

  const xnn_status setup_status = xnn_setup_fully_connected_nc_f32(
      context.op.get(),                                   // operator
      padded_input.data_ptr<float>(),                     // input
      output.data_ptr<float>());                          // output

  TORCH_CHECK(
      xnn_status_success == setup_status,
      "xnn_setup_fully_connected_nc_f32 failed!");

  const xnn_status run_status = xnn_run_operator(
      context.op.get(),         // operator
      caffe2::pthreadpool_());  // threadpool

  TORCH_INTERNAL_ASSERT(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

  // For compatibility with aten::linear
  if (input.ndimension() == 1) {
    output.squeeze_(0);
  }

  return output;
}
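
// Illustrative only: the same context can be run repeatedly, which is the
// point of splitting create()/run() instead of calling create_and_run():
//
//   ContextLinear ctx = create(w, b, ContextLinear::kMin, ContextLinear::kMax);
//   at::Tensor y0 = run(ctx, x0);  // weights packed once, above
//   at::Tensor y1 = run(ctx, x1);  // reused here with no repacking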

c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
    Tensor weight,
    std::optional<Tensor> bias,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return xnnpack::XNNPackLinearOpContext::create_context(
      std::move(weight), std::move(bias), output_min, output_max);
}
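
// Illustrative only: this factory backs the TorchScript-facing prepack custom
// op (its registration lives alongside the other XNNPACK op contexts, in
// RegisterOpContextClass.cpp in this directory). A sketch of direct C++ use,
// clamping to [0, 6] (ReLU6-style):
//
//   auto ctx = createLinearClampPrePackOpContext(
//       w, b, at::Scalar(0.0), at::Scalar(6.0));
//   at::Tensor y = linear_clamp_run(x, ctx);  // defined just below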

Tensor linear_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::LinearOpContext>& op_context) {
  return op_context->run(input);
}

IValue
unpack_prepacked_sizes_linear(const IValue& ivalue) {
  auto op_context = ivalue.toCustomClass<xnnpack::LinearOpContext>();
  const auto tuple = op_context->unpack();
  const auto& bias = std::get<1>(tuple);
  return IValue(std::make_tuple(
      std::get<0>(tuple).sizes(),
      (bias && bias->defined()) ? at::OptionalIntArrayRef(bias->sizes()) : std::nullopt));
}
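
// Illustrative only: for a context packed from w: [64, 128] with b: [64],
// this yields IValue(([64, 128], [64])); without a bias, the second tuple
// element is std::nullopt.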

} // namespace internal::linear

// Gate used by dispatch sites: true iff both the parameters and the concrete
// input are supported by the XNNPACK path.
bool use_linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  return internal::linear::available(
            weight,
            bias,
            ContextLinear::kMin,
            ContextLinear::kMax) &&
         internal::linear::usable(input);
}

Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  return internal::linear::create_and_run(
      input,
      weight,
      bias,
      ContextLinear::kMin,
      ContextLinear::kMax);
}
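
// Illustrative only: a caller is expected to gate on use_linear() before
// calling linear(), in the spirit of (hypothetical dispatch site):
//
//   if (at::native::xnnpack::use_linear(input, weight, bias)) {
//     return at::native::xnnpack::linear(input, weight, bias);
//   }
//   // ... otherwise fall back to the default aten implementation.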

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */