xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/TensorShape.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Config.h>
3 #include <ATen/InferSize.h>
4 #include <ATen/core/Tensor.h>
5 #include <c10/core/SymIntArrayRef.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_mkldnn_reshape_native.h>
11 #include <ATen/ops/_mkldnn_transpose_native.h>
12 #include <ATen/ops/clone_native.h>
13 #include <ATen/ops/view_native.h>
14 #endif
15 
16 #if !AT_MKLDNN_ENABLED()
17 
18 namespace at {
19 namespace native {
20 
// Stub used when ATen is built without MKLDNN support.
// Always fails via TORCH_CHECK(false, ...); it never returns a tensor.
// The message names this operator (`mkldnn_view`); it previously said
// `mkldnn_reshape` due to a copy-paste slip, which pointed users at the
// wrong op when diagnosing the failure.
Tensor mkldnn_view(const Tensor& self, IntArrayRef size) {
  TORCH_CHECK(false, "mkldnn_view: ATen not compiled with MKLDNN support");
}
24 
// Fallback for builds without MKLDNN: reshape of an MKLDNN tensor cannot be
// performed, so this unconditionally raises through TORCH_CHECK.
Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
  TORCH_CHECK(
      false,
      "mkldnn_reshape: ATen not compiled with MKLDNN support");
}
28 
// Fallback for builds without MKLDNN: cloning is unavailable, so this
// unconditionally raises through TORCH_CHECK regardless of the requested
// memory format.
Tensor mkldnn_clone(const Tensor& self, std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      false,
      "mkldnn_clone: ATen not compiled with MKLDNN support");
}
32 
// Fallback for builds without MKLDNN: the out-of-place transpose is
// unavailable, so this unconditionally raises through TORCH_CHECK.
Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
  TORCH_CHECK(
      false,
      "mkldnn_transpose: ATen not compiled with MKLDNN support");
}
36 
mkldnn_transpose_(Tensor & self,int64_t dim0,int64_t dim1)37 Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) {
38   TORCH_CHECK(false, "mkldnn_transpose_: ATen not compiled with MKLDNN support");
39 }
40 
41 } // namespace native
42 } // namespace at
43 
44 #else // AT_MKLDNN_ENABLED
45 
46 #include <ATen/native/mkldnn/MKLDNNCommon.h>
47 
48 namespace at {
49 namespace native {
50 
// `view` is rejected for MKLDNN tensors even when MKLDNN is available; the
// error message steers callers toward `reshape`, which is implemented below.
// This always raises through TORCH_CHECK.
Tensor mkldnn_view(const Tensor& self, IntArrayRef size) {
  TORCH_CHECK(
      false,
      "Currently Mkldnn tensor does not support view. Change to use reshape instead");
}
55 
// Reshape an MKLDNN tensor to `size`.
//
// The requested shape is first resolved against self.numel() with
// at::infer_size. If the resolved shape equals the current sizes, `self` is
// returned as-is. Otherwise an ideep::tensor copy of the underlying itensor
// is reshaped and wrapped in a new MKLDNN-backed Tensor with the same dtype
// and device as `self`.
//
// NOTE(review): whether the returned tensor shares storage with `self`
// depends on ideep::tensor copy/reshape semantics — confirm before relying
// on aliasing either way.
Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
  auto target_shape = at::infer_size(size, self.numel());

  // Shape already matches: nothing to do.
  if (self.sizes() == target_shape) {
    return self;
  }

  const ideep::tensor& source = itensor_from_mkldnn(self);
  ideep::tensor reshaped{source};
  reshaped.reshape(target_shape);

  const auto dtype = optTypeMetaToScalarType(self.options().dtype_opt());
  const auto device = self.options().device_opt();
  return new_with_itensor_mkldnn(std::move(reshaped), dtype, device);
}
67 
// Deep-copy an MKLDNN tensor.
//
// No explicit memory format is supported: any provided value is rejected
// with an error naming it. The underlying itensor is copied into a fresh
// ideep::tensor with ideep::direct_copy, and the copy is wrapped as a new
// MKLDNN-backed Tensor carrying the dtype and device of `self`.
Tensor mkldnn_clone(const Tensor& self, std::optional<c10::MemoryFormat> optional_memory_format) {
  // TORCH_CHECK formats its message only on failure, so value() is reached
  // only when a format was actually supplied.
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "unsupported memory format option ",
      optional_memory_format.value());

  const ideep::tensor& source = itensor_from_mkldnn(self);
  ideep::tensor copy;
  ideep::direct_copy::compute(source, copy);

  const auto dtype = optTypeMetaToScalarType(self.options().dtype_opt());
  const auto device = self.options().device_opt();
  return new_with_itensor_mkldnn(std::move(copy), dtype, device);
}
79 
// Out-of-place transpose of an MKLDNN tensor.
//
// `dim0` and `dim1` are wrapped against self.dim() with maybe_wrap_dim, so
// negative indices count from the end. A permutation that is the identity
// except for those two axes is handed to ideep::tensor::transpose_from. The
// result is wrapped as a new MKLDNN-backed Tensor with the dtype and device
// of `self`.
Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
  const auto rank = self.dim();
  const int64_t lhs = maybe_wrap_dim(dim0, rank);
  const int64_t rhs = maybe_wrap_dim(dim1, rank);

  const ideep::tensor& src = itensor_from_mkldnn(self);

  // Identity permutation over every axis of the source, then swap the pair.
  std::vector<int> perm(src.ndims());
  for (size_t axis = 0; axis < perm.size(); ++axis) {
    perm[axis] = static_cast<int>(axis);
  }
  std::swap(perm[lhs], perm[rhs]);

  ideep::tensor dst;
  dst.transpose_from(src, perm);

  const auto dtype = optTypeMetaToScalarType(self.options().dtype_opt());
  const auto device = self.options().device_opt();
  return new_with_itensor_mkldnn(std::move(dst), dtype, device);
}
93 
mkldnn_transpose_(Tensor & self,int64_t dim0,int64_t dim1)94 Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) {
95   TORCH_CHECK(false, "mkldnn_transpose_: in-place mkldnn operations are not supported yet");
96 }
97 
98 } // namespace native
99 } // namespace at
100 
101 #endif // AT_MKLDNN_ENABLED
102