1 #include <torch/library.h>
2
3 #include <ATen/FunctionalTensorWrapper.h>
4 #include <ATen/native/Resize.h>
5
6 #ifdef USE_CUDA
7 #include <ATen/native/cuda/Resize.h>
8 #endif
9
10 namespace torch::inductor {
11 using namespace at;
12
13 // NOLINTNEXTLINE(performance-unnecessary-value-param)
resize_storage_bytes_(const Tensor & variable,SymInt new_size)14 static void resize_storage_bytes_(const Tensor& variable, SymInt new_size) {
15 // similar to THPStorage_resize_ in StorageMethods.cpp, but is traceable
16 if (variable.storage().device_type() == at::kCUDA) {
17 // rocm build has undefined reference to resize_bytes_cuda
18 #if defined(USE_CUDA) && !defined(USE_ROCM)
19 at::native::resize_bytes_cuda(
20 variable.storage().unsafeGetStorageImpl(), new_size.expect_int());
21 #else
22 TORCH_CHECK(false, "built without cuda");
23 #endif
24 } else {
25 at::native::resize_bytes_nocuda(variable.storage(), new_size);
26 }
27 }
28
resize_storage_bytes__functionalize(const Tensor & variable,SymInt new_size)29 static void resize_storage_bytes__functionalize(
30 const Tensor& variable,
31 // NOLINTNEXTLINE(performance-unnecessary-value-param)
32 SymInt new_size) {
33 static auto op = c10::Dispatcher::singleton()
34 .findSchemaOrThrow("inductor::resize_storage_bytes_", "")
35 .typed<void(const Tensor&, SymInt)>();
36 if (!at::functionalization::impl::isFunctionalTensor(variable)) {
37 // Functionalization not active: nop
38 at::AutoDispatchSkipFunctionalize guard;
39 op.call(variable, new_size);
40 return;
41 }
42 // Don't functionalize, call the mutable op on the inner tensor.
43 auto functional_impl =
44 at::functionalization::impl::unsafeGetFunctionalWrapper(variable);
45 {
46 at::AutoDispatchSkipFunctionalize guard;
47 op.call(functional_impl->value(), new_size);
48 return;
49 }
50 }
51
TORCH_LIBRARY_FRAGMENT(inductor,m)52 TORCH_LIBRARY_FRAGMENT(inductor, m) {
53 m.def(
54 "resize_storage_bytes_(Tensor variable, SymInt new_size) -> ()",
55 dispatch(
56 c10::DispatchKey::CompositeExplicitAutograd, resize_storage_bytes_),
57 {at::Tag::pt2_compliant_tag});
58 }
59
TORCH_LIBRARY_IMPL(inductor,Functionalize,m)60 TORCH_LIBRARY_IMPL(inductor, Functionalize, m) {
61 m.impl(
62 "resize_storage_bytes_", TORCH_FN(resize_storage_bytes__functionalize));
63 }
64
65 } // namespace torch::inductor
66