// aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/native/mkldnn/xpu/detail/Utils.h>

namespace at::native::onednn {

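// Wraps dnnl::sycl_interop::make_memory to create a oneDNN memory object over
// a USM pointer; a null `ptr` asks oneDNN to allocate the buffer itself
// (DNNL_MEMORY_ALLOCATE).
//
// Illustrative usage (a sketch; `engine` and `usm_ptr` stand in for whatever
// SYCL engine and USM allocation the caller already has):
//   dnnl::memory::desc md(
//       {8, 16}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
//   dnnl::memory mem = make_onednn_memory(md, engine, usm_ptr);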
dnnl::memory make_onednn_memory(
    dnnl::memory::desc md,
    dnnl::engine& engine,
    void* ptr) {
  return dnnl::sycl_interop::make_memory(
      md,
      engine,
      dnnl::sycl_interop::memory_kind::usm,
      ptr == nullptr ? DNNL_MEMORY_ALLOCATE : ptr);
}

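// Returns oneDNN's plain (non-blocked) format tag for a given tensor rank. For
// 3-D/4-D/5-D tensors, `is_channels_last` selects the channels-last ordering
// (e.g. rank 4: acdb, i.e. NHWC, instead of abcd, i.e. NCHW). Ranks above 12
// either fail a TORCH_CHECK or return format_tag::undef, depending on
// `allow_undef`.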
dnnl::memory::format_tag get_dnnl_default_format(
    int ndims,
    bool is_channels_last,
    bool allow_undef) {
  switch (ndims) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return is_channels_last ? dnnl::memory::format_tag::acb
                              : dnnl::memory::format_tag::abc;
    case 4:
      return is_channels_last ? dnnl::memory::format_tag::acdb
                              : dnnl::memory::format_tag::abcd;
    case 5:
      return is_channels_last ? dnnl::memory::format_tag::acdeb
                              : dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    case 10:
      return dnnl::memory::format_tag::abcdefghij;
    case 11:
      return dnnl::memory::format_tag::abcdefghijk;
    case 12:
      return dnnl::memory::format_tag::abcdefghijkl;
    default:
      if (!allow_undef) {
        TORCH_CHECK(false, "oneDNN doesn't support tensor dimension > 12");
      }
      return dnnl::memory::format_tag::undef;
  }
}

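// Translates an ATen scalar type into the matching oneDNN data type; quantized
// 8-bit types map to plain s8/u8. Unsupported types either raise a TORCH_CHECK
// error or yield data_type::undef when `allow_undef` is set. Double is handled
// separately by get_onednn_dtype_include_double below.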
dnnl::memory::data_type get_onednn_dtype(
    const at::Tensor& tensor,
    bool allow_undef) {
  switch (tensor.scalar_type()) {
    case at::ScalarType::Byte:
      return dnnl::memory::data_type::u8;
    case at::ScalarType::Char:
      return dnnl::memory::data_type::s8;
    case at::ScalarType::QInt8:
      return dnnl::memory::data_type::s8;
    case at::ScalarType::QUInt8:
      return dnnl::memory::data_type::u8;
    case at::ScalarType::Int:
      return dnnl::memory::data_type::s32;
    case at::ScalarType::Half:
      return dnnl::memory::data_type::f16;
    case at::ScalarType::Float:
      return dnnl::memory::data_type::f32;
    case at::ScalarType::BFloat16:
      return dnnl::memory::data_type::bf16;
    default:
      if (!allow_undef) {
        TORCH_CHECK(
            false,
            c10::toString(tensor.scalar_type()),
            " is not supported in oneDNN!");
      }
      return dnnl::memory::data_type::undef;
  }
}

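// Same as get_onednn_dtype, but additionally maps Double to f64.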
dnnl::memory::data_type get_onednn_dtype_include_double(
    const at::Tensor& tensor,
    bool allow_undef) {
  if (tensor.scalar_type() == at::ScalarType::Double)
    return dnnl::memory::data_type::f64;
  return get_onednn_dtype(tensor, allow_undef);
}

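// True if the tensor's scalar type has a oneDNN counterpart (Double is
// excluded, since the check goes through get_onednn_dtype).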
bool is_supported_onednn_dtype(const at::Tensor& tensor) {
  return get_onednn_dtype(tensor, /*allow_undef*/ true) !=
      dnnl::memory::data_type::undef;
}

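// Thin converters from ATen metadata (sizes, strides, dtype) to the oneDNN
// types used when building memory descriptors.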
dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor) {
  dnnl::memory::dims dims;
  for (size_t i = 0; i < tensor.sizes().size(); i++)
    dims.push_back(tensor.size(i));
  return dims;
}

dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) {
  dnnl::memory::dims strides;
  for (size_t i = 0; i < tensor.strides().size(); i++)
    strides.push_back(tensor.stride(i));
  return strides;
}

dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) {
  return {
      get_onednn_dims(tensor),
      get_onednn_dtype(tensor),
      get_onednn_strides(tensor)};
}

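// Checks whether oneDNN can consume `src` with its current strides. A strided
// memory descriptor is built through the C API and queried; the check passes
// trivially for non-blocked formats, empty tensors and runtime dims, and
// otherwise verifies that, with dimensions sorted by ascending stride, every
// stride covers at least the full extent of the previous dimension (a zero
// stride is tolerated as broadcast).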
bool onednn_strides_check(const at::Tensor& src) {
  auto adims = get_onednn_dims(src);
  int ndims = (int)adims.size();
  auto dims = adims.data();
  auto data_type = static_cast<dnnl_data_type_t>(
      get_onednn_dtype(src, /*allow_undef*/ true));
  auto strides_info = get_onednn_strides(src);
  auto strides = strides_info.empty() ? nullptr : &strides_info[0];

  dnnl_memory_desc_t md;
  dnnl_memory_desc_create_with_strides(&md, ndims, dims, data_type, strides);
  dnnl_format_kind_t md_fmt_kind;
  int md_ndims;
  int md_inner_nblks;
  dnnl_dims_t* md_padded_dims = nullptr;

  dnnl_memory_desc_query(md, dnnl_query_inner_nblks_s32, &md_inner_nblks);
  dnnl_memory_desc_query(md, dnnl_query_format_kind, &md_fmt_kind);
  dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &md_ndims);
  dnnl_memory_desc_query(md, dnnl_query_padded_dims, &md_padded_dims);
  if (strides == nullptr || md_ndims == 0 ||
      md_fmt_kind != dnnl_format_kind_t::dnnl_blocked)
    return true;

  dnnl_dims_t blocks = {0};
  int perm[DNNL_MAX_NDIMS] = {0};
  for (int d = 0; d < md_ndims; ++d) {
    // no strides check needed for empty tensor
    if ((*md_padded_dims)[d] == 0)
      return true;

    // no strides verification for runtime dims
    if (strides[d] == DNNL_RUNTIME_DIM_VAL)
      return true;

    perm[d] = d;
    blocks[d] = 1;
  }

  auto block_size = 1;
  dnnl_dims_t md_inner_blks;
  dnnl_dims_t md_blk_inner_idxs;
  dnnl_memory_desc_query(md, dnnl_query_inner_idxs, &md_blk_inner_idxs);
  dnnl_memory_desc_query(md, dnnl_query_inner_blks, &md_inner_blks);
  for (int iblk = 0; iblk < md_inner_nblks; ++iblk) {
    blocks[md_blk_inner_idxs[iblk]] *= md_inner_blks[iblk];
    block_size *= md_inner_blks[iblk];
  }

  // A custom comparator to yield linear order on perm
  auto idx_sorter = [&](const int a, const int b) -> bool {
    if (strides[a] == strides[b] &&
        (*md_padded_dims)[a] == (*md_padded_dims)[b])
      return a < b;
    else if (strides[a] == strides[b])
      return (*md_padded_dims)[a] < (*md_padded_dims)[b];
    else
      return strides[a] < strides[b];
  };
  std::sort(perm, perm + md_ndims, idx_sorter);

  auto min_stride = block_size;
  for (int idx = 0; idx < md_ndims; ++idx) {
    const int d = perm[idx];

    // Make an exception for strides[d] == 0 as it has broadcast semantics
    // Note: owing to being sorted, these are the initial strides
    if (strides[d] == 0)
      continue;
    else if (strides[d] < min_stride)
      return false;

    // update min_stride for next iteration
    const auto padded_dim = (*md_padded_dims)[d];
    min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
  }
  return true;
}

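// A tensor is treated as broadcast if any stride is zero, i.e. it was expanded
// along that dimension.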
bool is_broadcast(const at::Tensor& t) {
  for (int i = 0; i < t.dim(); i++) {
    if (t.stride(i) == 0)
      return true;
  }
  return false;
}

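// Decides whether a tensor can be passed to oneDNN matmul without a layout
// copy: it must be 2-D or 3-D, non-overlapping, non-broadcast, keep a unit
// stride on the last axis (dst) or on one of the last two axes (src/weight),
// and pass the generic onednn_strides_check above.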
bool is_onednn_matmul_strides(
    const at::Tensor& tensor,
    bool is_dst) {
  // https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html
  // oneDNN matmul only supports 2-dim and 3-dim tensors:
  // 2D: src(MxK), wei(KxN), dst(MxN)
  // 3D: src(SxMxK), wei(WxKxN), dst(DxMxN)
  auto sizes = tensor.sizes();
  auto tensor_dim = sizes.size();
  if (tensor_dim != 2 && tensor_dim != 3)
    return false;

  if (tensor.is_contiguous())
    return true;

  // the overlapped cases are not supported
  dnnl::memory::dims strides = get_onednn_strides(tensor);
  int64_t storage_size = 1;
  for (size_t dim = 0; dim < tensor_dim; ++dim)
    storage_size += (sizes[dim] - 1) * strides[dim];
  if (storage_size < tensor.numel())
    return false;

  // the broadcast cases are not supported
  if (is_broadcast(tensor)) {
    return false;
  }

  if (is_dst) {
    // the memory format of the destination tensor should always be plain,
    // with the n axis contiguous
    if (strides[tensor_dim - 1] != 1)
      return false;
  } else {
    // src and weight must have at least one of the m/k (resp. k/n) axes
    // contiguous (i.e., stride = 1)
    if (strides[tensor_dim - 1] != 1 && strides[tensor_dim - 2] != 1)
      return false;
  }

  if (!onednn_strides_check(tensor))
    return false;
  return true;
}

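// True when `other` has to be broadcast (expanded) to match `self`'s shape.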
bool is_broadcast_from_other_to_self(
    const at::Tensor& self,
    const at::Tensor& other) {
  return (
      self.sizes() != other.sizes() &&
      at::is_expandable_to(other.sizes(), self.sizes()));
}

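// Picks the channels-last memory format matching the tensor rank: ChannelsLast
// for 4-D, ChannelsLast3d for 5-D, and plain Contiguous for 3-D, which has no
// channels-last variant.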
at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim) {
  TORCH_CHECK(
      3 == ndim || 4 == ndim || 5 == ndim,
      "ndim must be 3, 4 or 5 when getting the channels-last tag");
  if (3 == ndim) {
    return at::MemoryFormat::Contiguous;
  } else if (5 == ndim) {
    return at::MemoryFormat::ChannelsLast3d;
  } else {
    return at::MemoryFormat::ChannelsLast;
  }
}

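// Gatekeeper for dispatching an elementwise binary op (optionally in a fusion
// context, see `is_fusion`) to oneDNN; returns true only when the self/other
// pair satisfies the conditions listed inside.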
bool binary_valid(
    const at::Tensor& self,
    const at::Tensor& other,
    bool is_fusion) {
  if (self.sizes() != other.sizes() &&
      !is_broadcast_from_other_to_self(self, other))
    return false;

  /* The oneDNN path is selected only if all of the following hold:
     1. self and other are defined XPU tensors.
     2. neither self nor other is a scalar (wrapped number).
     3. self and other have the same dim, which is larger than 0 and smaller
        than 7.
     4. the datatype is supported by the oneDNN primitive.
     5. self and other have compatible datatypes (see the per-case check
        below).
     6. self and other are contiguous or channels-last contiguous. */

  // 1. self and other should be defined XPU tensors.
  if ((!self.defined()) || (!other.defined()) || (!self.is_xpu()) ||
      (!other.is_xpu()))
    return false;

  // 2. neither self nor other should be a scalar (wrapped number).
  if (self.unsafeGetTensorImpl()->is_wrapped_number() ||
      other.unsafeGetTensorImpl()->is_wrapped_number())
    return false;

  // 3. self and other must have the same dim, larger than 0 and smaller
  // than 7.
  if ((self.dim() <= 0) || (other.dim() <= 0) || (self.dim() != other.dim()) ||
      (self.dim() > 6) || (other.dim() > 6))
    return false;

  // 4. the datatype should be supported by the oneDNN primitive.
  switch (self.scalar_type()) {
    case at::ScalarType::Char:
      break;
    case at::ScalarType::Byte:
      break;
    case at::ScalarType::Half:
      break;
    case at::ScalarType::Float:
      break;
    case at::ScalarType::BFloat16:
      break;
    default:
      return false;
  }

  // 5. datatype check
  if (is_fusion) {
    // in the fusion case, other must either match self's datatype or be Float.
    if (self.scalar_type() != other.scalar_type() &&
        other.scalar_type() != at::ScalarType::Float) {
      return false;
    }
  } else {
    // in the non-fusion case, self and other must have the same datatype.
    if (self.scalar_type() != other.scalar_type()) {
      return false;
    }
  }

  // 6. self and other should be contiguous or channels-last contiguous.
  const auto ndim = self.ndimension();
  auto cl_tag = at::MemoryFormat::ChannelsLast;
  if (3 == ndim || 4 == ndim || 5 == ndim) {
    cl_tag = get_cl_tag_by_ndim(ndim);
  }
  if ((self.is_contiguous() && other.is_contiguous()) ||
      (self.is_contiguous(cl_tag) && other.is_contiguous(cl_tag)))
    return true;
  return false;
}

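// Helpers for checking whether a tensor's suggested memory format is one of
// the channels-last variants.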
static inline bool is_channels_last(at::MemoryFormat fmt) {
  return (at::MemoryFormat::ChannelsLast == fmt) ||
      (at::MemoryFormat::ChannelsLast3d == fmt);
}

static inline bool is_smf_channels_last(const Tensor& t) {
  return is_channels_last(t.suggest_memory_format());
}

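// Layout heuristic for convolution on XPU: prefer channels-last when either
// the input or the weight already suggests a channels-last format, otherwise
// (and for undefined or sparse inputs) fall back to channels-first. Note that
// `is_transpose` is not consulted by the current heuristic.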
bool use_channels_last_for_conv(
    const at::Tensor& src,
    const at::Tensor& weight,
    bool is_transpose) {
  if (!src.defined() || src.is_sparse()) {
    // suggest channels_first
    return false;
  }

  auto suggest_channels_last_format =
      (is_smf_channels_last(src) || is_smf_channels_last(weight));
  if (suggest_channels_last_format) {
    // suggest channels_last
    return true;
  }

  return false;
}

} // namespace at::native::onednn