#include <ATen/Config.h>

#if AT_MKLDNN_ENABLED()
#include <c10/core/CPUAllocator.h>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// A non-default dnnl::graph::allocator needs an allocator function.
// We let it use the allocator from c10::GetCPUAllocator(), which uses
// posix_memalign with a 64-byte alignment size, so the requested alignment
// is not used here.
static void* pytorch_default_allocator(size_t size, size_t alignment) {
  static c10::Allocator* c10_allocator = c10::GetCPUAllocator();
  return c10_allocator->raw_allocate(size);
}

// A non-default dnnl::graph::allocator needs a deallocator function.
// We let it use the deallocator from c10::GetCPUAllocator().
static void pytorch_default_deallocator(void* buf) {
  static c10::Allocator* c10_allocator = c10::GetCPUAllocator();
  c10_allocator->raw_deallocate(buf);
}

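// Returns a singleton CPU engine that routes its allocations through the
// PyTorch CPU allocator functions defined above.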
dnnl::engine& Engine::getEngine() {
  // Even if the default PyTorch CPU allocator changed, we would still use the
  // stale value cached here. In practice, we don't expect users to change the
  // CPU allocator dynamically anyway: users who want jemalloc/tcmalloc preload
  // it at runtime. But this behavior might need to change, as some models work
  // better with tcmalloc while others work better with jemalloc, so switching
  // the CPU allocator at runtime can be useful.
  static dnnl::graph::allocator alloc{
      pytorch_default_allocator, pytorch_default_deallocator};
  static dnnl::engine cpu_engine = dnnl::graph::make_engine_with_allocator(
      dnnl::engine::kind::cpu, /* device_id = */ 0, alloc);
  return cpu_engine;
}

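// Returns a singleton stream bound to the singleton CPU engine above.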
dnnl::stream& Stream::getStream() {
  static dnnl::stream cpu_stream{Engine::getEngine()};
  return cpu_stream;
}

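// Wraps an existing storage in a TensorImpl dispatched on MkldnnCPU, taking
// sizes and strides from the oneDNN Graph tensor descriptor.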
LlgaTensorImpl::LlgaTensorImpl(
    at::Storage&& storage,
    const caffe2::TypeMeta& data_type,
    const LlgaTensorDesc& desc)
    : at::TensorImpl(
          std::move(storage),
          c10::DispatchKeySet(c10::DispatchKey::MkldnnCPU),
          data_type),
      desc_(desc) {
  set_sizes_and_strides(desc.sizes(), desc.strides());
  refresh_numel();
}

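// Converts an LLGA tensor into a plain CPU ATen tensor by moving its storage
// into a new TensorImpl dispatched on CPU, preserving the storage offset,
// sizes, and strides.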
at::Tensor LlgaTensorImpl::llga_to_aten_tensor(LlgaTensorImpl* llgaImpl) {
  auto aten_tensor = at::detail::make_tensor<TensorImpl>(
      std::move(llgaImpl->storage_),
      c10::DispatchKeySet(c10::DispatchKey::CPU),
      llgaImpl->data_type_);
  auto impl = aten_tensor.unsafeGetTensorImpl();
  impl->set_storage_offset(llgaImpl->storage_offset_);
  impl->set_sizes_and_strides(llgaImpl->sizes(), llgaImpl->strides());
  return aten_tensor;
}

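// Allocates an uninitialized LLGA tensor whose storage is large enough for the
// (possibly opaque) layout described by `desc`.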
at::Tensor empty_llga(
    const LlgaTensorDesc& desc,
    const c10::TensorOptions& options) {
  auto nbytes = desc.storage_size();

  auto allocator = at::GetCPUAllocator();
  auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      nbytes,
      allocator->allocate(nbytes),
      allocator,
      /*resizable=*/false);

  return at::detail::make_tensor<LlgaTensorImpl>(
      std::move(storage_impl), options.dtype(), desc);
}

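// Retrieves the LLGA tensor descriptor stored on an MKL-DNN (LLGA) tensor.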
static const LlgaTensorDesc& get_llga_desc(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      tensor.is_mkldnn(), "get_llga_desc expects Mkldnn tensor input");
  return static_cast<LlgaTensorImpl*>(tensor.unsafeGetTensorImpl())->desc();
}

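// Builds a dnnl::graph::tensor view over an ATen tensor's data, using the
// tensor's LLGA logical tensor and the singleton CPU engine.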
dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor) {
  return {
      get_llga_desc(tensor).logical_tensor(),
      torch::jit::fuser::onednn::Engine::getEngine(),
      tensor.data_ptr()};
}

using data_type = dnnl::graph::logical_tensor::data_type;

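// Maps an ATen scalar type to the corresponding oneDNN Graph data type,
// returning data_type::undef for unsupported dtypes.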
data_type LlgaTensorDesc::getLlgaDataType(at::ScalarType dt) const {
  switch (dt) {
    case at::ScalarType::Float:
      return data_type::f32;
    case at::ScalarType::BFloat16:
      return data_type::bf16;
    case at::kInt:
      return data_type::s32;
    case at::ScalarType::QInt8:
      return data_type::s8;
    case at::ScalarType::QUInt8:
      return data_type::u8;
    default:
      // If a dtype is unsupported, oneDNN Graph marks the op as a wildcard
      // during graph construction. A wildcard op never ends up inside a oneDNN
      // Graph partition, so we never encounter inputs with unsupported dtypes
      // when executing compiled partitions.
      return data_type::undef;
  }
}

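// Produces a descriptor for `t` that reuses this descriptor's tensor id: an
// MKL-DNN input keeps its opaque layout info, while a dense input gets a
// default (strided) layout.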
LlgaTensorDesc LlgaTensorDesc::supplementTensorInfo(const at::Tensor& t) const {
  if (t.is_mkldnn()) {
    // An MKL-DNN input tensor originates from an upstream LLGA partition and
    // carries opaque layout info.
    return get_llga_desc(t).tid(tid_);
  } else {
    // If the input tensor is not an MKL-DNN tensor, use the default layout.
    auto sizes = t.sizes().vec();
    auto strides = t.strides().vec();
    auto dtype = getLlgaDataType(t.scalar_type());
    return {tid_, sizes, strides, dtype, property_type_};
  }
}

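// Maps this descriptor's oneDNN Graph data type back to the corresponding ATen
// scalar type; errors out on data types with no ATen counterpart.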
at::ScalarType LlgaTensorDesc::aten_scalar_type() const {
  switch (dtype_) {
    case data_type::f32:
      return at::ScalarType::Float;
    case data_type::bf16:
      return at::ScalarType::BFloat16;
    case data_type::s32:
      return at::kInt;
    case data_type::s8:
      return at::ScalarType::QInt8;
    case data_type::u8:
      return at::ScalarType::QUInt8;
    default:
      TORCH_CHECK(false, "Invalid data type ", static_cast<size_t>(dtype_));
  }
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

#endif // AT_MKLDNN_ENABLED()