1 #pragma once 2 3 #include <ATen/ATen.h> 4 5 #include <c10/core/Allocator.h> 6 #include <c10/util/ThreadLocalDebugInfo.h> 7 8 class TestMemoryReportingInfo : public c10::MemoryReportingInfoBase { 9 public: 10 struct Record { 11 void* ptr; 12 int64_t alloc_size; 13 size_t total_allocated; 14 size_t total_reserved; 15 c10::Device device; 16 }; 17 18 std::vector<Record> records; 19 20 TestMemoryReportingInfo() = default; 21 ~TestMemoryReportingInfo() override = default; 22 reportMemoryUsage(void * ptr,int64_t alloc_size,size_t total_allocated,size_t total_reserved,c10::Device device)23 void reportMemoryUsage( 24 void* ptr, 25 int64_t alloc_size, 26 size_t total_allocated, 27 size_t total_reserved, 28 c10::Device device) override { 29 records.emplace_back( 30 Record{ptr, alloc_size, total_allocated, total_reserved, device}); 31 } 32 memoryProfilingEnabled()33 bool memoryProfilingEnabled() const override { 34 return true; 35 } 36 getLatestRecord()37 Record getLatestRecord() { 38 return records.back(); 39 } 40 }; 41