xref: /aosp_15_r20/external/pytorch/c10/core/CachingDeviceAllocator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/core/Allocator.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
7 
8 namespace c10::CachingDeviceAllocator {
9 
10 struct Stat {
increaseStat11   void increase(size_t amount) {
12     current += static_cast<int64_t>(amount);
13     peak = std::max(current, peak);
14     allocated += static_cast<int64_t>(amount);
15   }
16 
decreaseStat17   void decrease(size_t amount) {
18     current -= static_cast<int64_t>(amount);
19     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
20         current >= 0,
21         "Negative tracked stat in device allocator (likely logic error).");
22     freed += static_cast<int64_t>(amount);
23   }
24 
reset_accumulatedStat25   void reset_accumulated() {
26     allocated = 0;
27     freed = 0;
28   }
29 
reset_peakStat30   void reset_peak() {
31     peak = current;
32   }
33 
34   int64_t current = 0;
35   int64_t peak = 0;
36   int64_t allocated = 0;
37   int64_t freed = 0;
38 };
39 
// Identifies the slots of the per-type statistic arrays (StatArray /
// StatTypes). The numeric values are contiguous from 0 so that NUM_TYPES can
// serve as the array length.
enum struct StatType : uint64_t {
  AGGREGATE = 0,
  SMALL_POOL = 1,
  LARGE_POOL = 2,
  NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};
46 
47 using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
48 using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
49 
50 template <typename Func>
for_each_selected_stat_type(const StatTypes & stat_types,Func f)51 void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
52   for (const auto stat_type : c10::irange(stat_types.size())) {
53     if (stat_types[stat_type]) {
54       f(stat_type);
55     }
56   }
57 }
58 
59 // Struct containing memory allocator summary statistics for a device.
60 struct DeviceStats {
61   // COUNT: allocations requested by client code
62   StatArray allocation;
63   // COUNT: number of allocated segments from device memory allocation.
64   StatArray segment;
65   // COUNT: number of active memory blocks (allocated or used by stream)
66   StatArray active;
67   // COUNT: number of inactive, split memory blocks (unallocated but can't be
68   // released via device memory deallocation)
69   StatArray inactive_split;
70 
71   // SUM: bytes allocated by this memory alocator
72   StatArray allocated_bytes;
73   // SUM: bytes reserved by this memory allocator (both free and used)
74   StatArray reserved_bytes;
75   // SUM: bytes within active memory blocks
76   StatArray active_bytes;
77   // SUM: bytes within inactive, split memory blocks
78   StatArray inactive_split_bytes;
79   // SUM: bytes requested by client code
80   StatArray requested_bytes;
81 
82   // COUNT: total number of failed calls to device malloc necessitating cache
83   // flushes.
84   int64_t num_alloc_retries = 0;
85 
86   // COUNT: total number of OOMs (i.e. failed calls to device memory allocation
87   // after cache flush)
88   int64_t num_ooms = 0;
89 
90   // COUNT: total number of oversize blocks allocated from pool
91   Stat oversize_allocations;
92 
93   // COUNT: total number of oversize blocks requiring malloc
94   Stat oversize_segments;
95 
96   // COUNT: total number of synchronize_and_free_events() calls
97   int64_t num_sync_all_streams = 0;
98 
99   // COUNT: total number of device memory allocation calls. This includes both
100   // mapped and malloced memory.
101   int64_t num_device_alloc = 0;
102 
103   // COUNT: total number of device memory deallocation calls. This includes both
104   // un-mapped and free memory.
105   int64_t num_device_free = 0;
106 
107   // SIZE: maximum block size that is allowed to be split.
108   int64_t max_split_size = 0;
109 };
110 
// Size pretty-printer.
//
// Renders a byte count as a human-readable string using binary units.
// Values up to and including 1024 are printed as an integer followed by
// " bytes"; larger values are scaled to KiB, MiB or GiB and printed with
// exactly two decimal places.
//
// Each threshold is inclusive on the smaller unit, so an exact power of 1024
// is reported in the unit below it (1024 -> "1024 bytes",
// 1048576 -> "1024.00 KiB", 1073741824 -> "1024.00 MiB").
inline std::string format_size(uint64_t size) {
  constexpr uint64_t kKiB = 1024;
  constexpr uint64_t kMiB = 1024 * kKiB;
  constexpr uint64_t kGiB = 1024 * kMiB;

  std::ostringstream os;
  // Fixed notation with two decimals; this only affects the floating-point
  // branches, the integer "bytes" branch is printed verbatim.
  os.precision(2);
  os << std::fixed;
  if (size <= kKiB) {
    os << size << " bytes";
  } else if (size <= kMiB) {
    os << static_cast<double>(size) / static_cast<double>(kKiB) << " KiB";
  } else if (size <= kGiB) {
    os << static_cast<double>(size) / static_cast<double>(kMiB) << " MiB";
  } else {
    os << static_cast<double>(size) / static_cast<double>(kGiB) << " GiB";
  }
  return os.str();
}
130 
131 } // namespace c10::CachingDeviceAllocator
132