#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorUtils.h>

#include <c10/core/CPUAllocator.h>

#include <utility>


namespace at::native {

// TODO: make all operations that resize given outputs use this function
//   for consistency and maintainability.
//   Some operations, like `cat`, might not be able to use resize_output
//   directly. For details on how this is handled in `cat`, see
//   https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
// Resizes outputs
// Functions accepting output tensors, like those with an "out" kwarg, should
//   call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
//   needs resizing.
// NOTE: In the future the warning will become an error.
// Returns a bool indicating whether the resize actually happened.
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
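
// A minimal usage sketch for resize_output (illustrative only; `add_out_example`
// is a hypothetical out-variant kernel, not an operator defined in ATen):
//
//   Tensor& add_out_example(const Tensor& self, const Tensor& other, Tensor& result) {
//     at::native::resize_output(result, self.sizes());
//     // ... write the elementwise result into `result` ...
//     return result;
//   }
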
// WARNING: Do NOT call this directly. If you are resizing an output and want
// to support dynamic shapes, call at::resize__symint and resize_output_check_symint.
// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);

// Utility for resize_output
//  Returns a bool indicating whether the resize should happen and
//  raises a warning if the tensor being resized has one or more elements.
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);

TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
TORCH_API void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& size_bytes);

inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in cuda/Resize.h)
  if (self->numel() == 0) {
    return;
  }

  const Storage& storage = self->unsafe_storage();
  if (!storage) {
    auto new_storage = c10::make_intrusive<StorageImpl>(
        StorageImpl::use_byte_size_t(),
        new_size_bytes,
        c10::GetCPUAllocator(),
        true);
    self->set_storage_keep_dtype(std::move(new_storage));
  } else if (new_size_bytes > storage.nbytes()) {
    resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
  }
}

TORCH_API TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage = true);

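// Rough sketch of how these pieces fit together on the CPU resize path
// (illustrative only; the real resize_impl_cpu_ lives in Resize.cpp and also
// handles strides, storage offsets, and overflow checks):
//
//   // 1. update the metadata to the requested size
//   self->set_sizes_contiguous(size);
//   // 2. grow the backing storage if the new view needs more bytes
//   size_t needed_bytes =
//       (self->storage_offset() + self->numel()) * self->dtype().itemsize();
//   maybe_resize_storage_cpu(self, needed_bytes);
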
template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }

template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
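
// Usage note (illustrative): callers select the return type explicitly, e.g.
//   int64_t concrete = maybe_convert_symint<int64_t>(sym);      // guards on the value
//   c10::SymInt sym2 = maybe_convert_symint<c10::SymInt>(sym);  // pass-through
// Any other instantiation hits the deleted primary template above and fails
// at compile time.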

template <typename T>
inline void checkInBoundsForStorage(
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset,
    const caffe2::TypeMeta& data_type,
    const Storage& new_storage) {
  T storage_size_bytes =
      at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
  T storage_offset_bytes = storage_offset * data_type.itemsize();
  if (storage_size_bytes == 0) {
    // NB: a tensor with any zero-sized dim occupies no storage, so its
    // storage can have any numel.
    return;
  }
  T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
  TORCH_CHECK(
      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
      "setStorage: sizes ",
      size,
      ", strides ",
      stride,
      ","
      " storage offset ",
      storage_offset,
      ", and itemsize ",
      data_type.itemsize(),
      " requiring a storage size of ",
      storage_size_bytes + storage_offset_bytes,
      " are out of bounds for storage of size ",
      new_storage_size_bytes);
}
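
// Worked example (illustrative numbers, assuming float32 with itemsize 4):
// for size = {2, 3}, stride = {3, 1}, storage_offset = 1, the largest
// addressable element lies (2-1)*3 + (3-1)*1 = 5 elements past the start, so
// computeStorageNbytes gives (5 + 1) * 4 = 24 bytes, the offset contributes
// another 1 * 4 = 4 bytes, and the check requires new_storage to hold at
// least 28 bytes.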

template <typename T>
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
                                   ArrayRef<T> size, ArrayRef<T> stride) {
  // FIXME: stride should be optional
  if (stride.data()) {
    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
                                              ") and stride length (", stride.size(), ")");
  }

#ifdef DEBUG
  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif

  // storage: note this can't be replaced with result.set_(storage), as the semantics
  // of that function are to set the tensor size to be equal to the size of the storage.
  if (!result.storage().is_alias_of(storage)) {
    // Caffe2 might have tensors whose storages are null, but we
    // don't allow it in PyTorch.
    TORCH_INTERNAL_ASSERT(storage);
    TORCH_INTERNAL_ASSERT(result.storage());

    // We used to allow this, but this breaks device caching.
    // Let's put an actual error message for this one.
    TORCH_CHECK(result.storage().device() == storage.device(),
                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                "\" to a storage on different device \"", storage.device(),
                "\".  This is no longer allowed; the devices must match.");
    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
  }

  // storageOffset
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}
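
// Rough sketch of how a `set_(storage, offset, size, stride)`-style kernel
// might combine these helpers (illustrative only; the actual kernels live in
// the corresponding .cpp files):
//
//   checkSetStorage(result, storage, storage_offset, size, stride);
//   result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
//   at::OptionalIntArrayRef stride_opt =
//       stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
//   resize_impl_cpu_(result.unsafeGetTensorImpl(), size, stride_opt);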

/**
 * Set self's sizes, strides, and storage_offset.
 * (size, stride, storage_offset) must be in bounds for self's storage.
 */
template <typename T>
inline void setStrided(
    const Tensor& self,
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset) {
  TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
  for (const auto& val : stride) {
    TORCH_CHECK(val >= 0,
                "as_strided: Negative strides are not supported at the moment, "
                "got strides: ", stride);
  }

  auto* self_ = self.unsafeGetTensorImpl();
  checkInBoundsForStorage(
      size, stride, storage_offset, self_->dtype(), self_->storage());

  /* storage offset */
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
  self_->set_sizes_and_strides(size, stride, std::make_optional(storage_offset));
}
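
// Rough sketch of how a view-producing `as_strided`-style kernel might use
// setStrided (illustrative only; the real kernels live elsewhere in native/):
//
//   Tensor as_strided_example(const Tensor& self, IntArrayRef size,
//                             IntArrayRef stride, int64_t storage_offset) {
//     auto result = self.alias();  // shares self's storage
//     setStrided(result, size, stride, storage_offset);
//     return result;
//   }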

} // namespace at::native