// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Resize.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/EmptyTensor.h>
4 #include <ATen/native/ResizeCommon.h>
5 
6 #include <c10/cuda/CUDAGuard.h>
7 
8 namespace at { namespace native {
9 
10 TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
11 
maybe_resize_storage_cuda(TensorImpl * self,size_t new_size_bytes)12 static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
13   // It does not make sense to try to resize a storage
14   // to hold 0 elements, and this can break
15   // if storage_offset is positive but
16   // new_size is 0, so just bail in that case
17   // (same comment is in Resize.h)
18   if (self->numel() == 0) {
19     return;
20   }
21 
22   const Storage &storage = self->unsafe_storage();
23   TORCH_CHECK(storage, "Tensor: invalid null storage");
24   if (new_size_bytes > storage.nbytes()) {
25     resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
26   }
27 }
28 
resize_impl_cuda_(TensorImpl * self,IntArrayRef size,at::OptionalIntArrayRef stride)29 inline TensorImpl* resize_impl_cuda_(
30     TensorImpl* self,
31     IntArrayRef size,
32     at::OptionalIntArrayRef stride) {
33   if (self->sizes() == size && (!stride || self->strides() == stride)) {
34     return self;
35   }
36   const auto itemsize = self->dtype().itemsize();
37   const auto storage_offset = self->storage_offset();
38   size_t storage_size = 1;
39   if (stride) {
40     self->set_sizes_and_strides(size, *stride);
41     storage_size = at::detail::computeStorageNbytes(
42         size, *stride, itemsize, storage_offset);
43   } else {
44     self->set_sizes_contiguous(size);
45     storage_size = at::detail::computeStorageNbytesContiguous(
46         size, itemsize, storage_offset);
47   }
48   maybe_resize_storage_cuda(self, storage_size);
49 
50   return self;
51 }
52 
53 }}
54