#include <torch/csrc/autograd/profiler_legacy.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/code_template.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include <cstring>
#include <fstream>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <ATen/record_function.h>
#include <c10/core/Allocator.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/irange.h>

#include <iostream>

namespace torch::autograd::profiler {
// We decompose the profiler logic into the following components:
//
// ThreadLocalDebugInfo:
//
// ThreadLocalDebugInfo is a thread-local mapping from slots to
// debug information structs.
// ThreadLocalDebugInfo is automatically propagated across thread
// boundaries, including when:
//  - launching async jobs with at::launch
//  - executing JIT continuations
//  - moving from forward threads into autograd (backward) threads
//
// Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard,
// which can be used to add or overwrite an entry in the thread-local
// mapping. The corresponding entry is removed when the guard is
// destroyed, potentially revealing the previously set value for the
// same slot.
//
// Slots set in the main thread before launching an async task are
// shared with and visible in that task.
//
// On the other hand, any addition or overwrite of the mapping by the
// async task is not visible to the main thread, and any modification
// (including removal of entries) in the main thread is not visible
// to the async task if it happens after the task is launched.
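//
// A minimal sketch of the guard mechanics (MyDebugInfo is a hypothetical
// subclass of c10::DebugInfoBase; any DebugInfoKind slot behaves the same):
//
//   auto info = std::make_shared<MyDebugInfo>();
//   {
//     c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, info);
//     // The entry is visible here and in async tasks launched from here.
//   }
//   // On guard destruction, the previous value of the slot is restored.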
//
// We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store the profiler
// config, as well as the list of events that happen during profiling.
// An instance of ThreadLocalDebugInfo is created each time we enter the
// profiler (i.e. enter the profiling context manager or call
// enableProfilerLegacy) and uniquely identifies a profiling run.
//
// We automatically propagate ThreadLocalDebugInfo into async tasks, as
// well as across JIT continuations and autograd threads, so all the
// operations that happen between profiling start and end (not necessarily
// within the same thread) are recorded, unless the profiling slot is
// overwritten by a nested profiling range (in which case events for the
// subrange are handled by the nested profiler).
//
// When we exit a profiling range (either by exiting the profiling context
// manager or by calling disableProfiler), we remove the previously set
// profiling entry from the thread-local mapping and consolidate the
// events into the profiling result.
//
//
// ThreadLocalState:
//
// ThreadLocalState takes a 'snapshot' of thread-local variables using the
// provided getters. It is used together with ThreadLocalStateGuard to
// transfer the snapshot across a thread boundary and set the thread-local
// values as in the parent task.
//
// The profiler uses ThreadLocalState to propagate its thread-local state.
// ThreadLocalState also automatically propagates profiler callbacks.
//
//
// at::RecordFunction and observers:
//
// The profiler uses the observer mechanism to add a pair of thread-local
// callbacks that are executed on a number of predetermined ranges,
// including:
//  - c10/ATen ops
//  - TorchScript functions/methods
//  - user-defined named ranges (see the `record_function` Python context
//    manager)
//
// The profiler sets up a pair of callbacks that record profiling events
// and save them into the thread-local profiler struct
// (ThreadLocalDebugInfo, PROFILER_STATE slot).
//
//
// Thus, the overall logic is:
//
// enableProfiler:
//  - checks that the profiler is not already enabled (otherwise throws)
//  - pushes a new ThreadLocalDebugInfo (slot PROFILER_STATE) as the
//    profiler state for the current thread
//  - pushes profiling callbacks for the current thread
//
// disableProfiler:
//  - pops the PROFILER_STATE slot from the current ThreadLocalDebugInfo
//    and consolidates events
//  - removes profiling callbacks
//
// ThreadLocalState:
//  - propagates ThreadLocalDebugInfo across threads
//  - propagates profiler callbacks across threads
//
// Profiler callbacks:
//  - get the current profiling state (PROFILER_STATE slot in
//    ThreadLocalDebugInfo)
//  - save profiling events into the profiling state
//
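// For reference, a minimal sketch of driving this flow directly from C++
// (the Python profiling context manager wraps essentially the same calls):
//
//   torch::autograd::profiler::enableProfilerLegacy(
//       torch::profiler::impl::ProfilerConfig(
//           torch::profiler::impl::ProfilerState::CPU));
//   // ... run the workload to be profiled ...
//   thread_event_lists lists =
//       torch::autograd::profiler::disableProfilerLegacy();
//   for (auto& thread_events : lists) {
//     for (auto& evt : thread_events) {
//       // inspect LegacyEvent: evt.name(), evt.kindStr(), ...
//     }
//   }
//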

namespace {
using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::ProfilerStateBase;

struct ProfilerLegacyThreadLocalState : public ProfilerStateBase {
  explicit ProfilerLegacyThreadLocalState(
      const torch::profiler::impl::ProfilerConfig& config)
      : ProfilerStateBase(config), remoteProfiledEvents_{std::nullopt} {}
  ~ProfilerLegacyThreadLocalState() override = default;

  static ProfilerLegacyThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY);
    return static_cast<ProfilerLegacyThreadLocalState*>(tls);
  }

  thread_event_lists consolidate();

  void mark(std::string name, bool include_cuda = true);

  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      std::vector<std::vector<int64_t>>&& shapes = {});

  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      size_t /* total_allocated, unused for legacy */,
      size_t /* total_reserved, unused for legacy */,
      c10::Device device) override;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::LEGACY;
  }

  void leakHandle() {
    handle_ = 0;
  }

 protected:
  RangeEventList& getEventList(
      std::optional<uint64_t> thread_id = std::nullopt);

  std::mutex state_mutex_;
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      event_lists_map_;

  std::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};

thread_event_lists ProfilerLegacyThreadLocalState::consolidate() {
  std::lock_guard<std::mutex> g(state_mutex_);
  thread_event_lists result;
  for (auto& kv : event_lists_map_) {
    auto& list = kv.second;
    result.emplace_back(list->consolidate());
  }
  // Consolidate remote events as well, if applicable.
  if (remoteProfiledEvents_) {
    result.insert(
        result.end(),
        std::make_move_iterator(remoteProfiledEvents_->begin()),
        std::make_move_iterator(remoteProfiledEvents_->end()));
  }
  return result;
}

void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->mark(name.c_str());
  } else {
    LegacyEvent evt(
        EventKind::Mark,
        at::StringView(std::move(name)),
        at::RecordFunction::currentThreadId(),
        include_cuda &&
            config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::setOrAddRemoteProfiledEvents(
    std::vector<LegacyEvent>&& remoteProfiledEvents) {
  // Lock to serialize access from multiple callback threads.
  std::lock_guard<std::mutex> guard(state_mutex_);
  if (remoteProfiledEvents_) {
    // Move rather than copy the incoming event list.
    remoteProfiledEvents_->emplace_back(std::move(remoteProfiledEvents));
  } else {
    remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
  }
}

void ProfilerLegacyThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes)
            .c_str());
  } else {
    LegacyEvent evt(
        EventKind::PushRange,
        at::StringView(std::string(fn.name())),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle(),
        std::move(shapes),
        at::RecordFunction::getDefaultNodeId(),
        fn.isAsync());
    evt.setSequenceNr(fn.seqNr());
    evt.setFwdThreadId(fn.forwardThreadId());
    evt.setScope((uint8_t)fn.scope());
    if (config_.with_flops) {
      evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
      evt.setFlops(torch::profiler::impl::computeFlops(
          std::string(fn.name()), evt.extraArgs()));
    }

// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
    // A backward node's source range corresponds to its forward node.
    // TODO: consider using C++ stack trace
    if (config_.with_stack &&
        fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
      auto cs =
          torch::profiler::impl::prepareCallstack(jit::currentCallstack());
      if (cs.empty()) {
        cs = torch::profiler::impl::prepareCallstack(
            jit::tracer::pythonCallstack());
      }
      evt.setStack(callstackStr(cs));
    }
#endif
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::popRange(
    const at::RecordFunction& fn,
    const bool record_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePop();
  } else {
    // In some cases RecordFunction (and popRange) may be called on a
    // different thread than pushRange. As a convention, we put the async
    // pop on the original thread and save the current thread id in the
    // pop event.
    LegacyEvent evt(
        EventKind::PopRange,
        at::StringView(""),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle());
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList(fn.threadId()).record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::reportMemoryUsage(
    void* /* unused */,
    int64_t alloc_size,
    size_t /* total_allocated, unused for legacy */,
    size_t /* total_reserved, unused for legacy */,
    c10::Device device) {
  if (config_.profile_memory && !config_.disabled()) {
    uint64_t thread_id = at::RecordFunction::currentThreadId();
    LegacyEvent evt(
        EventKind::MemoryAlloc,
        at::StringView(""),
        thread_id,
        config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.updateMemoryStats(alloc_size, device);
    getEventList(thread_id).record(std::move(evt));
  }
}

RangeEventList& ProfilerLegacyThreadLocalState::getEventList(
    std::optional<uint64_t> thread_id) {
  if (!thread_id.has_value()) {
    thread_id = at::RecordFunction::currentThreadId();
  }
  RangeEventList* list_ptr = nullptr;
  std::lock_guard<std::mutex> guard(state_mutex_);
  auto it = event_lists_map_.find(thread_id.value());
  if (it != event_lists_map_.end()) {
    list_ptr = it->second.get();
  } else {
    auto event_list = std::make_shared<RangeEventList>();
    event_lists_map_[thread_id.value()] = event_list;
    list_ptr = event_list.get();
  }
  return *list_ptr;
}

enum EventIValueIdx {
  KIND = 0,
  NAME,
  THREAD_ID,
  HANDLE,
  NODE_ID,
  CPU_MEM_USAGE,
  CPU_NS,
  CUDA_RECORDED,
  CUDA_MEM_USAGE,
  CUDA_DEVICE,
  CUDA_US,
  SHAPES,
  NUM_EVENT_IVALUE_IDX // must be last in list
};

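// Ops excluded from CUDA event recording when profiling in CUDA mode.
// These are (typically) cheap view/metadata ops, so bracketing them with
// CUDA events would add measurement overhead without capturing any real
// device work.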
const std::unordered_set<std::string> disable_cuda_profiling = {
    "aten::view",
    "aten::t",
    "aten::transpose",
    "aten::stride",
    "aten::empty",
    "aten::empty_like",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::expand",
    "aten::resize_",
    "aten::squeeze",
    "aten::unsqueeze",
    "aten::slice",
    "aten::_unsafe_view",
    "aten::size"};

void pushProfilingCallbacksLegacy() {
  auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return nullptr;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }

            if (state_ptr->config().report_input_shapes) {
              auto sizes = torch::profiler::impl::inputSizes(fn);
              state_ptr->pushRange(fn, record_cuda, std::move(sizes));
            } else {
              state_ptr->pushRange(fn, record_cuda);
            }

            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext*) {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }
            state_ptr->popRange(fn, record_cuda);
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .needsIds(true));
  registration_state_ptr->setCallbackHandle(handle);
}

} // namespace

void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig& new_config) {
  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::NVTX ||
          torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::KINETO,
      "Kineto profiling is not supported by the legacy profiler");

  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
  auto state = std::make_shared<ProfilerLegacyThreadLocalState>(new_config);
  c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);

  pushProfilingCallbacksLegacy();

  state->mark("__start_profile", false);
}

thread_event_lists disableProfilerLegacy(
    std::optional<ProfilerDisableOptions> profilerDisableOptions) {
  auto cleanupTLSState =
      profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
  auto consolidate =
      profilerDisableOptions ? profilerDisableOptions->consolidate : true;
  // All DebugInfoBase objects are scope-based and are supposed to be
  // managed via DebugInfoGuard.
  std::shared_ptr<c10::DebugInfoBase> state;
  if (cleanupTLSState) {
    state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  } else {
    state =
        c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE);
  }

  auto state_ptr = static_cast<ProfilerLegacyThreadLocalState*>(state.get());
  TORCH_CHECK(
      state_ptr && !state_ptr->config().disabled(),
      "Can't disable profiler when it's not running");

  cleanupTLSState ? state_ptr->removeCallback() : state_ptr->leakHandle();
  if (!consolidate ||
      state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) {
    return thread_event_lists();
  }

  state_ptr->mark("__stop_profile", false);
  // Note that this will erase the underlying events.
  return state_ptr->consolidate();
}

void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(state_ptr, "Profiler must be enabled.");
  state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
}

void LegacyEvent::record(bool record_cuda) {
  if (record_cuda) {
    torch::profiler::impl::cudaStubs()->record(&device_, &cuda_event, &cpu_ns_);
    return;
  }
  cpu_ns_ = c10::getTime();
}

/* static */ LegacyEvent LegacyEvent::fromIValue(
    const at::IValue& eventIValue) {
  TORCH_INTERNAL_ASSERT(
      eventIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = eventIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() >= NUM_EVENT_IVALUE_IDX,
      "Expected at least ",
      NUM_EVENT_IVALUE_IDX,
      " elements to reconstruct LegacyEvent.");

  // Reconstruct input shapes from ivalues.
  const auto& shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
  TORCH_INTERNAL_ASSERT(
      shapeListIValue.isList(),
      "Expected profiler shapes IValue to contain type c10::impl::GenericList.");

  auto shapeList = shapeListIValue.toList();
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(shapeList.size());
  for (const auto i : c10::irange(shapeList.size())) {
    std::vector<int64_t> s;
    const auto& shapeIValue = shapeList.get(i);
    TORCH_INTERNAL_ASSERT(
        shapeIValue.isList(),
        "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.");
    auto curShapesList = shapeIValue.toList();
    s.reserve(curShapesList.size());
    for (const auto j : c10::irange(curShapesList.size())) {
      s.emplace_back(curShapesList.get(j).toInt());
    }
    shapes.emplace_back(std::move(s));
  }

  LegacyEvent evt(
      static_cast<EventKind>(
          ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
      at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
      ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id
      static_cast<at::RecordFunctionHandle>(
          ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle
      std::move(shapes), // input shapes
      ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id
      true, // is remote
      ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage
      ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns
      ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded
      ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage
      c10::DeviceIndex(
          ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt()), // device
      static_cast<double>(
          ivalues.get(EventIValueIdx::CUDA_US).toInt()) // cuda_us
  );
  return evt;
}

at::IValue LegacyEvent::toIValue() const {
  c10::impl::GenericList eventIValueList(at::AnyType::get());
  eventIValueList.reserve(NUM_EVENT_IVALUE_IDX);
  eventIValueList.emplace_back(static_cast<int64_t>(kind_));
  eventIValueList.emplace_back(std::string(name_.str()));
  eventIValueList.emplace_back(static_cast<int64_t>(thread_id_));
  eventIValueList.emplace_back(static_cast<double>(handle_));
  eventIValueList.emplace_back(node_id_);
  eventIValueList.emplace_back(cpu_memory_usage_);
  eventIValueList.emplace_back(cpu_ns_);
  // CUDA event information
  bool cuda_profiling_enabled = hasCuda();
  eventIValueList.emplace_back(cuda_profiling_enabled);
  eventIValueList.emplace_back(static_cast<int64_t>(cuda_memory_usage_));
  eventIValueList.emplace_back(device_);
  eventIValueList.emplace_back(cuda_us_);
  // Shapes
  c10::impl::GenericList shapesList =
      c10::impl::GenericList(at::ListType::create(at::IntType::get()));
  shapesList.reserve(shapes_.size());
  for (const auto& shape : shapes_) {
    c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get());
    s.reserve(shape.size());
    for (const auto& k : shape) {
      s.emplace_back(k);
    }
    shapesList.emplace_back(s);
  }
  eventIValueList.emplace_back(shapesList);
  return at::IValue(eventIValueList);
}
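
// toIValue() and fromIValue() form the serialization round trip used to
// ship LegacyEvents across RPC for remote profiling (see
// setOrAddRemoteProfiledEvents above).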

double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
  TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
  TORCH_CHECK(
      e.device() == device(),
      c10::str(
          "Events are not on the same device: ", e.device(), " vs ", device()));
  if (isRemote() && e.isRemote()) {
    // Validate that cuda_us_ has been set properly.
    TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
    return static_cast<double>(e.cuda_us_ - cuda_us_);
  }
  return torch::profiler::impl::cudaStubs()->elapsed(
      &cuda_event, &e.cuda_event);
}

static const at::jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");

void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events) {
  TORCH_CHECK(out, "Could not open file");
  LegacyEvent* profiler_start = nullptr;
  for (LegacyEvent* e : events) {
    if (0 == strcmp(e->name(), "__start_profile")) {
      profiler_start = e;
      break;
    }
  }
  TORCH_CHECK(profiler_start, "Could not find __start_profile mark");

  struct PairHash {
    size_t operator()(
        std::pair<at::RecordFunctionHandle, int64_t> p) const noexcept {
      return std::hash<at::RecordFunctionHandle>()(p.first) ^
          std::hash<int64_t>()(p.second);
    }
  };
  std::unordered_map<
      std::pair<at::RecordFunctionHandle, int64_t>,
      LegacyEvent*,
      PairHash>
      events_map;
  out << "[\n";
  bool first = true;
  for (LegacyEvent* evt : events) {
    if (evt->kindStr() == "push") {
      events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
    } else if (evt->kindStr() == "pop") {
      if (!first) {
        out << ",\n";
      }
      first = false;
      auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
      TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
      LegacyEvent* evt_start = it->second;
      events_map.erase(it);

      at::jit::TemplateEnv env;
      env.s("name", evt_start->name());
      env.d("ts", profiler_start->cpuElapsedUs(*evt_start));
      env.d("dur", evt_start->cpuElapsedUs(*evt));
      env.d("tid", evt_start->threadId());
      out << event_template.format(env);
    }
  }
  out << "]\n";
}

RecordProfile::RecordProfile(std::ostream& out) : out_(out) {
  init();
}

RecordProfile::RecordProfile(const std::string& filename)
    : file_(std::make_unique<std::ofstream>(filename)), out_(*file_) {
  init();
}

void RecordProfile::init() {
  enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
      torch::profiler::impl::ProfilerState::CPU));
}

RecordProfile::~RecordProfile() {
  try {
    thread_event_lists event_lists = disableProfilerLegacy();
    std::vector<LegacyEvent*> events;
    for (auto& l : event_lists) {
      for (auto& e : l) {
        events.push_back(&e);
      }
    }
    processEvents(events);
  } catch (const std::exception& e) {
    LOG(ERROR) << e.what() << '\n';
  } catch (...) {
    LOG(ERROR) << "Unknown error" << '\n';
  }
}

void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
  writeProfilerEventsToStream(out_, events);
}
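
// A minimal usage sketch of RecordProfile (the output path is made up):
//
//   {
//     torch::autograd::profiler::RecordProfile guard("trace.json");
//     // ... run the workload to be profiled ...
//   } // on scope exit, profiling stops and trace.json is written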

} // namespace torch::autograd::profiler