xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Vitals.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Vitals.h>
2 #include <cstdlib>
3 #include <iostream>
4 
5 namespace at::vitals {
6 
7 APIVitals VitalsAPI;
8 
operator <<(std::ostream & os,TorchVital const & tv)9 std::ostream& operator<<(std::ostream& os, TorchVital const& tv) {
10   for (const auto& m : tv.attrs) {
11     os << "[TORCH_VITAL] " << tv.name << "." << m.first << "\t\t "
12        << m.second.value << "\n";
13   }
14   return os;
15 }
16 
~TorchVital()17 TorchVital::~TorchVital() {
18   if (torchVitalEnabled()) {
19     std::cout << *this;
20   }
21 }
22 
create(const std::string & attr)23 TorchVitalAttr& TorchVital::create(const std::string& attr) {
24   return create(attr, /* force = */ false);
25 }
26 
create(const std::string & attr,bool force)27 TorchVitalAttr& TorchVital::create(const std::string& attr, bool force) {
28   if (!(torchVitalEnabled() || force)) {
29     static TorchVitalAttr disabled;
30     return disabled;
31   }
32   auto iter = attrs.find(attr);
33   if (iter == attrs.end()) {
34     auto r = attrs.emplace(attr, TorchVitalAttr());
35     return r.first->second;
36   }
37   return iter->second;
38 }
39 
torchVitalEnabled()40 bool torchVitalEnabled() {
41   // If this is a performance hit, make `enabled` variable static
42   // and return `const bool&` instead
43   bool enabled = []() {
44     auto e = getenv("TORCH_VITAL");
45     if (e != nullptr) {
46       return e[0] != '\0';
47     }
48     return false;
49   }();
50   if (enabled) {
51     VitalsAPI.vitals_enabled = true;
52   }
53   return VitalsAPI.vitals_enabled;
54 }
55 
readVitals()56 std::string APIVitals::readVitals() {
57   if (!torchVitalEnabled()) {
58     return "";
59   }
60 
61   std::stringstream buf;
62   for (const auto& x : name_map_) {
63     buf << x.second;
64   }
65   return buf.str();
66 }
67 
setVital(const std::string & vital_name,const std::string & attr_name,const std::string & value,bool force)68 bool APIVitals::setVital(
69     const std::string& vital_name,
70     const std::string& attr_name,
71     const std::string& value,
72     bool force) {
73   if (!(torchVitalEnabled() || force)) {
74     return false;
75   }
76 
77   auto iter = name_map_.find(vital_name);
78   TorchVital* vital = nullptr;
79   if (iter == name_map_.end()) {
80     auto r = name_map_.emplace(vital_name, TorchVital(vital_name));
81     vital = &r.first->second;
82   } else {
83     vital = &iter->second;
84   }
85 
86   vital->create(attr_name, force).write(value, force);
87   return true;
88 }
89 
APIVitals()90 APIVitals::APIVitals() : vitals_enabled(false), name_map_() {
91   // Set default values, force is necessary because in unit tests the env
92   // variable may not be set when global APIVitals are constructed.
93   setVital("CUDA", "used", "False", /* force = */ true);
94 }
95 
96 } // namespace at::vitals
97