xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/xpu/detail/Conv.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/xpu/XPUFunctions.h>

#include <ATen/ATen.h>
#include <ATen/core/grad_mode.h>
#include <ATen/record_function.h>
#include <c10/core/MemoryFormat.h>

#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

// Logical dimension indices shared by the conv helpers below:
// dim 0 of src is the batch (N); dim 0 of weight is the output channels (O).
constexpr int src_batch_size_dim = 0;
constexpr int weight_dst_channels_dim = 0;
17 
conv_dst_size(int64_t ndim,IntArrayRef src_size,IntArrayRef weight_size,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation)18 dnnl::memory::dims conv_dst_size(
19     int64_t ndim,
20     IntArrayRef src_size,
21     IntArrayRef weight_size,
22     IntArrayRef padding_front_top_left,
23     IntArrayRef padding_back_bottom_right,
24     IntArrayRef stride,
25     IntArrayRef dilation) {
26   bool has_dilation = dilation.size() > 0;
27  dnnl::memory::dims dst_size(ndim);
28   dst_size[0] = src_size[src_batch_size_dim];
29   dst_size[1] = weight_size[weight_dst_channels_dim];
30   for (int d = 2; d < ndim; ++d) {
31     auto dilate = has_dilation ? dilation[d - 2] : 1;
32     auto kernel = dilate * (weight_size[d] - 1) + 1;
33     dst_size[d] =
34         (src_size[d] +
35          (padding_front_top_left[d - 2] + padding_back_bottom_right[d - 2]) -
36          kernel) /
37             stride[d - 2] +
38         1;
39   }
40   return dst_size;
41 }
42 
compatible_dilation(IntArrayRef & dilation)43 static inline dnnl::memory::dims compatible_dilation(IntArrayRef& dilation) {
44  dnnl::memory::dims ret = dilation.vec();
45   for (auto it = ret.begin(); it != ret.end(); it++) {
46     *it -= 1;
47   }
48   return ret;
49 }
50 
conv_src_fmt(const int64_t ndim,const bool is_channels_last=false)51 static inline dnnl::memory::format_tag conv_src_fmt(
52     const int64_t ndim,
53     const bool is_channels_last = false) {
54   if (!is_channels_last) {
55     return (ndim == 3)
56         ? dnnl::memory::format_tag::ncw
57         : ((ndim == 4) ? dnnl::memory::format_tag::nchw
58                        : ((ndim == 5) ? dnnl::memory::format_tag::ncdhw
59                                       : dnnl::memory::format_tag::undef));
60   } else {
61     return (ndim == 3)
62         ? dnnl::memory::format_tag::nwc
63         : ((ndim == 4) ? dnnl::memory::format_tag::nhwc
64                        : ((ndim == 5) ? dnnl::memory::format_tag::ndhwc
65                                       : dnnl::memory::format_tag::undef));
66   }
67 }
68 
conv_weight_fmt(const int64_t ndim,const bool grouped=false,const bool is_channels_last=false)69 static inline dnnl::memory::format_tag conv_weight_fmt(
70     const int64_t ndim,
71     const bool grouped = false,
72     const bool is_channels_last = false) {
73   if (!is_channels_last) {
74     return (ndim == 3)
75         ? (grouped ? dnnl::memory::format_tag::goiw : dnnl::memory::format_tag::oiw)
76         : (ndim == 4)
77         ? (grouped ? dnnl::memory::format_tag::goihw : dnnl::memory::format_tag::oihw)
78         : ((ndim == 5) ? (grouped ? dnnl::memory::format_tag::goidhw
79                                   : dnnl::memory::format_tag::oidhw)
80                        : dnnl::memory::format_tag::undef);
81   } else {
82     return (ndim == 3)
83         ? (grouped ? dnnl::memory::format_tag::gowi : dnnl::memory::format_tag::owi)
84         : (ndim == 4)
85         ? (grouped ? dnnl::memory::format_tag::gohwi : dnnl::memory::format_tag::ohwi)
86         : ((ndim == 5) ? (grouped ? dnnl::memory::format_tag::godhwi
87                                   : dnnl::memory::format_tag::odhwi)
88                        : dnnl::memory::format_tag::undef);
89   }
90 }
91 
compatible_weight_dims(const int64_t ndim,const int64_t groups,const int64_t oc,const int64_t ic,const IntArrayRef wsizes)92 static inline dnnl::memory::dims compatible_weight_dims(
93     const int64_t ndim,
94     const int64_t groups,
95     const int64_t oc,
96     const int64_t ic,
97     const IntArrayRef wsizes) {
98   if (ndim == 3) {
99     auto kw = wsizes[2];
100     return (groups != 1) ? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
101                          : dnnl::memory::dims({oc, ic, kw});
102   } else if (ndim == 4) {
103     auto kh = wsizes[2];
104     auto kw = wsizes[3];
105     return (groups != 1)
106         ? dnnl::memory::dims({groups, oc / groups, ic / groups, kh, kw})
107         : dnnl::memory::dims({oc, ic, kh, kw});
108   } else if (ndim == 5) {
109     auto kd = wsizes[2];
110     auto kh = wsizes[3];
111     auto kw = wsizes[4];
112     return (groups != 1)
113         ? dnnl::memory::dims({groups, oc / groups, ic / groups, kd, kh, kw})
114         : dnnl::memory::dims({oc, ic, kd, kh, kw});
115   }
116 
117   return {};
118 }
119 
120 static std::tuple<
121     dnnl::memory::desc,
122     dnnl::memory::desc,
123     dnnl::memory::desc>
conv_get_md(const at::Tensor & src,const at::Tensor & weight,const at::Tensor & dst,int64_t groups,bool is_channels_last)124  conv_get_md(
125     const at::Tensor& src,
126     const at::Tensor& weight,
127     const at::Tensor& dst,
128     int64_t groups,
129     bool is_channels_last) {
130   // create memory desc from the src/weight/dst tensors
131   dnnl::memory::desc src_usr_md, weight_usr_md, dst_usr_md;
132   auto ndim = src.ndimension();
133   auto fmt_src =
134       conv_src_fmt(ndim, is_channels_last);
135 
136   auto src_size = src.sizes().vec();
137   auto src_data_t = get_onednn_dtype_include_double(src);
138   src_usr_md = dnnl::memory::desc(src_size, src_data_t, fmt_src);
139 
140   auto dst_size = dst.sizes().vec();
141   auto dst_data_t = get_onednn_dtype_include_double(dst);
142   dst_usr_md = dnnl::memory::desc(dst_size, dst_data_t, fmt_src);
143 
144   auto ic = src.size(1);
145   auto oc = dst.size(1);
146   auto wei_data_t = get_onednn_dtype_include_double(weight);
147   dnnl::memory::dims weight_size =
148       compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
149   auto fmt_weight = conv_weight_fmt(
150       ndim,
151       groups != 1,
152       is_channels_last);
153   weight_usr_md = dnnl::memory::desc(weight_size, wei_data_t, fmt_weight);
154 
155   return {src_usr_md, weight_usr_md, dst_usr_md};
156 }
157 
convolution(at::Tensor & dst,const at::Tensor & src,const at::Tensor & weight,const at::Tensor & bia,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation,int64_t groups,Attr & attr,const std::vector<sycl::event> & deps)158 sycl::event convolution(
159     at::Tensor& dst,
160     const at::Tensor& src,
161     const at::Tensor& weight,
162     const at::Tensor& bia,
163     IntArrayRef padding_front_top_left,
164     IntArrayRef padding_back_bottom_right,
165     IntArrayRef stride,
166     IntArrayRef dilation,
167     int64_t groups,
168     Attr& attr,
169     const std::vector<sycl::event>& deps) {
170   auto engine =
171       GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
172   auto stream = GpuStreamManager::Instance().get_stream();
173 
174   bool is_channels_last = use_channels_last_for_conv(src, weight, false);
175 
176   // create usr_md for tensors, and md for conv primitive
177   auto [src_md, weight_md, dst_md] = conv_get_md(src, weight, dst, groups, is_channels_last);
178 
179   auto bia_fmt = dnnl::memory::format_tag::x;
180   auto bia_md = bia.defined()
181       ? dnnl::memory::desc(
182             {dst.size(1)}, get_onednn_dtype_include_double(bia), bia_fmt)
183       : dnnl::memory::desc();
184 
185   // create conv primitive descriptor
186   dnnl::memory::dims _stride = stride.vec();
187   dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
188   dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
189   dnnl::memory::dims _dilation = compatible_dilation(dilation);
190 
191   // extract post ops
192   dnnl::primitive_attr pattr;
193   dnnl::post_ops po = attr.extract_post_ops(dst);
194   pattr.set_post_ops(po);
195 
196   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
197 
198   #if ONEDNN_SUPPORT_DETERMINISTIC
199     if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
200         pattr.set_deterministic(true);
201     }
202   #endif
203 
204   auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
205       engine,
206       dnnl::prop_kind::forward,
207       dnnl::algorithm::convolution_direct,
208       src_md,
209       weight_md,
210       bia_md,
211       dst_md,
212       _stride,
213       _dilation,
214       _padding_front_top_left,
215       _padding_back_bottom_right,
216       pattr);
217 
218   dnnl::memory src_m, weight_m, dst_m, bia_m;
219   at::Tensor src_blocked, weight_blocked, dst_blocked = dst;
220 
221   src_m = make_onednn_memory(src_md, engine, src.data_ptr());
222   weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
223   dst_m = make_onednn_memory(dst_md, engine, dst.data_ptr());
224 
225 
226   std::unordered_map<int, dnnl::memory> args;
227   if (bia.defined()) {
228     bia_m = make_onednn_memory(bia_md, engine, bia.data_ptr());
229     args.insert({DNNL_ARG_BIAS, bia_m});
230   }
231   auto expected_dst_md = conv_fwd_pd.dst_desc();
232   if (attr.with_binary())
233     attr.construct_post_binary(conv_fwd_pd, args);
234 
235   args.insert({DNNL_ARG_SRC, src_m});
236   args.insert({DNNL_ARG_WEIGHTS, weight_m});
237   args.insert({DNNL_ARG_DST, dst_m});
238 
239   size_t scratchpad_size = conv_fwd_pd.scratchpad_desc().get_size();
240   at::Tensor scratchpad_tensor = at::empty(
241       {static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
242   auto scratchpad_m = make_onednn_memory(
243       conv_fwd_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
244   args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
245 
246   auto conv_forward = dnnl::convolution_forward(conv_fwd_pd);
247   auto conv_fwd_event = dnnl::sycl_interop::execute(conv_forward, stream, args, deps);
248 
249   return conv_fwd_event;
250 }
251 
convolution_backward_weights(at::Tensor & diff_weight,at::Tensor & diff_bia,const at::Tensor & diff_dst,const at::Tensor & src,IntArrayRef diff_weight_aten_size,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation,int64_t groups,const std::vector<sycl::event> & deps)252 sycl::event convolution_backward_weights(
253     at::Tensor& diff_weight,
254     at::Tensor& diff_bia,
255     const at::Tensor& diff_dst,
256     const at::Tensor& src,
257     IntArrayRef diff_weight_aten_size,
258     IntArrayRef padding_front_top_left,
259     IntArrayRef padding_back_bottom_right,
260     IntArrayRef stride,
261     IntArrayRef dilation,
262     int64_t groups,
263     const std::vector<sycl::event>& deps) {
264   auto engine =
265       GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
266   auto stream = GpuStreamManager::Instance().get_stream();
267 
268   bool is_channels_last = use_channels_last_for_conv(src, diff_dst, /*is_transposed=*/false);
269 
270   // create dnnl::memory desc
271   auto [src_md, weight_md, dst_md] =
272       conv_get_md(src, diff_weight, diff_dst, groups, is_channels_last);
273   dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
274   auto bia_md = diff_bia.defined()
275       ? dnnl::memory::desc({diff_dst.size(1)}, src_md.get_data_type(), bia_fmt)
276       : dnnl::memory::desc();
277 
278   // create fwd primitive hint
279   dnnl::memory::dims _stride = stride.vec();
280   dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
281   dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
282   dnnl::memory::dims _dilation = compatible_dilation(dilation);
283   dnnl::primitive_attr pattr;
284 
285   #if ONEDNN_SUPPORT_DETERMINISTIC
286     if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
287         pattr.set_deterministic(true);
288     }
289   #endif
290 
291   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
292   auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
293       engine,
294       dnnl::prop_kind::forward,
295       dnnl::algorithm::convolution_direct,
296       src_md,
297       weight_md,
298       bia_md,
299       dst_md,
300       _stride,
301       _dilation,
302       _padding_front_top_left,
303       _padding_back_bottom_right,
304       pattr);
305 
306   // create bwd weight primitive
307   auto conv_bwd_w_pd = dnnl::convolution_backward_weights::primitive_desc(
308       engine,
309       dnnl::algorithm::convolution_direct,
310       src_md,
311       weight_md,
312       bia_md,
313       dst_md,
314       _stride,
315       _dilation,
316       _padding_front_top_left,
317       _padding_back_bottom_right,
318       conv_fwd_pd,
319       pattr);
320 
321   // create bwd memory
322   at::Tensor expected_src, expected_diff_dst, expected_diff_weight;
323   dnnl::memory src_m, diff_dst_m, diff_weight_m;
324 
325   src_m = make_onednn_memory(src_md, engine, src.data_ptr());
326   diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
327   diff_weight_m = make_onednn_memory(weight_md, engine, diff_weight.data_ptr());
328 
329   // insert args
330   std::unordered_map<int, dnnl::memory> args;
331   args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
332   args.insert({DNNL_ARG_SRC, src_m});
333   args.insert({DNNL_ARG_DIFF_WEIGHTS, diff_weight_m});
334   if (diff_bia.defined()) {
335     dnnl::memory diff_bia_m =
336         make_onednn_memory(bia_md, engine, diff_bia.data_ptr());
337     args.insert({DNNL_ARG_DIFF_BIAS, diff_bia_m});
338   }
339 
340   size_t scratchpad_size = conv_bwd_w_pd.scratchpad_desc().get_size();
341   at::Tensor scratchpad_tensor = at::empty(
342       {static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
343   auto scratchpad_m = make_onednn_memory(
344       conv_bwd_w_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
345   args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
346 
347   // execute primitive
348   auto conv_bwd_w = dnnl::convolution_backward_weights(conv_bwd_w_pd);
349   sycl::event conv_bwd_w_event = dnnl::sycl_interop::execute(conv_bwd_w, stream, args, deps);
350 
351   return conv_bwd_w_event;
352 }
353 
convolution_backward_data(at::Tensor & diff_src,const at::Tensor & diff_dst,const at::Tensor & weight,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation,int64_t groups,bool bias_defined,const std::vector<sycl::event> & deps)354 sycl::event convolution_backward_data(
355     at::Tensor& diff_src,
356     const at::Tensor& diff_dst,
357     const at::Tensor& weight,
358     IntArrayRef padding_front_top_left,
359     IntArrayRef padding_back_bottom_right,
360     IntArrayRef stride,
361     IntArrayRef dilation,
362     int64_t groups,
363     bool bias_defined,
364     const std::vector<sycl::event>& deps) {
365   auto engine =
366       GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
367   auto stream = GpuStreamManager::Instance().get_stream();
368 
369   bool is_channels_last = use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/false);
370 
371   // create memory desc
372   auto [src_md, weight_md, dst_md] =
373       conv_get_md(diff_src, weight, diff_dst, groups, is_channels_last);
374   dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
375   auto bia_md = bias_defined
376       ? dnnl::memory::desc({diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
377       : dnnl::memory::desc();
378 
379   // create fwd primitive desc hint
380   dnnl::primitive_attr pattr;
381 
382   #if ONEDNN_SUPPORT_DETERMINISTIC
383     if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
384         pattr.set_deterministic(true);
385     }
386   #endif
387 
388   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
389   dnnl::memory::dims _stride = stride.vec();
390   dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
391   dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
392   dnnl::memory::dims _dilation = compatible_dilation(dilation);
393   auto conv_forward_pd = dnnl::convolution_forward::primitive_desc(
394       engine,
395       dnnl::prop_kind::forward,
396       dnnl::algorithm::convolution_direct,
397       src_md,
398       weight_md,
399       bia_md,
400       dst_md,
401       _stride,
402       _dilation,
403       _padding_front_top_left,
404       _padding_back_bottom_right,
405       pattr);
406 
407   auto conv_backward_data_pd = dnnl::convolution_backward_data::primitive_desc(
408       engine,
409       dnnl::algorithm::convolution_direct,
410       src_md,
411       weight_md,
412       dst_md,
413       _stride,
414       _dilation,
415       _padding_front_top_left,
416       _padding_back_bottom_right,
417       conv_forward_pd,
418       pattr);
419 
420   // create memory
421   at::Tensor expected_src, expected_wei, expected_dst;
422   dnnl::memory diff_dst_m, wei_m, diff_src_m;
423 
424   diff_src_m = make_onednn_memory(src_md, engine, diff_src.data_ptr());
425   wei_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
426   diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
427 
428 
429   // insert args
430   std::unordered_map<int, dnnl::memory> args;
431   size_t scratchpad_size = conv_backward_data_pd.scratchpad_desc().get_size();
432   at::Tensor scratchpad_tensor = at::empty(
433       {static_cast<int64_t>(scratchpad_size)}, diff_dst.options().dtype(at::kByte), std::nullopt);
434   auto scratchpad_memory = make_onednn_memory(
435       conv_backward_data_pd.scratchpad_desc(),
436       engine,
437       scratchpad_tensor.data_ptr());
438   args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});
439   args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
440   args.insert({DNNL_ARG_WEIGHTS, wei_m});
441   args.insert({DNNL_ARG_DIFF_SRC, diff_src_m});
442 
443   // execute primitive
444   auto conv_backward_data =
445       dnnl::convolution_backward_data(conv_backward_data_pd);
446   auto conv_backward_data_event = dnnl::sycl_interop::execute(conv_backward_data, stream, args, deps);
447   return conv_backward_data_event;
448 
449 }
450 
} // namespace at::native::onednn