xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalPrepackOpContext.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <functional>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

#include <ATen/Tensor.h>
#include <torch/custom_class.h>
5 
6 namespace at::native::metal {
7 
8 using SerializationTypeConv2dPrePack = std::tuple<
9     Tensor,
10     std::optional<Tensor>,
11     std::vector<int64_t>,
12     std::vector<int64_t>,
13     std::vector<int64_t>,
14     int64_t,
15     std::optional<Scalar>,
16     std::optional<Scalar>>;
17 
18 class Conv2dOpContext : public torch::jit::CustomClassHolder {
19  public:
pack()20   SerializationTypeConv2dPrePack pack() {
21     return std::make_tuple(
22         weight_,
23         bias_,
24         stride_,
25         padding_,
26         dilation_,
27         groups_,
28         output_min_,
29         output_max_);
30   }
31   Conv2dOpContext() = delete;
Conv2dOpContext(at::Tensor && weight,std::optional<at::Tensor> && bias,std::vector<int64_t> stride,std::vector<int64_t> padding,std::vector<int64_t> dilation,int64_t groups,std::optional<Scalar> output_min,std::optional<Scalar> output_max)32   Conv2dOpContext(
33       at::Tensor&& weight,
34       std::optional<at::Tensor>&& bias,
35       std::vector<int64_t> stride,
36       std::vector<int64_t> padding,
37       std::vector<int64_t> dilation,
38       int64_t groups,
39       std::optional<Scalar> output_min,
40       std::optional<Scalar> output_max)
41       : weight_(std::move(weight)),
42         bias_(std::move(bias)),
43         stride_(std::move(stride)),
44         padding_(std::move(padding)),
45         dilation_(std::move(dilation)),
46         groups_(groups),
47         output_min_(std::move(output_min)),
48         output_max_(std::move(output_max)) {}
49 
~Conv2dOpContext()50   ~Conv2dOpContext() override {
51     if (releaseCallback_) {
52       releaseCallback_(conv2dOp_);
53     }
54   }
55 
release_resources()56   void release_resources() override {
57     if (releaseCallback_) {
58       releaseCallback_(conv2dOp_);
59     }
60   }
61 
get_weight()62   const Tensor& get_weight() const {
63     return weight_;
64   }
65 
get_bias()66   const std::optional<Tensor>& get_bias() const {
67     return bias_;
68   }
69 
get_stride()70   const std::vector<int64_t>& get_stride() const {
71     return stride_;
72   }
73 
get_padding()74   const std::vector<int64_t>& get_padding() const {
75     return padding_;
76   }
77 
get_dilation()78   const std::vector<int64_t>& get_dilation() const {
79     return dilation_;
80   }
81 
get_groups()82   int64_t get_groups() const {
83     return groups_;
84   }
85 
get_output_min()86   const std::optional<Scalar>& get_output_min() const {
87     return output_min_;
88   }
89 
get_output_max()90   const std::optional<Scalar>& get_output_max() const {
91     return output_max_;
92   }
93 
set_conv2dOpPtr(void * ptr)94   void set_conv2dOpPtr(void* ptr) {
95       conv2dOp_ = ptr;
96   }
97 
get_conv2dOpPtr()98   void* get_conv2dOpPtr() const {
99     return conv2dOp_;
100   }
101 
set_releaseCallback(const std::function<void (void *)> & func)102   void set_releaseCallback(const std::function<void(void*)>& func) {
103     releaseCallback_ = func;
104   }
105 
get_releaseCallback()106   std::function<void(void*)>& get_releaseCallback() {
107      return releaseCallback_;
108   }
109 
110   private:
111     Tensor weight_;
112     std::optional<Tensor> bias_;
113     std::vector<int64_t> stride_;
114     std::vector<int64_t> padding_;
115     std::vector<int64_t> dilation_;
116     int64_t groups_;
117     std::optional<Scalar> output_min_;
118     std::optional<Scalar> output_max_;
119     std::function<void(void*)> releaseCallback_ = nullptr;
120     void* conv2dOp_ = nullptr; // reserved to hold MPSCNNConv2dOp objects
121 };
122 
123 using SerializationTypeLinearPrePack = std::tuple<
124     Tensor,
125     std::optional<Tensor>,
126     std::optional<Scalar>,
127     std::optional<Scalar>>;
128 
129 class LinearOpContext : public torch::jit::CustomClassHolder {
130  public:
pack()131   SerializationTypeLinearPrePack pack() {
132     return std::make_tuple(weight_, bias_, output_min_, output_max_);
133   }
134   LinearOpContext() = delete;
LinearOpContext(at::Tensor && weight,std::optional<at::Tensor> && bias,std::optional<Scalar> output_min,std::optional<Scalar> output_max)135   LinearOpContext(
136       at::Tensor&& weight,
137       std::optional<at::Tensor>&& bias,
138       std::optional<Scalar> output_min,
139       std::optional<Scalar> output_max)
140       : weight_(std::move(weight)),
141         bias_(std::move(bias)),
142         output_min_(std::move(output_min)),
143         output_max_(std::move(output_max)) {}
144 
~LinearOpContext()145   ~LinearOpContext() override {
146     if (releaseCallback_) {
147       releaseCallback_(opaqueOpPtr_);
148     }
149   }
150 
release_resources()151   void release_resources() override {
152     if (releaseCallback_) {
153       releaseCallback_(opaqueOpPtr_);
154     }
155   }
156 
get_weight()157   const Tensor& get_weight() const {
158     return weight_;
159   }
160 
get_bias()161   const std::optional<Tensor>& get_bias() const {
162     return bias_;
163   }
164 
get_output_min()165   const std::optional<Scalar>& get_output_min() const {
166     return output_min_;
167   }
168 
get_output_max()169   const std::optional<Scalar>& get_output_max() const {
170     return output_max_;
171   }
172 
set_opaqueOpPtr(void * ptr)173   void set_opaqueOpPtr(void* ptr) {
174     opaqueOpPtr_ = ptr;
175   }
176 
get_opaqueOpPtr()177   void* get_opaqueOpPtr() const {
178     return opaqueOpPtr_;
179   }
180 
set_releaseCallback(const std::function<void (void *)> & func)181   void set_releaseCallback(const std::function<void(void*)>& func) {
182     releaseCallback_ = func;
183   }
184 
get_releaseCallback()185   std::function<void(void*)>& get_releaseCallback() {
186     return releaseCallback_;
187   }
188 
189  private:
190   Tensor weight_;
191   std::optional<Tensor> bias_;
192   std::optional<Scalar> output_min_;
193   std::optional<Scalar> output_max_;
194   void* opaqueOpPtr_ = nullptr; // reserved to hold MPSCNNFullyConnected objects
195   std::function<void(void*)> releaseCallback_ = nullptr;
196 };
197 
198 } // namespace at::native::metal
199