1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <ATen/native/mkldnn/Common.h> 5 #include <ATen/native/mkldnn/OpContext.h> 6 7 #if AT_MKLDNN_ENABLED() 8 9 namespace at { 10 namespace native { 11 namespace mkldnn { 12 namespace internal { 13 namespace convolution { 14 15 c10::intrusive_ptr<mkldnn::ConvOpContext> createConvPrePackOpContext( 16 Tensor weight, 17 std::optional<Tensor> bias, 18 std::vector<int64_t> stride, 19 std::vector<int64_t> padding, 20 std::vector<int64_t> dilation, 21 int64_t groups, 22 std::vector<int64_t> input_size, 23 std::string attr); 24 25 Tensor conv_run( 26 const Tensor& input, 27 const c10::intrusive_ptr<mkldnn::ConvOpContext>& op_context); 28 29 ContextConv create( 30 const Tensor& weight, 31 const std::optional<Tensor>& bias, 32 const IntArrayRef padding, 33 const IntArrayRef stride, 34 const IntArrayRef dilation, 35 const int64_t groups, 36 const IntArrayRef input_size, 37 const ideep::attr_t& attr); 38 39 Tensor run(ContextConv& context, const Tensor& input); 40 41 void run(ContextConv& context, const Tensor& input, void* output); 42 43 } // namespace convolution 44 } // namespace internal 45 } // namespace mkldnn 46 } // namespace native 47 } // namespace at 48 49 #endif // AT_MKLDNN_ENABLED() 50