#import <ATen/native/metal/MetalDevice.h>
#import <ATen/native/metal/MetalShaders.h>
#import <ATen/native/metal/MetalContext.h>

#include <c10/util/Exception.h>

#include <cstring>
#include <mutex>
#include <string>
#include <unordered_map>

#if C10_IOS
#import <UIKit/UIKit.h>
#elif TARGET_OS_MAC
#import <Foundation/NSProcessInfo.h>
#endif

using namespace at::native::metal;

// Function-constant slot layout used by -specializedPipelineState:Constants:.
// Integer constants are written to indices [0, kMetalFloatArgBaseIndex) and
// float constants to [kMetalFloatArgBaseIndex, kMetalFloatArgMaxIndex].
// NOTE(review): these must agree with the [[function_constant(N)]]
// declarations inside PT_METAL_SHADERS — confirm against MetalShaders.h.
static constexpr NSUInteger kMetalFloatArgBaseIndex = 12;
static constexpr NSUInteger kMetalFloatArgMaxIndex = 14;

// Returns a printable description of `error`. Messaging a nil NSError yields a
// NULL UTF8String, which must never reach TORCH_CHECK or std::string, so a
// caller-supplied fallback is used in that case.
static std::string MetalContextErrorDescription(
    NSError* error,
    const std::string& fallback) {
  const char* description = error.localizedDescription.UTF8String;
  return description ? std::string(description) : fallback;
}

@implementation MetalContext {
  // Guards _pipelineCache; pipeline states may be requested from any thread.
  std::mutex _pipelineCacheMutex;
  MetalDeviceInfo _deviceInfo;
  // Compiled pipeline states keyed by kernel name (plus "_<constant>" suffixes
  // for specialized kernels).
  std::unordered_map<std::string, id<MTLComputePipelineState>> _pipelineCache;
}

#pragma mark - Lifecycle

/// Returns the process-wide context, created once via dispatch_once.
/// The system default device, its device info and a command queue are set up
/// here; the shader library starts as nil and is compiled by -available.
+ (instancetype)sharedInstance {
  static dispatch_once_t onceToken;
  static MetalContext* instance = nil;
  dispatch_once(&onceToken, ^{
    instance = [[MetalContext alloc] init];
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    instance->_device = device;
    instance->_deviceInfo = createDeviceInfo(device);
    instance->_library = nil;
    instance->_commandQueue = [instance.device newCommandQueue];
  });
  return instance;
}

#pragma mark - Availability

/// Reports whether the Metal backend can be used on this machine.
/// Checks MPS support and the minimum OS version (iOS 11 / macOS 10.13, plus
/// macOS GPU family 1 v3), then compiles the shader library.
/// @return YES only when a device, compiled library and command queue exist.
/// Raises a c10::Error (TORCH_CHECK) carrying the compiler message and device
/// description if shader compilation fails. Compilation runs once per
/// process, so that error is only surfaced by the first call.
- (BOOL)available {
#if !defined(__APPLE__)
  return false;
#elif TARGET_OS_IPHONE
  if (!MPSSupportsMTLDevice(_device)) {
    return false;
  }
  if ([UIDevice currentDevice].systemVersion.floatValue < 11.0) {
    return false;
  }
#elif TARGET_OS_MAC
  if (!MPSSupportsMTLDevice(_device)) {
    return false;
  }
  NSOperatingSystemVersion supportedVer = {10, 13, 0};
  if (![[NSProcessInfo processInfo]
          isOperatingSystemAtLeastVersion:supportedVer]) {
    return false;
  }
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-declarations")
  if (![_device supportsFeatureSet:MTLFeatureSet_macOS_GPUFamily1_v3]) {
    return false;
  }
C10_CLANG_DIAGNOSTIC_POP()
#else
  return false;
#endif
  NSError* error = [self _compileProgram];
  if (error) {
    std::string compilationError = MetalContextErrorDescription(
        error, "Unknown Metal shader compilation error");
    std::string deviceInfo = self.description.UTF8String;
    TORCH_CHECK(false, compilationError + "\n" + deviceInfo);
  }
  return _device && _library && _commandQueue;
}

#pragma mark - Pipeline states

/// Returns the (cached) compute pipeline state for the shader function
/// `kernel`, creating and caching it on first use.
/// Raises via TORCH_CHECK if the library is not loaded, the function does not
/// exist, or pipeline creation fails.
- (id<MTLComputePipelineState>)pipelineState:(const std::string&)kernel {
  TORCH_CHECK(_library, "Failed to load Metal shaders");
  std::lock_guard<std::mutex> guard(_pipelineCacheMutex);
  // find() rather than operator[] so a miss does not insert a nil entry.
  auto cached = _pipelineCache.find(kernel);
  if (cached != _pipelineCache.end() && cached->second) {
    return cached->second;
  }
  NSString* functionName =
      [NSString stringWithUTF8String:kernel.c_str()] ?: @"";
  id<MTLFunction> func = [_library newFunctionWithName:functionName];
  TORCH_CHECK(func, "Failed to load the Metal Shader function: ", kernel);
  NSError* pipelineError = nil;
  id<MTLComputePipelineState> state =
      [_device newComputePipelineStateWithFunction:func error:&pipelineError];
  TORCH_CHECK(
      state,
      MetalContextErrorDescription(
          pipelineError,
          "Failed to create the compute pipeline state for: " + kernel));
  _pipelineCache[kernel] = state;
  return state;
}

