xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/logging.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/logging.h>
2 
3 #include <atomic>
4 #include <chrono>
5 #include <mutex>
6 #include <stdexcept>
7 #include <unordered_map>
8 
9 namespace torch::jit::logging {
10 
11 // TODO: multi-scale histogram for this thing
12 
addStatValue(const std::string & stat_name,int64_t val)13 void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
14   std::unique_lock<std::mutex> lk(m);
15   auto& raw_counter = raw_counters[stat_name];
16   raw_counter.sum += val;
17   raw_counter.count++;
18 }
19 
getCounterValue(const std::string & name) const20 int64_t LockingLogger::getCounterValue(const std::string& name) const {
21   std::unique_lock<std::mutex> lk(m);
22   if (!raw_counters.count(name)) {
23     return 0;
24   }
25   AggregationType type =
26       agg_types.count(name) ? agg_types.at(name) : AggregationType::SUM;
27   const auto& raw_counter = raw_counters.at(name);
28   switch (type) {
29     case AggregationType::SUM: {
30       return raw_counter.sum;
31     } break;
32     case AggregationType::AVG: {
33       return raw_counter.sum / raw_counter.count;
34     } break;
35   }
36   throw std::runtime_error("Unknown aggregation type!");
37 }
38 
setAggregationType(const std::string & stat_name,AggregationType type)39 void LockingLogger::setAggregationType(
40     const std::string& stat_name,
41     AggregationType type) {
42   agg_types[stat_name] = type;
43 }
44 
45 std::atomic<LoggerBase*> global_logger{new NoopLogger()};
46 
getLogger()47 LoggerBase* getLogger() {
48   return global_logger.load();
49 }
50 
setLogger(LoggerBase * logger)51 LoggerBase* setLogger(LoggerBase* logger) {
52   LoggerBase* previous = global_logger.load();
53   while (!global_logger.compare_exchange_strong(previous, logger)) {
54     previous = global_logger.load();
55   }
56   return previous;
57 }
58 
timePoint()59 JITTimePoint timePoint() {
60   return JITTimePoint{std::chrono::high_resolution_clock::now()};
61 }
62 
recordDurationSince(const std::string & name,const JITTimePoint & tp)63 void recordDurationSince(const std::string& name, const JITTimePoint& tp) {
64   auto end = std::chrono::high_resolution_clock::now();
65   // Measurement in microseconds.
66   auto seconds = std::chrono::duration<double>(end - tp.point).count() * 1e9;
67   logging::getLogger()->addStatValue(name, seconds);
68 }
69 
70 } // namespace torch::jit::logging
71