xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/OpContext.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/native/mkldnn/Common.h>
6 
7 #if AT_MKLDNN_ENABLED()
8 
9 namespace at {
10 namespace native {
11 namespace mkldnn {
12 
13 const static std::map<std::string, ideep::attr_t> fusion_attr_map = {
14     {"none", ideep::attr_t()},
15     {"relu", ideep::attr_t::fuse_relu()},
16 };
17 
18 using SerializationTypeConvPrePack = std::tuple<
19     Tensor,
20     std::optional<Tensor>,
21     std::vector<int64_t>,
22     std::vector<int64_t>,
23     std::vector<int64_t>,
24     int64_t,
25     std::vector<int64_t>,
26     std::string>;
27 
28 class ConvOpContext : public torch::jit::CustomClassHolder {
29  protected:
30   Tensor orig_weight_;
31   std::optional<Tensor> orig_bias_;
32   std::vector<int64_t> stride_;
33   std::vector<int64_t> padding_;
34   std::vector<int64_t> dilation_;
35   int64_t groups_;
36   std::vector<int64_t> input_size_;
37   std::string attr_;
38 
39  public:
unpack()40   SerializationTypeConvPrePack unpack() {
41     return std::make_tuple(
42         orig_weight_,
43         orig_bias_,
44         stride_,
45         padding_,
46         dilation_,
47         groups_,
48         input_size_,
49         attr_);
50   }
51 
52   virtual Tensor run(const Tensor& input) = 0;
53   virtual void run(const Tensor& input, void* output) = 0;
54 };
55 
56 class MkldnnConvOpContext final : public ConvOpContext {
57  private:
58   ContextConv op_context_;
59 
60  public:
MkldnnConvOpContext(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && padding,std::vector<int64_t> && stride,std::vector<int64_t> && dilation,uint64_t groups,std::vector<int64_t> && input_size,ContextConv && op_context)61   MkldnnConvOpContext(
62       Tensor&& weight,
63       std::optional<Tensor>&& bias,
64       std::vector<int64_t>&& padding,
65       std::vector<int64_t>&& stride,
66       std::vector<int64_t>&& dilation,
67       uint64_t groups,
68       std::vector<int64_t>&& input_size,
69       ContextConv&& op_context)
70       : op_context_(std::move(op_context)) {
71     orig_weight_ = std::move(weight);
72     orig_bias_ = std::move(bias);
73     padding_ = std::move(padding);
74     stride_ = std::move(stride);
75     dilation_ = std::move(dilation);
76     groups_ = groups;
77     input_size_ = std::move(input_size);
78   }
79 
80   Tensor run(const Tensor& input) override;
81 
82   void run(const Tensor& input, void* output) override;
83 
84   static c10::intrusive_ptr<ConvOpContext> create_context(
85       Tensor&& weight,
86       std::optional<Tensor>&& bias,
87       std::vector<int64_t>&& padding,
88       std::vector<int64_t>&& stride,
89       std::vector<int64_t>&& dilation,
90       int64_t groups,
91       std::vector<int64_t>&& input_size,
92       const ideep::attr_t& attr);
93 };
94 
95 } // namespace mkldnn
96 } // namespace native
97 } // namespace at
98 
99 #endif // AT_MKLDNN_ENABLED()
100