/// Returns the (cached) pipeline state for `kernel` specialized with Metal
/// function constants.
/// Integer-typed NSNumbers (NSInteger/NSUInteger encodings) are bound, in
/// order, as MTLDataTypeUShort at indices starting at 0. Float/double
/// NSNumbers are bound as MTLDataTypeFloat at indices starting at
/// kMetalFloatArgBaseIndex. NSNumbers of any other encoding are not bound,
/// although they still contribute to the cache key.
/// The cache key is `kernel` followed by "_<stringValue>" for each constant.
- (id<MTLComputePipelineState>)specializedPipelineState:(const std::string&)kernel
                                              Constants:(NSArray<NSNumber*>*)
                                                            constants {
  TORCH_CHECK(_library, "Failed to load Metal shaders");
  std::string kernelStr = kernel;
  for (NSNumber* constant in constants) {
    kernelStr += "_" + std::string([constant stringValue].UTF8String);
  }
  std::lock_guard<std::mutex> guard(_pipelineCacheMutex);
  // find() rather than operator[] so a miss does not insert a nil entry.
  auto cached = _pipelineCache.find(kernelStr);
  if (cached != _pipelineCache.end() && cached->second) {
    return cached->second;
  }
  MTLFunctionConstantValues* constantValues =
      [[MTLFunctionConstantValues alloc] init];
  NSUInteger ushortArgIndex = 0;
  NSUInteger floatArgIndex = kMetalFloatArgBaseIndex;
  for (NSNumber* constant in constants) {
    const char* type = constant.objCType;
    if (strcmp(type, @encode(NSUInteger)) == 0 ||
        strcmp(type, @encode(NSInteger)) == 0) {
      // Integer slots must stop before the first float slot; the previous
      // `<= 12` check let an integer overwrite float index 12.
      TORCH_CHECK(
          ushortArgIndex < kMetalFloatArgBaseIndex,
          "Too many integer function constants for kernel: ",
          kernel);
      ushort value = ushort([constant unsignedIntegerValue]);
      [constantValues setConstantValue:&value
                                  type:MTLDataTypeUShort
                               atIndex:ushortArgIndex];
      ushortArgIndex++;
    } else if (
        strcmp(type, @encode(float)) == 0 ||
        strcmp(type, @encode(double)) == 0) {
      TORCH_CHECK(
          floatArgIndex <= kMetalFloatArgMaxIndex,
          "Too many float function constants for kernel: ",
          kernel);
      float value = [constant floatValue];
      [constantValues setConstantValue:&value
                                  type:MTLDataTypeFloat
                               atIndex:floatArgIndex];
      floatArgIndex++;
    }
  }
  NSString* functionName =
      [NSString stringWithUTF8String:kernel.c_str()] ?: @"";
  NSError* functionError = nil;
  id<MTLFunction> func = [_library newFunctionWithName:functionName
                                        constantValues:constantValues
                                                 error:&functionError];
  TORCH_CHECK(
      func,
      MetalContextErrorDescription(
          functionError,
          "Failed to load the specialized Metal Shader function: " + kernel));
  NSError* pipelineError = nil;
  id<MTLComputePipelineState> state =
      [_device newComputePipelineStateWithFunction:func error:&pipelineError];
  TORCH_CHECK(
      state,
      MetalContextErrorDescription(
          pipelineError,
          "Failed to create the compute pipeline state for: " + kernelStr));
  _pipelineCache[kernelStr] = state;
  return state;
}

#pragma mark - Buffers

/// Allocates a new MTLBuffer of `size` bytes with write-combined CPU caching.
/// Raises via TORCH_CHECK when there is no device or `size` is negative
/// (a negative int64_t would otherwise wrap to a huge NSUInteger length).
- (id<MTLBuffer>)emptyMTLBuffer:(int64_t)size {
  TORCH_CHECK(_device);
  TORCH_CHECK(size >= 0, "Invalid Metal buffer size: ", size);
  id<MTLBuffer> buffer =
      [_device newBufferWithLength:static_cast<NSUInteger>(size)
                           options:MTLResourceCPUCacheModeWriteCombined];
  return buffer;
}

#pragma mark - NSObject

/// Human-readable device name and Metal language version; appended to shader
/// compilation errors raised by -available.
- (NSString*)description {
  NSString* desc =
      [NSString stringWithFormat:@"DeviceName: %s, LanguageVersion: %lu",
                                 _deviceInfo.name.c_str(),
                                 (unsigned long)_deviceInfo.languageVersion];
  return desc;
}

#pragma mark - Private

/// Compiles PT_METAL_SHADERS into _library (fast math enabled, language
/// version taken from the device info). Runs at most once per process.
/// @return The compiler error from that single compilation when invoked for
/// the first time; nil on every later call, even if compilation had failed.
- (NSError*)_compileProgram {
  __block NSError* compilationError = nil;
  static dispatch_once_t onceToken;
  dispatch_once(&onceToken, ^{
    NSError* localError = nil;
    MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
    [options setLanguageVersion:_deviceInfo.languageVersion];
    [options setFastMathEnabled:YES];
    _library = [_device
        newLibraryWithSource:[NSString stringWithUTF8String:PT_METAL_SHADERS] ?: @""
                     options:options
                       error:&localError];
    compilationError = localError;
  });
  return compilationError;
}

@end