1 #include <torch/csrc/monitor/counters.h> 2 3 #include <unordered_set> 4 5 namespace torch { 6 namespace monitor { 7 aggregationName(Aggregation agg)8const char* aggregationName(Aggregation agg) { 9 switch (agg) { 10 case Aggregation::NONE: 11 return "none"; 12 case Aggregation::VALUE: 13 return "value"; 14 case Aggregation::MEAN: 15 return "mean"; 16 case Aggregation::COUNT: 17 return "count"; 18 case Aggregation::SUM: 19 return "sum"; 20 case Aggregation::MAX: 21 return "max"; 22 case Aggregation::MIN: 23 return "min"; 24 default: 25 throw std::runtime_error( 26 "unknown aggregation: " + std::to_string(static_cast<int>(agg))); 27 } 28 } 29 30 namespace { 31 struct Stats { 32 std::mutex mu; 33 34 std::unordered_set<Stat<double>*> doubles; 35 std::unordered_set<Stat<int64_t>*> int64s; 36 }; 37 stats()38Stats& stats() { 39 static Stats stats; 40 return stats; 41 } 42 } // namespace 43 44 namespace detail { registerStat(Stat<double> * stat)45void registerStat(Stat<double>* stat) { 46 std::lock_guard<std::mutex> guard(stats().mu); 47 48 stats().doubles.insert(stat); 49 } registerStat(Stat<int64_t> * stat)50void registerStat(Stat<int64_t>* stat) { 51 std::lock_guard<std::mutex> guard(stats().mu); 52 53 stats().int64s.insert(stat); 54 } unregisterStat(Stat<double> * stat)55void unregisterStat(Stat<double>* stat) { 56 std::lock_guard<std::mutex> guard(stats().mu); 57 58 stats().doubles.erase(stat); 59 } unregisterStat(Stat<int64_t> * stat)60void unregisterStat(Stat<int64_t>* stat) { 61 std::lock_guard<std::mutex> guard(stats().mu); 62 63 stats().int64s.erase(stat); 64 } 65 } // namespace detail 66 67 } // namespace monitor 68 } // namespace torch 69