xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSDevice.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1//  Copyright © 2022 Apple Inc.
2
3#include <c10/util/CallOnce.h>
4
5#include <ATen/mps/IndexKernels.h>
6#include <ATen/mps/MPSAllocatorInterface.h>
7#include <ATen/mps/MPSDevice.h>
8#include <ATen/mps/MPSStream.h>
9#include <ATen/native/mps/MPSGraphSequoiaOps.h>
10
11namespace at::mps {
12
13static std::unique_ptr<MPSDevice> mps_device;
14static c10::once_flag mpsdev_init;
15
16static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
17  // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
18  // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+
19  TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
20  return MTLLanguageVersion3_0;
21}
22
23MPSDevice* MPSDevice::getInstance() {
24  c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
25  return mps_device.get();
26}
27
28id<MTLLibrary> MPSDevice::getMetalIndexingLibrary() {
29  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
30  NSError* error = nil;
31  if (!_mtl_indexing_library) {
32    MTLCompileOptions* options = [MTLCompileOptions new];
33
34    [options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
35
36    if (isMacOS13Plus(MacOSVersion::MACOS_VER_15_0_PLUS)) {
37      options.mathMode = MTLMathModeFast;
38    } else {
39      [options setFastMathEnabled:YES];
40    }
41    _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
42                                                                                 encoding:NSASCIIStringEncoding]
43                                                      options:options
44                                                        error:&error];
45    TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
46  }
47  return _mtl_indexing_library;
48}
49
50id<MTLComputePipelineState> MPSDevice::metalIndexingPSO(const std::string& kernel) {
51  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
52  NSError* error = nil;
53  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
54  id<MTLLibrary> indexing_lib = getMetalIndexingLibrary();
55  id<MTLComputePipelineState> state = psoCache[kernel];
56  if (state) {
57    return state;
58  }
59
60  id<MTLFunction> indexFunction =
61      [[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
62  TORCH_CHECK(indexFunction, "Can't find function ", kernel);
63
64  state = [_mtl_device newComputePipelineStateWithFunction:indexFunction error:&error];
65  TORCH_CHECK(state, error.localizedDescription.UTF8String);
66  psoCache[kernel] = state;
67  return state;
68}
69
70MPSDevice::~MPSDevice() {
71  [_mtl_device release];
72  [_mtl_indexing_library release];
73  _mtl_device = nil;
74  _mtl_indexing_library = nil;
75}
76
77MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) {
78  // Check that MacOS 13.0+ version of MPS framework is available
79  // Create the MPSGraph and check method introduced in 13.0
80  // which is used by MPS backend.
81  id mpsCD = NSClassFromString(@"MPSGraph");
82
83  if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) {
84    return;
85  }
86
87  NSArray* devices = [MTLCopyAllDevices() autorelease];
88  for (unsigned long i = 0; i < [devices count]; i++) {
89    id<MTLDevice> device = devices[i];
90    if ([device isLowPower]) { // exclude Intel GPUs
91      continue;
92    }
93    if (![device supportsFamily:MTLGPUFamilyMac2]) {
94      // Exclude devices that does not support Metal 2.0
95      // Virtualised MPS device on MacOS 12.6 should fail this check
96      TORCH_WARN("Skipping device ", [[device name] UTF8String], " that does not support Metal 2.0");
97      continue;
98    }
99    _mtl_device = [device retain];
100    break;
101  }
102  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
103}
104
105bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
106  auto is_os_version_at_least = [](int major, int minor) {
107    @autoreleasepool {
108      NSProcessInfo* processInfo = [[NSProcessInfo alloc] init];
109      return [processInfo
110          isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
111    }
112  };
113  static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
114  static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
115  static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
116  static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
117  static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
118  static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
119
120  switch (version) {
121    case MacOSVersion::MACOS_VER_13_1_PLUS:
122      return _macos_13_1_plus;
123    case MacOSVersion::MACOS_VER_13_2_PLUS:
124      return _macos_13_2_plus;
125    case MacOSVersion::MACOS_VER_13_3_PLUS:
126      return _macos_13_3_plus;
127    case MacOSVersion::MACOS_VER_14_0_PLUS:
128      return _macos_14_0_plus;
129    case MacOSVersion::MACOS_VER_14_4_PLUS:
130      return _macos_14_4_plus;
131    case MacOSVersion::MACOS_VER_15_0_PLUS:
132      return _macos_15_0_plus;
133    default:
134      return false;
135  }
136}
137
138at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
139  return getIMPSAllocator(useSharedAllocator);
140}
141
142bool is_available() {
143  return MPSDevice::getInstance()->device() != nil;
144}
145
146bool is_macos_13_or_newer(MacOSVersion version) {
147  return MPSDevice::getInstance()->isMacOS13Plus(version);
148}
149
150} // namespace at::mps
151