// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalContext.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#import <ATen/native/metal/MetalDevice.h>
#import <ATen/native/metal/MetalShaders.h>
#import <ATen/native/metal/MetalContext.h>

#include <c10/util/Exception.h>

#include <cstdint>
#include <cstring>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#if C10_IOS
#import <UIKit/UIKit.h>
#elif TARGET_OS_MAC
#import <Foundation/NSProcessInfo.h>
#endif

using namespace at::native::metal;
@implementation MetalContext {
  // Guards _pipelineCache; taken by both pipeline-state lookup methods.
  std::mutex _pipelineCacheMutex;
  // Filled once in +sharedInstance from createDeviceInfo().
  MetalDeviceInfo _deviceInfo;
  // Compiled pipeline states. Plain kernels are keyed by their name;
  // specialized kernels by "<name>|<ushort values>|<float bit patterns>".
  // '|' cannot appear in a Metal function name, so the two never collide.
  std::unordered_map<std::string, id<MTLComputePipelineState>> _pipelineCache;
}

/// Returns the process-wide MetalContext, created on first use for the system
/// default Metal device together with a command queue on that device.
/// The shader library is not compiled here; -available does that.
+ (instancetype)sharedInstance {
  static dispatch_once_t onceToken;
  static MetalContext* instance = nil;
  dispatch_once(&onceToken, ^{
    instance = [[MetalContext alloc] init];
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    instance->_device = device;
    instance->_deviceInfo = createDeviceInfo(device);
    instance->_library = nil;
    instance->_commandQueue = [instance.device newCommandQueue];
  });
  return instance;
}

/// Reports whether the Metal backend can be used on this machine.
///
/// Checks MPS support for the device and a minimum OS version (iOS 11 /
/// macOS 10.13, plus the macOS GPUFamily1_v3 feature set), then compiles the
/// bundled shaders. Raises (TORCH_CHECK) with the compiler message and the
/// device description if the compilation performed by this call fails.
- (BOOL)available {
#if !defined(__APPLE__)
  return false;
#elif TARGET_OS_IPHONE
  if (!MPSSupportsMTLDevice(_device)) {
    return false;
  }
  if ([UIDevice currentDevice].systemVersion.floatValue < 11.0) {
    return false;
  }
#elif TARGET_OS_MAC
  if (!MPSSupportsMTLDevice(_device)) {
    return false;
  }
  NSOperatingSystemVersion supportedVer = {10, 13, 0};
  if (![[NSProcessInfo processInfo]
          isOperatingSystemAtLeastVersion:supportedVer]) {
    return false;
  }
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-declarations")
  if (![_device supportsFeatureSet:MTLFeatureSet_macOS_GPUFamily1_v3]) {
    return false;
  }
C10_CLANG_DIAGNOSTIC_POP()
#else
  return false;
#endif
  NSError* error = [self _compileProgram];
  if (error) {
    std::string compilationError = [self _messageForError:error];
    std::string deviceInfo = self.description.UTF8String;
    TORCH_CHECK(false, compilationError + "\n" + deviceInfo);
  }
  return _device && _library && _commandQueue;
}

/// Returns the (cached) compute pipeline state for the shader function named
/// `kernel` in the compiled library.
/// Raises if the library is missing, the function does not exist, or the
/// pipeline state cannot be created. Failed lookups are not cached.
- (id<MTLComputePipelineState>)pipelineState:(const std::string&)kernel {
  TORCH_CHECK(_library, "Failed to load Metal shaders");
  std::lock_guard<std::mutex> guard(_pipelineCacheMutex);
  // find() rather than operator[]: a miss must not leave a nil entry behind.
  auto cached = _pipelineCache.find(kernel);
  if (cached != _pipelineCache.end()) {
    return cached->second;
  }
  id<MTLFunction> func = [_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()] ?: @""];
  TORCH_CHECK(func, "Failed to load the Metal Shader function: ", kernel);
  NSError* errors = nil;
  id<MTLComputePipelineState> state =
      [_device newComputePipelineStateWithFunction:func error:&errors];
  TORCH_CHECK(state, [self _messageForError:errors]);
  _pipelineCache[kernel] = state;
  return state;
}

