#pragma once

#ifdef USE_XNNPACK

#include <ATen/Tensor.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/OpContext.h>

namespace at::native::xnnpack {
namespace internal::linear {

// Prepacks the weight and bias into an XNNPACK LinearOpContext; when provided,
// output_min and output_max clamp the operator's output.
c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
    Tensor weight,
    std::optional<Tensor> bias,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max);

// Runs a prepacked linear (with clamp) operator on the given input.
Tensor linear_clamp_run(const Tensor& input, const c10::intrusive_ptr<xnnpack::LinearOpContext>& op_context);
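//
// Usage sketch (illustrative only; assumes XNNPACK is enabled and that
// `weight`, `bias`, and `input` are float CPU tensors of compatible shapes):
//
//   auto op_context = createLinearClampPrePackOpContext(
//       weight, bias, /*output_min=*/std::nullopt, /*output_max=*/std::nullopt);
//   Tensor output = linear_clamp_run(input, op_context);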

// Given an IValue holding a prepacked linear op context, returns the sizes of
// the packed weight and bias, wrapped in an IValue.
IValue unpack_prepacked_sizes_linear(const IValue& ivalue);

// Creates a ContextLinear holding XNNPACK state for the given weight and bias;
// output_min and output_max clamp the operator's output.
ContextLinear create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max);

// Applies the linear operator described by the context to the input tensor.
Tensor run(const ContextLinear& context, const Tensor& input);
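//
// Usage sketch (illustrative only; assumes `weight`, `bias`, and `input` are
// float CPU tensors of compatible shapes, and that passing +/-infinity for the
// bounds effectively disables output clamping):
//
//   ContextLinear context = create(
//       weight, bias,
//       /*output_min=*/-std::numeric_limits<float>::infinity(),
//       /*output_max=*/std::numeric_limits<float>::infinity());
//   Tensor output = run(context, input);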
} // namespace internal::linear

// Returns true if XNNPACK can handle this linear call with the given
// input, weight, and bias.
bool use_linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias);

// Computes the linear transform input * weight^T + bias using XNNPACK.
Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias);
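//
// Usage sketch (illustrative only): a caller would typically check
// use_linear() before routing a linear call through the XNNPACK backend:
//
//   if (use_linear(input, weight, bias)) {
//     return at::native::xnnpack::linear(input, weight, bias);
//   }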

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */