1 #pragma once
2
3 #include <ATen/EmptyTensor.h>
4 #include <ATen/native/ResizeCommon.h>
5
6 #include <c10/cuda/CUDAGuard.h>
7
8 namespace at { namespace native {
9
10 TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
11
maybe_resize_storage_cuda(TensorImpl * self,size_t new_size_bytes)12 static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
13 // It does not make sense to try to resize a storage
14 // to hold 0 elements, and this can break
15 // if storage_offset is positive but
16 // new_size is 0, so just bail in that case
17 // (same comment is in Resize.h)
18 if (self->numel() == 0) {
19 return;
20 }
21
22 const Storage &storage = self->unsafe_storage();
23 TORCH_CHECK(storage, "Tensor: invalid null storage");
24 if (new_size_bytes > storage.nbytes()) {
25 resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
26 }
27 }
28
resize_impl_cuda_(TensorImpl * self,IntArrayRef size,at::OptionalIntArrayRef stride)29 inline TensorImpl* resize_impl_cuda_(
30 TensorImpl* self,
31 IntArrayRef size,
32 at::OptionalIntArrayRef stride) {
33 if (self->sizes() == size && (!stride || self->strides() == stride)) {
34 return self;
35 }
36 const auto itemsize = self->dtype().itemsize();
37 const auto storage_offset = self->storage_offset();
38 size_t storage_size = 1;
39 if (stride) {
40 self->set_sizes_and_strides(size, *stride);
41 storage_size = at::detail::computeStorageNbytes(
42 size, *stride, itemsize, storage_offset);
43 } else {
44 self->set_sizes_contiguous(size);
45 storage_size = at::detail::computeStorageNbytesContiguous(
46 size, itemsize, storage_offset);
47 }
48 maybe_resize_storage_cuda(self, storage_size);
49
50 return self;
51 }
52
53 }}
54