xref: /aosp_15_r20/external/pytorch/aten/src/ATen/detail/MPSHooksInterface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2022 Apple Inc.
2 
3 #pragma once
4 
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

#include <cstddef>
#include <cstdint>
#include <string>
12 
13 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
14 namespace at {
15 
16 struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
17   // this fails the implementation if MPSHooks functions are called, but
18   // MPS backend is not present.
19   #define FAIL_MPSHOOKS_FUNC(func) \
20     TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
21 
22   ~MPSHooksInterface() override = default;
23 
24   // Initialize the MPS library state
initMPSMPSHooksInterface25   virtual void initMPS() const {
26     FAIL_MPSHOOKS_FUNC(__func__);
27   }
hasMPSMPSHooksInterface28   virtual bool hasMPS() const {
29     return false;
30   }
31   virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
32     FAIL_MPSHOOKS_FUNC(__func__);
33   }
getDefaultMPSGeneratorMPSHooksInterface34   virtual const Generator& getDefaultMPSGenerator() const {
35     FAIL_MPSHOOKS_FUNC(__func__);
36   }
getMPSDeviceAllocatorMPSHooksInterface37   virtual Allocator* getMPSDeviceAllocator() const {
38     FAIL_MPSHOOKS_FUNC(__func__);
39   }
deviceSynchronizeMPSHooksInterface40   virtual void deviceSynchronize() const {
41     FAIL_MPSHOOKS_FUNC(__func__);
42   }
commitStreamMPSHooksInterface43   virtual void commitStream() const {
44     FAIL_MPSHOOKS_FUNC(__func__);
45   }
getCommandBufferMPSHooksInterface46   virtual void* getCommandBuffer() const {
47     FAIL_MPSHOOKS_FUNC(__func__);
48   }
getDispatchQueueMPSHooksInterface49   virtual void* getDispatchQueue() const {
50     FAIL_MPSHOOKS_FUNC(__func__);
51   }
emptyCacheMPSHooksInterface52   virtual void emptyCache() const {
53     FAIL_MPSHOOKS_FUNC(__func__);
54   }
getCurrentAllocatedMemoryMPSHooksInterface55   virtual size_t getCurrentAllocatedMemory() const {
56     FAIL_MPSHOOKS_FUNC(__func__);
57   }
getDriverAllocatedMemoryMPSHooksInterface58   virtual size_t getDriverAllocatedMemory() const {
59     FAIL_MPSHOOKS_FUNC(__func__);
60   }
getRecommendedMaxMemoryMPSHooksInterface61   virtual size_t getRecommendedMaxMemory() const {
62     FAIL_MPSHOOKS_FUNC(__func__);
63   }
setMemoryFractionMPSHooksInterface64   virtual void setMemoryFraction(double /*ratio*/) const {
65     FAIL_MPSHOOKS_FUNC(__func__);
66   }
profilerStartTraceMPSHooksInterface67   virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
68     FAIL_MPSHOOKS_FUNC(__func__);
69   }
profilerStopTraceMPSHooksInterface70   virtual void profilerStopTrace() const {
71     FAIL_MPSHOOKS_FUNC(__func__);
72   }
acquireEventMPSHooksInterface73   virtual uint32_t acquireEvent(bool enable_timing) const {
74     FAIL_MPSHOOKS_FUNC(__func__);
75   }
releaseEventMPSHooksInterface76   virtual void releaseEvent(uint32_t event_id) const {
77     FAIL_MPSHOOKS_FUNC(__func__);
78   }
recordEventMPSHooksInterface79   virtual void recordEvent(uint32_t event_id) const {
80     FAIL_MPSHOOKS_FUNC(__func__);
81   }
waitForEventMPSHooksInterface82   virtual void waitForEvent(uint32_t event_id) const {
83     FAIL_MPSHOOKS_FUNC(__func__);
84   }
synchronizeEventMPSHooksInterface85   virtual void synchronizeEvent(uint32_t event_id) const {
86     FAIL_MPSHOOKS_FUNC(__func__);
87   }
queryEventMPSHooksInterface88   virtual bool queryEvent(uint32_t event_id) const {
89     FAIL_MPSHOOKS_FUNC(__func__);
90   }
elapsedTimeOfEventsMPSHooksInterface91   virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
92     FAIL_MPSHOOKS_FUNC(__func__);
93   }
hasPrimaryContextMPSHooksInterface94   bool hasPrimaryContext(DeviceIndex device_index) const override {
95     FAIL_MPSHOOKS_FUNC(__func__);
96   }
isPinnedPtrMPSHooksInterface97   bool isPinnedPtr(const void* data) const override {
98     return false;
99   }
getPinnedMemoryAllocatorMPSHooksInterface100   Allocator* getPinnedMemoryAllocator() const override {
101     FAIL_MPSHOOKS_FUNC(__func__);
102   }
103   #undef FAIL_MPSHOOKS_FUNC
104 };
105 
106 struct TORCH_API MPSHooksArgs {};
107 
108 TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs);
109 #define REGISTER_MPS_HOOKS(clsname) \
110   C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname)
111 
112 namespace detail {
113 TORCH_API const MPSHooksInterface& getMPSHooks();
114 
115 } // namespace detail
116 } // namespace at
117 C10_DIAGNOSTIC_POP()
118