#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>

#if AT_MKLDNN_ENABLED()

#include <ideep.hpp>

namespace at { namespace native {

/**
 * `IntrusivePtrTargetWrapper` wraps a custom storage handle of a tensor
 * (as template param) and inherits `c10::intrusive_ptr_target` so that it
 * can be used with `c10::intrusive_ptr`.
 *
 * It currently only supports wrapping a custom handle by copy or move
 * construction from an existing handle.
 *
 * See `OpaqueTensorImpl::opaque_handle_`.
 *
 * NOTE: if this is generally useful we may want to move this to its own header.
 */
template <typename T>
struct TORCH_API IntrusivePtrTargetWrapper : c10::intrusive_ptr_target {
private:
  T target_;

public:
  IntrusivePtrTargetWrapper() = delete;
  IntrusivePtrTargetWrapper(const T& target): target_(target) {}
  IntrusivePtrTargetWrapper(T&& target): target_(std::move(target)) {}

  T& get_target() {
    return target_;
  }
};

using IDeepTensorWrapper = IntrusivePtrTargetWrapper<ideep::tensor>;
using IDeepTensorWrapperPtr = c10::intrusive_ptr<IDeepTensorWrapper>;
using MKLDNNTensorImpl = OpaqueTensorImpl<IDeepTensorWrapperPtr>;
using MKLDNNTensor = Tensor;

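// A minimal usage sketch of the wrapper (illustrative only, not part of this
// file's API): `buf` is a hypothetical caller-owned buffer viewed as a 2x3
// f32 tensor, using the same `{dims, data_type, strides}` construction
// pattern as itensor_view_from_dense below.
//
//   float buf[6] = {0};
//   ideep::tensor t({{2, 3}, ideep::tensor::data_type::f32, {3, 1}}, buf);
//   IDeepTensorWrapperPtr handle =
//       c10::make_intrusive<IDeepTensorWrapper>(std::move(t));
//   ideep::tensor& inner = handle->get_target();  // borrowed; handle owns it
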
// Map an ATen ScalarType to the corresponding oneDNN (MKL-DNN) data type.
ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
  switch (type) {
    case ScalarType::Float:
      return ideep::tensor::data_type::f32;
    case ScalarType::QInt32:
      return ideep::tensor::data_type::s32;
    case ScalarType::QInt8:
    case ScalarType::Char:
      return ideep::tensor::data_type::s8;
    case ScalarType::QUInt8:
    case ScalarType::Byte:
      return ideep::tensor::data_type::u8;
    case ScalarType::BFloat16:
      return ideep::tensor::data_type::bf16;
    case ScalarType::Half:
      return ideep::tensor::data_type::f16;
    default:
      TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
  }
}

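// Usage sketch: the mapping is typically consulted when building an ideep
// desc from ATen metadata; unsupported dtypes (e.g. Double) trip TORCH_CHECK.
//
//   auto dt = get_mkldnn_dtype(ScalarType::BFloat16);
//   // dt == ideep::tensor::data_type::bf16
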
// Expose the raw data pointer of an MKL-DNN tensor as an integer
// (registered below as mkldnn::data_ptr).
int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
  MKLDNNTensorImpl* mklimpl =
      static_cast<MKLDNNTensorImpl*>(mkldnn_tensor.unsafeGetTensorImpl());
  void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
  return reinterpret_cast<int64_t>(data_ptr);
}

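// Usage sketch, assuming `mkldnn_tensor` is an existing MkldnnCPU tensor:
//
//   int64_t addr = data_ptr_from_mkldnn(mkldnn_tensor);
//   void* raw = reinterpret_cast<void*>(addr);  // same pointer ideep holds
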
at::Tensor mkldnn_tensor_from_data_ptr(
    void* data_ptr,
    at::IntArrayRef dims,
    at::ScalarType dtype,
    at::Device device,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size) {
  std::vector<uint8_t> vector_serialized_md{
      opaque_metadata, opaque_metadata + opaque_metadata_size};
  ideep::tensor::desc deserialized_ideep_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
  // groups is needed for grouped conv
  deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
#else
  TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
#endif

  auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
  return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
}

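// Round-trip sketch, assuming `packed` is an existing MKL-DNN tensor and
// `serialized_md` holds bytes previously produced by serializing its ideep
// desc (the serialization step lives outside this file):
//
//   void* raw = reinterpret_cast<void*>(data_ptr_from_mkldnn(packed));
//   at::Tensor restored = mkldnn_tensor_from_data_ptr(
//       raw, packed.sizes(), packed.scalar_type(), packed.device(),
//       serialized_md.data(), static_cast<int64_t>(serialized_md.size()));
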
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
  // NOTE: int32_t dims from ideep::tensor but sizes need int64_t
  // TODO: support int64_t dims in ideep::tensor to avoid extra conversion
  auto dims = it.get_dims();
  IDeepTensorWrapperPtr handle = c10::make_intrusive<IDeepTensorWrapper>(std::move(it));
  caffe2::TypeMeta dtype_ = scalarTypeToTypeMeta(dtype_or_default(dtype));
  Device device_ = device_or_default(device);
  return detail::make_tensor<MKLDNNTensorImpl>(
      DispatchKeySet(DispatchKey::MkldnnCPU),
      dtype_, device_, handle,
      std::vector<int64_t>(dims.begin(), dims.end()));
}

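// Sketch: allocate a plain 4x4 f32 ideep tensor and hand ownership to an
// opaque ATen tensor (the desc/format_tag values here are illustrative):
//
//   ideep::tensor y(ideep::tensor::desc(
//       {4, 4}, ideep::tensor::data_type::f32, ideep::format_tag::ab));
//   at::Tensor opaque = new_with_itensor_mkldnn(
//       std::move(y), ScalarType::Float, at::Device(at::kCPU));
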
ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
  TORCH_CHECK(mkldnn_tensor.is_mkldnn(),
              "itensor_from_mkldnn expects MKL-DNN tensor input");
  MKLDNNTensorImpl* mklimpl =
      static_cast<MKLDNNTensorImpl*>(mkldnn_tensor.unsafeGetTensorImpl());
  return mklimpl->unsafe_opaque_handle()->get_target();
}

int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
  const ideep::tensor& t = itensor_from_mkldnn(mkldnn_tensor);
  return t.get_desc().get_size();
}

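// Usage sketch, with `opaque` as in the sketch above: both helpers reach
// into the same underlying ideep::tensor handle.
//
//   ideep::tensor& impl = itensor_from_mkldnn(opaque);  // borrowed reference
//   int64_t bytes = nbytes_from_mkldnn(opaque);         // size per ideep desc
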
ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
  TORCH_CHECK(
      tensor.device().is_cpu(),
      "itensor_view_from_dense expects CPU tensor input");
  TORCH_CHECK(
      tensor.layout() == Layout::Strided,
      "itensor_view_from_dense expects dense tensor input");
  if (tensor.scalar_type() == ScalarType::Float) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::f32,
             tensor.strides().vec()},
            from_const_data_ptr ?
                const_cast<float*>(tensor.template const_data_ptr<float>()) :
                tensor.template data_ptr<float>()};
  } else if (tensor.scalar_type() == ScalarType::BFloat16) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::bf16,
             tensor.strides().vec()},
            from_const_data_ptr ?
                const_cast<BFloat16*>(tensor.template const_data_ptr<BFloat16>()) :
                tensor.template data_ptr<BFloat16>()};
  } else if (tensor.scalar_type() == ScalarType::Half) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::f16,
             tensor.strides().vec()},
            from_const_data_ptr ?
                const_cast<Half*>(tensor.template const_data_ptr<Half>()) :
                tensor.template data_ptr<Half>()};
  } else if (tensor.scalar_type() == ScalarType::Byte) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::u8,
             tensor.strides().vec()},
            from_const_data_ptr ?
                const_cast<void*>(tensor.const_data_ptr()) :
                tensor.data_ptr()};
  } else if (tensor.scalar_type() == ScalarType::Char) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::s8,
             tensor.strides().vec()},
            from_const_data_ptr ?
                const_cast<void*>(tensor.const_data_ptr()) :
                tensor.data_ptr()};
  } else {
    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
  }
}

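// Aliasing sketch: the returned ideep::tensor is a non-owning view, so the
// source tensor must outlive it; `src` is a hypothetical dense CPU tensor:
//
//   at::Tensor src = at::rand({8, 8});
//   ideep::tensor v = itensor_view_from_dense(src, /*from_const_data_ptr=*/false);
//   // use v only while src is alive; v holds no reference to src's storage
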
ideep::tensor itensor_view_from_dense(
    const at::Tensor& tensor,
    const ideep::tensor::desc& desc) {
  TORCH_CHECK(
      tensor.device().is_cpu(),
      "itensor_view_from_dense expects CPU tensor input");
  TORCH_CHECK(
      tensor.layout() == at::Layout::Strided,
      "itensor_view_from_dense expects dense tensor input");
  TORCH_CHECK(
      tensor.scalar_type() == at::ScalarType::Float ||
          tensor.scalar_type() == at::ScalarType::BFloat16 ||
          tensor.scalar_type() == at::ScalarType::Half,
      "itensor_view_from_dense expects float, bfloat16 or half tensor input");
  return {desc, tensor.data_ptr()};
}

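// Sketch: viewing the same dense buffer through an explicit desc (e.g. a
// blocked weight layout); the desc's size and type must match the storage:
//
//   ideep::tensor::desc d(
//       {8, 8}, ideep::tensor::data_type::f32, ideep::format_tag::ab);
//   ideep::tensor v2 = itensor_view_from_dense(src, d);
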
// Helper function for getting an ideep tensor out of an aten Tensor.
// Note: if the aten Tensor is dense, the returned ideep tensor is just a
// view of the dense tensor's storage, so the caller must ensure the aten
// dense tensor outlives the ideep tensor.
ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_const_data_ptr) {
  if (tensor.is_mkldnn()) {
    return itensor_from_mkldnn(tensor);
  } else {
    return itensor_view_from_dense(tensor, from_const_data_ptr);
  }
}

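// Usage sketch: dense input yields a non-owning view, MKL-DNN input yields a
// shallow copy of the stored handle (both share the underlying buffer):
//
//   ideep::tensor a = itensor_from_tensor(dense_cpu, /*from_const_data_ptr=*/false);
//   ideep::tensor b = itensor_from_tensor(mkldnn_opaque, /*from_const_data_ptr=*/false);
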
// Set the oneDNN verbose level (0 disables verbose output).
int set_verbose(int level) {
  return ideep::utils::set_verbose(level);
}

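// Usage sketch: level 1 prints primitive execution info (comparable to
// running with DNNL_VERBOSE=1); level 0 turns it back off.
//
//   at::native::set_verbose(1);
//   at::native::set_verbose(0);
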
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
      TORCH_FN(data_ptr_from_mkldnn));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
      TORCH_FN(nbytes_from_mkldnn));
}

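// The registrations above are reachable through the dispatcher, e.g.:
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("mkldnn::data_ptr", "")
//       .typed<int64_t(const at::Tensor&)>();
//   int64_t addr = op.call(mkldnn_tensor);
//
// From Python these surface as torch.ops.mkldnn.data_ptr and
// torch.ops.mkldnn._nbytes.
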
}} // namespace at::native

#endif // AT_MKLDNN_ENABLED()