1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/Config.h> 5 6 #if AT_MKLDNN_ENABLED() 7 8 #include <ideep/tensor.hpp> 9 10 namespace at { 11 namespace native { 12 namespace mkldnn { 13 14 struct ContextConv final { 15 ideep::tensor weight_packed_; 16 std::optional<at::Tensor> at_bias_; 17 std::vector<int64_t> padding_; 18 std::vector<int64_t> stride_; 19 std::vector<int64_t> dilation_; 20 int64_t groups_; 21 ideep::attr_t attr_; 22 23 ContextConv() = delete; 24 ContextConvfinal25 ContextConv( 26 ideep::tensor&& weight_packed, 27 std::optional<at::Tensor> at_bias, 28 std::vector<int64_t> padding, 29 std::vector<int64_t> stride, 30 std::vector<int64_t> dilation, 31 int64_t groups, 32 ideep::attr_t attr) 33 : weight_packed_(std::move(weight_packed)), 34 at_bias_(std::move(at_bias)), 35 padding_(padding), 36 stride_(stride), 37 dilation_(dilation), 38 groups_(groups), 39 attr_(attr) {} 40 }; 41 42 } // namespace mkldnn 43 } // namespace native 44 } // namespace at 45 46 #endif // AT_MKLDNN_ENABLED() 47