#include <c10/xpu/XPUFunctions.h>
#include <ATen/ATen.h>

#include <oneapi/dnnl/dnnl.hpp>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>

namespace at::native::onednn {

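// oneDNN encodes dilation as the extra gap between kernel elements (0 means
// dense), while PyTorch uses 1-based dilation (1 means dense), so every
// element is shifted down by one.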
static inline dnnl::memory::dims deconv_compatible_dilation(IntArrayRef& dilation) {
  dnnl::memory::dims ret = dilation.vec();
  for (auto it = ret.begin(); it != ret.end(); it++) {
    *it -= 1;
  }
  return ret;
}

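// PyTorch stores transposed-conv weights as (ic, oc/g, ...), while oneDNN
// expects grouped deconvolution weights as (g, oc/g, ic/g, ...). Rather than
// materializing a reordered tensor, describe the existing memory by swapping
// the in/out channel strides and prepending a group stride.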
static inline std::vector<int64_t> compatible_groups_deconv_strides(
    const at::Tensor& weight,
    dnnl::memory::dims group_size) {
  std::vector<int64_t> strides = weight.strides().vec();
  strides[0] = weight.strides()[1];
  strides[1] = weight.strides()[0];
  strides.insert(strides.begin(), group_size[2] * weight.strides()[0]);
  return strides;
}

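// Output spatial size of a transposed convolution:
//   dst[d] = (src[d] - 1) * stride - 2 * padding
//            + dilation * (kernel - 1) + 1 + output_padding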
dnnl::memory::dims deconv_dst_size(
    IntArrayRef src_size,
    IntArrayRef weight_size,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    IntArrayRef dst_padding,
    int64_t groups) {
  auto dim = src_size.size();
  dnnl::memory::dims dst_size(dim);
  auto kernel_size = weight_size.slice(2);

  dst_size[0] = src_size[0];
  dst_size[1] = weight_size[1] * groups;
  for (size_t d = 2; d < dim; ++d) {
    dst_size[d] = (src_size[d] - 1) * stride[d - 2] - 2 * padding[d - 2] +
        (dilation[d - 2] * (kernel_size[d - 2] - 1) + 1) + dst_padding[d - 2];
  }
  return dst_size;
}

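// Plain (non-blocked) format tag for src/dst: contiguous NC[D][H]W or its
// channels-last counterpart, depending on the suggested memory format.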
static inline dnnl::memory::format_tag deconv_src_fmt(
    const int64_t ndim,
    const bool is_channels_last = false) {
  // 3D: n/c/w (n/w/c) [a/b/c (a/c/b)]
  // 4D: n/c/h/w (n/h/w/c) [a/b/c/d (a/c/d/b)]
  // 5D: n/c/d/h/w (n/d/h/w/c) [a/b/c/d/e (a/c/d/e/b)]
  if (!is_channels_last) {
    return (ndim == 3)
        ? dnnl::memory::format_tag::ncw
        : ((ndim == 4) ? dnnl::memory::format_tag::nchw
                       : ((ndim == 5) ? dnnl::memory::format_tag::ncdhw
                                      : dnnl::memory::format_tag::undef));
  } else {
    return (ndim == 3)
        ? dnnl::memory::format_tag::nwc
        : ((ndim == 4) ? dnnl::memory::format_tag::nhwc
                       : ((ndim == 5) ? dnnl::memory::format_tag::ndhwc
                                      : dnnl::memory::format_tag::undef));
  }
}

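// Despite the name, this returns explicit strides rather than a format tag:
// the weight memory is described in oneDNN's (g)o/i order by permuting the
// strides of the PyTorch i/o-ordered weight in place.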
static inline std::vector<int64_t> deconv_weight_fmt(
    const at::Tensor& weight,
    const int64_t ndim,
    dnnl::memory::dims weight_size,
    const bool grouped = false,
    const bool is_channels_last = false) {
  // 3D fmt: (g)i/o/w ((g)i/w/o) [b/a/c (b/c/a)]
  // 4D fmt: (g)i/o/h/w ((g)i/h/w/o) [b/a/c/d (b/c/d/a)]
  // 5D fmt: (g)i/o/d/h/w ((g)i/d/h/w/o) [b/a/c/d/e (b/c/d/e/a)]
  std::vector<int64_t> strides;
  if (grouped) {
    strides = compatible_groups_deconv_strides(weight, weight_size);
  } else {
    strides = weight.strides().vec();
    std::swap(strides[0], strides[1]);
  }
  return strides;
}

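// Logical weight dims as oneDNN expects them: an explicit leading group dim
// when groups > 1, with per-group output channels ahead of input channels.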
static inline dnnl::memory::dims deconv_compatible_weight_dims(
    int64_t ndim,
    int64_t groups,
    int64_t oc,
    int64_t ic,
    IntArrayRef weight_size) {
  if (ndim == 3) {
    auto kw = weight_size[2];
    return (groups != 1)
        ? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
        : dnnl::memory::dims({oc, ic, kw});
  } else if (ndim == 4) {
    auto kh = weight_size[2];
    auto kw = weight_size[3];
    return (groups != 1)
        ? dnnl::memory::dims({groups, oc / groups, ic / groups, kh, kw})
        : dnnl::memory::dims({oc, ic, kh, kw});
  } else if (ndim == 5) {
    auto kd = weight_size[2];
    auto kh = weight_size[3];
    auto kw = weight_size[4];
    return (groups != 1)
        ? dnnl::memory::dims({groups, oc / groups, ic / groups, kd, kh, kw})
        : dnnl::memory::dims({oc, ic, kd, kh, kw});
  } else {
    TORCH_CHECK(false, "unsupported dimension in xpu oneDNN deconvolution...");
  }
}

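// Build plain (non-blocked) memory descriptors for src, weight, and dst. The
// weight descriptor reuses the PyTorch allocation via permuted strides, so no
// reorder is needed before handing it to the primitive.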
static std::tuple<dnnl::memory::desc, dnnl::memory::desc, dnnl::memory::desc>
deconv_get_plain_md(
    const at::Tensor& src,
    const at::Tensor& weight,
    const at::Tensor& dst,
    int64_t groups,
    bool is_channels_last_suggested) {
  auto ndim = src.ndimension();
  auto src_data_t = get_onednn_dtype_include_double(src);
  auto fmt_src = deconv_src_fmt(ndim, is_channels_last_suggested);
  auto src_usr_md = dnnl::memory::desc(src.sizes().vec(), src_data_t, fmt_src);

  auto dst_data_t = get_onednn_dtype_include_double(dst);
  auto dst_usr_md = dnnl::memory::desc(dst.sizes().vec(), dst_data_t, fmt_src);

  auto ic = src.size(1);
  auto oc = dst.size(1);
  dnnl::memory::dims weight_size =
      deconv_compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
  auto weight_dt = get_onednn_dtype_include_double(weight);
  auto fmt_weight = deconv_weight_fmt(
      weight, ndim, weight_size, groups != 1, is_channels_last_suggested);
  dnnl::memory::desc weight_usr_md(weight_size, weight_dt, fmt_weight);

  return {src_usr_md, weight_usr_md, dst_usr_md};
}

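// Forward deconvolution (ConvTranspose). Maps the PyTorch tensors onto oneDNN
// memory descriptors, applies fused post-ops from `attr`, and submits the
// primitive to the XPU stream; the returned SYCL event can be used to order
// downstream work.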
sycl::event deconvolution(
    at::Tensor& dst,
    const at::Tensor& src,
    const at::Tensor& weight,
    const at::Tensor& bia,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dst_padding,
    IntArrayRef dilation,
    int64_t groups,
    Attr& attr,
    const std::vector<sycl::event>& deps) {
  auto engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto stream = GpuStreamManager::Instance().get_stream();

  bool is_channels_last_suggested =
      use_channels_last_for_conv(src, weight, /*is_transposed=*/true);

  // create usr_md for tensors, and md for conv primitive
  auto [src_md, weight_md, dst_md] =
      deconv_get_plain_md(src, weight, dst, groups, is_channels_last_suggested);

  dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
  auto bia_md = bia.defined()
      ? dnnl::memory::desc(
            {dst.size(1)}, get_onednn_dtype_include_double(bia), bia_fmt)
      : dnnl::memory::desc();

  // create primitive desc
  dnnl::memory::dims _stride = stride.vec();
  dnnl::memory::dims _padding = padding.vec();
  dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);

  // construct primitive attr
  dnnl::primitive_attr pattr;
  dnnl::post_ops po = attr.extract_post_ops(dst);
  pattr.set_post_ops(po);
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    pattr.set_deterministic(true);
#endif

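  // user-mode scratchpad: the caller owns the scratchpad buffer (allocated
  // below as an ATen tensor) instead of relying on oneDNN's internal one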
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bia_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      pattr);

  auto src_m = make_onednn_memory(src_md, engine, src.data_ptr());
  auto weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
  auto dst_m = make_onednn_memory(dst_md, engine, dst.data_ptr());

  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_SRC, src_m});
  args.insert({DNNL_ARG_WEIGHTS, weight_m});
  args.insert({DNNL_ARG_DST, dst_m});

  if (bia.defined()) {
    auto bia_m = make_onednn_memory(bia_md, engine, bia.data_ptr());
    args.insert({DNNL_ARG_BIAS, bia_m});
  }
  if (attr.with_binary())
    attr.construct_post_binary(deconv_fwd_pd, args);

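  // back the scratchpad with a byte tensor so its lifetime is managed by the
  // ATen allocator and it is freed with ordinary tensor semantics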
  size_t scratchpad_size = deconv_fwd_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      src.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_m = make_onednn_memory(
      deconv_fwd_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});

  auto deconv_fwd = dnnl::deconvolution_forward(deconv_fwd_pd);
  sycl::event deconv_event =
      dnnl::sycl_interop::execute(deconv_fwd, stream, args, deps);
  return deconv_event;
}

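// Backward-data pass: computes diff_src from diff_dst and weight. oneDNN
// backward primitives require a matching forward primitive descriptor as a
// hint, so one is constructed first with identical shapes and attributes.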
sycl::event deconvolution_backward_data(
    at::Tensor& diff_src,
    const at::Tensor& diff_dst,
    const at::Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    const std::vector<sycl::event>& deps) {
  auto engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto stream = GpuStreamManager::Instance().get_stream();

  bool is_channels_last_suggested =
      use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/true);

  // create memory desc
  auto [src_md, weight_md, dst_md] = deconv_get_plain_md(
      diff_src, weight, diff_dst, groups, is_channels_last_suggested);

  dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
  auto bias_md = bias_defined
      ? dnnl::memory::desc({diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
      : dnnl::memory::desc();

  // create fwd primitive desc hint
  dnnl::primitive_attr pattr;
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    pattr.set_deterministic(true);
#endif

  dnnl::memory::dims _stride = stride.vec();
  dnnl::memory::dims _padding = padding.vec();
  dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);
  auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bias_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      pattr);

  // create bwd primitive desc
  auto deconv_backward_data_pd = dnnl::deconvolution_backward_data::primitive_desc(
      engine,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      deconv_fwd_pd);

  // create memory
  auto diff_src_m = make_onednn_memory(src_md, engine, diff_src.data_ptr());
  auto wei_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
  auto diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());

  // insert args
  std::unordered_map<int, dnnl::memory> args;
  size_t scratchpad_size = deconv_backward_data_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      diff_dst.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      deconv_backward_data_pd.scratchpad_desc(),
      engine,
      scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});
  args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
  args.insert({DNNL_ARG_WEIGHTS, wei_m});
  args.insert({DNNL_ARG_DIFF_SRC, diff_src_m});

  // execute primitive
  auto deconv_backward_data =
      dnnl::deconvolution_backward_data(deconv_backward_data_pd);
  sycl::event deconv_bwd_data_event =
      dnnl::sycl_interop::execute(deconv_backward_data, stream, args, deps);
  return deconv_bwd_data_event;
}

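// Backward-weights pass: computes diff_weight (and diff_bias when defined)
// from src and diff_dst, again seeded with a forward primitive descriptor
// hint built from the same shapes and attributes.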
sycl::event deconvolution_backward_weights(
    at::Tensor& diff_weight,
    at::Tensor& diff_bia,
    const at::Tensor& diff_dst,
    const at::Tensor& src,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    const std::vector<sycl::event>& deps) {
  auto engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  auto stream = GpuStreamManager::Instance().get_stream();

  bool is_channels_last_suggested =
      use_channels_last_for_conv(src, diff_dst, /*is_transposed=*/true);

  // create memory desc
  auto [src_md, weight_md, dst_md] = deconv_get_plain_md(
      src, diff_weight, diff_dst, groups, is_channels_last_suggested);

  dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
  auto bia_md = diff_bia.defined()
      ? dnnl::memory::desc({diff_dst.size(1)}, src_md.get_data_type(), bia_fmt)
      : dnnl::memory::desc();

  // create fwd primitive desc hint
  dnnl::memory::dims _stride = stride.vec();
  dnnl::memory::dims _padding = padding.vec();
  dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);
  dnnl::primitive_attr pattr;
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    pattr.set_deterministic(true);
#endif
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bia_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      pattr);

  auto deconv_bwd_w_pd = dnnl::deconvolution_backward_weights::primitive_desc(
      engine,
      dnnl::algorithm::deconvolution_direct,
      src_md,
      weight_md,
      bia_md,
      dst_md,
      _stride,
      _dilation,
      _padding,
      _padding,
      deconv_fwd_pd,
      pattr);

  // create bwd dnnl::memory
  auto src_m = make_onednn_memory(src_md, engine, src.data_ptr());
  auto diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
  auto diff_weight_m = make_onednn_memory(weight_md, engine, diff_weight.data_ptr());

  // insert args
  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
  args.insert({DNNL_ARG_SRC, src_m});
  args.insert({DNNL_ARG_DIFF_WEIGHTS, diff_weight_m});

  if (diff_bia.defined()) {
    dnnl::memory diff_bia_m =
        make_onednn_memory(bia_md, engine, diff_bia.data_ptr());
    args.insert({DNNL_ARG_DIFF_BIAS, diff_bia_m});
  }

  size_t scratchpad_size = deconv_bwd_w_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      src.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_m = make_onednn_memory(
      deconv_bwd_w_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});

  // execute primitive
  auto deconv_bwd_w = dnnl::deconvolution_backward_weights(deconv_bwd_w_pd);
  sycl::event deconv_bwd_w_event =
      dnnl::sycl_interop::execute(deconv_bwd_w, stream, args, deps);
  return deconv_bwd_w_event;
}

} // namespace at::native::onednn