xref: /aosp_15_r20/external/pytorch/torch/csrc/monitor/counters.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/monitor/counters.h>
2 
3 #include <unordered_set>
4 
5 namespace torch {
6 namespace monitor {
7 
aggregationName(Aggregation agg)8 const char* aggregationName(Aggregation agg) {
9   switch (agg) {
10     case Aggregation::NONE:
11       return "none";
12     case Aggregation::VALUE:
13       return "value";
14     case Aggregation::MEAN:
15       return "mean";
16     case Aggregation::COUNT:
17       return "count";
18     case Aggregation::SUM:
19       return "sum";
20     case Aggregation::MAX:
21       return "max";
22     case Aggregation::MIN:
23       return "min";
24     default:
25       throw std::runtime_error(
26           "unknown aggregation: " + std::to_string(static_cast<int>(agg)));
27   }
28 }
29 
30 namespace {
31 struct Stats {
32   std::mutex mu;
33 
34   std::unordered_set<Stat<double>*> doubles;
35   std::unordered_set<Stat<int64_t>*> int64s;
36 };
37 
stats()38 Stats& stats() {
39   static Stats stats;
40   return stats;
41 }
42 } // namespace
43 
44 namespace detail {
registerStat(Stat<double> * stat)45 void registerStat(Stat<double>* stat) {
46   std::lock_guard<std::mutex> guard(stats().mu);
47 
48   stats().doubles.insert(stat);
49 }
registerStat(Stat<int64_t> * stat)50 void registerStat(Stat<int64_t>* stat) {
51   std::lock_guard<std::mutex> guard(stats().mu);
52 
53   stats().int64s.insert(stat);
54 }
unregisterStat(Stat<double> * stat)55 void unregisterStat(Stat<double>* stat) {
56   std::lock_guard<std::mutex> guard(stats().mu);
57 
58   stats().doubles.erase(stat);
59 }
unregisterStat(Stat<int64_t> * stat)60 void unregisterStat(Stat<int64_t>* stat) {
61   std::lock_guard<std::mutex> guard(stats().mu);
62 
63   stats().int64s.erase(stat);
64 }
65 } // namespace detail
66 
67 } // namespace monitor
68 } // namespace torch
69