xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>

#if AT_MKLDNN_ENABLED()

#include <ideep.hpp>

namespace at { namespace native {

/**
 * `IntrusivePtrTargetWrapper` wraps a custom storage handle of a tensor
 * (as template param) and inherits `c10::intrusive_ptr_target` so that it
 * can be used with `c10::intrusive_ptr`.
 *
 * It currently only supports wrapping the custom handle by:
 * - Constructing with an existing custom handle by copy/move constructor.
 *
 * See `OpaqueTensorImpl::opaque_handle_`.
 *
 * NOTE: if this is generally useful we may want to move this to its own header.
 */
template <typename T>
struct TORCH_API IntrusivePtrTargetWrapper : c10::intrusive_ptr_target {
private:
  T target_;

public:
  IntrusivePtrTargetWrapper() = delete;
  IntrusivePtrTargetWrapper(const T& target): target_(target) {}
  IntrusivePtrTargetWrapper(T&& target): target_(std::move(target)) {}

  T& get_target() {
    return target_;
  }
};
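
// Usage sketch: wrapping an ideep::tensor in the intrusive wrapper, in the
// same way new_with_itensor_mkldnn below does; `it` stands for an
// already-initialized ideep::tensor (hypothetical here).
//
//   ideep::tensor it = ...;
//   auto handle =
//       c10::make_intrusive<IntrusivePtrTargetWrapper<ideep::tensor>>(std::move(it));
//   ideep::tensor& wrapped = handle->get_target();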

using IDeepTensorWrapper = IntrusivePtrTargetWrapper<ideep::tensor>;
using IDeepTensorWrapperPtr = c10::intrusive_ptr<IDeepTensorWrapper>;
using MKLDNNTensorImpl = OpaqueTensorImpl<IDeepTensorWrapperPtr>;
using MKLDNNTensor = Tensor;

ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
  switch (type) {
    case ScalarType::Float:
      return ideep::tensor::data_type::f32;
    case ScalarType::QInt32:
      return ideep::tensor::data_type::s32;
    case ScalarType::QInt8:
    case ScalarType::Char:
      return ideep::tensor::data_type::s8;
    case ScalarType::QUInt8:
    case ScalarType::Byte:
      return ideep::tensor::data_type::u8;
    case ScalarType::BFloat16:
      return ideep::tensor::data_type::bf16;
    case ScalarType::Half:
      return ideep::tensor::data_type::f16;
    default:
      TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
  }
}

int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
  MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
  void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
  return reinterpret_cast<int64_t>(data_ptr);
}

at::Tensor mkldnn_tensor_from_data_ptr(
    void* data_ptr,
    at::IntArrayRef dims,
    at::ScalarType dtype,
    at::Device device,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size) {
  std::vector<uint8_t> vector_serialized_md{
      opaque_metadata, opaque_metadata + opaque_metadata_size};
  ideep::tensor::desc deserialized_ideep_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
  // groups is needed for grouped conv
  deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
#else
  TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
#endif

  auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
  return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
}

Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
  // NOTE: int32_t dims from ideep::tensor, but sizes need int64_t
  // TODO: support int64_t dims in ideep::tensor to avoid extra conversion
  auto dims = it.get_dims();
  IDeepTensorWrapperPtr handle = c10::make_intrusive<IDeepTensorWrapper>(std::move(it));
  caffe2::TypeMeta dtype_ = scalarTypeToTypeMeta(dtype_or_default(dtype));
  Device device_ = device_or_default(device);
  return detail::make_tensor<MKLDNNTensorImpl>(
    DispatchKeySet(DispatchKey::MkldnnCPU),
    dtype_, device_, handle,
    std::vector<int64_t>(dims.begin(), dims.end()));
}
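
// Usage sketch: wrapping an existing ideep::tensor `src` (assumed to be
// already populated) into an ATen MKLDNN tensor on the CPU. Both optional
// arguments are passed explicitly here; defaults apply when they are omitted.
//
//   ideep::tensor src = ...;
//   Tensor t = new_with_itensor_mkldnn(std::move(src), ScalarType::Float, Device(kCPU));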

ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
  TORCH_CHECK(mkldnn_tensor.is_mkldnn(),
             "itensor_from_mkldnn expects MKL-DNN tensor input");
  MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
  return mklimpl->unsafe_opaque_handle()->get_target();
}

int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
  ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
  return t.get_desc().get_size();
}

ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
  TORCH_CHECK(
      tensor.device().is_cpu(),
      "itensor_view_from_dense expects CPU tensor input");
  TORCH_CHECK(
      tensor.layout() == Layout::Strided,
      "itensor_view_from_dense expects dense tensor input");
  if (tensor.scalar_type() == ScalarType::Float) {
    return {{tensor.sizes().vec(),
            ideep::tensor::data_type::f32,
            tensor.strides().vec()},
            from_const_data_ptr ?
              const_cast<float*>(tensor.template const_data_ptr<float>()) :
              tensor.template data_ptr<float>()};
  }
  else if (tensor.scalar_type() == ScalarType::BFloat16) {
    return {{tensor.sizes().vec(),
            ideep::tensor::data_type::bf16,
            tensor.strides().vec()},
            from_const_data_ptr ?
              const_cast<BFloat16*>(tensor.template const_data_ptr<BFloat16>()) :
              tensor.template data_ptr<BFloat16>()};
  }
  else if (tensor.scalar_type() == ScalarType::Half) {
    return {{tensor.sizes().vec(),
            ideep::tensor::data_type::f16,
            tensor.strides().vec()},
            from_const_data_ptr ?
              const_cast<Half*>(tensor.template const_data_ptr<Half>()) :
              tensor.template data_ptr<Half>()};
  }
  else if (tensor.scalar_type() == ScalarType::Byte) {
    return {{tensor.sizes().vec(),
            ideep::tensor::data_type::u8,
            tensor.strides().vec()},
            from_const_data_ptr ?
              const_cast<void*>(tensor.const_data_ptr()) :
              tensor.data_ptr()};
  }
  else if (tensor.scalar_type() == ScalarType::Char) {
    return {{tensor.sizes().vec(),
            ideep::tensor::data_type::s8,
            tensor.strides().vec()},
            from_const_data_ptr ?
              const_cast<void*>(tensor.const_data_ptr()) :
              tensor.data_ptr()};
  }
  else {
    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
  }
}

ideep::tensor itensor_view_from_dense(
    const at::Tensor& tensor,
    const ideep::tensor::desc& desc) {
  TORCH_CHECK(
      tensor.device().is_cpu(),
      "itensor_view_from_dense expects CPU tensor input");
  TORCH_CHECK(
      tensor.layout() == at::Layout::Strided,
      "itensor_view_from_dense expects dense tensor input");
  TORCH_CHECK(
      tensor.scalar_type() == at::ScalarType::Float ||
          tensor.scalar_type() == at::ScalarType::BFloat16 ||
          tensor.scalar_type() == at::ScalarType::Half,
      "itensor_view_from_dense expects float, bfloat16 or half tensor input");
  return {desc, tensor.data_ptr()};
}

// Helper function for getting an ideep tensor out of an aten Tensor.
// Note: in case the aten Tensor is a dense tensor, the returned ideep
// tensor is just a view of the storage of the aten dense tensor, so
// the caller needs to make sure the aten dense tensor's lifetime is
// longer than the ideep tensor.
ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_const_data_ptr) {
  if (tensor.is_mkldnn()) {
    return itensor_from_mkldnn(tensor);
  } else {
    return itensor_view_from_dense(tensor, from_const_data_ptr);
  }
}
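
// Usage sketch: for a dense CPU tensor the returned ideep::tensor is only a
// view of the storage, so `t` (a hypothetical float tensor here) must outlive
// `view`.
//
//   Tensor t = at::randn({2, 3});
//   ideep::tensor view = itensor_from_tensor(t, /*from_const_data_ptr=*/false);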

int set_verbose(int level) {
  return ideep::utils::set_verbose(level);
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
      TORCH_FN(data_ptr_from_mkldnn));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
      TORCH_FN(nbytes_from_mkldnn));
}

}}

#endif // AT_MKLDNN_ENABLED()