// Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <torch/library.h>
#include <ATen/mps/EmptyTensor.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/mps/Copy.h>

#define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
#define MPS_ERROR_RUNTIME_TOO_LOW \
  "The MPS backend is supported on MacOS 12.3+.", \
  "Current OS version can be queried using `sw_vers`"
#define MPS_ERROR_DOUBLE_NOT_SUPPORTED "Cannot convert a MPS Tensor to float64 dtype " \
  "as the MPS framework doesn't support float64. Please use float32 instead."

namespace at::detail {
TensorBase empty_mps(
    IntArrayRef size,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
#if defined(__APPLE__)
#if __is_target_os(macOS)
  if (at::hasMPS()) {
    auto device = device_or_default(device_opt);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::MPS);

    TORCH_CHECK_NOT_IMPLEMENTED(
        layout_or_default(layout_opt) == Layout::Strided,
        "only strided tensors are supported on MPS");

    TORCH_CHECK(size.size() <= 16, "MPS supports tensors with dimensions <= 16, but got ", size.size(), ".");

    check_size_nonnegative(size);

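    // Shape and layout have been validated; everything below allocates the
    // backing storage from the MPS allocator and wires it into a TensorImpl.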
    auto* allocator = at::mps::GetMPSAllocator();
    int64_t nelements = c10::multiply_integers(size);
    auto dtype = dtype_or_default(dtype_opt);
    TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
    TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer");

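    // The storage is sized in bytes: element count times the item size of the
    // requested dtype.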
    auto dtype_meta = scalarTypeToTypeMeta(dtype);
    int64_t size_bytes = nelements * dtype_meta.itemsize();
    auto storage_impl = c10::make_intrusive<StorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
        size_bytes,
        allocator->allocate(size_bytes),
        allocator,
        /*resizable=*/true);

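    // Wrap the storage in a TensorImpl that dispatches through the MPS key.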
    auto tensor =
        detail::make_tensor<TensorImpl>(storage_impl, DispatchKey::MPS, dtype_meta);
    // Default TensorImpl has size [0]
    if (size.size() != 1 || size[0] != 0) {
      tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
    }

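    // Lay the tensor out according to the requested memory format
    // (contiguous by default).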
    auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
    tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
    // See Note [Enabling Deterministic Operations]
    if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
      at::native::fill_empty_deterministic_(tensor);
    }
    return tensor;
  } else {
    TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
  }
#else
  TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
#endif
#else
  TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
#endif
}
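
// Usage sketch (illustrative only, not part of this file): the dispatcher
// routes empty-tensor creation on MPS devices into empty_mps above, so a
// typical caller would write something like
//   auto t = at::empty({2, 3}, at::TensorOptions().dtype(at::kFloat).device(at::kMPS));
// and end up here rather than calling at::detail::empty_mps directly.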

TensorBase empty_mps(
    IntArrayRef size, const TensorOptions &options) {
  return at::detail::empty_mps(
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt(),
      options.memory_format_opt());
}

TensorBase empty_strided_mps(
    IntArrayRef size,
    IntArrayRef stride,
    ScalarType dtype,
    std::optional<Device> device_opt) {
#if defined(__APPLE__)
#if __is_target_os(macOS)
  if (at::hasMPS()) {
    auto device = device_or_default(device_opt);
    TORCH_INTERNAL_ASSERT(device.is_mps());
    TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
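    // Make the requested MPS device current while the strided storage is
    // allocated, then build the tensor through the generic strided helper.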
    const DeviceGuard device_guard(device);
    auto* allocator = at::mps::GetMPSAllocator();
    constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
    Tensor result = at::detail::empty_strided_generic(
        size, stride, allocator, mps_dks, dtype);
    // See Note [Enabling Deterministic Operations]
    if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
      at::native::fill_empty_deterministic_(result);
    }
    return result;
  } else {
    TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
  }
#else
  TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
#endif
#else
  TORCH_CHECK(false, MPS_ERROR_NOT_COMPILED)
#endif
}
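
// Usage sketch (illustrative only): strided empty tensors on MPS are normally
// requested through the public factory, e.g.
//   auto t = at::empty_strided({2, 3}, {1, 2},
//                              at::TensorOptions().dtype(at::kFloat).device(at::kMPS));
// which reaches empty_strided_mps above via the dispatcher.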

TensorBase empty_strided_mps(
    IntArrayRef size,
    IntArrayRef stride,
    const TensorOptions &options) {
  return at::native::empty_strided_mps(
      size,
      stride,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

} // namespace at::detail