xref: /aosp_15_r20/external/pytorch/aten/src/ATen/StorageUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Functions.h>
2 #include <ATen/MapAllocator.h>
3 #include <ATen/StorageUtils.h>
4 #include <c10/core/TensorOptions.h>
5 
6 namespace at {
7 
new_shm_fd_storage(size_t size)8 C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(
9     size_t size) {
10   int flags = ALLOCATOR_MAPPED_SHAREDMEM | ALLOCATOR_MAPPED_EXCLUSIVE |
11       ALLOCATOR_MAPPED_KEEPFD | ALLOCATOR_MAPPED_UNLINK;
12   std::string handle = NewProcessWideShmHandle();
13   auto sptr = MapAllocator::makeDataPtr(
14       handle.c_str(), flags, size * sizeof(uint8_t), nullptr);
15   return c10::make_intrusive<StorageImpl>(
16       c10::StorageImpl::use_byte_size_t(),
17       size,
18       std::move(sptr),
19       /*allocator=*/nullptr,
20       /*resizable=*/false);
21 }
22 
storage_copy(c10::Storage & dst,const c10::Storage & src,bool non_blocking)23 C10_EXPORT void storage_copy(
24     c10::Storage& dst,
25     const c10::Storage& src,
26     bool non_blocking) {
27   auto dst_options = c10::TensorOptions().device(dst.device()).dtype(at::kByte);
28   auto dst_t = at::empty({0}, dst_options).set_(dst);
29 
30   auto src_options = c10::TensorOptions().device(src.device()).dtype(at::kByte);
31   auto src_t = at::empty({0}, src_options).set_(src);
32   dst_t.copy_(src_t, non_blocking);
33 }
34 
share_memory_(TensorBase & t)35 C10_EXPORT void share_memory_(TensorBase& t) {
36   if (t.device() != at::kCPU) {
37     return;
38   }
39 
40   const at::Storage& origStorage = t.storage();
41 
42   if (MapAllocator::fromDataPtr(origStorage.data_ptr()) != nullptr) {
43     // already shared
44     return;
45   }
46   at::Storage newStorage(new_shm_fd_storage(origStorage.nbytes()));
47   storage_copy(newStorage, origStorage);
48 
49   // Replace the old data_ptr and allocator with the new ones
50   c10::StorageImpl* origStorageImpl = origStorage.unsafeGetStorageImpl();
51   c10::StorageImpl* newStorageImpl = newStorage.unsafeGetStorageImpl();
52   origStorageImpl->set_data_ptr(std::move(newStorageImpl->mutable_data_ptr()));
53   origStorageImpl->set_allocator(newStorageImpl->allocator());
54 }
55 
56 } // namespace at
57