// aten/src/ATen/native/mkldnn/xpu/detail/Deconv.cpp
#include <c10/xpu/XPUFunctions.h>
#include <ATen/ATen.h>

#include <oneapi/dnnl/dnnl.hpp>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>

namespace at::native::onednn {

static inline dnnl::memory::dims deconv_compatible_dilation(IntArrayRef& dilation) {
  dnnl::memory::dims ret = dilation.vec();
  for (auto it = ret.begin(); it != ret.end(); it++) {
    *it -= 1;
  }
  return ret;
}
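// Note: oneDNN encodes "no dilation" as 0 per spatial dim, whereas PyTorch
// uses 1, hence the subtraction above; e.g. a PyTorch dilation of {2, 2}
// becomes the oneDNN dilation {1, 1}.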

static inline std::vector<int64_t> compatible_groups_deconv_strides(
    const at::Tensor& weight,
    dnnl::memory::dims group_size) {
  std::vector<int64_t> strides = weight.strides().vec();
  strides[0] = weight.strides()[1];
  strides[1] = weight.strides()[0];
  strides.insert(strides.begin(), group_size[2] * weight.strides()[0]);
  return strides;
}
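// The swap plus the prepended group stride reinterprets PyTorch's transposed
// convolution weight of shape (ic, oc/g, ...) as oneDNN's grouped
// deconvolution layout (g, oc/g, ic/g, ...) without copying data. A worked
// sketch, assuming a contiguous weight: ic = 8, oc = 8, g = 2 with a 3x3
// kernel gives a weight of shape (8, 4, 3, 3) and strides (36, 9, 3, 1); the
// result is strides (144, 9, 36, 3, 1) for the logical dims (2, 4, 4, 3, 3).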

dnnl::memory::dims deconv_dst_size(
    IntArrayRef src_size,
    IntArrayRef weight_size,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    IntArrayRef dst_padding,
    int64_t groups) {
  auto dim = src_size.size();
  dnnl::memory::dims dst_size(dim);
  auto kernel_size = weight_size.slice(2);

  dst_size[0] = src_size[0];
  dst_size[1] = weight_size[1] * groups;
  for (size_t d = 2; d < dim; ++d) {
    dst_size[d] = (src_size[d] - 1) * stride[d - 2] - 2 * padding[d - 2] +
        (dilation[d - 2] * (kernel_size[d - 2] - 1) + 1) + dst_padding[d - 2];
  }
  return dst_size;
}
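// This is the standard transposed-convolution output size per spatial dim:
//   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding
// e.g. in = 14, stride = 2, padding = 1, dilation = 1, kernel = 3,
// output_padding = 1 gives (14 - 1) * 2 - 2 + 3 + 1 = 28.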

static inline dnnl::memory::format_tag deconv_src_fmt(
    const int64_t ndim,
    const bool is_channels_last = false) {
  // 3D: n/c/w (n/w/c)         [a/b/c (a/c/b)]
  // 4D: n/c/h/w (n/h/w/c)     [a/b/c/d (a/c/d/b)]
  // 5D: n/c/d/h/w (n/d/h/w/c) [a/b/c/d/e (a/c/d/e/b)]
  if (!is_channels_last) {
    return (ndim == 3)
        ? dnnl::memory::format_tag::ncw
        : ((ndim == 4) ? dnnl::memory::format_tag::nchw
                       : ((ndim == 5) ? dnnl::memory::format_tag::ncdhw
                                      : dnnl::memory::format_tag::undef));
  } else {
    return (ndim == 3)
        ? dnnl::memory::format_tag::nwc
        : ((ndim == 4) ? dnnl::memory::format_tag::nhwc
                       : ((ndim == 5) ? dnnl::memory::format_tag::ndhwc
                                      : dnnl::memory::format_tag::undef));
  }
}
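// Choosing nwc/nhwc/ndhwc in the channels-last case lets oneDNN consume
// tensors in torch.channels_last memory format in place, with no reorder.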

static inline std::vector<int64_t> deconv_weight_fmt(
    const at::Tensor& weight,
    const int64_t ndim,
    dnnl::memory::dims weight_size,
    const bool grouped = false,
    const bool is_channels_last = false) {
  // 3D fmt: (g)i/o/w ((g)i/w/o)  [b/a/c  (b/c/a)]
  // 4D fmt: (g)i/o/h/w ((g)i/h/w/o) [b/a/c/d (b/c/d/a)]
  // 5D fmt: (g)i/o/d/h/w ((g)i/d/h/w/o) [b/a/c/d/e (b/c/d/e/a)]
  std::vector<int64_t> strides;
  if (grouped) {
    strides = compatible_groups_deconv_strides(weight, weight_size);
  } else {
    strides = weight.strides().vec();
    std::swap(strides[0], strides[1]);
  }
  return strides;
}
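// Despite the name, this helper returns an explicit stride vector (used below
// to build a strided memory::desc) rather than a format tag; ndim and
// is_channels_last are currently unused on this plain-layout path.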

static inline dnnl::memory::dims deconv_compatible_weight_dims(
    int64_t ndim,
    int64_t groups,
    int64_t oc,
    int64_t ic,
    IntArrayRef weight_size) {
  if (ndim == 3) {
    auto kw = weight_size[2];
    return (groups != 1)
        ? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
        : dnnl::memory::dims({oc, ic, kw});
  } else if (ndim == 4) {
    auto kh = weight_size[2];
    auto kw = weight_size[3];
    return (groups != 1)
        ? dnnl::memory::dims({groups, oc / groups, ic / groups, kh, kw})
        : dnnl::memory::dims({oc, ic, kh, kw});
  } else if (ndim == 5) {
    auto kd = weight_size[2];
    auto kh = weight_size[3];
    auto kw = weight_size[4];
    return (groups != 1)
        ? dnnl::memory::dims({groups, oc / groups, ic / groups, kd, kh, kw})
        : dnnl::memory::dims({oc, ic, kd, kh, kw});
  } else {
    TORCH_CHECK(0, "unsupported dimension in xpu oneDNN deconvolution...");
  }
}
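// For grouped deconvolution, oneDNN expects an explicit leading group dim
// with both channel dims split by the group count, e.g. groups = 2,
// oc = ic = 8 with a 4D weight yields logical dims {2, 4, 4, kh, kw}.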

static std::tuple<
    dnnl::memory::desc,
    dnnl::memory::desc,
    dnnl::memory::desc>
deconv_get_plain_md(
    const at::Tensor& src,
    const at::Tensor& weight,
    const at::Tensor& dst,
    int64_t groups,
    bool is_channels_last_suggested) {
  auto ndim = src.ndimension();
  auto src_data_t = get_onednn_dtype_include_double(src);
  auto fmt_src = deconv_src_fmt(ndim, is_channels_last_suggested);
  auto src_usr_md = dnnl::memory::desc(src.sizes().vec(), src_data_t, fmt_src);

  auto dst_data_t = get_onednn_dtype_include_double(dst);
  auto dst_usr_md = dnnl::memory::desc(dst.sizes().vec(), dst_data_t, fmt_src);

  auto ic = src.size(1);
  auto oc = dst.size(1);
  dnnl::memory::dims weight_size =
      deconv_compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
  auto weight_dt = get_onednn_dtype_include_double(weight);
  auto fmt_weight = deconv_weight_fmt(
      weight, ndim, weight_size, groups != 1, is_channels_last_suggested);
  dnnl::memory::desc weight_usr_md =
      dnnl::memory::desc(weight_size, weight_dt, fmt_weight);

  return {src_usr_md, weight_usr_md, dst_usr_md};
}
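
// "Plain" memory descriptors use non-blocked layouts given by plain format
// tags or explicit strides, so the primitives below can read and write
// PyTorch-owned buffers directly, without a reorder.

// Forward transposed convolution: runs a oneDNN deconvolution primitive on
// the current XPU stream, with optional bias and any fused post-ops carried
// in attr, and returns the SYCL event of the submitted primitive.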
sycl::event deconvolution(
    at::Tensor& dst,
    const at::Tensor& src,
    const at::Tensor& weight,
    const at::Tensor& bia,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dst_padding,
    IntArrayRef dilation,
    int64_t groups,
    Attr& attr,
    const std::vector<sycl::event>& deps) {
  auto engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto stream = GpuStreamManager::Instance().get_stream();

  bool is_channels_last_suggested =
      use_channels_last_for_conv(src, weight, /*is_transposed=*/true);

  // create usr_md for tensors, and md for conv primitive
  auto [src_md, weight_md, dst_md] =
      deconv_get_plain_md(src, weight, dst, groups, is_channels_last_suggested);

  dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
  auto bia_md = bia.defined()
      ? dnnl::memory::desc(
            {dst.size(1)}, get_onednn_dtype_include_double(bia), bia_fmt)
      : dnnl::memory::desc();
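  // A default-constructed memory::desc above tells oneDNN that no bias is
  // supplied.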

  // create primitive desc
  dnnl::memory::dims _stride = stride.vec();
  dnnl::memory::dims _padding = padding.vec();
  dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);

  // construct primitive attr
  dnnl::primitive_attr pattr;
  dnnl::post_ops po = attr.extract_post_ops(dst);
  pattr.set_post_ops(po);
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn()) {
    pattr.set_deterministic(true);
  }
#endif

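  // scratchpad_mode::user makes the caller responsible for the scratchpad
  // allocation; the byte tensor created below routes it through PyTorch's
  // caching allocator instead of oneDNN's internal one.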
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

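  // oneDNN takes separate left and right padding; PyTorch's single symmetric
  // padding is passed for both.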
  auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bia_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      pattr);

  dnnl::memory src_m, weight_m, dst_m;

  src_m = make_onednn_memory(src_md, engine, src.data_ptr());
  weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
  dst_m = make_onednn_memory(dst_md, engine, dst.data_ptr());

  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_SRC, src_m});
  args.insert({DNNL_ARG_WEIGHTS, weight_m});
  args.insert({DNNL_ARG_DST, dst_m});

  if (bia.defined()) {
    auto bia_m = make_onednn_memory(bia_md, engine, bia.data_ptr());
    args.insert({DNNL_ARG_BIAS, bia_m});
  }
  if (attr.with_binary()) {
    attr.construct_post_binary(deconv_fwd_pd, args);
  }

  size_t scratchpad_size = deconv_fwd_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      src.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_m = make_onednn_memory(
      deconv_fwd_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});

  auto deconv_fwd = dnnl::deconvolution_forward(deconv_fwd_pd);
  sycl::event deconv_event =
      dnnl::sycl_interop::execute(deconv_fwd, stream, args, deps);
  return deconv_event;
}
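
// Backward-by-data for transposed convolution: computes diff_src from
// diff_dst and the weight. (Mathematically this is a regular convolution of
// diff_dst with the same weight.)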
sycl::event deconvolution_backward_data(
    at::Tensor& diff_src,
    const at::Tensor& diff_dst,
    const at::Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    const std::vector<sycl::event>& deps) {
  auto engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto stream = GpuStreamManager::Instance().get_stream();

  bool is_channels_last_suggested =
      use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/true);
  // create memory desc
  auto [src_md, weight_md, dst_md] = deconv_get_plain_md(
      diff_src, weight, diff_dst, groups, is_channels_last_suggested);

  dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
  auto bias_md = bias_defined
      ? dnnl::memory::desc({diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
      : dnnl::memory::desc();

  // create fwd primitive desc hint
  dnnl::primitive_attr pattr;
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn()) {
    pattr.set_deterministic(true);
  }
#endif

  dnnl::memory::dims _stride = stride.vec();
  dnnl::memory::dims _padding = padding.vec();
  dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);
  auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bias_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      pattr);
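  // The forward pd above exists only to serve as a hint for the backward pd;
  // no forward primitive is executed here.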

  // create bwd primitive desc
  auto deconv_backward_data_pd =
      dnnl::deconvolution_backward_data::primitive_desc(
          engine,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weight_md,
          dst_md,
          _stride,
          _dilation,
          _padding,
          _padding,
          deconv_fwd_pd,
          // pass pattr so the scratchpad_mode::user and deterministic settings
          // configured above apply to the backward primitive as well
          pattr);

  // create memory
  dnnl::memory diff_dst_m, wei_m, diff_src_m;

  diff_src_m = make_onednn_memory(src_md, engine, diff_src.data_ptr());
  wei_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
  diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());

  // insert args
  std::unordered_map<int, dnnl::memory> args;
  size_t scratchpad_size = deconv_backward_data_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      diff_dst.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      deconv_backward_data_pd.scratchpad_desc(),
      engine,
      scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});
  args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
  args.insert({DNNL_ARG_WEIGHTS, wei_m});
  args.insert({DNNL_ARG_DIFF_SRC, diff_src_m});

  // execute primitive
  auto deconv_backward_data =
      dnnl::deconvolution_backward_data(deconv_backward_data_pd);
  sycl::event deconv_bwd_data_event =
      dnnl::sycl_interop::execute(deconv_backward_data, stream, args, deps);
  return deconv_bwd_data_event;
}
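
// Backward-by-weights for transposed convolution: computes diff_weight from
// src and diff_dst and, when diff_bia is defined, diff_bias as the
// per-output-channel reduction of diff_dst.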
sycl::event deconvolution_backward_weights(
    at::Tensor& diff_weight,
    at::Tensor& diff_bia,
    const at::Tensor& diff_dst,
    const at::Tensor& src,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    const std::vector<sycl::event>& deps) {
  auto engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto stream = GpuStreamManager::Instance().get_stream();

  bool is_channels_last_suggested =
      use_channels_last_for_conv(src, diff_dst, /*is_transposed=*/true);

  // create memory desc
  auto [src_md, weight_md, dst_md] = deconv_get_plain_md(
      src, diff_weight, diff_dst, groups, is_channels_last_suggested);

  dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
  auto bia_md = diff_bia.defined()
      ? dnnl::memory::desc({diff_dst.size(1)}, src_md.get_data_type(), bia_fmt)
      : dnnl::memory::desc();

  // create fwd primitive desc hint
  dnnl::memory::dims _stride = stride.vec();
  dnnl::memory::dims _padding = padding.vec();
  dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);
  dnnl::primitive_attr pattr;
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn()) {
    pattr.set_deterministic(true);
  }
#endif
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bia_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      pattr);

  auto deconv_bwd_w_pd = dnnl::deconvolution_backward_weights::primitive_desc(
      engine,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bia_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      deconv_fwd_pd,
      pattr);

  // create bwd dnnl::memory
  dnnl::memory src_m, diff_dst_m, diff_weight_m;

  src_m = make_onednn_memory(src_md, engine, src.data_ptr());
  diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
  diff_weight_m = make_onednn_memory(weight_md, engine, diff_weight.data_ptr());

  // insert args
  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
  args.insert({DNNL_ARG_SRC, src_m});
  args.insert({DNNL_ARG_DIFF_WEIGHTS, diff_weight_m});

  if (diff_bia.defined()) {
    dnnl::memory diff_bia_m =
        make_onednn_memory(bia_md, engine, diff_bia.data_ptr());
    args.insert({DNNL_ARG_DIFF_BIAS, diff_bia_m});
  }

  size_t scratchpad_size = deconv_bwd_w_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      src.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_m = make_onednn_memory(
      deconv_bwd_w_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});

  // execute primitive
  auto deconv_bwd_w = dnnl::deconvolution_backward_weights(deconv_bwd_w_pd);
  sycl::event deconv_bwd_w_event =
      dnnl::sycl_interop::execute(deconv_bwd_w, stream, args, deps);
  return deconv_bwd_w_event;
}

} // namespace at::native::onednn