1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/cuda/Resize.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <ATen/cuda/PeerToPeerAccess.h>
6 #include <ATen/native/ResizeCommon.h>
7 #include <c10/cuda/CUDAGuard.h>
8
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/resize_native.h>
13 #endif
14
15 namespace at::native {
16
resize_bytes_cuda(StorageImpl * storage,size_t size_bytes)17 void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) {
18 TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
19 auto allocator = storage->allocator();
20 TORCH_CHECK(allocator != nullptr, "Trying to resize storage without an allocator");
21
22 c10::Device device = storage->device();
23
24 if (size_bytes == 0) {
25 storage->set_data_ptr_noswap(at::DataPtr(nullptr, device));
26 storage->set_nbytes(0);
27 return;
28 }
29
30 c10::cuda::CUDAGuard guard(device.index());
31 at::DataPtr data = allocator->allocate(size_bytes);
32 if (storage->data_ptr()) {
33 at::globalContext().lazyInitCUDA();
34
35 C10_CUDA_CHECK(
36 cudaMemcpyAsync(
37 data.get(),
38 storage->data(),
39 std::min(storage->nbytes(), size_bytes),
40 cudaMemcpyDeviceToDevice,
41 c10::cuda::getCurrentCUDAStream()));
42 }
43
44 // Destructively overwrite data_ptr
45 storage->set_data_ptr_noswap(std::move(data));
46 storage->set_nbytes(size_bytes);
47 }
48
// In-place resize of a CUDA tensor to `size`.
//
// - Named tensors are delegated to resize_named_tensor_.
// - Otherwise the TensorImpl (and, if needed, its storage) is resized with
//   unspecified strides, then optionally restrided to `optional_memory_format`
//   (MemoryFormat::Preserve is rejected — there is no old layout to preserve
//   through an arbitrary shape change).
// - Under deterministic mode with deterministic fill enabled, any bytes
//   beyond the old storage size are filled to a fixed value so freshly
//   exposed memory is reproducible.
//
// Returns `self` to allow call chaining, per the resize_ convention.
const Tensor& resize_cuda_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  auto* self_ = self.unsafeGetTensorImpl();
  // Record the pre-resize storage size so the deterministic fill below can
  // target only the newly-allocated tail. A tensor may have no storage yet.
  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
  resize_impl_cuda_(self_, size, /*strides=*/std::nullopt);
  if (optional_memory_format.has_value()) {
    auto memory_format =
        optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        // Trailing space: TORCH_CHECK concatenates its message arguments
        // verbatim, so without it the format name would fuse into the text
        // ("...memory formatPreserve").
        "Unsupported memory format ",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  // See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
  }
  return self;
}
74 } // namespace at::native
75