#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/_resize_output.h>
#include <ATen/ops/_resize_output_native.h>
#endif

namespace at::native {

// Returns true if resize is necessary
template <typename T>
bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
  // Tests for resizing of tensors with one or more elements
  if (at::symint::sizes<T>(output).equals(shape)) {
    return false;
  }
  if (at::symint::numel<T>(output) != 0) {
    TORCH_WARN(
        "An output with one or more elements was resized since it had ",
        "shape ", at::symint::sizes<T>(output), ", which does not match the required ",
        "output shape ", shape, ". ",
        "This behavior is deprecated, and in a future PyTorch release outputs ",
        "will not be resized unless they have zero elements. You can explicitly ",
        "reuse an out tensor t by resizing it, inplace, to zero elements with ",
        "t.resize_(0).");
  }
  return true;
}

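// Non-template entry points for the check above, for concrete int64_t and
// symbolic (SymInt) shapes respectively.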
bool resize_output_check(const Tensor& output, IntArrayRef shape) {
  return _resize_output_check(output, shape);
}

bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape) {
  return _resize_output_check(output, shape);
}

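// Overloads that let _resize_output (below) call the native resize kernels
// directly, avoiding a dispatcher round trip on the fast path.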
static void native_resize_(const Tensor& output, IntArrayRef shape) {
  native::resize_(output, shape);
}

static void native_resize_(const Tensor& output, SymIntArrayRef shape) {
  native::resize__symint(output, shape);
}

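// Resizes `output` to `shape` when the check above says it is necessary.
// Returns true if a resize actually happened.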
template <typename T>
bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
  if (_resize_output_check<T>(output, shape)) {
    // avoid a redispatch for cpu and cuda.
    // TODO: when resize_cuda_ is re-written to be unified with resize_,
    // we can provide the same benefit for cuda.
    //
    // TODO(#61485): functorch wrapped tensors should not go through the
    // fast path. This is a hack; longer-term solutions are in the issue.
    if (output.is_cpu() && !isTensorSubclassLike(output)) {
      native_resize_(output, shape);
    } else {
      at::symint::resize_<T>(output, shape);
    }
    return true;
  } else {
    return false;
  }
}

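// Concrete entry points for resizing out= results with int64_t or SymInt
// shapes; they return true if the output was resized.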
bool resize_output(const Tensor& output, IntArrayRef shape) {
  return _resize_output(output, shape);
}

bool resize_output_symint(const Tensor& output, SymIntArrayRef shape) {
  return _resize_output(output, shape);
}

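// Implementation of the `_resize_output_` op: checks that `self` is already
// on the expected `device`, then resizes it to `shape`.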
const Tensor& _resize_output_(const Tensor& self, IntArrayRef shape, c10::Device device) {
  TORCH_CHECK(self.device() == device, "out Tensor doesn't have the correct device set");
  at::native::resize_output(self, shape);
  return self;
}

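// CPU storage reallocation: allocate a new buffer of `size_bytes`, copy over
// min(old, new) bytes of the existing data, and swap the new buffer in.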
void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes) {
  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");

  at::DataPtr new_data;
  if (size_bytes != 0) {
    new_data = storage->allocator()->allocate(size_bytes);
  }
  const at::DataPtr& old_data = storage->data_ptr();
  const auto old_capacity = storage->nbytes();
  const auto copy_capacity = std::min(size_bytes, old_capacity);
  if (old_data != nullptr && copy_capacity > 0) {
    memcpy(new_data.get(), old_data.get(), copy_capacity);
  }
  storage->set_data_ptr_noswap(std::move(new_data));
  storage->set_nbytes(size_bytes);
}

// Call the sparse implementation in SparseTensor.cpp directly.
// A dynamic dispatch here is NOT necessary, so I didn't put
// this function in native_functions.yaml
const Tensor& resize_as_sparse_(const Tensor& self, const Tensor& src);

// TODO(VitalyFedyunin): Move it to HTML docs.
//
// The strides of the output tensor of the `resize_as_` operator are
// determined by the input tensor's strides and the value of the
// memory_format argument.
//
// If memory_format is omitted and the input tensor has the same shape as the
// output tensor, the output strides remain unchanged. If the shapes differ,
// the output strides are set to contiguous.
//
// If memory_format is equal to MemoryFormat::Contiguous (torch.contiguous_format),
// the output tensor will have contiguous strides.
//
// If memory_format is equal to MemoryFormat::ChannelsLast (torch.channels_last)
// and the input tensor is 4D, the output tensor will have channels-last
// memory layout.
//
// If memory_format is equal to MemoryFormat::Preserve (torch.preserve_format),
// the output tensor's strides are determined by the strides of the input
// tensor, following the memory format preservation rule:
//
//   - If the input tensor's strides are in channels-last format, the output
//     tensor will have channels-last memory layout.
//
//   - Otherwise, the output tensor will have contiguous memory layout.
//
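// For example (an illustrative sketch using the C++ API; the tensor names
// `src` and `dst` are hypothetical):
//
//   auto src = at::empty({2, 3, 4, 5}).contiguous(at::MemoryFormat::ChannelsLast);
//   auto dst = at::empty({0});
//   dst.resize_as_(src, at::MemoryFormat::Preserve);
//   // dst now has channels-last strides, because Preserve follows
//   // src.suggest_memory_format().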
const Tensor& resize_as_(
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.is_sparse() && the_template.is_sparse()) {
    TORCH_CHECK(
        !optional_memory_format.has_value(),
        "Unsupported memory format for sparse tensor resize_as_ :",
        optional_memory_format.value());
    return at::native::resize_as_sparse_(self, the_template);
  }
  const Tensor& result = self.resize_(the_template.sizes());
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    if (memory_format == MemoryFormat::Preserve) {
      memory_format = the_template.suggest_memory_format();
    }
    self.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
  }
  namedinference::propagate_names(result, the_template);
  return result;
}

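// Meta storages carry no real data, so resizing one only updates the
// recorded byte count.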
void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
  storage->set_nbytes(std::move(size_bytes));
}

static void maybe_resize_storage_meta(TensorImpl* self, c10::SymInt new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in Resize.h)
  if (self->sym_numel() == 0) {
    return;
  }

  const Storage& storage = self->unsafe_storage();
  if (!storage) {
    TORCH_INTERNAL_ASSERT(0, "NYI, this should only be Caffe2");
  } else if (new_size_bytes > storage.sym_nbytes()) {
    resize_bytes_meta(storage.unsafeGetStorageImpl(), std::move(new_size_bytes));
  }
}

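// Helpers for _resize_impl_: the int64_t overload always targets CPU storage,
// while the SymInt overload dispatches to the CPU or meta path based on the
// tensor's device.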
static void _maybe_resize_storage(TensorImpl* self, int64_t new_size_bytes) {
  maybe_resize_storage_cpu(self, new_size_bytes);
}

static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes) {
  if (self->is_cpu()) {
    maybe_resize_storage_cpu(self, new_size_bytes.expect_int());
    return;
  }
  TORCH_INTERNAL_ASSERT(self->is_meta());
  maybe_resize_storage_meta(self, std::move(new_size_bytes));
}

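// Core resize routine shared by the int64_t and SymInt paths. It is a no-op
// when the requested sizes/strides already match; otherwise it updates the
// tensor's metadata (contiguous strides if none are given), computes the
// required storage size in bytes, and optionally grows the storage.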
template <typename T>
TensorImpl* _resize_impl_(
    TensorImpl* self,
    ArrayRef<T> size,
    at::OptionalArrayRef<T> stride,
    bool resize_storage) {
  if (self->generic_sizes<T>() == size && (!stride || self->generic_strides<T>() == stride.value())) {
    return self;
  }

  const auto itemsize = self->dtype().itemsize();
  const auto storage_offset = self->generic_storage_offset<T>();
  T storage_size = T(1);
  if (stride) {
    self->set_sizes_and_strides(size, *stride);
    storage_size = at::detail::computeStorageNbytes(
        size, *stride, itemsize, storage_offset);
  } else {
    self->generic_set_sizes_contiguous(size);
    storage_size = at::detail::computeStorageNbytesContiguous(
        size, itemsize, storage_offset);
  }

  if (resize_storage) {
    _maybe_resize_storage(self, std::move(storage_size));
  }

  return self;
}

TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage) {
  return _resize_impl_(self, size, stride, resize_storage);
}

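// Shared implementation of resize_ for int64_t and SymInt sizes: resizes the
// tensor in place, applies the requested memory format (Preserve is not
// supported here), and, when deterministic fill of uninitialized memory is
// enabled, fills any newly allocated bytes.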
template <typename T>
const Tensor& _resize_(
    const Tensor& self,
    ArrayRef<T> size,
    std::optional<MemoryFormat> optional_memory_format) {
  auto* self_ = self.unsafeGetTensorImpl();
  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().sym_nbytes().maybe_as_int().value_or(-1) : 0;
  // NOLINTNEXTLINE(bugprone-argument-comment)
  _resize_impl_<T>(self_, size, /*strides=*/std::nullopt, true);
  if (optional_memory_format.has_value()) {
    auto memory_format =
        optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format ",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  // See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory() && old_storage_nbytes != -1)) {
    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
  }
  return self;
}

const Tensor& resize_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  return _resize_(self, size, optional_memory_format);
}

const Tensor& resize__symint(
    const Tensor& self,
    c10::SymIntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  TORCH_INTERNAL_ASSERT(!self.has_names());
  return _resize_(self, size, optional_memory_format);
}

void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) {
  // handles all devices except cuda (which needs to be in a different .so)
  c10::DeviceType device_type = storage.device_type();
  if (device_type == at::kCPU) {
    at::native::resize_bytes_cpu(storage.unsafeGetStorageImpl(), newsize.expect_int());
  } else if (device_type == at::kMeta) {
    at::native::resize_bytes_meta(storage.unsafeGetStorageImpl(), newsize);
  } else if (device_type == at::kPrivateUse1) {
    at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes(
        storage, newsize.expect_int());
  } else if (device_type == at::kXPU || device_type == at::kHPU || device_type == at::kMTIA) {
    ptrdiff_t size_bytes_i = newsize.expect_int();
    TORCH_CHECK(
        !c10::overflows<int64_t>(size_bytes_i),
        "Requested storage size (",
        size_bytes_i,
        ") cannot be represented as an int64_t");
    const auto size_bytes = static_cast<int64_t>(size_bytes_i);
    void* original_data_ptr = storage.data_ptr().get();

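    // These backends resize the storage indirectly: wrap it in a 1-D byte
    // tensor and resize that tensor, which routes through the device's
    // regular resize_ kernel.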
    auto src_option =
        c10::TensorOptions().device(storage.device()).dtype(at::kByte);
    auto src_tensor = at::empty({0}, src_option).set_(storage);
    src_tensor.resize_({size_bytes});

    // When resize_ is used in place of resize_bytes_xxx, it sometimes still
    // returns the original data_ptr, which is inconsistent with the
    // resize_bytes_xxx behavior. In those cases an additional memory copy and
    // storage update are required.
    if (original_data_ptr == src_tensor.storage().data_ptr().get()) {
      auto new_tensor = at::empty(src_tensor.sizes(), src_tensor.options());
      new_tensor.copy_(src_tensor);
      storage.set_data_ptr_noswap(
          std::move(new_tensor.storage().mutable_data_ptr()));
      storage.unsafeGetStorageImpl()->set_allocator(
          new_tensor.storage().unsafeGetStorageImpl()->allocator());
      storage.set_nbytes(new_tensor.storage().nbytes());
    }
  } else {
    TORCH_CHECK(
        false,
        "UntypedStorage.resize_: got unexpected device type ",
        device_type);
  }
}

} // namespace at::native