// Copyright © 2022 Apple Inc.

#include <c10/util/CallOnce.h>

#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>

namespace at::mps {

// Process-wide MPSDevice singleton; created on first use by MPSDevice::getInstance().
static std::unique_ptr<MPSDevice> mps_device;
// Guards the one-time construction of `mps_device`.
static c10::once_flag mpsdev_init;

// Returns the Metal language version used when compiling the indexing shaders.
// Fails a TORCH_CHECK if `device` does not support MTLGPUFamilyMac2.
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
  // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
  // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
  TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
  return MTLLanguageVersion3_0;
}

// Returns the singleton MPSDevice, constructing it exactly once via c10::call_once.
MPSDevice* MPSDevice::getInstance() {
  c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
  return mps_device.get();
}

// Returns the Metal library holding the indexing kernels, compiling it from
// `mps::indexing_metal_shaders` on first call and caching it in
// `_mtl_indexing_library` afterwards.
// Fails a TORCH_CHECK if the device lacks MTLGPUFamilyMac2 support or if the
// shader source does not compile.
id<MTLLibrary> MPSDevice::getMetalIndexingLibrary() {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
  if (!_mtl_indexing_library) {
    // Resolve the language version before allocating the options object: this call
    // can throw, and nothing owned by this function must be live when it does.
    const MTLLanguageVersion languageVersion = getMetalLanguageVersion(_mtl_device);

    NSError* error = nil;
    MTLCompileOptions* options = [MTLCompileOptions new];
    [options setLanguageVersion:languageVersion];

    // `mathMode` is used on macOS 15+; older systems fall back to `fastMathEnabled`.
    if (isMacOS13Plus(MacOSVersion::MACOS_VER_15_0_PLUS)) {
      options.mathMode = MTLMathModeFast;
    } else {
      [options setFastMathEnabled:YES];
    }

    NSString* shaderSource = [NSString stringWithCString:mps::indexing_metal_shaders encoding:NSASCIIStringEncoding];
    _mtl_indexing_library = [_mtl_device newLibraryWithSource:shaderSource options:options error:&error];

    // This file uses manual retain/release; `options` came from +new and is owned
    // here, so it is released once compilation has finished (and before the check
    // below can throw).
    [options release];

    TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
  }
  return _mtl_indexing_library;
}

// Returns the compute pipeline state for the indexing kernel named `kernel`,
// building it on first request and caching it for the lifetime of the process.
// Fails a TORCH_CHECK if the function is missing from the indexing library or if
// pipeline creation fails.
// NOTE(review): `psoCache` is a function-local static with no locking in this
// function; confirm that callers serialize access.
id<MTLComputePipelineState> MPSDevice::metalIndexingPSO(const std::string& kernel) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  id<MTLLibrary> indexing_lib = getMetalIndexingLibrary();

  // Use find() rather than operator[] so that a miss does not insert a nil entry
  // that would linger in the cache if pipeline creation fails below.
  const auto cached = psoCache.find(kernel);
  if (cached != psoCache.end() && cached->second) {
    return cached->second;
  }

  id<MTLFunction> indexFunction =
      [[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
  TORCH_CHECK(indexFunction, "Can't find function ", kernel);

  NSError* error = nil;
  id<MTLComputePipelineState> state = [_mtl_device newComputePipelineStateWithFunction:indexFunction error:&error];
  TORCH_CHECK(state, error.localizedDescription.UTF8String);
  // The +1 reference returned by new... is kept by the cache and never released.
  psoCache[kernel] = state;
  return state;
}

// Releases the retained Metal device and the cached indexing library.
MPSDevice::~MPSDevice() {
  [_mtl_device release];
  [_mtl_indexing_library release];
  _mtl_device = nil;
  _mtl_indexing_library = nil;
}

// Selects the Metal device used by the MPS backend. `_mtl_device` stays nil when
// the required MPSGraph API is missing or no suitable device is found.
MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) {
  // Check that MacOS 13.0+ version of MPS framework is available
  // Create the MPSGraph and check method introduced in 13.0
  // which is used by MPS backend.
  id mpsCD = NSClassFromString(@"MPSGraph");

  // Messaging a nil class yields NO, so a missing MPSGraph class also returns here.
  if (![mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)]) {
    return;
  }

  NSArray* devices = [MTLCopyAllDevices() autorelease];
  for (id<MTLDevice> device in devices) {
    if ([device isLowPower]) { // exclude Intel GPUs
      continue;
    }
    if (![device supportsFamily:MTLGPUFamilyMac2]) {
      // Exclude devices that do not support Metal 2.0
      // Virtualised MPS device on MacOS 12.6 should fail this check
      TORCH_WARN("Skipping device ", [[device name] UTF8String], " that does not support Metal 2.0");
      continue;
    }
    // Take ownership of the first acceptable device; released in the destructor.
    _mtl_device = [device retain];
    break;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
}

// Reports whether the running OS is at least the release named by `version`.
// Each threshold is evaluated once, on the first call, and cached in a
// function-local static. Enum values without a case below return false.
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
  auto is_os_version_at_least = [](int major, int minor) {
    @autoreleasepool {
      // Use the shared process-info object; allocating a fresh NSProcessInfo here
      // would leak it, since this file uses manual retain/release.
      NSProcessInfo* processInfo = [NSProcessInfo processInfo];
      return [processInfo
          isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
    }
  };
  static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
  static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
  static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
  static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
  static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
  static bool _macos_15_0_plus = is_os_version_at_least(15, 0);

  switch (version) {
    case MacOSVersion::MACOS_VER_13_1_PLUS:
      return _macos_13_1_plus;
    case MacOSVersion::MACOS_VER_13_2_PLUS:
      return _macos_13_2_plus;
    case MacOSVersion::MACOS_VER_13_3_PLUS:
      return _macos_13_3_plus;
    case MacOSVersion::MACOS_VER_14_0_PLUS:
      return _macos_14_0_plus;
    case MacOSVersion::MACOS_VER_14_4_PLUS:
      return _macos_14_4_plus;
    case MacOSVersion::MACOS_VER_15_0_PLUS:
      return _macos_15_0_plus;
    default:
      return false;
  }
}

// Returns the MPS allocator; `useSharedAllocator` is forwarded to getIMPSAllocator().
at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
  return getIMPSAllocator(useSharedAllocator);
}

// True when the MPSDevice singleton holds a non-nil Metal device.
bool is_available() {
  return MPSDevice::getInstance()->device() != nil;
}

// Convenience wrapper around MPSDevice::isMacOS13Plus() on the singleton device.
bool is_macos_13_or_newer(MacOSVersion version) {
  return MPSDevice::getInstance()->isMacOS13Plus(version);
}

} // namespace at::mps