1 // Copyright © 2023 Apple Inc. 2 3 #pragma once 4 5 #include <c10/core/Allocator.h> 6 #include <c10/util/Registry.h> 7 #include <ATen/core/ATen_fwd.h> 8 9 #define MB(x) (x * 1048576UL) 10 11 namespace at::mps { 12 13 // this is a public interface to access MPSAllocator. 14 // Do not declare methods that would depend on MPS or Metal frameworks. 15 class IMPSAllocator : public c10::Allocator { 16 public: 17 // see the comments in MPSAllocator.h for the description of these methods. 18 virtual void emptyCache() const = 0; 19 virtual void freeInactiveBuffers() const = 0; 20 virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0; 21 virtual IntArrayRef getBufferShape(const void* ptr) const = 0; 22 virtual id_t getBufferId(const void* ptr) const = 0; 23 virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0; 24 virtual bool isSharedBuffer(const void* ptr) const = 0; 25 virtual bool isSharedStorageSupported() const = 0; 26 virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0; 27 virtual std::string formatSize(size_t size) const = 0; 28 virtual void setLowWatermarkRatio(double ratio) const = 0; 29 virtual void setHighWatermarkRatio(double ratio) const = 0; 30 virtual ssize_t getLowWatermarkValue() const = 0; 31 virtual size_t getLowWatermarkLimit() const = 0; 32 virtual size_t getHighWatermarkLimit() const = 0; 33 virtual size_t getTotalAllocatedMemory() const = 0; 34 virtual size_t getCurrentAllocatedMemory() const = 0; 35 virtual size_t getDriverAllocatedMemory() const = 0; 36 virtual size_t getRecommendedMaxMemory() const = 0; 37 virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0; 38 virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0; 39 virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0; 40 }; 41 42 class IMpsAllocatorCallback { 43 public: 44 enum class EventType { 45 ALLOCATED, // buffer got allocated to be used immediately 46 RECYCLED, // buffer pulled from free list to be reused 47 FREED, // buffer put to free list for future recycling 48 RELEASED, // buffer memory released 49 ALLOCATION_FAILED // buffer allocation failed 50 }; 51 virtual ~IMpsAllocatorCallback() = default; 52 virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0; 53 }; 54 55 // MPS allocator will execute every registered callback when a block of memory is freed. 56 C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback); 57 #define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \ 58 C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__); 59 60 IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false); 61 62 bool isMPSPinnedPtr(const void* data); 63 64 } // namespace at::mps 65