#pragma once

#include <ATen/Tensor.h>
#include <torch/custom_class.h>

namespace at::native::metal {

using SerializationTypeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;

class Conv2dOpContext : public torch::jit::CustomClassHolder {
 public:
  SerializationTypeConv2dPrePack pack() {
    return std::make_tuple(
        weight_,
        bias_,
        stride_,
        padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }
  Conv2dOpContext() = delete;
  Conv2dOpContext(
      at::Tensor&& weight,
      std::optional<at::Tensor>&& bias,
      std::vector<int64_t> stride,
      std::vector<int64_t> padding,
      std::vector<int64_t> dilation,
      int64_t groups,
      std::optional<Scalar> output_min,
      std::optional<Scalar> output_max)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        output_min_(std::move(output_min)),
        output_max_(std::move(output_max)) {}

  ~Conv2dOpContext() override {
    if (releaseCallback_) {
      releaseCallback_(conv2dOp_);
    }
  }

  void release_resources() override {
    if (releaseCallback_) {
      releaseCallback_(conv2dOp_);
    }
  }

  const Tensor& get_weight() const {
    return weight_;
  }

  const std::optional<Tensor>& get_bias() const {
    return bias_;
  }

  const std::vector<int64_t>& get_stride() const {
    return stride_;
  }

  const std::vector<int64_t>& get_padding() const {
    return padding_;
  }

  const std::vector<int64_t>& get_dilation() const {
    return dilation_;
  }

  int64_t get_groups() const {
    return groups_;
  }

  const std::optional<Scalar>& get_output_min() const {
    return output_min_;
  }

  const std::optional<Scalar>& get_output_max() const {
    return output_max_;
  }

  void set_conv2dOpPtr(void* ptr) {
    conv2dOp_ = ptr;
  }

  void* get_conv2dOpPtr() const {
    return conv2dOp_;
  }

  void set_releaseCallback(const std::function<void(void*)>& func) {
    releaseCallback_ = func;
  }

  std::function<void(void*)>& get_releaseCallback() {
    return releaseCallback_;
  }

 private:
  Tensor weight_;
  std::optional<Tensor> bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> dilation_;
  int64_t groups_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  std::function<void(void*)> releaseCallback_ = nullptr;
  void* conv2dOp_ = nullptr; // reserved to hold MPSCNNConv2dOp objects
};

using SerializationTypeLinearPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::optional<Scalar>,
    std::optional<Scalar>>;

class LinearOpContext : public torch::jit::CustomClassHolder {
 public:
  SerializationTypeLinearPrePack pack() {
    return std::make_tuple(weight_, bias_, output_min_, output_max_);
  }
  LinearOpContext() = delete;
  LinearOpContext(
      at::Tensor&& weight,
      std::optional<at::Tensor>&& bias,
      std::optional<Scalar> output_min,
      std::optional<Scalar> output_max)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        output_min_(std::move(output_min)),
        output_max_(std::move(output_max)) {}

  ~LinearOpContext() override {
    if (releaseCallback_) {
      releaseCallback_(opaqueOpPtr_);
    }
  }

  void release_resources() override {
    if (releaseCallback_) {
      releaseCallback_(opaqueOpPtr_);
    }
  }

  const Tensor& get_weight() const {
    return weight_;
  }

  const std::optional<Tensor>& get_bias() const {
    return bias_;
  }

  const std::optional<Scalar>& get_output_min() const {
    return output_min_;
  }

  const std::optional<Scalar>& get_output_max() const {
    return output_max_;
  }

  void set_opaqueOpPtr(void* ptr) {
    opaqueOpPtr_ = ptr;
  }

  void* get_opaqueOpPtr() const {
    return opaqueOpPtr_;
  }

  void set_releaseCallback(const std::function<void(void*)>& func) {
    releaseCallback_ = func;
  }

  std::function<void(void*)>& get_releaseCallback() {
    return releaseCallback_;
  }

 private:
  Tensor weight_;
  std::optional<Tensor> bias_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  void* opaqueOpPtr_ = nullptr; // reserved to hold MPSCNNFullyConnected objects
  std::function<void(void*)> releaseCallback_ = nullptr;
};

} // namespace at::native::metal
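
// ---------------------------------------------------------------------------
// Illustrative usage (not part of this header): a minimal sketch of how a
// Conv2dOpContext might be constructed, wired to its release callback, and
// serialized via pack(). The tensor shapes, parameter values, and callback
// body below are assumptions for illustration only; in practice these
// contexts are produced by the Metal prepack ops, and LinearOpContext follows
// the same pattern with set_opaqueOpPtr()/get_opaqueOpPtr().
//
//   #include <ATen/ATen.h>
//
//   using namespace at::native::metal;
//
//   void example() {
//     // A hypothetical 3x3 convolution: 16 output channels, 3 input
//     // channels, stride 1, padding 1, dilation 1, one group, no clamping.
//     auto ctx = c10::make_intrusive<Conv2dOpContext>(
//         at::rand({16, 3, 3, 3}),
//         std::optional<at::Tensor>(at::rand({16})),
//         std::vector<int64_t>{1, 1},
//         std::vector<int64_t>{1, 1},
//         std::vector<int64_t>{1, 1},
//         /*groups=*/1,
//         std::nullopt,
//         std::nullopt);
//
//     // The backend would stash its MPSCNNConv2dOp behind the opaque pointer
//     // and register a callback so that object is freed with the context.
//     ctx->set_releaseCallback([](void* op) { /* release backend object */ });
//
//     // pack() flattens the state into the tuple used for serialization.
//     SerializationTypeConv2dPrePack state = ctx->pack();
//   }
// ---------------------------------------------------------------------------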