// torch/csrc/inductor/resize_storage_bytes.cpp
#include <torch/library.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/native/Resize.h>

#ifdef USE_CUDA
#include <ATen/native/cuda/Resize.h>
#endif

namespace torch::inductor {
using namespace at;
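
// Resize the bytes backing `variable`'s storage to `new_size`, in place.
// Dispatches to the CUDA byte-resize helper for CUDA storages and to the
// generic (non-CUDA) helper otherwise.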
// NOLINTNEXTLINE(performance-unnecessary-value-param)
static void resize_storage_bytes_(const Tensor& variable, SymInt new_size) {
  // Similar to THPStorage_resize_ in StorageMethods.cpp, but traceable.
  if (variable.storage().device_type() == at::kCUDA) {
    // The ROCm build has an undefined reference to resize_bytes_cuda, so only
    // take this path for non-ROCm CUDA builds.
#if defined(USE_CUDA) && !defined(USE_ROCM)
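    // expect_int() requires the SymInt to hold a concrete integer; the CUDA
    // resize helper takes a plain byte count rather than a SymInt.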
    at::native::resize_bytes_cuda(
        variable.storage().unsafeGetStorageImpl(), new_size.expect_int());
#else
    TORCH_CHECK(false, "built without cuda");
#endif
  } else {
    at::native::resize_bytes_nocuda(variable.storage(), new_size);
  }
}
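
// Functionalization kernel for inductor::resize_storage_bytes_. The op
// mutates storage in place, so when functionalization is active the resize is
// applied to the inner tensor held by the FunctionalTensorWrapper.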
static void resize_storage_bytes__functionalize(
    const Tensor& variable,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    SymInt new_size) {
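  // Typed handle to the base inductor::resize_storage_bytes_ op; both branches
  // below redispatch to it with the Functionalize key skipped.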
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("inductor::resize_storage_bytes_", "")
                       .typed<void(const Tensor&, SymInt)>();
  if (!at::functionalization::impl::isFunctionalTensor(variable)) {
    // Functionalization is not active: nothing to unwrap, call the base op
    // directly.
    at::AutoDispatchSkipFunctionalize guard;
    op.call(variable, new_size);
    return;
  }
  // Don't functionalize; call the mutable op on the inner tensor.
  auto functional_impl =
      at::functionalization::impl::unsafeGetFunctionalWrapper(variable);
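  // value() is the inner, non-functional tensor whose storage holds the real
  // allocation; resize that storage directly.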
  {
    at::AutoDispatchSkipFunctionalize guard;
    op.call(functional_impl->value(), new_size);
    return;
  }
}
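
// Register the op schema with the dispatcher. From Python it is reachable as
// torch.ops.inductor.resize_storage_bytes_; a usage sketch (actual call sites
// in Inductor-generated code may differ):
//   torch.ops.inductor.resize_storage_bytes_(tensor, new_size_in_bytes)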
TORCH_LIBRARY_FRAGMENT(inductor, m) {
  m.def(
      "resize_storage_bytes_(Tensor variable, SymInt new_size) -> ()",
      dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, resize_storage_bytes_),
      {at::Tag::pt2_compliant_tag});
}
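
// When the Functionalize dispatch key is active, route calls to the wrapper
// above so the mutation is applied to the tensor inside the
// FunctionalTensorWrapper.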
TORCH_LIBRARY_IMPL(inductor, Functionalize, m) {
  m.impl(
      "resize_storage_bytes_", TORCH_FN(resize_storage_bytes__functionalize));
}

} // namespace torch::inductor