#include <ATen/native/mkldnn/xpu/detail/Utils.h>

#include <algorithm>

namespace at::native::onednn {

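// Wrap a USM allocation in a dnnl::memory on the given engine. A null ptr
// asks oneDNN to allocate the buffer itself (DNNL_MEMORY_ALLOCATE).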
dnnl::memory make_onednn_memory(
    dnnl::memory::desc md,
    dnnl::engine& engine,
    void* ptr) {
  return dnnl::sycl_interop::make_memory(
      md,
      engine,
      dnnl::sycl_interop::memory_kind::usm,
      ptr == nullptr ? DNNL_MEMORY_ALLOCATE : ptr);
}

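// Map a tensor rank to oneDNN's plain format tag, or the channels-last
// variant for 3D/4D/5D tensors. Ranks above 12 either raise an error or map
// to format_tag::undef, depending on allow_undef.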
dnnl::memory::format_tag get_dnnl_default_format(
    int ndims,
    bool is_channels_last,
    bool allow_undef) {
  switch (ndims) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return is_channels_last ? dnnl::memory::format_tag::acb
                              : dnnl::memory::format_tag::abc;
    case 4:
      return is_channels_last ? dnnl::memory::format_tag::acdb
                              : dnnl::memory::format_tag::abcd;
    case 5:
      return is_channels_last ? dnnl::memory::format_tag::acdeb
                              : dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    case 10:
      return dnnl::memory::format_tag::abcdefghij;
    case 11:
      return dnnl::memory::format_tag::abcdefghijk;
    case 12:
      return dnnl::memory::format_tag::abcdefghijkl;
    default:
      if (!allow_undef) {
        TORCH_CHECK(false, "oneDNN doesn't support tensor dimension > 12");
      }
      return dnnl::memory::format_tag::undef;
  }
}

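// Translate an ATen scalar type into the matching oneDNN data type.
// Unsupported types either raise an error or map to data_type::undef,
// depending on allow_undef.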
dnnl::memory::data_type get_onednn_dtype(
    const at::Tensor& tensor,
    bool allow_undef) {
  switch (tensor.scalar_type()) {
    case at::ScalarType::Byte:
      return dnnl::memory::data_type::u8;
    case at::ScalarType::Char:
      return dnnl::memory::data_type::s8;
    case at::ScalarType::QInt8:
      return dnnl::memory::data_type::s8;
    case at::ScalarType::QUInt8:
      return dnnl::memory::data_type::u8;
    case at::ScalarType::Int:
      return dnnl::memory::data_type::s32;
    case at::ScalarType::Half:
      return dnnl::memory::data_type::f16;
    case at::ScalarType::Float:
      return dnnl::memory::data_type::f32;
    case at::ScalarType::BFloat16:
      return dnnl::memory::data_type::bf16;
    default:
      if (!allow_undef) {
        TORCH_CHECK(
            false,
            c10::toString(tensor.scalar_type()),
            " is not supported in oneDNN!");
      }
      return dnnl::memory::data_type::undef;
  }
}

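// Same as get_onednn_dtype, but additionally maps Double to f64.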
dnnl::memory::data_type get_onednn_dtype_include_double(
    const at::Tensor& tensor,
    bool allow_undef) {
  if (tensor.scalar_type() == at::ScalarType::Double)
    return dnnl::memory::data_type::f64;
  return get_onednn_dtype(tensor, allow_undef);
}

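// True when the tensor's scalar type has a oneDNN counterpart.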
bool is_supported_onednn_dtype(const at::Tensor& tensor) {
  return get_onednn_dtype(tensor, /*allow_undef*/ true) !=
      dnnl::memory::data_type::undef;
}

dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor) {
  dnnl::memory::dims dims;
  for (size_t i = 0; i < tensor.sizes().size(); i++)
    dims.push_back(tensor.size(i));
  return dims;
}

dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) {
  dnnl::memory::dims strides;
  for (size_t i = 0; i < tensor.strides().size(); i++)
    strides.push_back(tensor.stride(i));
  return strides;
}

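// Build a plain (strided) memory descriptor from the tensor's sizes, data
// type, and strides.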
dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) {
  return {
      get_onednn_dims(tensor),
      get_onednn_dtype(tensor),
      get_onednn_strides(tensor)};
}

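// Check whether the tensor's strides can be represented by a oneDNN blocked
// memory descriptor: dimensions are ordered by stride, and every non-broadcast
// dimension must start past the storage spanned by the dimensions below it.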
bool onednn_strides_check(const at::Tensor& src) {
  auto adims = get_onednn_dims(src);
  int ndims = (int)adims.size();
  auto dims = adims.data();
  auto data_type = static_cast<dnnl_data_type_t>(
      get_onednn_dtype(src, /*allow_undef*/ true));
  auto strides_info = get_onednn_strides(src);
  auto strides = strides_info.empty() ? nullptr : &strides_info[0];

  dnnl_memory_desc_t md;
  dnnl_memory_desc_create_with_strides(&md, ndims, dims, data_type, strides);
  dnnl_format_kind_t md_fmt_kind;
  int md_ndims;
  int md_inner_nblks;
  // dnnl_memory_desc_query returns pointers into the memory descriptor for
  // dims-typed queries, so these are queried as dnnl_dims_t*.
  dnnl_dims_t* md_padded_dims = nullptr;

  dnnl_memory_desc_query(md, dnnl_query_inner_nblks_s32, &md_inner_nblks);
  dnnl_memory_desc_query(md, dnnl_query_format_kind, &md_fmt_kind);
  dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &md_ndims);
  dnnl_memory_desc_query(md, dnnl_query_padded_dims, &md_padded_dims);
  if (strides == nullptr || md_ndims == 0 ||
      md_fmt_kind != dnnl_format_kind_t::dnnl_blocked)
    return true;

  dnnl_dims_t blocks = {0};
  int perm[DNNL_MAX_NDIMS] = {0};
  for (int d = 0; d < md_ndims; ++d) {
    // no strides check needed for empty tensor
    if ((*md_padded_dims)[d] == 0)
      return true;

    // no strides verification for runtime dims
    if (strides[d] == DNNL_RUNTIME_DIM_VAL)
      return true;

    perm[d] = d;
    blocks[d] = 1;
  }

  dnnl_dim_t block_size = 1;
  dnnl_dims_t* md_inner_blks = nullptr;
  dnnl_dims_t* md_blk_inner_idxs = nullptr;
  dnnl_memory_desc_query(md, dnnl_query_inner_idxs, &md_blk_inner_idxs);
  dnnl_memory_desc_query(md, dnnl_query_inner_blks, &md_inner_blks);
  for (int iblk = 0; iblk < md_inner_nblks; ++iblk) {
    blocks[(*md_blk_inner_idxs)[iblk]] *= (*md_inner_blks)[iblk];
    block_size *= (*md_inner_blks)[iblk];
  }

  // A custom comparator to yield linear order on perm
  auto idx_sorter = [&](const int a, const int b) -> bool {
    if (strides[a] == strides[b] &&
        (*md_padded_dims)[a] == (*md_padded_dims)[b])
      return a < b;
    else if (strides[a] == strides[b])
      return (*md_padded_dims)[a] < (*md_padded_dims)[b];
    else
      return strides[a] < strides[b];
  };
  std::sort(perm, perm + md_ndims, idx_sorter);

  auto min_stride = block_size;
  for (int idx = 0; idx < md_ndims; ++idx) {
    const int d = perm[idx];

    // Make an exception for strides[d] == 0 as it has broadcast semantics
    // Note: owing to being sorted, these are the initial strides
    if (strides[d] == 0)
      continue;
    else if (strides[d] < min_stride)
      return false;

    // update min_stride for next iteration
    const auto padded_dim = (*md_padded_dims)[d];
    min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
  }
  return true;
}

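// A stride of 0 on any dimension means the tensor is expanded (broadcast)
// along that dimension.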
bool is_broadcast(const at::Tensor& t) {
  for (int i = 0; i < t.dim(); i++) {
    if (t.stride(i) == 0)
      return true;
  }
  return false;
}

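// Decide whether a 2D/3D matmul operand has strides oneDNN matmul can use
// directly: no overlapping or broadcast strides, the last axis contiguous for
// dst, and at least one of the last two axes contiguous for src/weight.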
bool is_onednn_matmul_strides(
    const at::Tensor& tensor,
    bool is_dst) {
  // https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html
  // oneDNN matmul only supports 2-dim and 3-dim tensors:
  // 2D: src(MxK), wei(KxN), dst(MxN)
  // 3D: src(SxMxK), wei(WxKxN), dst(DxMxN)
  auto sizes = tensor.sizes();
  auto tensor_dim = sizes.size();
  if (tensor_dim != 2 && tensor_dim != 3)
    return false;

  if (tensor.is_contiguous())
    return true;

  // the overlapped cases are not supported
  dnnl::memory::dims strides = get_onednn_strides(tensor);
  int64_t storage_size = 1;
  for (size_t dim = 0; dim < tensor_dim; ++dim)
    storage_size += (sizes[dim] - 1) * strides[dim];
  if (storage_size < tensor.numel())
    return false;

  // the broadcast cases are not supported
  if (is_broadcast(tensor)) {
    return false;
  }

  if (is_dst) {
    // the memory format of the destination tensor should always
    // be plain with the n axis contiguous
    if (strides[tensor_dim - 1] != 1)
      return false;
  } else {
    // src and weight must have at least one of the axes m or k (src) and
    // k or n (weight) contiguous (i.e., stride == 1)
    if (strides[tensor_dim - 1] != 1 && strides[tensor_dim - 2] != 1)
      return false;
  }

  if (!onednn_strides_check(tensor))
    return false;
  return true;
}

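// True when other has a different shape that can be expanded (broadcast) to
// self's shape.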
bool is_broadcast_from_other_to_self(
    const at::Tensor& self,
    const at::Tensor& other) {
  return (
      self.sizes() != other.sizes() &&
      at::is_expandable_to(other.sizes(), self.sizes()));
}

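// Pick the channels-last memory format matching the tensor rank: ChannelsLast
// for 4D, ChannelsLast3d for 5D; 3D tensors fall back to Contiguous.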
at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim) {
  TORCH_CHECK(
      3 == ndim || 4 == ndim || 5 == ndim,
      "ndim must be 3, 4 or 5 to get a channels-last memory format");
  if (3 == ndim) {
    return at::MemoryFormat::Contiguous;
  } else if (5 == ndim) {
    return at::MemoryFormat::ChannelsLast3d;
  } else {
    return at::MemoryFormat::ChannelsLast;
  }
}

bool binary_valid(
    const at::Tensor& self,
    const at::Tensor& other,
    bool is_fusion) {
  if (self.sizes() != other.sizes() &&
      !is_broadcast_from_other_to_self(self, other))
    return false;

  /* The oneDNN path is selected only if all of the following hold:
   * 1. self and other are defined XPU tensors.
   * 2. neither self nor other is a scalar (wrapped number).
   * 3. self and other have the same dim, which is larger than 0 and
   *    smaller than 7.
   * 4. the datatype is supported by the oneDNN primitive.
   * 5. self and other have the same datatype (or other is Float for fusion).
   * 6. self and other are contiguous or channels-last contiguous. */

  // 1. self and other should be defined XPU tensors.
  if ((!self.defined()) || (!other.defined()) || (!self.is_xpu()) ||
      (!other.is_xpu()))
    return false;

  // 2. self or other should not be a scalar (wrapped number).
  if (self.unsafeGetTensorImpl()->is_wrapped_number() ||
      other.unsafeGetTensorImpl()->is_wrapped_number())
    return false;

  // 3. dim of self and other should be equal and must be larger than 0 and
  // smaller than 7.
  if ((self.dim() <= 0) || (other.dim() <= 0) || (self.dim() != other.dim()) ||
      (self.dim() > 6) || (other.dim() > 6))
    return false;

  // 4. the datatype should be supported by the oneDNN primitive.
  switch (self.scalar_type()) {
    case at::ScalarType::Char:
    case at::ScalarType::Byte:
    case at::ScalarType::Half:
    case at::ScalarType::Float:
    case at::ScalarType::BFloat16:
      break;
    default:
      return false;
  }

  // 5. datatype check
  if (is_fusion) {
    // for the fusion case, other may either match self's datatype or be Float.
    if (self.scalar_type() != other.scalar_type() &&
        other.scalar_type() != at::ScalarType::Float) {
      return false;
    }
  } else {
    // for the non-fusion case, self and other should have the same datatype.
    if (self.scalar_type() != other.scalar_type()) {
      return false;
    }
  }

  // 6. self and other should be contiguous or channels-last contiguous.
  const auto ndim = self.ndimension();
  auto cl_tag = at::MemoryFormat::ChannelsLast;
  if (3 == ndim || 4 == ndim || 5 == ndim) {
    cl_tag = get_cl_tag_by_ndim(ndim);
  }
  return (self.is_contiguous() && other.is_contiguous()) ||
      (self.is_contiguous(cl_tag) && other.is_contiguous(cl_tag));
}

static inline bool is_channels_last(at::MemoryFormat fmt) {
  return (at::MemoryFormat::ChannelsLast == fmt) ||
      (at::MemoryFormat::ChannelsLast3d == fmt);
}

static inline bool is_smf_channels_last(const at::Tensor& t) {
  return is_channels_last(t.suggest_memory_format());
}

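// Suggest channels-last for convolution when either the input or the weight
// already prefers a channels-last layout; otherwise suggest channels-first.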
bool use_channels_last_for_conv(
    const at::Tensor& src,
    const at::Tensor& weight,
    bool is_transpose) {
  if (!src.defined() || src.is_sparse()) {
    // suggest channels_first
    return false;
  }

  auto suggest_channels_last_format =
      (is_smf_channels_last(src) || is_smf_channels_last(weight));
  if (suggest_channels_last_format) {
    // suggest channels_last
    return true;
  }

  return false;
}

} // namespace at::native::onednn