xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSAllocatorInterface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2023 Apple Inc.
2 
3 #pragma once
4 
5 #include <c10/core/Allocator.h>
6 #include <c10/util/Registry.h>
7 #include <ATen/core/ATen_fwd.h>
8 
// Converts a count of mebibytes to bytes (1 MiB == 1048576 bytes).
// The argument is parenthesized so that compound expressions such as
// MB(a + b) scale the whole expression instead of only its last operand.
// The result keeps the usual arithmetic conversion with `unsigned long`.
#define MB(x) ((x) * 1048576UL)
10 
11 namespace at::mps {
12 
13 // this is a public interface to access MPSAllocator.
14 // Do not declare methods that would depend on MPS or Metal frameworks.
// Abstract interface over the MPS allocator. It derives from c10::Allocator
// and every member declared here is a pure virtual, const-qualified method,
// so all behavior is supplied by the implementing class.
// The notes below only restate what the signatures show; anything beyond
// that is marked as an assumption to be confirmed against MPSAllocator.h.
// NOTE(review): this header uses ssize_t, id_t, uint32_t, std::string and
// std::pair without including their headers directly - presumably they are
// reached transitively through the c10/ATen includes; confirm.
class IMPSAllocator : public c10::Allocator {
public:
  // see the comments in MPSAllocator.h for the description of these methods.

  // Whole-allocator operations: no arguments, no return value.
  virtual void emptyCache() const = 0;
  virtual void freeInactiveBuffers() const = 0;

  // Per-buffer queries: each takes a `const void*` identifying the buffer.
  // getUnalignedBufferSize returns a signed ssize_t.
  // NOTE(review): presumably a negative value signals an unknown pointer -
  // confirm with the implementation.
  virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
  virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
  virtual id_t getBufferId(const void* ptr) const = 0;
  // Declared const although the name suggests it records state for `ptr`;
  // how that state is stored is not visible from this header.
  virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
  virtual bool isSharedBuffer(const void* ptr) const = 0;

  // Takes no buffer argument, so the answer is not tied to one allocation.
  virtual bool isSharedStorageSupported() const = 0;

  // Returns a c10::DataPtr built from a raw `value` pointer and a `size`.
  // NOTE(review): presumably `size` is the byte count read from `value` -
  // confirm with the implementation.
  virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;

  // Produces a std::string from a size_t (presumably a human-readable size).
  virtual std::string formatSize(size_t size) const = 0;

  // Watermark configuration: setters take a `double` ratio; the getters return
  // integers (signed for getLowWatermarkValue, unsigned for the two limits).
  virtual void setLowWatermarkRatio(double ratio) const = 0;
  virtual void setHighWatermarkRatio(double ratio) const = 0;
  virtual ssize_t getLowWatermarkValue() const = 0;
  virtual size_t getLowWatermarkLimit() const = 0;
  virtual size_t getHighWatermarkLimit() const = 0;

  // Memory statistics, each returned as size_t (units presumably bytes - confirm).
  virtual size_t getTotalAllocatedMemory() const = 0;
  virtual size_t getCurrentAllocatedMemory() const = 0;
  virtual size_t getDriverAllocatedMemory() const = 0;
  virtual size_t getRecommendedMaxMemory() const = 0;

  // Returns a (pointer, uint32_t) pair for the given buffer pointer.
  // NOTE(review): the meaning of the uint32_t component is not visible here.
  virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;

  // Event operations over a list of buffer pointers; both return bool.
  // NOTE(review): presumably the bool reports success - confirm.
  virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
  virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
41 
// Observer interface for MPS allocator buffer events. Implementations
// override executeMPSAllocatorCallback(), which receives the buffer pointer
// together with the kind of event being reported.
class IMpsAllocatorCallback {
 public:
  // The kinds of buffer events a callback can be notified about.
  enum class EventType {
    // buffer got allocated to be used immediately
    ALLOCATED,
    // buffer pulled from free list to be reused
    RECYCLED,
    // buffer put to free list for future recycling
    FREED,
    // buffer memory released
    RELEASED,
    // buffer allocation failed
    ALLOCATION_FAILED,
  };

  // Virtual so that implementations can be destroyed through this interface.
  virtual ~IMpsAllocatorCallback() = default;

  // Called with the affected buffer pointer and the event that occurred.
  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};
54 
// MPS allocator will execute every registered callback when a block of memory is freed.
// NOTE(review): IMpsAllocatorCallback::EventType also lists ALLOCATED, RECYCLED,
// RELEASED and ALLOCATION_FAILED, so callbacks are presumably invoked for those
// events as well - confirm against the allocator implementation.
//
// Declares the registry of IMpsAllocatorCallback implementations.
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
// Registers a callback class under `name` in MPSAllocatorCallbacksRegistry by
// forwarding to C10_REGISTER_CLASS.
// NOTE(review): the expansion already ends with ';', so a call site written as
// `REGISTER_MPS_ALLOCATOR_CALLBACK(...);` yields an extra empty declaration.
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
  C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
59 
// Returns a pointer to an IMPSAllocator. `sharedAllocator` defaults to false.
// NOTE(review): presumably `true` selects an allocator backed by shared
// storage, and the returned pointer is not owned by the caller - neither is
// visible from this declaration; confirm against the definition.
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
61 
// Reports a yes/no answer for the given pointer.
// NOTE(review): presumably true when `data` refers to pinned memory managed by
// the MPS allocator - confirm against the definition.
bool isMPSPinnedPtr(const void* data);
63 
64 } // namespace at::mps
65