#pragma once

#include <ATen/ATen.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_types.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

namespace at::native::onednn {
/* oneDNN quantization usage:
   https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html#

   src_fp32 = scale_src * (src_int8 - zero_point)
   wei_fp32 = scale_wei * (wei_int8 - zero_point)
   dst_fp32 = scale_dst * (dst_int8 - zero_point)
   fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32
   Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei)
   Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32;

   Considering zero points (asymmetric quantization):
   dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc
   dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc
   dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + dst_zp

   Considering bias:
   fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias
   Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + bias
   Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias / (scale_src * scale_wei))
                                * (scale_src * scale_wei)
   Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32;
*/
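
/*
   A small worked example with hypothetical scales (placeholders, not values
   used anywhere in this file), ignoring zero points and bias:
   scale_src = 0.1, scale_wei = 0.05, scale_dst = 0.2
   dst_fp32 = (src_int8 * wei_int8) * (0.1 * 0.05)
   dst_int8 = 1 / 0.2 * dst_fp32
            = (src_int8 * wei_int8) * (0.1 * 0.05 / 0.2)
            = (src_int8 * wei_int8) * 0.025
*/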

/*
   oneDNN postops usage:
   Currently, oneDNN supports 5 kinds of post ops. More details can be found
   in the oneDNN doc:
   https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise

0. without post ops
   dst = Conv(src, wei) + bias;
   dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale
   fp32 API: Attr attr;
   int8 API: Attr attr(q_scale);

1. append eltwise post op
   dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta}
   dst_int8 = 1/q_scale * dst;
   fp32 API:
   Attr attr;
   attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
   attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm)
   int8 API:
   Attr attr(q_scale);
   attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
   attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm)

2. append sum post op
   dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp)
   dst_int8 = 1/q_scale * dst;
   fp32 API:
   Attr attr;
   attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
   attr.append_post_sum(sum_scale)
   int8 API:
   Attr attr(q_scale);
   attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
   attr.append_post_sum(sum_scale)

3. append binary post op
   dst = Binary[Conv(src, wei)]

*/
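
/*
   A minimal sketch of how the pieces above compose for an int8 convolution
   fused with sum + ReLU (q_scale, sum_scale and the surrounding primitive
   setup are hypothetical placeholders, not values defined in this file):

   Attr attr(q_scale);
   attr.append_post_sum(sum_scale);
   attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu);
   // the resulting post ops are later attached to the primitive through
   // extract_post_ops(dst), see below.
*/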
using kind_t = dnnl::primitive::kind;
struct PostOpParam {
  // eltwise post op constructor
  PostOpParam(float scale, float alpha, float beta, dnnl::algorithm algo, kind_t kind)
      : scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {}
  // sum post op constructor
  PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {}
  // binary post op constructor
  PostOpParam(
      at::Tensor& binary,
      dnnl::memory::desc& binary_md,
      dnnl::memory::desc& expected_md,
      dnnl::algorithm algo,
      kind_t kind)
      : binary_(binary),
        meta_(binary_md),
        expected_meta_(expected_md),
        algo_(algo),
        kind_(kind) {}
  // prelu post op constructor
  PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {}

  // post sum or binary with scale post op constructor
  PostOpParam(at::Tensor& binary, float scale, dnnl::algorithm algo, kind_t kind)
      : scale_(scale), binary_(binary), algo_(algo), kind_(kind) {}

  // for int8 sum/eltwise
  float scale_ = 1.0;
  // for eltwise
  float alpha_ = 0.0;
  float beta_ = 0.0;
  // for binary
  at::Tensor binary_ = at::Tensor();
  at::Tensor expected_binary_ = at::Tensor();
  void* binary_ptr_ = nullptr;
  dnnl::memory::desc meta_ = dnnl::memory::desc();
  dnnl::memory::desc expected_meta_ = dnnl::memory::desc();
  // for prelu
  int mask_ = 0;
  // common
  dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu;
  kind_t kind_ = kind_t::eltwise;
};

class Attr {
 public:
  Attr() : q_scale_(1.f), q_zero_point_(0) {}
  Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {}

  /***** eltwise *****/
  dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu;
  dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic;
  dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh;
  dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf;
  dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish;
  dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear;
  dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish;
  dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt;
  dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh;
  dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square;
  dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs;
  dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp;
  dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log;
  dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round;
  dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish;
  dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu;
  dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu;
  dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow;
  dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip;
  // note: oneDNN does not seem to support hardsigmoid yet
  dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid;

  /***** binary *****/
  dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul;
  dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add;
  dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub;
  dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div;
  dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq;
  dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne;
  dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge;
  dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt;
  dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le;
  dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt;
  dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max;
  dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min;

  // append sum post op
  Attr& append_post_sum(
      float sum_scale,
      float sum_q_scale = 1.f,
      int64_t zp = 0) {
    ops_params_.push_back(
        PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum));
    return *this;
  }

  // append eltwise post op
  Attr& append_post_eltwise(
      float scale,
      float alpha,
      float beta,
      dnnl::algorithm algo) {
    ops_params_.push_back(
        PostOpParam(scale, alpha, beta, algo, kind_t::eltwise));
    return *this;
  }

  // append binary post op
  Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) {
    auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary;
    bool binary_is_channels_last = (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
                                      binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);

    binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
    dnnl::memory::desc md = get_onednn_md(binary_);
    auto expected_md = dnnl::memory::desc(
        md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any);
    ops_params_.push_back(
        PostOpParam(binary_, md, expected_md, algo, kind_t::binary));
    return *this;
  }
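  // Usage sketch (the residual tensor `other` is a hypothetical placeholder,
  // not defined in this file):
  //   attr.append_post_binary(attr.kind_with_binary_add, other);
  // This fuses an element-wise add of `other` into the primitive's output.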

  Attr& append_scale_binary(
      dnnl::algorithm algo,
      at::Tensor binary,
      float scale,
      float sum_q_scale = 1.f,
      int64_t zp = 0) {
    ops_params_.push_back(PostOpParam(
        binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary));
    return *this;
  }

  // append bias with binary_add method (only used for QConv now)
  template <int N>
  Attr& append_bias(const at::Tensor& binary) {
    // In PyTorch, the bias is in shape [OC].
    // We expand its shape according to the Conv dimension:
    // Conv1d [OC, 1, 1], Conv2d [1, OC, 1, 1], Conv3d [1, OC, 1, 1, 1]
    at::Tensor binary_ = binary.contiguous();
    dnnl::memory::desc binary_md;
    switch (N) {
      case 1:
        binary_md = dnnl::memory::desc(
            {binary.size(0), 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abc);
        break;
      case 2:
        binary_md = dnnl::memory::desc(
            {1, binary.size(0), 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcd);
        break;
      case 3:
        binary_md = dnnl::memory::desc(
            {1, binary.size(0), 1, 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcde);
        break;
      default:
        TORCH_INTERNAL_ASSERT(0,
            "XPU only supports append_bias for Conv1d, Conv2d and Conv3d.");
    }
    // In this case, expected_md = binary_md
    ops_params_.push_back(PostOpParam(
        binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary));
    return *this;
  }
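  // Usage sketch for QConv2d (the `bias` tensor of shape [OC] is a
  // hypothetical placeholder):
  //   attr.append_bias<2>(bias);
  // which appends the bias as a binary_add post op broadcast over [1, OC, 1, 1].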

  // append prelu post op
  Attr& append_post_prelu(int mask) {
    ops_params_.push_back(PostOpParam(mask, kind_t::prelu));
    return *this;
  }

  dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
    // This function extracts the post op params from ops_params_
    // and puts them into the oneDNN post ops.
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      kind_t kind = ops_params_[i].kind_;
      switch (kind) {
        case kind_t::eltwise: {
          dnnl::algorithm algo = ops_params_[i].algo_;
          float alpha = ops_params_[i].alpha_;
          float beta = ops_params_[i].beta_;
          dnnl_post_ops_.append_eltwise(algo, alpha, beta);
          break;
        }
        case kind_t::sum: {
          float scale = ops_params_[i].scale_;
          // TODO [Asymmetric]:
          // Post-sum zp for gpu is not supported currently
          dnnl_post_ops_.append_sum(scale);
          break;
        }
        case kind_t::binary: {
          dnnl::algorithm algo = ops_params_[i].algo_;
          auto expected_md = ops_params_[i].expected_meta_;
          // The user may create the src1 memory descriptor with
          // format_tag::any or set a specific tag. However, in the latter
          // case, if the tag mismatches with dst, it would result in
          // suboptimal performance. So here we use format_tag::any to make
          // sure the fastest implementation can be selected.
          // Thus we use expected_md (with format_tag::any) to create the pd
          // instead of the original md.
          dnnl_post_ops_.append_binary(algo, expected_md);
          break;
        }
        default:
          break;
      }
    }

    // if the output is quantized, append an eltwise linear post op to adjust
    // the output scale/zero_point
    if (dst.is_quantized()) {
      // [Note: Gap of u8 qtensor scale between oneDNN and PyTorch]
      // The /2 here is because the output_scale collected by the observer
      // differs from the quantization requirements in oneDNN.
      // For the Observer, the conv_scale (activation scale in other cases) is
      // computed as 2 * max_v / (qmax - qmin), where max_v is collected
      // from the tensor being observed.
      // (https://pytorch.org/docs/stable/generated/torch.quantization.observer.MinMaxObserver.html#torch.quantization.observer.MinMaxObserver)
      // On the other hand, for u8 in oneDNN, the scale for quantization is
      // defined as max_v / (qmax - qmin). Hence, we need to divide by 2 here.
      // (https://oneapi-src.github.io/oneDNN/dev_guide_inference_int8.html)
      dnnl_post_ops_.append_eltwise(
          kind_with_linear, 1.f / q_scale_, q_zero_point_);
    }
    return dnnl_post_ops_;
  }
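  // Wiring sketch (assumes a primitive is being built elsewhere; `pattr` and
  // `dst` below are placeholders, not defined in this file):
  //   dnnl::primitive_attr pattr;
  //   pattr.set_post_ops(attr.extract_post_ops(dst));
  //   // pattr is then passed to the primitive descriptor constructor.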

  bool with_sum() {
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      if (ops_params_[i].kind_ == kind_t::sum) {
        return true;
      }
    }
    return false;
  }

  bool with_binary() {
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      if (ops_params_[i].kind_ == kind_t::binary) {
        return true;
      }
    }
    return false;
  }

  void construct_post_binary(
      dnnl::primitive_desc& pd,
      std::unordered_map<int, dnnl::memory>& args) {
    // This function constructs the binary memory used by binary post ops.
    // According to the oneDNN doc, the binary tensor can be in shape of
    // [1, 1, 1, 1], tensor broadcast
    // [1, C, 1, 1], channel broadcast
    // [dst.shape], no broadcast and element-wise binary operations on dst

    auto engine =
        GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      kind_t kind = ops_params_[i].kind_;
      if (kind == kind_t::binary) {
        dnnl::memory binary_m;
        auto binary = ops_params_[i].binary_;
        auto md = ops_params_[i].meta_;
        // query expected_md to achieve peak performance
        auto expected_md = pd.query_md(
            dnnl::query::exec_arg_md,
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);

        binary_m = at::native::onednn::make_onednn_memory(
          md, engine, binary.data_ptr()
        );

        args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m});
      }
    }
  }
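  // Usage sketch (assumes `pd` and the primitive execution happen elsewhere;
  // names are placeholders):
  //   std::unordered_map<int, dnnl::memory> args;
  //   attr.construct_post_binary(pd, args);
  //   // args now holds a {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1,
  //   // memory} entry for every binary post op, ready for execution.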

  float q_scale_ = 1.0; // the scale used to quantize the fused result from
                        // fp32 to int8; only used in the int8 case
  int64_t q_zero_point_ = 0;
  std::vector<PostOpParam> ops_params_; // series of post ops
  dnnl::post_ops dnnl_post_ops_;
};

} // namespace at::native::onednn