1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS 2 #include <ATen/core/Tensor.h> 3 #include <ATen/Config.h> 4 #include <ATen/Parallel.h> 5 #include <ATen/cpu/vec/functional.h> 6 #include <ATen/cpu/vec/vec.h> 7 8 #ifndef AT_PER_OPERATOR_HEADERS 9 #include <ATen/NativeFunctions.h> 10 #else 11 #include <ATen/ops/zero_native.h> 12 #endif 13 14 #if !AT_MKLDNN_ENABLED() 15 16 namespace at { 17 namespace native { 18 mkldnn_zero_(Tensor & self)19Tensor& mkldnn_zero_(Tensor& self) { 20 TORCH_CHECK(false, "mkldnn_zero_: ATen not compiled with MKLDNN support"); 21 } 22 23 } // namespace native 24 } // namespace at 25 26 #else // AT_MKLDNN_ENABLED 27 28 #include <ATen/native/mkldnn/MKLDNNCommon.h> 29 30 namespace at { 31 namespace native { 32 mkldnn_zero_(Tensor & self)33Tensor& mkldnn_zero_(Tensor& self) { 34 using Vec = vec::Vectorized<float>; 35 36 ideep::tensor& x = itensor_from_mkldnn(self); 37 38 auto n = x.get_nelems(); 39 auto* x_ = static_cast<float*>(x.get_data_handle()); 40 parallel_for(0, n, 2048, [x_](int64_t begin, int64_t end) { 41 vec::map( 42 [](Vec /* unused */) { return 0.0; }, 43 x_ + begin, 44 x_ + begin, 45 end - begin); 46 }); 47 48 return self; 49 } 50 51 } // namespace native 52 } // namespace at 53 54 #endif // AT_MKLDNN_ENABLED 55