xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Config.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/cpu/vec/functional.h>
6 #include <ATen/cpu/vec/vec.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/zero_native.h>
12 #endif
13 
14 #if !AT_MKLDNN_ENABLED()
15 
16 namespace at {
17 namespace native {
18 
mkldnn_zero_(Tensor & self)19 Tensor& mkldnn_zero_(Tensor& self) {
20   TORCH_CHECK(false, "mkldnn_zero_: ATen not compiled with MKLDNN support");
21 }
22 
23 } // namespace native
24 } // namespace at
25 
26 #else // AT_MKLDNN_ENABLED
27 
28 #include <ATen/native/mkldnn/MKLDNNCommon.h>
29 
30 namespace at {
31 namespace native {
32 
mkldnn_zero_(Tensor & self)33 Tensor& mkldnn_zero_(Tensor& self) {
34   using Vec = vec::Vectorized<float>;
35 
36   ideep::tensor& x = itensor_from_mkldnn(self);
37 
38   auto n = x.get_nelems();
39   auto* x_ = static_cast<float*>(x.get_data_handle());
40   parallel_for(0, n, 2048, [x_](int64_t begin, int64_t end) {
41     vec::map(
42         [](Vec /* unused */) { return 0.0; },
43         x_ + begin,
44         x_ + begin,
45         end - begin);
46   });
47 
48   return self;
49 }
50 
51 } // namespace native
52 } // namespace at
53 
54 #endif // AT_MKLDNN_ENABLED
55