/// Returns the (cached) pipeline state for `kernel` specialized with Metal
/// function constants.
///
/// Integer-typed NSNumbers are written, in order, as `ushort` constants at
/// indices 0..11. Floating-point NSNumbers are written, in order, as `float`
/// constants starting at index 12. The cache key records the exact values
/// written (float bit patterns, not their printed form), so two different
/// specializations can never share an entry.
/// Raises on too many constants, on an unsupported NSNumber type, or if
/// specialization/compilation fails.
- (id<MTLComputePipelineState>)specializedPipelineState:(const std::string&)kernel
                                              Constants:(NSArray<NSNumber*>*)
                                                            constants {
  TORCH_CHECK(_library, "Failed to load Metal shaders");
  // Function-constant slot layout: ushort slots are [0, 12) so they never
  // overlap the float slots that begin at 12.
  constexpr NSUInteger kFirstUShortIndex = 0;
  constexpr NSUInteger kFirstFloatIndex = 12;
  // Inclusive upper bound on float slots, unchanged from the previous check.
  // NOTE(review): confirm against the float_arg_* constants in MetalShaders.h.
  constexpr NSUInteger kLastFloatIndex = 14;

  std::vector<ushort> ushortValues;
  std::vector<float> floatValues;
  for (NSNumber* constant in constants) {
    const char* type = constant.objCType;
    // NSNumber may report a narrower type than it was created with
    // (e.g. @(int) -> "i"), so accept every scalar integer encoding.
    const bool isScalar = type && type[0] != '\0' && type[1] == '\0';
    const bool isInteger =
        isScalar && std::strchr("cCsSiIlLqQB", type[0]) != nullptr;
    const bool isFloating = isScalar && (type[0] == 'f' || type[0] == 'd');
    if (isInteger) {
      TORCH_CHECK(kFirstUShortIndex + ushortValues.size() < kFirstFloatIndex,
                  "Too many integer function constants for kernel ", kernel);
      ushortValues.push_back(
          static_cast<ushort>([constant unsignedIntegerValue]));
    } else if (isFloating) {
      TORCH_CHECK(kFirstFloatIndex + floatValues.size() <= kLastFloatIndex,
                  "Too many float function constants for kernel ", kernel);
      floatValues.push_back([constant floatValue]);
    } else {
      TORCH_CHECK(false, "Unsupported function constant type '", type,
                  "' for kernel ", kernel);
    }
  }

  // Key on the exact values handed to Metal; the float bit pattern avoids
  // collisions between floats whose -stringValue renderings coincide.
  std::string cacheKey = kernel + "|";
  for (ushort value : ushortValues) {
    cacheKey += std::to_string(value) + ",";
  }
  cacheKey += "|";
  for (float value : floatValues) {
    uint32_t bits = 0;
    std::memcpy(&bits, &value, sizeof(bits));
    cacheKey += std::to_string(bits) + ",";
  }

  std::lock_guard<std::mutex> guard(_pipelineCacheMutex);
  auto cached = _pipelineCache.find(cacheKey);
  if (cached != _pipelineCache.end()) {
    return cached->second;
  }
  MTLFunctionConstantValues* constantValues =
      [[MTLFunctionConstantValues alloc] init];
  for (size_t i = 0; i < ushortValues.size(); ++i) {
    [constantValues setConstantValue:&ushortValues[i]
                                type:MTLDataTypeUShort
                             atIndex:kFirstUShortIndex + i];
  }
  for (size_t i = 0; i < floatValues.size(); ++i) {
    [constantValues setConstantValue:&floatValues[i]
                                type:MTLDataTypeFloat
                             atIndex:kFirstFloatIndex + i];
  }
  NSError* errors = nil;
  id<MTLFunction> func = [_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()] ?: @""
                                        constantValues:constantValues
                                                 error:&errors];
  TORCH_CHECK(func, [self _messageForError:errors]);
  id<MTLComputePipelineState> state =
      [_device newComputePipelineStateWithFunction:func error:&errors];
  TORCH_CHECK(state, [self _messageForError:errors]);
  _pipelineCache[cacheKey] = state;
  return state;
}

/// Allocates a new buffer of `size` bytes on the context's device with
/// write-combined CPU caching. May return nil if Metal cannot allocate it.
- (id<MTLBuffer>)emptyMTLBuffer:(int64_t)size {
  TORCH_CHECK(_device);
  // A negative size would otherwise wrap around to an enormous NSUInteger.
  TORCH_CHECK(size >= 0, "Invalid Metal buffer size: ", size);
  return [_device newBufferWithLength:static_cast<NSUInteger>(size)
                              options:MTLResourceCPUCacheModeWriteCombined];
}

/// Human-readable device summary, appended to shader-compilation errors.
- (NSString*)description {
  NSString* desc =
      [NSString stringWithFormat:@"DeviceName: %s, LanguageVersion: %lu",
                                 _deviceInfo.name.c_str(),
                                 (unsigned long)_deviceInfo.languageVersion];
  return desc;
}

/// Compiles PT_METAL_SHADERS into _library exactly once per process, with
/// fast math enabled and the device's Metal language version.
/// Returns the compiler error only from the call that ran the compilation.
/// Later calls return nil even if compilation failed; _library then stays nil.
- (NSError*)_compileProgram {
  __block NSError* compilationError = nil;
  static dispatch_once_t onceToken;
  dispatch_once(&onceToken, ^{
    NSError* localError = nil;
    MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
    [options setLanguageVersion:_deviceInfo.languageVersion];
    [options setFastMathEnabled:YES];
    _library = [_device
        newLibraryWithSource:[NSString stringWithUTF8String:PT_METAL_SHADERS] ?: @""
                     options:options
                       error:&localError];
    compilationError = localError;
  });
  return compilationError;
}

/// Converts an NSError into a message that is always safe to stream into
/// TORCH_CHECK. Uses a fallback text when `error` (or its description) is nil,
/// instead of passing a null `const char*`.
- (std::string)_messageForError:(NSError*)error {
  const char* message = error.localizedDescription.UTF8String;
  return message ? std::string(message) : std::string("Unknown Metal error");
}

@end
174