1 #include <c10/xpu/XPUFunctions.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/core/grad_mode.h>
5 #include <ATen/record_function.h>
6 #include <c10/core/MemoryFormat.h>
7
8 #include <ATen/native/mkldnn/xpu/detail/Attr.h>
9 #include <ATen/native/mkldnn/xpu/detail/Utils.h>
10
11 #include <oneapi/dnnl/dnnl.hpp>
12
13 namespace at::native::onednn {
14
// Index into a src tensor's sizes that holds the batch (N) extent.
constexpr int src_batch_size_dim{0};
// Index into a weight tensor's sizes that holds the output-channel count;
// conv_dst_size copies this value into dst dimension 1.
constexpr int weight_dst_channels_dim{0};
17
conv_dst_size(int64_t ndim,IntArrayRef src_size,IntArrayRef weight_size,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation)18 dnnl::memory::dims conv_dst_size(
19 int64_t ndim,
20 IntArrayRef src_size,
21 IntArrayRef weight_size,
22 IntArrayRef padding_front_top_left,
23 IntArrayRef padding_back_bottom_right,
24 IntArrayRef stride,
25 IntArrayRef dilation) {
26 bool has_dilation = dilation.size() > 0;
27 dnnl::memory::dims dst_size(ndim);
28 dst_size[0] = src_size[src_batch_size_dim];
29 dst_size[1] = weight_size[weight_dst_channels_dim];
30 for (int d = 2; d < ndim; ++d) {
31 auto dilate = has_dilation ? dilation[d - 2] : 1;
32 auto kernel = dilate * (weight_size[d] - 1) + 1;
33 dst_size[d] =
34 (src_size[d] +
35 (padding_front_top_left[d - 2] + padding_back_bottom_right[d - 2]) -
36 kernel) /
37 stride[d - 2] +
38 1;
39 }
40 return dst_size;
41 }
42
compatible_dilation(IntArrayRef & dilation)43 static inline dnnl::memory::dims compatible_dilation(IntArrayRef& dilation) {
44 dnnl::memory::dims ret = dilation.vec();
45 for (auto it = ret.begin(); it != ret.end(); it++) {
46 *it -= 1;
47 }
48 return ret;
49 }
50
conv_src_fmt(const int64_t ndim,const bool is_channels_last=false)51 static inline dnnl::memory::format_tag conv_src_fmt(
52 const int64_t ndim,
53 const bool is_channels_last = false) {
54 if (!is_channels_last) {
55 return (ndim == 3)
56 ? dnnl::memory::format_tag::ncw
57 : ((ndim == 4) ? dnnl::memory::format_tag::nchw
58 : ((ndim == 5) ? dnnl::memory::format_tag::ncdhw
59 : dnnl::memory::format_tag::undef));
60 } else {
61 return (ndim == 3)
62 ? dnnl::memory::format_tag::nwc
63 : ((ndim == 4) ? dnnl::memory::format_tag::nhwc
64 : ((ndim == 5) ? dnnl::memory::format_tag::ndhwc
65 : dnnl::memory::format_tag::undef));
66 }
67 }
68
conv_weight_fmt(const int64_t ndim,const bool grouped=false,const bool is_channels_last=false)69 static inline dnnl::memory::format_tag conv_weight_fmt(
70 const int64_t ndim,
71 const bool grouped = false,
72 const bool is_channels_last = false) {
73 if (!is_channels_last) {
74 return (ndim == 3)
75 ? (grouped ? dnnl::memory::format_tag::goiw : dnnl::memory::format_tag::oiw)
76 : (ndim == 4)
77 ? (grouped ? dnnl::memory::format_tag::goihw : dnnl::memory::format_tag::oihw)
78 : ((ndim == 5) ? (grouped ? dnnl::memory::format_tag::goidhw
79 : dnnl::memory::format_tag::oidhw)
80 : dnnl::memory::format_tag::undef);
81 } else {
82 return (ndim == 3)
83 ? (grouped ? dnnl::memory::format_tag::gowi : dnnl::memory::format_tag::owi)
84 : (ndim == 4)
85 ? (grouped ? dnnl::memory::format_tag::gohwi : dnnl::memory::format_tag::ohwi)
86 : ((ndim == 5) ? (grouped ? dnnl::memory::format_tag::godhwi
87 : dnnl::memory::format_tag::odhwi)
88 : dnnl::memory::format_tag::undef);
89 }
90 }
91
compatible_weight_dims(const int64_t ndim,const int64_t groups,const int64_t oc,const int64_t ic,const IntArrayRef wsizes)92 static inline dnnl::memory::dims compatible_weight_dims(
93 const int64_t ndim,
94 const int64_t groups,
95 const int64_t oc,
96 const int64_t ic,
97 const IntArrayRef wsizes) {
98 if (ndim == 3) {
99 auto kw = wsizes[2];
100 return (groups != 1) ? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
101 : dnnl::memory::dims({oc, ic, kw});
102 } else if (ndim == 4) {
103 auto kh = wsizes[2];
104 auto kw = wsizes[3];
105 return (groups != 1)
106 ? dnnl::memory::dims({groups, oc / groups, ic / groups, kh, kw})
107 : dnnl::memory::dims({oc, ic, kh, kw});
108 } else if (ndim == 5) {
109 auto kd = wsizes[2];
110 auto kh = wsizes[3];
111 auto kw = wsizes[4];
112 return (groups != 1)
113 ? dnnl::memory::dims({groups, oc / groups, ic / groups, kd, kh, kw})
114 : dnnl::memory::dims({oc, ic, kd, kh, kw});
115 }
116
117 return {};
118 }
119
120 static std::tuple<
121 dnnl::memory::desc,
122 dnnl::memory::desc,
123 dnnl::memory::desc>
conv_get_md(const at::Tensor & src,const at::Tensor & weight,const at::Tensor & dst,int64_t groups,bool is_channels_last)124 conv_get_md(
125 const at::Tensor& src,
126 const at::Tensor& weight,
127 const at::Tensor& dst,
128 int64_t groups,
129 bool is_channels_last) {
130 // create memory desc from the src/weight/dst tensors
131 dnnl::memory::desc src_usr_md, weight_usr_md, dst_usr_md;
132 auto ndim = src.ndimension();
133 auto fmt_src =
134 conv_src_fmt(ndim, is_channels_last);
135
136 auto src_size = src.sizes().vec();
137 auto src_data_t = get_onednn_dtype_include_double(src);
138 src_usr_md = dnnl::memory::desc(src_size, src_data_t, fmt_src);
139
140 auto dst_size = dst.sizes().vec();
141 auto dst_data_t = get_onednn_dtype_include_double(dst);
142 dst_usr_md = dnnl::memory::desc(dst_size, dst_data_t, fmt_src);
143
144 auto ic = src.size(1);
145 auto oc = dst.size(1);
146 auto wei_data_t = get_onednn_dtype_include_double(weight);
147 dnnl::memory::dims weight_size =
148 compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
149 auto fmt_weight = conv_weight_fmt(
150 ndim,
151 groups != 1,
152 is_channels_last);
153 weight_usr_md = dnnl::memory::desc(weight_size, wei_data_t, fmt_weight);
154
155 return {src_usr_md, weight_usr_md, dst_usr_md};
156 }
157
convolution(at::Tensor & dst,const at::Tensor & src,const at::Tensor & weight,const at::Tensor & bia,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation,int64_t groups,Attr & attr,const std::vector<sycl::event> & deps)158 sycl::event convolution(
159 at::Tensor& dst,
160 const at::Tensor& src,
161 const at::Tensor& weight,
162 const at::Tensor& bia,
163 IntArrayRef padding_front_top_left,
164 IntArrayRef padding_back_bottom_right,
165 IntArrayRef stride,
166 IntArrayRef dilation,
167 int64_t groups,
168 Attr& attr,
169 const std::vector<sycl::event>& deps) {
170 auto engine =
171 GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
172 auto stream = GpuStreamManager::Instance().get_stream();
173
174 bool is_channels_last = use_channels_last_for_conv(src, weight, false);
175
176 // create usr_md for tensors, and md for conv primitive
177 auto [src_md, weight_md, dst_md] = conv_get_md(src, weight, dst, groups, is_channels_last);
178
179 auto bia_fmt = dnnl::memory::format_tag::x;
180 auto bia_md = bia.defined()
181 ? dnnl::memory::desc(
182 {dst.size(1)}, get_onednn_dtype_include_double(bia), bia_fmt)
183 : dnnl::memory::desc();
184
185 // create conv primitive descriptor
186 dnnl::memory::dims _stride = stride.vec();
187 dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
188 dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
189 dnnl::memory::dims _dilation = compatible_dilation(dilation);
190
191 // extract post ops
192 dnnl::primitive_attr pattr;
193 dnnl::post_ops po = attr.extract_post_ops(dst);
194 pattr.set_post_ops(po);
195
196 pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
197
198 #if ONEDNN_SUPPORT_DETERMINISTIC
199 if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
200 pattr.set_deterministic(true);
201 }
202 #endif
203
204 auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
205 engine,
206 dnnl::prop_kind::forward,
207 dnnl::algorithm::convolution_direct,
208 src_md,
209 weight_md,
210 bia_md,
211 dst_md,
212 _stride,
213 _dilation,
214 _padding_front_top_left,
215 _padding_back_bottom_right,
216 pattr);
217
218 dnnl::memory src_m, weight_m, dst_m, bia_m;
219 at::Tensor src_blocked, weight_blocked, dst_blocked = dst;
220
221 src_m = make_onednn_memory(src_md, engine, src.data_ptr());
222 weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
223 dst_m = make_onednn_memory(dst_md, engine, dst.data_ptr());
224
225
226 std::unordered_map<int, dnnl::memory> args;
227 if (bia.defined()) {
228 bia_m = make_onednn_memory(bia_md, engine, bia.data_ptr());
229 args.insert({DNNL_ARG_BIAS, bia_m});
230 }
231 auto expected_dst_md = conv_fwd_pd.dst_desc();
232 if (attr.with_binary())
233 attr.construct_post_binary(conv_fwd_pd, args);
234
235 args.insert({DNNL_ARG_SRC, src_m});
236 args.insert({DNNL_ARG_WEIGHTS, weight_m});
237 args.insert({DNNL_ARG_DST, dst_m});
238
239 size_t scratchpad_size = conv_fwd_pd.scratchpad_desc().get_size();
240 at::Tensor scratchpad_tensor = at::empty(
241 {static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
242 auto scratchpad_m = make_onednn_memory(
243 conv_fwd_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
244 args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
245
246 auto conv_forward = dnnl::convolution_forward(conv_fwd_pd);
247 auto conv_fwd_event = dnnl::sycl_interop::execute(conv_forward, stream, args, deps);
248
249 return conv_fwd_event;
250 }
251
convolution_backward_weights(at::Tensor & diff_weight,at::Tensor & diff_bia,const at::Tensor & diff_dst,const at::Tensor & src,IntArrayRef diff_weight_aten_size,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation,int64_t groups,const std::vector<sycl::event> & deps)252 sycl::event convolution_backward_weights(
253 at::Tensor& diff_weight,
254 at::Tensor& diff_bia,
255 const at::Tensor& diff_dst,
256 const at::Tensor& src,
257 IntArrayRef diff_weight_aten_size,
258 IntArrayRef padding_front_top_left,
259 IntArrayRef padding_back_bottom_right,
260 IntArrayRef stride,
261 IntArrayRef dilation,
262 int64_t groups,
263 const std::vector<sycl::event>& deps) {
264 auto engine =
265 GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
266 auto stream = GpuStreamManager::Instance().get_stream();
267
268 bool is_channels_last = use_channels_last_for_conv(src, diff_dst, /*is_transposed=*/false);
269
270 // create dnnl::memory desc
271 auto [src_md, weight_md, dst_md] =
272 conv_get_md(src, diff_weight, diff_dst, groups, is_channels_last);
273 dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
274 auto bia_md = diff_bia.defined()
275 ? dnnl::memory::desc({diff_dst.size(1)}, src_md.get_data_type(), bia_fmt)
276 : dnnl::memory::desc();
277
278 // create fwd primitive hint
279 dnnl::memory::dims _stride = stride.vec();
280 dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
281 dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
282 dnnl::memory::dims _dilation = compatible_dilation(dilation);
283 dnnl::primitive_attr pattr;
284
285 #if ONEDNN_SUPPORT_DETERMINISTIC
286 if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
287 pattr.set_deterministic(true);
288 }
289 #endif
290
291 pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
292 auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
293 engine,
294 dnnl::prop_kind::forward,
295 dnnl::algorithm::convolution_direct,
296 src_md,
297 weight_md,
298 bia_md,
299 dst_md,
300 _stride,
301 _dilation,
302 _padding_front_top_left,
303 _padding_back_bottom_right,
304 pattr);
305
306 // create bwd weight primitive
307 auto conv_bwd_w_pd = dnnl::convolution_backward_weights::primitive_desc(
308 engine,
309 dnnl::algorithm::convolution_direct,
310 src_md,
311 weight_md,
312 bia_md,
313 dst_md,
314 _stride,
315 _dilation,
316 _padding_front_top_left,
317 _padding_back_bottom_right,
318 conv_fwd_pd,
319 pattr);
320
321 // create bwd memory
322 at::Tensor expected_src, expected_diff_dst, expected_diff_weight;
323 dnnl::memory src_m, diff_dst_m, diff_weight_m;
324
325 src_m = make_onednn_memory(src_md, engine, src.data_ptr());
326 diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
327 diff_weight_m = make_onednn_memory(weight_md, engine, diff_weight.data_ptr());
328
329 // insert args
330 std::unordered_map<int, dnnl::memory> args;
331 args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
332 args.insert({DNNL_ARG_SRC, src_m});
333 args.insert({DNNL_ARG_DIFF_WEIGHTS, diff_weight_m});
334 if (diff_bia.defined()) {
335 dnnl::memory diff_bia_m =
336 make_onednn_memory(bia_md, engine, diff_bia.data_ptr());
337 args.insert({DNNL_ARG_DIFF_BIAS, diff_bia_m});
338 }
339
340 size_t scratchpad_size = conv_bwd_w_pd.scratchpad_desc().get_size();
341 at::Tensor scratchpad_tensor = at::empty(
342 {static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
343 auto scratchpad_m = make_onednn_memory(
344 conv_bwd_w_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
345 args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
346
347 // execute primitive
348 auto conv_bwd_w = dnnl::convolution_backward_weights(conv_bwd_w_pd);
349 sycl::event conv_bwd_w_event = dnnl::sycl_interop::execute(conv_bwd_w, stream, args, deps);
350
351 return conv_bwd_w_event;
352 }
353
convolution_backward_data(at::Tensor & diff_src,const at::Tensor & diff_dst,const at::Tensor & weight,IntArrayRef padding_front_top_left,IntArrayRef padding_back_bottom_right,IntArrayRef stride,IntArrayRef dilation,int64_t groups,bool bias_defined,const std::vector<sycl::event> & deps)354 sycl::event convolution_backward_data(
355 at::Tensor& diff_src,
356 const at::Tensor& diff_dst,
357 const at::Tensor& weight,
358 IntArrayRef padding_front_top_left,
359 IntArrayRef padding_back_bottom_right,
360 IntArrayRef stride,
361 IntArrayRef dilation,
362 int64_t groups,
363 bool bias_defined,
364 const std::vector<sycl::event>& deps) {
365 auto engine =
366 GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
367 auto stream = GpuStreamManager::Instance().get_stream();
368
369 bool is_channels_last = use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/false);
370
371 // create memory desc
372 auto [src_md, weight_md, dst_md] =
373 conv_get_md(diff_src, weight, diff_dst, groups, is_channels_last);
374 dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
375 auto bia_md = bias_defined
376 ? dnnl::memory::desc({diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
377 : dnnl::memory::desc();
378
379 // create fwd primitive desc hint
380 dnnl::primitive_attr pattr;
381
382 #if ONEDNN_SUPPORT_DETERMINISTIC
383 if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
384 pattr.set_deterministic(true);
385 }
386 #endif
387
388 pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
389 dnnl::memory::dims _stride = stride.vec();
390 dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
391 dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
392 dnnl::memory::dims _dilation = compatible_dilation(dilation);
393 auto conv_forward_pd = dnnl::convolution_forward::primitive_desc(
394 engine,
395 dnnl::prop_kind::forward,
396 dnnl::algorithm::convolution_direct,
397 src_md,
398 weight_md,
399 bia_md,
400 dst_md,
401 _stride,
402 _dilation,
403 _padding_front_top_left,
404 _padding_back_bottom_right,
405 pattr);
406
407 auto conv_backward_data_pd = dnnl::convolution_backward_data::primitive_desc(
408 engine,
409 dnnl::algorithm::convolution_direct,
410 src_md,
411 weight_md,
412 dst_md,
413 _stride,
414 _dilation,
415 _padding_front_top_left,
416 _padding_back_bottom_right,
417 conv_forward_pd,
418 pattr);
419
420 // create memory
421 at::Tensor expected_src, expected_wei, expected_dst;
422 dnnl::memory diff_dst_m, wei_m, diff_src_m;
423
424 diff_src_m = make_onednn_memory(src_md, engine, diff_src.data_ptr());
425 wei_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
426 diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
427
428
429 // insert args
430 std::unordered_map<int, dnnl::memory> args;
431 size_t scratchpad_size = conv_backward_data_pd.scratchpad_desc().get_size();
432 at::Tensor scratchpad_tensor = at::empty(
433 {static_cast<int64_t>(scratchpad_size)}, diff_dst.options().dtype(at::kByte), std::nullopt);
434 auto scratchpad_memory = make_onednn_memory(
435 conv_backward_data_pd.scratchpad_desc(),
436 engine,
437 scratchpad_tensor.data_ptr());
438 args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});
439 args.insert({DNNL_ARG_DIFF_DST, diff_dst_m});
440 args.insert({DNNL_ARG_WEIGHTS, wei_m});
441 args.insert({DNNL_ARG_DIFF_SRC, diff_src_m});
442
443 // execute primitive
444 auto conv_backward_data =
445 dnnl::convolution_backward_data(conv_backward_data_pd);
446 auto conv_backward_data_event = dnnl::sycl_interop::execute(conv_backward_data, stream, args, deps);
447 return conv_backward_data_event;
448
449 }
450
451 } // namespace at::native::onednn
452