#pragma once

#ifdef USE_XNNPACK

#include <ATen/core/ivalue.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/Tensor.h>

namespace at::native::xnnpack {

using SerializationTypeLinearPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::optional<Scalar>,
    std::optional<Scalar>>;
using SerializationTypeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;
using SerializationTypeTransposeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;

class LinearOpContext : public torch::jit::CustomClassHolder {
 protected:
  Tensor orig_weight_;
  std::optional<Tensor> orig_bias_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  bool orig_weight_and_bias_freed_;

 public:
  SerializationTypeLinearPrePack unpack() {
    TORCH_CHECK(
        !orig_weight_and_bias_freed_,
        "Original weight and bias have been freed");
    return std::make_tuple(orig_weight_, orig_bias_, output_min_, output_max_);
  }

  virtual Tensor run(const Tensor& input) = 0;
  virtual void free_orig_weight_and_bias() = 0;
};

class XNNPackLinearOpContext final : public LinearOpContext {
 private:
  ContextLinear op_context_;

 public:
  XNNPackLinearOpContext(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      const std::optional<Scalar>& min,
      const std::optional<Scalar>& max,
      ContextLinear&& op_context)
      : op_context_(std::move(op_context)) {
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    output_min_ = min;
    output_max_ = max;
    orig_weight_and_bias_freed_ = false;
  }

  Tensor run(const Tensor& input) override;
  void free_orig_weight_and_bias() override;

  static c10::intrusive_ptr<LinearOpContext> create_context(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      const std::optional<Scalar>& output_min,
      const std::optional<Scalar>& output_max);
};
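// A minimal usage sketch for XNNPackLinearOpContext (defined above). This is
// a hypothetical call site: real callers go through the linear prepack/run
// custom ops registered elsewhere in the XNNPACK backend.
//
//   auto ctx = XNNPackLinearOpContext::create_context(
//       std::move(weight), std::move(bias),
//       /*output_min=*/std::nullopt, /*output_max=*/std::nullopt);
//   Tensor out = ctx->run(input);
//
// unpack() reverses create_context(), returning the original prepack
// arguments for serialization, until free_orig_weight_and_bias() drops them.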
class Conv2dOpContext : public torch::jit::CustomClassHolder {
 protected:
  Tensor orig_weight_;
  std::optional<Tensor> orig_bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> dilation_;
  int64_t groups_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  bool orig_weight_and_bias_freed_;

 public:
  SerializationTypeConv2dPrePack unpack() {
    TORCH_CHECK(
        !orig_weight_and_bias_freed_,
        "Original weight and bias have been freed");
    return std::make_tuple(
        orig_weight_,
        orig_bias_,
        stride_,
        padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }

  virtual Tensor run(const Tensor& input) = 0;
  virtual void free_orig_weight_and_bias() = 0;
};

class TransposeConv2dOpContext : public torch::jit::CustomClassHolder {
 protected:
  Tensor orig_weight_;
  std::optional<Tensor> orig_bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> output_padding_;
  std::vector<int64_t> dilation_;
  int64_t groups_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  bool orig_weight_and_bias_freed_;

 public:
  SerializationTypeTransposeConv2dPrePack unpack() {
    TORCH_CHECK(
        !orig_weight_and_bias_freed_,
        "Original weight and bias have been freed");
    return std::make_tuple(
        orig_weight_,
        orig_bias_,
        stride_,
        padding_,
        output_padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }

  virtual Tensor run(const Tensor& input) = 0;
  virtual void free_orig_weight_and_bias() = 0;
};

class XNNPackConv2dOpContext final : public Conv2dOpContext {
 private:
  ContextConv2D op_context_;
  // XNNPACK convolutions use an indirection buffer that must be set up at
  // runtime and/or whenever the input dimensions change. If the same model
  // runs on multiple threads, the buffer could otherwise be read and updated
  // concurrently from different threads, so access is serialized with this
  // mutex.
  std::mutex xnnp_mutex_;

 public:
  XNNPackConv2dOpContext(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      uint64_t groups,
      const std::optional<Scalar>& min,
      const std::optional<Scalar>& max,
      ContextConv2D&& op_context)
      : op_context_(std::move(op_context)) {
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    padding_ = std::move(padding);
    stride_ = std::move(stride);
    dilation_ = std::move(dilation);
    groups_ = groups;
    output_min_ = min;
    output_max_ = max;
    orig_weight_and_bias_freed_ = false;
  }

  Tensor run(const Tensor& input) override;
  void free_orig_weight_and_bias() override;

  static c10::intrusive_ptr<Conv2dOpContext> create_context(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const std::optional<Scalar>& output_min,
      const std::optional<Scalar>& output_max);
};
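// A minimal usage sketch for XNNPackConv2dOpContext (defined above), again a
// hypothetical call site: real callers reach this through the conv2d
// prepack/run custom ops. Note the argument order of create_context is
// padding, stride, dilation, and run() is expected to guard the indirection
// buffer with xnnp_mutex_ (see the comment on that member).
//
//   auto ctx = XNNPackConv2dOpContext::create_context(
//       std::move(weight), std::move(bias),
//       /*padding=*/{0, 0}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
//       /*groups=*/1,
//       /*output_min=*/std::nullopt, /*output_max=*/std::nullopt);
//   Tensor out = ctx->run(input);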
class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext {
 private:
  ContextConv2D op_context_;
  // XNNPACK convolutions use an indirection buffer that must be set up at
  // runtime and/or whenever the input dimensions change. If the same model
  // runs on multiple threads, the buffer could otherwise be read and updated
  // concurrently from different threads, so access is serialized with this
  // mutex.
  std::mutex xnnp_mutex_;

 public:
  XNNPackTransposeConv2dOpContext(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& output_padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      uint64_t groups,
      const std::optional<Scalar>& min,
      const std::optional<Scalar>& max,
      ContextConv2D&& op_context)
      : op_context_(std::move(op_context)) {
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    padding_ = std::move(padding);
    output_padding_ = std::move(output_padding);
    stride_ = std::move(stride);
    dilation_ = std::move(dilation);
    groups_ = groups;
    output_min_ = min;
    output_max_ = max;
    orig_weight_and_bias_freed_ = false;
  }

  Tensor run(const Tensor& input) override;
  void free_orig_weight_and_bias() override;

  static c10::intrusive_ptr<TransposeConv2dOpContext> create_context(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& output_padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const std::optional<Scalar>& output_min,
      const std::optional<Scalar>& output_max);
};

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */