#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/EmptyTensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/EmptyTensor.h>

namespace at::detail {

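// Allocates an uninitialized CUDA tensor with the given sizes and dtype on the
// requested (or default current) CUDA device, dispatching to the generic empty path.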
TensorBase empty_cuda(
    IntArrayRef size,
    ScalarType dtype,
    std::optional<Device> device_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  at::globalContext().lazyInitCUDA();
  const auto device = device_or_default(device_opt);
  TORCH_INTERNAL_ASSERT(device.is_cuda());
  const DeviceGuard device_guard(device);
  auto* allocator = at::cuda::getCUDADeviceAllocator();
  constexpr c10::DispatchKeySet cuda_dks(c10::DispatchKey::CUDA);
  return at::detail::empty_generic(
      size, allocator, cuda_dks, dtype, memory_format_opt);
}

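// Optional-argument overload: rejects pinned memory, asserts a strided layout,
// then forwards to the dtype-specific empty_cuda above.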
TensorBase empty_cuda(
    IntArrayRef size,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned");
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout_or_default(layout_opt) == Layout::Strided);

  const auto dtype = dtype_or_default(dtype_opt);
  return at::detail::empty_cuda(size, dtype, device_opt, memory_format_opt);
}

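// TensorOptions convenience overload for empty_cuda: unpacks the options struct
// and forwards to the optional-argument overload.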
TensorBase empty_cuda(
    IntArrayRef size, const TensorOptions &options) {
  return at::detail::empty_cuda(
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt(),
      options.memory_format_opt());
}

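// Strided variant: allocates an uninitialized CUDA tensor with explicit sizes and
// strides via the generic strided allocation path.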
TensorBase empty_strided_cuda(
    IntArrayRef size,
    IntArrayRef stride,
    ScalarType dtype,
    std::optional<Device> device_opt) {
  at::globalContext().lazyInitCUDA();
  const auto device = device_or_default(device_opt);
  TORCH_INTERNAL_ASSERT(device.is_cuda());
  const DeviceGuard device_guard(device);
  auto* allocator = at::cuda::getCUDADeviceAllocator();
  constexpr c10::DispatchKeySet cuda_dks(c10::DispatchKey::CUDA);
  return at::detail::empty_strided_generic(
      size, stride, allocator, cuda_dks, dtype);
}

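// Optional-argument overload: rejects pinned memory, allows strided (and, for now,
// jagged) layouts, then forwards to the dtype-specific empty_strided_cuda above.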
TensorBase empty_strided_cuda(
    IntArrayRef size,
    IntArrayRef stride,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned");
  // TODO: remove check for jagged, see https://github.com/pytorch/pytorch/issues/130073
  const auto layout = layout_or_default(layout_opt);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::Strided || layout == Layout::Jagged);

  const auto dtype = dtype_or_default(dtype_opt);
  return at::detail::empty_strided_cuda(size, stride, dtype, device_opt);
}

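// TensorOptions convenience overload for empty_strided_cuda: unpacks the options
// struct and forwards to the optional-argument overload.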
TensorBase empty_strided_cuda(
    IntArrayRef size,
    IntArrayRef stride,
    const TensorOptions &options) {
  return at::detail::empty_strided_cuda(
      size,
      stride,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

}  // namespace at::detail
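
// Illustrative sketch (not part of this file): how a caller might reach these
// helpers, assuming the usual ATen headers and a CUDA-enabled build.
//
//   #include <ATen/cuda/EmptyTensor.h>
//
//   // Contiguous 2x3 float tensor on CUDA device 0, default memory format.
//   auto t = at::detail::empty_cuda(
//       {2, 3}, at::kFloat, c10::Device(c10::kCUDA, 0), std::nullopt);
//
//   // Same shape with explicit (row-major) strides.
//   auto s = at::detail::empty_strided_cuda(
//       {2, 3}, {3, 1}, at::kFloat, c10::Device(c10::kCUDA, 0));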