1 #include <torch/csrc/jit/runtime/script_profile.h>
2
3 #include <atomic>
4 #include <chrono>
5 #include <mutex>
6 #include <unordered_set>
7
8 #include <c10/util/Exception.h>
9 #include <c10/util/intrusive_ptr.h>
10 #include <torch/csrc/jit/api/function_impl.h>
11
12 namespace torch::jit {
13
14 namespace {
15
16 class ProfilesRegistry {
17 public:
empty()18 bool empty() {
19 return empty_.load(std::memory_order_relaxed);
20 }
21
addProfile(ScriptProfile & p)22 void addProfile(ScriptProfile& p) {
23 std::lock_guard<std::mutex> g(mutex_);
24 enabledProfiles_.emplace(&p);
25 empty_.store(false, std::memory_order_relaxed);
26 }
27
removeProfile(ScriptProfile & p)28 void removeProfile(ScriptProfile& p) {
29 std::lock_guard<std::mutex> g(mutex_);
30 enabledProfiles_.erase(&p);
31 if (enabledProfiles_.empty()) {
32 empty_.store(true, std::memory_order_relaxed);
33 }
34 }
35
send(std::unique_ptr<profiling::Datapoint> datapoint)36 void send(std::unique_ptr<profiling::Datapoint> datapoint) {
37 auto shared = std::shared_ptr<profiling::Datapoint>(std::move(datapoint));
38 std::lock_guard<std::mutex> g(mutex_);
39 for (auto* p : enabledProfiles_) {
40 p->addDatapoint(shared);
41 }
42 }
43
44 private:
45 std::atomic<bool> empty_{true};
46 std::mutex mutex_;
47 std::unordered_set<ScriptProfile*> enabledProfiles_;
48 };
49
getProfilesRegistry()50 ProfilesRegistry& getProfilesRegistry() {
51 static auto registry = std::ref(*new ProfilesRegistry{});
52 return registry;
53 }
54
// Registers the TorchBind classes that expose profiling results under the
// "profiling" namespace: SourceRef, InstructionStats, SourceStats and
// _ScriptProfile. Returns nullptr only so the call can initialize a
// namespace-scope constant (torchBindInitializer), which makes the
// registration run during static initialization.
// NOTE(review): each class is registered before any class whose methods
// return it (SourceRef before SourceStats::source, SourceStats before
// _ScriptProfile::_dump_stats). TorchBind presumably needs this order to
// resolve return types, so keep it when editing.
auto initBindings() {
  // SourceRef: exposes the referenced source's starting line number (as
  // int64) and its full text.
  torch::class_<SourceRef>("profiling", "SourceRef")
      .def(
          "starting_lineno",
          [](const c10::intrusive_ptr<SourceRef>& self) {
            return static_cast<int64_t>((*self)->starting_line_no());
          })
      .def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
        return (*self)->text_str().str();
      });

  // InstructionStats: aggregated hit count and accumulated duration for one
  // source line.
  torch::class_<InstructionStats>("profiling", "InstructionStats")
      .def(
          "count",
          [](const c10::intrusive_ptr<InstructionStats>& self) {
            return self->count;
          })
      // Returns duration.count() as int64. The unit is whatever
      // InstructionStats::duration uses; nanoseconds per the method name
      // (confirm against the header).
      .def("duration_ns", [](const c10::intrusive_ptr<InstructionStats>& self) {
        return static_cast<int64_t>(self->duration.count());
      });

  // SourceStats: one source plus its per-line InstructionStats map.
  torch::class_<SourceStats>("profiling", "SourceStats")
      .def(
          "source",
          [](const c10::intrusive_ptr<SourceStats>& self) {
            // Wraps a copy of the stored SourceRef in a new intrusive handle.
            return c10::make_intrusive<SourceRef>(self->getSourceRef());
          })
      .def("line_map", &SourceStats::getLineMap);

  // _ScriptProfile: the user-facing profiler object.
  torch::class_<ScriptProfile>("profiling", "_ScriptProfile")
      .def(torch::init<>())
      .def("enable", &ScriptProfile::enable)
      .def("disable", &ScriptProfile::disable)
      // Converts the C++ SourceMap from dumpStats() into a TorchScript-
      // friendly List<SourceStats>. Each source's line stats are copied into
      // freshly allocated InstructionStats objects keyed by line number.
      // dumpStats() fails its TORCH_CHECK if the profile is still enabled.
      .def("_dump_stats", [](const c10::intrusive_ptr<ScriptProfile>& self) {
        const auto& stats = self->dumpStats();
        c10::List<c10::intrusive_ptr<SourceStats>> ret;
        for (const auto& source : stats) {
          SourceStats::LineMap lineMap;
          for (const auto& line : source.second) {
            lineMap.insert(
                line.first, c10::make_intrusive<InstructionStats>(line.second));
          }
          ret.push_back(c10::make_intrusive<SourceStats>(
              source.first, std::move(lineMap)));
        }
        return ret;
      });
  return nullptr;
}
104
// Calling initBindings() from this namespace-scope initializer performs the
// TorchBind registrations when this translation unit is statically
// initialized. The (nullptr) value itself is never used.
const auto C10_UNUSED torchBindInitializer = initBindings();
106
107 } // namespace
108
109 namespace profiling {
110
InstructionSpan(Node & node)111 InstructionSpan::InstructionSpan(Node& node) {
112 datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
113 }
114
~InstructionSpan()115 InstructionSpan::~InstructionSpan() {
116 datapoint_->end = std::chrono::steady_clock::now();
117 getProfilesRegistry().send(std::move(datapoint_));
118 }
119
isProfilingOngoing()120 bool isProfilingOngoing() {
121 return !getProfilesRegistry().empty();
122 }
123
124 } // namespace profiling
125
enable()126 void ScriptProfile::enable() {
127 if (!std::exchange(enabled_, true)) {
128 getProfilesRegistry().addProfile(*this);
129 }
130 }
131
disable()132 void ScriptProfile::disable() {
133 if (std::exchange(enabled_, false)) {
134 getProfilesRegistry().removeProfile(*this);
135 }
136 }
137
addDatapoint(std::shared_ptr<profiling::Datapoint> datapoint)138 void ScriptProfile::addDatapoint(
139 std::shared_ptr<profiling::Datapoint> datapoint) {
140 TORCH_CHECK(enabled_, "Cannot only add datapoint to disabled profilers.");
141 datapoints_.push_back(std::move(datapoint));
142 }
143
dumpStats()144 const ScriptProfile::SourceMap& ScriptProfile::dumpStats() {
145 TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats.");
146
147 for (const auto& datapoint : datapoints_) {
148 if (const auto& source = datapoint->sourceRange.source()) {
149 if (auto fileLineCol = datapoint->sourceRange.file_line_col()) {
150 auto it = sourceMap_.find(*source);
151 if (it == sourceMap_.end()) {
152 it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first;
153 }
154 auto& stats = it->second[std::get<1>(*fileLineCol)];
155 stats.count++;
156 stats.duration += datapoint->end - datapoint->start;
157 }
158 }
159 }
160 datapoints_.clear();
161
162 return sourceMap_;
163 }
164
~ScriptProfile()165 ScriptProfile::~ScriptProfile() {
166 if (enabled_) {
167 getProfilesRegistry().removeProfile(*this);
168 }
169 }
170
171 } // namespace torch::jit
172