1 #include <torch/csrc/jit/runtime/logging.h> 2 3 #include <atomic> 4 #include <chrono> 5 #include <mutex> 6 #include <stdexcept> 7 #include <unordered_map> 8 9 namespace torch::jit::logging { 10 11 // TODO: multi-scale histogram for this thing 12 addStatValue(const std::string & stat_name,int64_t val)13void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) { 14 std::unique_lock<std::mutex> lk(m); 15 auto& raw_counter = raw_counters[stat_name]; 16 raw_counter.sum += val; 17 raw_counter.count++; 18 } 19 getCounterValue(const std::string & name) const20int64_t LockingLogger::getCounterValue(const std::string& name) const { 21 std::unique_lock<std::mutex> lk(m); 22 if (!raw_counters.count(name)) { 23 return 0; 24 } 25 AggregationType type = 26 agg_types.count(name) ? agg_types.at(name) : AggregationType::SUM; 27 const auto& raw_counter = raw_counters.at(name); 28 switch (type) { 29 case AggregationType::SUM: { 30 return raw_counter.sum; 31 } break; 32 case AggregationType::AVG: { 33 return raw_counter.sum / raw_counter.count; 34 } break; 35 } 36 throw std::runtime_error("Unknown aggregation type!"); 37 } 38 setAggregationType(const std::string & stat_name,AggregationType type)39void LockingLogger::setAggregationType( 40 const std::string& stat_name, 41 AggregationType type) { 42 agg_types[stat_name] = type; 43 } 44 45 std::atomic<LoggerBase*> global_logger{new NoopLogger()}; 46 getLogger()47LoggerBase* getLogger() { 48 return global_logger.load(); 49 } 50 setLogger(LoggerBase * logger)51LoggerBase* setLogger(LoggerBase* logger) { 52 LoggerBase* previous = global_logger.load(); 53 while (!global_logger.compare_exchange_strong(previous, logger)) { 54 previous = global_logger.load(); 55 } 56 return previous; 57 } 58 timePoint()59JITTimePoint timePoint() { 60 return JITTimePoint{std::chrono::high_resolution_clock::now()}; 61 } 62 recordDurationSince(const std::string & name,const JITTimePoint & tp)63void recordDurationSince(const std::string& name, const JITTimePoint& tp) { 64 auto end = std::chrono::high_resolution_clock::now(); 65 // Measurement in microseconds. 66 auto seconds = std::chrono::duration<double>(end - tp.point).count() * 1e9; 67 logging::getLogger()->addStatValue(name, seconds); 68 } 69 70 } // namespace torch::jit::logging 71