1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <ATen/core/ivalue.h> 5 #include <ATen/native/mkldnn/Common.h> 6 7 #if AT_MKLDNN_ENABLED() 8 9 namespace at { 10 namespace native { 11 namespace mkldnn { 12 13 const static std::map<std::string, ideep::attr_t> fusion_attr_map = { 14 {"none", ideep::attr_t()}, 15 {"relu", ideep::attr_t::fuse_relu()}, 16 }; 17 18 using SerializationTypeConvPrePack = std::tuple< 19 Tensor, 20 std::optional<Tensor>, 21 std::vector<int64_t>, 22 std::vector<int64_t>, 23 std::vector<int64_t>, 24 int64_t, 25 std::vector<int64_t>, 26 std::string>; 27 28 class ConvOpContext : public torch::jit::CustomClassHolder { 29 protected: 30 Tensor orig_weight_; 31 std::optional<Tensor> orig_bias_; 32 std::vector<int64_t> stride_; 33 std::vector<int64_t> padding_; 34 std::vector<int64_t> dilation_; 35 int64_t groups_; 36 std::vector<int64_t> input_size_; 37 std::string attr_; 38 39 public: unpack()40 SerializationTypeConvPrePack unpack() { 41 return std::make_tuple( 42 orig_weight_, 43 orig_bias_, 44 stride_, 45 padding_, 46 dilation_, 47 groups_, 48 input_size_, 49 attr_); 50 } 51 52 virtual Tensor run(const Tensor& input) = 0; 53 virtual void run(const Tensor& input, void* output) = 0; 54 }; 55 56 class MkldnnConvOpContext final : public ConvOpContext { 57 private: 58 ContextConv op_context_; 59 60 public: MkldnnConvOpContext(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && padding,std::vector<int64_t> && stride,std::vector<int64_t> && dilation,uint64_t groups,std::vector<int64_t> && input_size,ContextConv && op_context)61 MkldnnConvOpContext( 62 Tensor&& weight, 63 std::optional<Tensor>&& bias, 64 std::vector<int64_t>&& padding, 65 std::vector<int64_t>&& stride, 66 std::vector<int64_t>&& dilation, 67 uint64_t groups, 68 std::vector<int64_t>&& input_size, 69 ContextConv&& op_context) 70 : op_context_(std::move(op_context)) { 71 orig_weight_ = std::move(weight); 72 orig_bias_ = std::move(bias); 73 padding_ = std::move(padding); 74 stride_ = std::move(stride); 75 dilation_ = std::move(dilation); 76 groups_ = groups; 77 input_size_ = std::move(input_size); 78 } 79 80 Tensor run(const Tensor& input) override; 81 82 void run(const Tensor& input, void* output) override; 83 84 static c10::intrusive_ptr<ConvOpContext> create_context( 85 Tensor&& weight, 86 std::optional<Tensor>&& bias, 87 std::vector<int64_t>&& padding, 88 std::vector<int64_t>&& stride, 89 std::vector<int64_t>&& dilation, 90 int64_t groups, 91 std::vector<int64_t>&& input_size, 92 const ideep::attr_t& attr); 93 }; 94 95 } // namespace mkldnn 96 } // namespace native 97 } // namespace at 98 99 #endif // AT_MKLDNN_ENABLED() 100