#pragma once

#include <ATen/ATen.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_types.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

namespace at::native::onednn {
/* oneDNN quantization usage:
https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html#

src_fp32 = scale_src * (src_int8 - zero_point)
wei_fp32 = scale_wei * (wei_int8 - zero_point)
dst_fp32 = scale_dst * (dst_int8 - zero_point)
fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32
Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei)
Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32;

Considering zero-point (asymmetric):
dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc
dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc
dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + dst_zp

Considering bias:
fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias
Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + bias
Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias / (scale_src * scale_wei)) * (scale_src * scale_wei)
Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32;
*/

/*
oneDNN postops usage:
Currently, oneDNN supports 5 kinds of post ops. More details can be found in
the oneDNN doc:
https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise

0. without post ops
   dst = Conv(src, wei) + bias;
   dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale
   fp32 API: Attr attr;
   int8 API: Attr attr(q_scale);

1. append eltwise post op
   dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta}
   dst_int8 = 1/q_scale * dst;
   fp32 API:
     Attr attr;
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm)
   int8 API:
     Attr attr(q_scale);
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm)

2. append sum post op
   dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp)
   dst_int8 = 1/q_scale * dst;
   fp32 API:
     Attr attr;
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_sum(sum_scale)
   int8 API:
     Attr attr(q_scale);
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_sum(sum_scale)

3. append binary post op
   dst = Binary[Conv(src, wei)]
*/
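/*
Illustrative sketch (not part of this header; scale and zero-point values are
made-up placeholders): how the quantization scheme and the post-op list above
are typically combined into a single Attr (defined below) for a fused int8
convolution of the form y = ReLU(Conv(src, wei) + bias) + residual.

  float q_scale = 0.05f;  // output requantization scale collected by an observer
  Attr attr(q_scale, 0);  // 0 is the output zero-point
  attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu);
  attr.append_post_sum(1.f);
  // extract_post_ops(dst) later converts this list into dnnl::post_ops and,
  // because dst is quantized, appends an eltwise_linear op applying 1/q_scale.
*/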
using kind_t = dnnl::primitive::kind;
struct PostOpParam {
  // eltwise post op constructor
  PostOpParam(
      float scale,
      float alpha,
      float beta,
      dnnl::algorithm algo,
      kind_t kind)
      : scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {}
  // sum post op constructor
  PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {}
  // binary post op constructor
  PostOpParam(
      at::Tensor& binary,
      dnnl::memory::desc& binary_md,
      dnnl::memory::desc& expected_md,
      dnnl::algorithm algo,
      kind_t kind)
      : binary_(binary),
        meta_(binary_md),
        expected_meta_(expected_md),
        algo_(algo),
        kind_(kind) {}
  // prelu post op constructor
  PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {}

  // post sum or binary with scale post op constructor
  PostOpParam(at::Tensor& binary, float scale, dnnl::algorithm algo, kind_t kind)
      : scale_(scale), binary_(binary), algo_(algo), kind_(kind) {}

  // for int8 sum/eltwise
  float scale_ = 1.0;
  // for eltwise
  float alpha_ = 0.0;
  float beta_ = 0.0;
  // for binary
  at::Tensor binary_ = at::Tensor();
  at::Tensor expected_binary_ = at::Tensor();
  void* binary_ptr_ = nullptr;
  dnnl::memory::desc meta_ = dnnl::memory::desc();
  dnnl::memory::desc expected_meta_ = dnnl::memory::desc();
  // for prelu
  int mask_ = 0;
  // common
  dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu;
  kind_t kind_ = kind_t::eltwise;
};
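/*
Sketch of how the constructors above map onto post-op kinds (values are
hypothetical; these objects are normally created by the Attr::append_* helpers
below rather than constructed by hand):

  PostOpParam relu(1.f, 0.f, 0.f, dnnl::algorithm::eltwise_relu, kind_t::eltwise);
  PostOpParam residual(1.f, kind_t::sum);  // scale-only sum post op

Each variant fills only the fields relevant to its kind_; Attr::extract_post_ops()
dispatches on kind_ to decide which fields to read.
*/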
class Attr {
 public:
  Attr() : q_scale_(1.f), q_zero_point_(0) {}
  Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {}

  /***** eltwise *****/
  dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu;
  dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic;
  dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh;
  dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf;
  dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish;
  dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear;
  dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish;
  dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt;
  dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh;
  dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square;
  dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs;
  dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp;
  dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log;
  dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round;
  dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish;
  dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu;
  dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu;
  dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow;
  dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip;
  // note: oneDNN does not seem to support hardsigmoid yet
  dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid;

  /***** binary *****/
  dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul;
  dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add;
  dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub;
  dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div;
  dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq;
  dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne;
  dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge;
  dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt;
  dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le;
  dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt;
  dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max;
  dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min;

  // append sum post op
  Attr& append_post_sum(
      float sum_scale,
      float sum_q_scale = 1.f,
      int64_t zp = 0) {
    ops_params_.push_back(
        PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum));
    return *this;
  }

  // append eltwise post op
  Attr& append_post_eltwise(
      float scale,
      float alpha,
      float beta,
      dnnl::algorithm algo) {
    ops_params_.push_back(
        PostOpParam(scale, alpha, beta, algo, kind_t::eltwise));
    return *this;
  }

  // append binary post op
  Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) {
    auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary;
    bool binary_is_channels_last =
        (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
         binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);

    binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
    dnnl::memory::desc md = get_onednn_md(binary_);
    auto expected_md = dnnl::memory::desc(
        md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any);
    ops_params_.push_back(
        PostOpParam(binary_, md, expected_md, algo, kind_t::binary));
    return *this;
  }

  Attr& append_scale_binary(
      dnnl::algorithm algo,
      at::Tensor binary,
      float scale,
      float sum_q_scale = 1.f,
      int64_t zp = 0) {
    ops_params_.push_back(PostOpParam(
        binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary));
    return *this;
  }
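  /*
  Sketch of appending a binary post op (shape and algorithm are hypothetical; the
  tensor must match dst's shape or be broadcastable per oneDNN's rules):

    at::Tensor other = at::randn({1, 8, 14, 14});  // e.g. a residual branch
    Attr attr;
    attr.append_post_binary(attr.kind_with_binary_mul, other);

  append_post_binary dequantizes a quantized tensor, keeps a channels-last tensor
  as-is (otherwise makes it contiguous), and records both the real memory::desc
  and an "expected" desc with format_tag::any so the primitive can pick its
  preferred layout; construct_post_binary() below binds the actual memory at
  execution time.
  */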
  // append bias with binary_add method (only used for QConv now)
  template <int N>
  Attr& append_bias(const at::Tensor& binary) {
    // In PyTorch, bias is in shape of [OC];
    // we expand its shape according to the Conv dimension:
    // Conv1d [OC, 1, 1], Conv2d [1, OC, 1, 1], Conv3d [1, OC, 1, 1, 1]
    at::Tensor binary_ = binary.contiguous();
    dnnl::memory::desc binary_md;
    switch (N) {
      case 1:
        binary_md = dnnl::memory::desc(
            {binary.size(0), 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abc);
        break;
      case 2:
        binary_md = dnnl::memory::desc(
            {1, binary.size(0), 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcd);
        break;
      case 3:
        binary_md = dnnl::memory::desc(
            {1, binary.size(0), 1, 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcde);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            0, "XPU only supports append_bias for Conv1d, Conv2d and Conv3d.");
    }
    // In this case, expected_md = binary_md
    ops_params_.push_back(PostOpParam(
        binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary));
    return *this;
  }

  // append prelu post op
  Attr& append_post_prelu(int mask) {
    ops_params_.push_back(PostOpParam(mask, kind_t::prelu));
    return *this;
  }

  dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
    // This function extracts the post-op params from ops_params_
    // and puts them into oneDNN post ops.
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      kind_t kind = ops_params_[i].kind_;
      switch (kind) {
        case kind_t::eltwise: {
          dnnl::algorithm algo = ops_params_[i].algo_;
          float alpha = ops_params_[i].alpha_;
          float beta = ops_params_[i].beta_;
          dnnl_post_ops_.append_eltwise(algo, alpha, beta);
          break;
        }
        case kind_t::sum: {
          float scale = ops_params_[i].scale_;
          // TODO [Asymmetric]:
          // Post-sum zp for gpu is not supported currently
          dnnl_post_ops_.append_sum(scale);
          break;
        }
        case kind_t::binary: {
          dnnl::algorithm algo = ops_params_[i].algo_;
          auto expected_md = ops_params_[i].expected_meta_;
          // In this case, the user may create the src1 memory descriptor with
          // format_tag::any or set a specific tag. However, in the latter case,
          // if the tags mismatch with dst, it would result in suboptimal
          // performance. So here we use format_tag::any to make sure the fast
          // path can be selected.
          // Thus we use expected_md (with format_tag::any) here to create the
          // pd instead of the original md.
          dnnl_post_ops_.append_binary(algo, expected_md);
          break;
        }
        default:
          break;
      }
    }

    // If the output is quantized, then append an eltwise linear to adjust the
    // output scale/zero_point.
    if (dst.is_quantized()) {
      // [Note: Gap of u8 qtensor scale between oneDNN and PyTorch]
      // The /2 here is because the output_scale collected by the observer
      // differs from the quantization requirements in oneDNN.
      // For the Observer, the conv_scale (activation scale in other cases) is
      // computed as 2 * max_v / (qmax - qmin), where max_v is collected
      // from the tensor to be observed.
      // (https://pytorch.org/docs/stable/generated/torch.quantization.observer.MinMaxObserver.html#torch.quantization.observer.MinMaxObserver)
      // On the other hand, for u8 in oneDNN, the scale for quantization is
      // defined as max_v/(qmax-qmin). Hence, we need to divide by 2 here.
      // (https://oneapi-src.github.io/oneDNN/dev_guide_inference_int8.html)
      dnnl_post_ops_.append_eltwise(
          kind_with_linear, 1.f / q_scale_, q_zero_point_);
    }
    return dnnl_post_ops_;
  }
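  /*
  Sketch of how a caller is expected to consume the extracted post ops
  (identifiers such as conv_pd and dst are placeholders; the real primitive
  setup lives in the convolution/matmul integration code, not in this header):

    dnnl::primitive_attr pattr;
    pattr.set_post_ops(attr.extract_post_ops(dst));
    // ... create the primitive_desc (e.g. convolution_forward::primitive_desc)
    // with pattr, then let Attr bind the binary post-op tensors:
    std::unordered_map<int, dnnl::memory> args;
    attr.construct_post_binary(conv_pd, args);
    // args now holds DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1 entries
    // for each binary post op, ready to be merged with the primitive's own args.
  */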
  bool with_sum() {
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      if (ops_params_[i].kind_ == kind_t::sum) {
        return true;
      }
    }
    return false;
  }

  bool with_binary() {
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      if (ops_params_[i].kind_ == kind_t::binary) {
        return true;
      }
    }
    return false;
  }

  void construct_post_binary(
      dnnl::primitive_desc& pd,
      std::unordered_map<int, dnnl::memory>& args) {
    // This function constructs the binary memory descriptors for binary post ops.
    // According to the oneDNN doc, the binary tensor can be in shape of
    // [1, 1, 1, 1], tensor broadcast
    // [1, C, 1, 1], channel broadcast
    // [dst.shape], no broadcast and element-wise binary operations on dst

    auto engine = GpuEngineManager::Instance().get_engine(
        {c10::kXPU, c10::xpu::current_device()});
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      kind_t kind = ops_params_[i].kind_;
      if (kind == kind_t::binary) {
        dnnl::memory binary_m;
        auto binary = ops_params_[i].binary_;
        auto md = ops_params_[i].meta_;
        // query expected_md to achieve peak performance
        auto expected_md = pd.query_md(
            dnnl::query::exec_arg_md,
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);

        binary_m = at::native::onednn::make_onednn_memory(
            md, engine, binary.data_ptr());

        args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m});
      }
    }
  }

  float q_scale_ = 1.0; // the scale used to quantize the fused result from fp32
                        // to int8; only works for the int8 case
  int64_t q_zero_point_ = 0;
  std::vector<PostOpParam> ops_params_; // series of post ops
  dnnl::post_ops dnnl_post_ops_;
};

} // namespace at::native::onednn