1 // Copyright © 2022 Apple Inc. 2 3 #pragma once 4 5 #include <c10/core/Allocator.h> 6 #include <ATen/core/Generator.h> 7 #include <ATen/detail/AcceleratorHooksInterface.h> 8 #include <c10/util/Exception.h> 9 #include <c10/util/Registry.h> 10 11 #include <cstddef> 12 13 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") 14 namespace at { 15 16 struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { 17 // this fails the implementation if MPSHooks functions are called, but 18 // MPS backend is not present. 19 #define FAIL_MPSHOOKS_FUNC(func) \ 20 TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend."); 21 22 ~MPSHooksInterface() override = default; 23 24 // Initialize the MPS library state initMPSMPSHooksInterface25 virtual void initMPS() const { 26 FAIL_MPSHOOKS_FUNC(__func__); 27 } hasMPSMPSHooksInterface28 virtual bool hasMPS() const { 29 return false; 30 } 31 virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const { 32 FAIL_MPSHOOKS_FUNC(__func__); 33 } getDefaultMPSGeneratorMPSHooksInterface34 virtual const Generator& getDefaultMPSGenerator() const { 35 FAIL_MPSHOOKS_FUNC(__func__); 36 } getMPSDeviceAllocatorMPSHooksInterface37 virtual Allocator* getMPSDeviceAllocator() const { 38 FAIL_MPSHOOKS_FUNC(__func__); 39 } deviceSynchronizeMPSHooksInterface40 virtual void deviceSynchronize() const { 41 FAIL_MPSHOOKS_FUNC(__func__); 42 } commitStreamMPSHooksInterface43 virtual void commitStream() const { 44 FAIL_MPSHOOKS_FUNC(__func__); 45 } getCommandBufferMPSHooksInterface46 virtual void* getCommandBuffer() const { 47 FAIL_MPSHOOKS_FUNC(__func__); 48 } getDispatchQueueMPSHooksInterface49 virtual void* getDispatchQueue() const { 50 FAIL_MPSHOOKS_FUNC(__func__); 51 } emptyCacheMPSHooksInterface52 virtual void emptyCache() const { 53 FAIL_MPSHOOKS_FUNC(__func__); 54 } getCurrentAllocatedMemoryMPSHooksInterface55 virtual size_t getCurrentAllocatedMemory() const { 56 FAIL_MPSHOOKS_FUNC(__func__); 57 } getDriverAllocatedMemoryMPSHooksInterface58 virtual size_t getDriverAllocatedMemory() const { 59 FAIL_MPSHOOKS_FUNC(__func__); 60 } getRecommendedMaxMemoryMPSHooksInterface61 virtual size_t getRecommendedMaxMemory() const { 62 FAIL_MPSHOOKS_FUNC(__func__); 63 } setMemoryFractionMPSHooksInterface64 virtual void setMemoryFraction(double /*ratio*/) const { 65 FAIL_MPSHOOKS_FUNC(__func__); 66 } profilerStartTraceMPSHooksInterface67 virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const { 68 FAIL_MPSHOOKS_FUNC(__func__); 69 } profilerStopTraceMPSHooksInterface70 virtual void profilerStopTrace() const { 71 FAIL_MPSHOOKS_FUNC(__func__); 72 } acquireEventMPSHooksInterface73 virtual uint32_t acquireEvent(bool enable_timing) const { 74 FAIL_MPSHOOKS_FUNC(__func__); 75 } releaseEventMPSHooksInterface76 virtual void releaseEvent(uint32_t event_id) const { 77 FAIL_MPSHOOKS_FUNC(__func__); 78 } recordEventMPSHooksInterface79 virtual void recordEvent(uint32_t event_id) const { 80 FAIL_MPSHOOKS_FUNC(__func__); 81 } waitForEventMPSHooksInterface82 virtual void waitForEvent(uint32_t event_id) const { 83 FAIL_MPSHOOKS_FUNC(__func__); 84 } synchronizeEventMPSHooksInterface85 virtual void synchronizeEvent(uint32_t event_id) const { 86 FAIL_MPSHOOKS_FUNC(__func__); 87 } queryEventMPSHooksInterface88 virtual bool queryEvent(uint32_t event_id) const { 89 FAIL_MPSHOOKS_FUNC(__func__); 90 } elapsedTimeOfEventsMPSHooksInterface91 virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const { 92 FAIL_MPSHOOKS_FUNC(__func__); 93 } hasPrimaryContextMPSHooksInterface94 bool hasPrimaryContext(DeviceIndex device_index) const override { 95 FAIL_MPSHOOKS_FUNC(__func__); 96 } isPinnedPtrMPSHooksInterface97 bool isPinnedPtr(const void* data) const override { 98 return false; 99 } getPinnedMemoryAllocatorMPSHooksInterface100 Allocator* getPinnedMemoryAllocator() const override { 101 FAIL_MPSHOOKS_FUNC(__func__); 102 } 103 #undef FAIL_MPSHOOKS_FUNC 104 }; 105 106 struct TORCH_API MPSHooksArgs {}; 107 108 TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs); 109 #define REGISTER_MPS_HOOKS(clsname) \ 110 C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname) 111 112 namespace detail { 113 TORCH_API const MPSHooksInterface& getMPSHooks(); 114 115 } // namespace detail 116 } // namespace at 117 C10_DIAGNOSTIC_POP() 118