//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#pragma once

// Obj-C headers
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>

// Runtime headers
#include <executorch/runtime/backend/interface.h>

// MPS headers
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

// Standard library headers. <cstdint> and <string> are included explicitly
// because this header uses uint32_t and std::string directly rather than
// relying on them arriving transitively through the framework headers.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Converts a count of mebibytes to bytes. The argument is parenthesized so
// that an expression argument such as MB(a + b) expands to
// ((a + b) * 1048576UL) instead of (a + b * 1048576UL).
#define MB(x) ((x) * 1048576UL)

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Helper enum to check if a MPSGraph op is supported in a given macOS version
enum class MacOSVersion : uint32_t {
  MACOS_VER_13_0_PLUS = 0,
  MACOS_VER_13_1_PLUS,
  MACOS_VER_13_2_PLUS,
  MACOS_VER_13_3_PLUS,
  MACOS_VER_14_0_PLUS,
  MACOS_VER_15_0_PLUS,
};

// Identifies a Metal library that MPSDevice can compile and cache
// (see compileLibrary / compilePSO).
enum class LibraryType : uint32_t {
  INDEXING_KERNELS = 0,
  // Aliases the highest-valued library type (currently INDEXING_KERNELS).
  MAX = INDEXING_KERNELS,
};

class MPSDevice {
 public:
  /**
   * MPSDevice should not be cloneable.
   */
  MPSDevice(const MPSDevice&) = delete;
  /**
   * MPSDevice should not be assignable.
   */
  MPSDevice& operator=(const MPSDevice&) = delete;
  /**
   * Gets single instance of the Device.
   */
  static MPSDevice* getInstance();
  /**
   * Returns the Metal device held by this object.
   */
  id<MTLDevice> device() {
    return _mtl_device;
  }

  /**
   * Checks the host OS against the minimum macOS version named by `version`.
   * Despite the method name, the threshold is selected by the argument
   * rather than being fixed at macOS 13.
   * NOTE(review): implementation is not in this header; presumably returns
   * true when the running OS is at least `version` — confirm in the .mm.
   */
  bool isMacOS13Plus(MacOSVersion version) const;

  ~MPSDevice();

  /**
   * Compile a PSO for a given library type.
   * Once compiled, the library and PSOs are cached.
   *
   * @param libraryType Library that contains the kernel.
   * @param kernelName  Name of the kernel function to build a PSO for.
   * @return An executorch::runtime::Error status.
   */
  executorch::runtime::Error compilePSO(
      LibraryType libraryType,
      const char* kernelName);

  /**
   * Compiles the Metal library identified by the given LibraryType.
   * NOTE(review): presumably stores the result in _m_library_cache — confirm
   * in the implementation.
   *
   * @return An executorch::runtime::Error status.
   */
  executorch::runtime::Error compileLibrary(LibraryType);

 private:
  static MPSDevice* _device;
  id<MTLDevice> _mtl_device;
  // Compiled Metal libraries, one per LibraryType.
  std::unordered_map<LibraryType, id<MTLLibrary>> _m_library_cache;
  // Compute pipeline states keyed by string (presumably the kernel name —
  // verify against compilePSO).
  std::unordered_map<std::string, id<MTLComputePipelineState>> _m_pso_cache;
  // Private: instances are obtained through getInstance().
  MPSDevice();
};

/**
 * Free-function check of the host OS against a minimum macOS version;
 * defaults to macOS 13.0.
 * NOTE(review): implementation is not in this header — presumably returns
 * true when the running OS is at least `version`; confirm.
 */
bool is_macos_13_or_newer(
    MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch