#pragma once

// Everything below is compiled only when the build defines USE_XNNPACK.
#ifdef USE_XNNPACK

#include <ATen/Tensor.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/OpContext.h>

// Declarations for the XNNPACK linear operator. This header contains
// declarations only; the definitions live in another translation unit.
namespace at::native::xnnpack {
namespace internal::linear {

// Creates a reference-counted LinearOpContext from a weight tensor, an
// optional bias tensor, and optional lower/upper output bounds.
// `weight` and `bias` are taken by value.
// NOTE(review): the name suggests the weights are pre-packed and the output
// is clamped to [output_min, output_max] -- confirm against the definition.
c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
    Tensor weight,
    std::optional<Tensor> bias,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max);

// Runs the operation described by `op_context` on `input` and returns the
// resulting tensor.
// NOTE(review): presumably `op_context` is one produced by
// createLinearClampPrePackOpContext -- confirm against callers.
Tensor linear_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::LinearOpContext>& op_context);

// Takes an IValue and returns an IValue.
// NOTE(review): judging by the name, this reports the sizes held by a
// pre-packed linear context -- the exact payload is not visible here.
IValue unpack_prepacked_sizes_linear(const IValue& ivalue);

// Builds a ContextLinear from a weight tensor, an optional bias tensor, and
// float output bounds.
// The by-value float parameters carry no top-level `const` here: it is not
// part of the function type, so the definition may still declare them const.
ContextLinear create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    float output_min,
    float output_max);

// Applies a previously created ContextLinear to `input` and returns the
// resulting tensor.
Tensor run(const ContextLinear& context, const Tensor& input);

} // namespace internal::linear

// Predicate over (input, weight, bias).
// NOTE(review): presumably reports whether the XNNPACK path can handle these
// arguments, i.e. whether `linear` below may be called -- confirm against the
// definition.
bool use_linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias);

// Computes a linear operation on `input` with `weight` and `bias` and returns
// the resulting tensor.
// NOTE(review): `bias` is a plain `const Tensor&` here (not std::optional);
// how an absent bias is represented is not visible from this header.
Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias);

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */