// Source: aosp_15_r20 external/pytorch/aten/src/ATen/test/reportMemoryUsage.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/ATen.h>

#include <c10/core/Allocator.h>
#include <c10/util/ThreadLocalDebugInfo.h>

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

8 class TestMemoryReportingInfo : public c10::MemoryReportingInfoBase {
9  public:
10   struct Record {
11     void* ptr;
12     int64_t alloc_size;
13     size_t total_allocated;
14     size_t total_reserved;
15     c10::Device device;
16   };
17 
18   std::vector<Record> records;
19 
20   TestMemoryReportingInfo() = default;
21   ~TestMemoryReportingInfo() override = default;
22 
reportMemoryUsage(void * ptr,int64_t alloc_size,size_t total_allocated,size_t total_reserved,c10::Device device)23   void reportMemoryUsage(
24       void* ptr,
25       int64_t alloc_size,
26       size_t total_allocated,
27       size_t total_reserved,
28       c10::Device device) override {
29     records.emplace_back(
30         Record{ptr, alloc_size, total_allocated, total_reserved, device});
31   }
32 
memoryProfilingEnabled()33   bool memoryProfilingEnabled() const override {
34     return true;
35   }
36 
getLatestRecord()37   Record getLatestRecord() {
38     return records.back();
39   }
40 };
41