// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Resize.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Resize.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/PeerToPeerAccess.h>
#include <ATen/native/ResizeCommon.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/resize_native.h>
#endif

namespace at::native {

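// Resizes a CUDA storage to size_bytes. Existing data is preserved up to
// min(old size, new size); growing the storage reallocates and copies.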
void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) {
  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
  auto allocator = storage->allocator();
  TORCH_CHECK(allocator != nullptr, "Trying to resize storage without an allocator");

  c10::Device device = storage->device();

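  // Fast path: a zero-byte resize drops the current allocation entirely.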
  if (size_bytes == 0) {
    storage->set_data_ptr_noswap(at::DataPtr(nullptr, device));
    storage->set_nbytes(0);
    return;
  }

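  // Allocate on the storage's own device, regardless of the current device.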
  c10::cuda::CUDAGuard guard(device.index());
  at::DataPtr data = allocator->allocate(size_bytes);
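  // If the storage currently holds data, copy as much of it as fits into the
  // new allocation.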
  if (storage->data_ptr()) {
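    // Make sure the global CUDA context is initialized before issuing the
    // asynchronous copy below.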
    at::globalContext().lazyInitCUDA();

    C10_CUDA_CHECK(
        cudaMemcpyAsync(
            data.get(),
            storage->data(),
            std::min(storage->nbytes(), size_bytes),
            cudaMemcpyDeviceToDevice,
            c10::cuda::getCurrentCUDAStream()));
  }

  // Destructively overwrite data_ptr; the previous allocation is released
  // when the old DataPtr is destroyed.
  storage->set_data_ptr_noswap(std::move(data));
  storage->set_nbytes(size_bytes);
}

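// In-place resize for (unnamed) CUDA tensors; named tensors are routed to
// resize_named_tensor_.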
const Tensor& resize_cuda_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  auto* self_ = self.unsafeGetTensorImpl();
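  // Remember the old storage size so deterministic mode can fill only the
  // newly allocated bytes below.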
  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
  resize_impl_cuda_(self_, size, /*strides=*/std::nullopt);
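  // A concrete memory format request recomputes the strides; Preserve is
  // rejected because there is no prior layout to preserve after a resize.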
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format ",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  // See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
  }
  return self;
}
} // namespace at::native
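
// Illustrative call site (not part of this file): a minimal sketch assuming a
// CUDA-enabled ATen build. In practice resize_cuda_ is reached through the
// dispatcher via Tensor::resize_ rather than being called directly.
//
//   #include <ATen/ATen.h>
//
//   at::Tensor t = at::empty({4}, at::device(at::kCUDA).dtype(at::kFloat));
//   t.resize_({2, 8});  // may reallocate the storage and copy the old bytes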