// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/EmptyTensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2022 Apple Inc.
2 
3 #include <ATen/ATen.h>
4 #include <ATen/Tensor.h>
5 #include <ATen/Utils.h>
6 #include <torch/library.h>
7 #include <ATen/mps/EmptyTensor.h>
8 #include <ATen/mps/MPSDevice.h>
9 #include <ATen/native/Resize.h>
10 #include <ATen/native/TensorFactories.h>
11 #include <ATen/native/mps/Copy.h>
12 
13 #define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
14 #define MPS_ERROR_RUNTIME_TOO_LOW \
15   "The MPS backend is supported on MacOS 12.3+.", \
16   "Current OS version can be queried using `sw_vers`"
17 #define MPS_ERROR_DOUBLE_NOT_SUPPORTED "Cannot convert a MPS Tensor to float64 dtype " \
18   "as the MPS framework doesn't support float64. Please use float32 instead."
19 
20 namespace at::detail {
empty_mps(IntArrayRef size,std::optional<ScalarType> dtype_opt,std::optional<Layout> layout_opt,std::optional<Device> device_opt,std::optional<bool> pin_memory_opt,std::optional<c10::MemoryFormat> memory_format_opt)21 TensorBase empty_mps(
22     IntArrayRef size,
23     std::optional<ScalarType> dtype_opt,
24     std::optional<Layout> layout_opt,
25     std::optional<Device> device_opt,
26     std::optional<bool> pin_memory_opt,
27     std::optional<c10::MemoryFormat> memory_format_opt) {
28 #if defined(__APPLE__)
29 #if __is_target_os(macOS)
30   if (at::hasMPS()) {
31     auto device = device_or_default(device_opt);
32     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::MPS);
33 
34     TORCH_CHECK_NOT_IMPLEMENTED(
35         layout_or_default(layout_opt) == Layout::Strided,
36         "only strided tensors are supported on MPS");
37 
38     TORCH_CHECK(size.size() <= 16, "MPS supports tensors with dimensions <= 16, but got ", size.size(), ".");
39 
40     check_size_nonnegative(size);
41 
42     auto* allocator = at::mps::GetMPSAllocator();
43     int64_t nelements = c10::multiply_integers(size);
44     auto dtype = dtype_or_default(dtype_opt);
45     TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
46     TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer");
47 
48 
49     auto dtype_meta = scalarTypeToTypeMeta(dtype);
50     int64_t size_bytes = nelements * dtype_meta.itemsize();
51     auto storage_impl = c10::make_intrusive<StorageImpl>(
52         c10::StorageImpl::use_byte_size_t(),
53         size_bytes,
54         allocator->allocate(size_bytes),
55         allocator,
56         /*resizeable=*/true);
57 
58     auto tensor =
59         detail::make_tensor<TensorImpl>(storage_impl, DispatchKey::MPS, dtype_meta);
60     // Default TensorImpl has size [0]
61     if (size.size() != 1 || size[0] != 0) {
62       tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
63     }
64 
65     auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
66     tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
67     // See Note [Enabling Deterministic Operations]
68     if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
69       at::native::fill_empty_deterministic_(tensor);
70     }
71     return tensor;
72   } else {
73     TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
74   }
75 #else
76   TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
77 #endif
78 #else
79   TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
80 #endif
81 }
82 
empty_mps(IntArrayRef size,const TensorOptions & options)83 TensorBase empty_mps(
84     IntArrayRef size, const TensorOptions &options) {
85   return at::detail::empty_mps(
86       size,
87       optTypeMetaToScalarType(options.dtype_opt()),
88       options.layout_opt(),
89       options.device_opt(),
90       options.pinned_memory_opt(),
91       options.memory_format_opt());
92 }
93 
empty_strided_mps(IntArrayRef size,IntArrayRef stride,ScalarType dtype,std::optional<Device> device_opt)94 TensorBase empty_strided_mps(
95     IntArrayRef size,
96     IntArrayRef stride,
97     ScalarType dtype,
98     std::optional<Device> device_opt) {
99 #if defined(__APPLE__)
100 #if __is_target_os(macOS)
101   if (at::hasMPS()) {
102     auto device = device_or_default(device_opt);
103     TORCH_INTERNAL_ASSERT(device.is_mps());
104     TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
105     const DeviceGuard device_guard(device);
106     auto* allocator = at::mps::GetMPSAllocator();
107     constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
108     Tensor result = at::detail::empty_strided_generic(
109         size, stride, allocator, mps_dks, dtype);
110     // See Note [Enabling Deterministic Operations]
111     if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
112       at::native::fill_empty_deterministic_(result);
113     }
114     return result;
115   } else {
116     TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
117   }
118 #else
119   TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
120 #endif
121 #else
122   TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
123 #endif
124 }
125 
empty_strided_mps(IntArrayRef size,IntArrayRef stride,const TensorOptions & options)126 TensorBase empty_strided_mps(
127     IntArrayRef size,
128     IntArrayRef stride,
129     const TensorOptions &options) {
130   return at::native::empty_strided_mps(
131       size,
132       stride,
133       optTypeMetaToScalarType(options.dtype_opt()),
134       options.layout_opt(),
135       options.device_opt(),
136       options.pinned_memory_opt());
137 }
138 
139 } // namespace at::detail
140