xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/xnnpack/OpContext.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_XNNPACK
4 
5 #include <ATen/core/ivalue.h>
6 #include <ATen/native/xnnpack/Common.h>
7 #include <ATen/Tensor.h>
8 
9 namespace at::native::xnnpack {
10 
// Flattened, serializable state of a prepacked linear op, in the exact
// order produced by LinearOpContext::unpack():
// (weight, bias, output_min, output_max).
using SerializationTypeLinearPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::optional<Scalar>,
    std::optional<Scalar>>;
// Flattened, serializable state of a prepacked conv2d op, in the exact
// order produced by Conv2dOpContext::unpack():
// (weight, bias, stride, padding, dilation, groups, output_min, output_max).
using SerializationTypeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;
// Flattened, serializable state of a prepacked transposed conv2d op, in the
// exact order produced by TransposeConv2dOpContext::unpack():
// (weight, bias, stride, padding, output_padding, dilation, groups,
//  output_min, output_max).
using SerializationTypeTransposeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;
35 
36 
37 
38 class LinearOpContext : public torch::jit::CustomClassHolder {
39  protected:
40   Tensor orig_weight_;
41   std::optional<Tensor> orig_bias_;
42   std::optional<Scalar> output_min_;
43   std::optional<Scalar> output_max_;
44   bool orig_weight_and_bias_freed_;
45 
46  public:
unpack()47   SerializationTypeLinearPrePack unpack() {
48     TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed");
49     return std::make_tuple(orig_weight_, orig_bias_, output_min_, output_max_);
50   }
51 
52   virtual Tensor run(const Tensor& input) = 0;
53   virtual void free_orig_weight_and_bias() = 0;
54 };
55 
56 class XNNPackLinearOpContext final : public LinearOpContext {
57  private:
58   ContextLinear op_context_;
59 
60  public:
XNNPackLinearOpContext(Tensor && weight,std::optional<Tensor> && bias,const std::optional<Scalar> & min,const std::optional<Scalar> & max,ContextLinear && op_context)61   XNNPackLinearOpContext(
62       Tensor&& weight,
63       std::optional<Tensor>&& bias,
64       const std::optional<Scalar>& min,
65       const std::optional<Scalar>& max,
66       ContextLinear&& op_context)
67       : op_context_(std::move(op_context)) {
68     orig_weight_ = std::move(weight);
69     orig_bias_ = std::move(bias);
70     output_min_ = min;
71     output_max_ = max;
72     orig_weight_and_bias_freed_ = false;
73   }
74 
75   Tensor run(const Tensor& input) override;
76   void free_orig_weight_and_bias() override;
77 
78   static c10::intrusive_ptr<LinearOpContext> create_context(
79       Tensor&& weight,
80       std::optional<Tensor>&& bias,
81       const std::optional<Scalar>& output_min,
82       const std::optional<Scalar>& output_max);
83 };
84 
85 class Conv2dOpContext : public torch::jit::CustomClassHolder {
86  protected:
87   Tensor orig_weight_;
88   std::optional<Tensor> orig_bias_;
89   std::vector<int64_t> stride_;
90   std::vector<int64_t> padding_;
91   std::vector<int64_t> dilation_;
92   int64_t groups_;
93   std::optional<Scalar> output_min_;
94   std::optional<Scalar> output_max_;
95   bool orig_weight_and_bias_freed_;
96 
97  public:
unpack()98   SerializationTypeConv2dPrePack unpack() {
99     TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed");
100     return std::make_tuple(
101         orig_weight_,
102         orig_bias_,
103         stride_,
104         padding_,
105         dilation_,
106         groups_,
107         output_min_,
108         output_max_);
109   }
110 
111   virtual Tensor run(const Tensor& input) = 0;
112   virtual void free_orig_weight_and_bias() = 0;
113 };
114 
115 class TransposeConv2dOpContext : public torch::jit::CustomClassHolder {
116  protected:
117   Tensor orig_weight_;
118   std::optional<Tensor> orig_bias_;
119   std::vector<int64_t> stride_;
120   std::vector<int64_t> padding_;
121   std::vector<int64_t> output_padding_;
122   std::vector<int64_t> dilation_;
123   int64_t groups_;
124   std::optional<Scalar> output_min_;
125   std::optional<Scalar> output_max_;
126   bool orig_weight_and_bias_freed_;
127 
128  public:
unpack()129   SerializationTypeTransposeConv2dPrePack unpack() {
130     TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed");
131     return std::make_tuple(
132         orig_weight_,
133         orig_bias_,
134         stride_,
135         padding_,
136         output_padding_,
137         dilation_,
138         groups_,
139         output_min_,
140         output_max_);
141   }
142 
143   virtual Tensor run(const Tensor& input) = 0;
144   virtual void free_orig_weight_and_bias() = 0;
145 };
146 
147 class XNNPackConv2dOpContext final : public Conv2dOpContext {
148  private:
149   ContextConv2D op_context_;
150   // xnnpack convs use indirection buffer.
151   // These buffers need setup at runtime and/or when input
152   // dims change. If we are running the same model on multiple
153   // threads, this can lead to contention where indirection buffer
154   // is being accessed and updated at the same time from two different
155   // threads.
156   std::mutex xnnp_mutex_;
157 
158  public:
XNNPackConv2dOpContext(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && padding,std::vector<int64_t> && stride,std::vector<int64_t> && dilation,uint64_t groups,const std::optional<Scalar> & min,const std::optional<Scalar> & max,ContextConv2D && op_context)159   XNNPackConv2dOpContext(
160       Tensor&& weight,
161       std::optional<Tensor>&& bias,
162       std::vector<int64_t>&& padding,
163       std::vector<int64_t>&& stride,
164       std::vector<int64_t>&& dilation,
165       uint64_t groups,
166       const std::optional<Scalar>& min,
167       const std::optional<Scalar>& max,
168       ContextConv2D&& op_context)
169       : op_context_(std::move(op_context)) {
170     orig_weight_ = std::move(weight);
171     orig_bias_ = std::move(bias);
172     padding_ = std::move(padding);
173     stride_ = std::move(stride);
174     dilation_ = std::move(dilation);
175     groups_ = groups;
176     output_min_ = min;
177     output_max_ = max;
178     orig_weight_and_bias_freed_ = false;
179   }
180 
181   Tensor run(const Tensor& input) override;
182   void free_orig_weight_and_bias() override;
183 
184   static c10::intrusive_ptr<Conv2dOpContext> create_context(
185       Tensor&& weight,
186       std::optional<Tensor>&& bias,
187       std::vector<int64_t>&& padding,
188       std::vector<int64_t>&& stride,
189       std::vector<int64_t>&& dilation,
190       int64_t groups,
191       const std::optional<Scalar>& output_min,
192       const std::optional<Scalar>& output_max);
193 };
194 
195 class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext {
196  private:
197   ContextConv2D op_context_;
198   // xnnpack convs use indirection buffer.
199   // These buffers need setup at runtime and/or when input
200   // dims change. If we are running the same model on multiple
201   // threads, this can lead to contention where indirection buffer
202   // is being accessed and updated at the same time from two different
203   // threads.
204   std::mutex xnnp_mutex_;
205 
206  public:
XNNPackTransposeConv2dOpContext(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && padding,std::vector<int64_t> && output_padding,std::vector<int64_t> && stride,std::vector<int64_t> && dilation,uint64_t groups,const std::optional<Scalar> & min,const std::optional<Scalar> & max,ContextConv2D && op_context)207   XNNPackTransposeConv2dOpContext(
208       Tensor&& weight,
209       std::optional<Tensor>&& bias,
210       std::vector<int64_t>&& padding,
211       std::vector<int64_t>&& output_padding,
212       std::vector<int64_t>&& stride,
213       std::vector<int64_t>&& dilation,
214       uint64_t groups,
215       const std::optional<Scalar>& min,
216       const std::optional<Scalar>& max,
217       ContextConv2D&& op_context)
218       : op_context_(std::move(op_context)) {
219     orig_weight_ = std::move(weight);
220     orig_bias_ = std::move(bias);
221     padding_ = std::move(padding);
222     output_padding_ = std::move(output_padding);
223     stride_ = std::move(stride);
224     dilation_ = std::move(dilation);
225     groups_ = groups;
226     output_min_ = min;
227     output_max_ = max;
228     orig_weight_and_bias_freed_ = false;
229   }
230 
231   Tensor run(const Tensor& input) override;
232   void free_orig_weight_and_bias() override;
233 
234   static c10::intrusive_ptr<TransposeConv2dOpContext> create_context(
235       Tensor&& weight,
236       std::optional<Tensor>&& bias,
237       std::vector<int64_t>&& padding,
238       std::vector<int64_t>&& output_padding,
239       std::vector<int64_t>&& stride,
240       std::vector<int64_t>&& dilation,
241       int64_t groups,
242       const std::optional<Scalar>& output_min,
243       const std::optional<Scalar>& output_max);
244 };
245 
246 } // namespace at::native::xnnpack
247 
248 #endif /* USE_XNNPACK */
249