#include <torch/csrc/autograd/profiler_legacy.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/code_template.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include <cstring>
#include <fstream>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <ATen/record_function.h>
#include <c10/core/Allocator.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/irange.h>

#include <iostream>

namespace torch::autograd::profiler {

// We decompose the profiler logic into the following components:
//
// ThreadLocalDebugInfo:
//
// ThreadLocalDebugInfo is a thread local mapping from slots into
// debug information structs.
// ThreadLocalDebugInfo is automatically propagated across thread
// boundaries, including the cases of:
//  - launching async jobs with at::launch
//  - executing JIT continuations
//  - moving from forward threads into autograd (backward) threads
//
// Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard,
// which can be used to add or overwrite an entry in the thread local
// mapping. The corresponding entry is removed when the guard is destroyed,
// potentially revealing the previously set value for the same slot.
//
// For async tasks, slots set in the main thread before the task is
// launched are shared with and visible in the async task.
//
// On the other hand, any addition to or overwrite of the mapping by the
// async task is not visible to the main thread, and any modification
// (including removal of entries) made in the main thread after the task
// is launched is not visible to the async task.
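//
// As an illustration, a minimal sketch of DebugInfoGuard usage
// (`MyDebugInfo` is a hypothetical c10::DebugInfoBase subclass, not part
// of this file):
//
//   {
//     auto info = std::make_shared<MyDebugInfo>();
//     c10::DebugInfoGuard guard(c10::DebugInfoKind::PROFILER_STATE, info);
//     // the entry is visible here and in async tasks launched from here
//   } // guard destroyed; a previously set entry for the slot reappears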
//
// We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store the profiler
// config, as well as the list of events that happen during profiling.
// An instance of ThreadLocalDebugInfo is created each time we enter the
// profiler (i.e. enter the profiling context manager or call
// enableProfiler) and uniquely identifies a profiling run.
//
// We automatically propagate ThreadLocalDebugInfo into async tasks,
// as well as across JIT continuations and autograd threads, so all
// the operations that happen between profiling start and end
// (not necessarily within the same thread) are recorded, unless the
// profiling slot is overwritten, as in the case of nested profiling
// ranges (in which case events for the subrange are handled by the
// nested profiler).
//
// When we exit a profiling range (either by exiting the profiling context
// manager or by calling disableProfiler), we remove the previously set
// profiling entry from the given thread local mapping and consolidate
// events into the profiling result.
//
//
// ThreadLocalState:
//
// ThreadLocalState takes a 'snapshot' of thread local variables
// using the provided getters. It is used together with ThreadLocalStateGuard
// to transfer the snapshot across a thread boundary and set the thread
// local values to match those of the parent task.
//
// The profiler uses ThreadLocalState to propagate its thread local state.
// ThreadLocalState also automatically propagates profiler callbacks.
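//
// A minimal sketch of this propagation pattern (illustrative only; the
// std::thread plumbing here is assumed, not part of this file):
//
//   at::ThreadLocalState snapshot; // captured in the parent thread
//   std::thread worker([snapshot]() {
//     at::ThreadLocalStateGuard guard(snapshot); // restore parent's TLS
//     // thread local values (including profiler callbacks) now match the
//     // parent task for the lifetime of the guard
//   });
//   worker.join();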
//
//
// at::RecordFunction and observers
//
// The profiler uses the observer mechanism to add a pair of thread local
// callbacks that are executed around a number of predetermined ranges,
// including:
//  - c10/ATen ops
//  - TorchScript functions/methods
//  - user defined named ranges (see the `record_function` python context
//    manager)
//
// The profiler sets up this pair of callbacks to record profiling events
// and save them into the thread local profiler struct (ThreadLocalDebugInfo,
// PROFILER_STATE slot).
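//
// For example, a user defined range can be emitted from C++ with the
// RECORD_USER_SCOPE macro from ATen/record_function.h (a sketch; my_op is
// a hypothetical function):
//
//   void my_op() {
//     RECORD_USER_SCOPE("my_op"); // observed by the profiler callbacks
//     // ... work attributed to the "my_op" range ...
//   }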
//
//
// Thus, the overall logic is:
//
// enableProfiler:
//  - checks that the profiler is not already enabled (otherwise throws)
//  - pushes a new ThreadLocalDebugInfo (slot PROFILER_STATE) as the profiler
//    config for the current thread
//  - pushes profiling callbacks for the current thread
//
// disableProfiler:
//  - pops the PROFILER_STATE slot from the current ThreadLocalDebugInfo and
//    consolidates events
//  - removes profiling callbacks
//
// ThreadLocalState:
//  - propagates ThreadLocalDebugInfo across threads
//  - propagates profiler callbacks across threads
//
// Profiler callbacks:
//  - get the current profiling state (PROFILER_STATE slot in
//    ThreadLocalDebugInfo)
//  - save profiling events into the profiling state
//
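// A minimal end-to-end sketch using the legacy entry points defined in this
// file (run_workload is an assumed placeholder):
//
//   torch::autograd::profiler::enableProfilerLegacy(
//       torch::profiler::impl::ProfilerConfig(
//           torch::profiler::impl::ProfilerState::CPU));
//   run_workload(); // hypothetical code under measurement
//   thread_event_lists lists =
//       torch::autograd::profiler::disableProfilerLegacy();
//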

namespace {
using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::ProfilerStateBase;

struct ProfilerLegacyThreadLocalState : public ProfilerStateBase {
  explicit ProfilerLegacyThreadLocalState(
      const torch::profiler::impl::ProfilerConfig& config)
      : ProfilerStateBase(config), remoteProfiledEvents_{std::nullopt} {}
  ~ProfilerLegacyThreadLocalState() override = default;

  static ProfilerLegacyThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY);
    return static_cast<ProfilerLegacyThreadLocalState*>(tls);
  }

  thread_event_lists consolidate();

  void mark(std::string name, bool include_cuda = true);

  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      std::vector<std::vector<int64_t>>&& shapes = {});

  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      size_t /* total_allocated, unused for legacy */,
      size_t /* total_reserved, unused for legacy */,
      c10::Device device) override;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::LEGACY;
  }

  void leakHandle() {
    handle_ = 0;
  }

 protected:
  RangeEventList& getEventList(
      std::optional<uint64_t> thread_id = std::nullopt);

  std::mutex state_mutex_;
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      event_lists_map_;

  std::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};

thread_event_lists ProfilerLegacyThreadLocalState::consolidate() {
  std::lock_guard<std::mutex> g(state_mutex_);
  thread_event_lists result;
  for (auto& kv : event_lists_map_) {
    auto& list = kv.second;
    result.emplace_back(list->consolidate());
  }
  // Consolidate remote events if applicable as well.
  if (remoteProfiledEvents_) {
    result.insert(
        result.end(),
        std::make_move_iterator(remoteProfiledEvents_->begin()),
        std::make_move_iterator(remoteProfiledEvents_->end()));
  }
  return result;
}

void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->mark(name.c_str());
  } else {
    LegacyEvent evt(
        EventKind::Mark,
        at::StringView(std::move(name)),
        at::RecordFunction::currentThreadId(),
        include_cuda &&
            config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::setOrAddRemoteProfiledEvents(
    std::vector<LegacyEvent>&& remoteProfiledEvents) {
  // Lock to serialize access from multiple callback threads.
  std::lock_guard<std::mutex> guard(state_mutex_);
  if (remoteProfiledEvents_) {
    remoteProfiledEvents_->emplace_back(std::move(remoteProfiledEvents));
  } else {
    remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
  }
}

void ProfilerLegacyThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes)
            .c_str());
  } else {
    LegacyEvent evt(
        EventKind::PushRange,
        at::StringView(std::string(fn.name())),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle(),
        std::move(shapes),
        at::RecordFunction::getDefaultNodeId(),
        fn.isAsync());
    evt.setSequenceNr(fn.seqNr());
    evt.setFwdThreadId(fn.forwardThreadId());
    evt.setScope((uint8_t)fn.scope());
    if (config_.with_flops) {
      evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
      evt.setFlops(torch::profiler::impl::computeFlops(
          std::string(fn.name()), evt.extraArgs()));
    }

// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
    // A backward node's source range corresponds to its forward node.
    // TODO: consider using C++ stack trace
    if (config_.with_stack &&
        fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
      auto cs =
          torch::profiler::impl::prepareCallstack(jit::currentCallstack());
      if (cs.empty()) {
        cs = torch::profiler::impl::prepareCallstack(
            jit::tracer::pythonCallstack());
      }
      evt.setStack(callstackStr(cs));
    }
#endif
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::popRange(
    const at::RecordFunction& fn,
    const bool record_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePop();
  } else {
    // In some cases RecordFunction (and popRange) may be called on a
    // different thread than pushRange. As a convention, we put the async
    // pop on the original thread and save the current thread id in the
    // pop event.
    LegacyEvent evt(
        EventKind::PopRange,
        at::StringView(""),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle());
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList(fn.threadId()).record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::reportMemoryUsage(
    void* /* unused */,
    int64_t alloc_size,
    size_t /* total_allocated, unused for legacy */,
    size_t /* total_reserved, unused for legacy */,
    c10::Device device) {
  if (config_.profile_memory && !config_.disabled()) {
    uint64_t thread_id = at::RecordFunction::currentThreadId();
    LegacyEvent evt(
        EventKind::MemoryAlloc,
        at::StringView(""),
        thread_id,
        config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.updateMemoryStats(alloc_size, device);
    getEventList(thread_id).record(std::move(evt));
  }
}

RangeEventList& ProfilerLegacyThreadLocalState::getEventList(
    std::optional<uint64_t> thread_id) {
  if (!thread_id.has_value()) {
    thread_id = at::RecordFunction::currentThreadId();
  }
  RangeEventList* list_ptr = nullptr;
  std::lock_guard<std::mutex> guard(state_mutex_);
  auto it = event_lists_map_.find(thread_id.value());
  if (it != event_lists_map_.end()) {
    list_ptr = it->second.get();
  } else {
    auto event_list = std::make_shared<RangeEventList>();
    event_lists_map_[thread_id.value()] = event_list;
    list_ptr = event_list.get();
  }
  return *list_ptr;
}

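// Field order used when (de)serializing a LegacyEvent via toIValue() /
// fromIValue() below; each enumerator is an index into the event's IValue
// list.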
enum EventIValueIdx {
  KIND = 0,
  NAME,
  THREAD_ID,
  HANDLE,
  NODE_ID,
  CPU_MEM_USAGE,
  CPU_NS,
  CUDA_RECORDED,
  CUDA_MEM_USAGE,
  CUDA_DEVICE,
  CUDA_US,
  SHAPES,
  NUM_EVENT_IVALUE_IDX // must be last in list
};

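// Ops excluded from CUDA event recording even when ProfilerState::CUDA is
// active (see the profiling callbacks below).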
const std::unordered_set<std::string> disable_cuda_profiling = {
    "aten::view",
    "aten::t",
    "aten::transpose",
    "aten::stride",
    "aten::empty",
    "aten::empty_like",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::expand",
    "aten::resize_",
    "aten::squeeze",
    "aten::unsqueeze",
    "aten::slice",
    "aten::_unsafe_view",
    "aten::size"};

void pushProfilingCallbacksLegacy() {
  auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return nullptr;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }

            if (state_ptr->config().report_input_shapes) {
              auto sizes = torch::profiler::impl::inputSizes(fn);
              state_ptr->pushRange(fn, record_cuda, std::move(sizes));
            } else {
              state_ptr->pushRange(fn, record_cuda);
            }

            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext*) {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }
            state_ptr->popRange(fn, record_cuda);
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .needsIds(true));
  registration_state_ptr->setCallbackHandle(handle);
}

} // namespace

void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig& new_config) {
  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::NVTX ||
          torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::KINETO,
      "Legacy profiler does not support Kineto profiling");

  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
  auto state = std::make_shared<ProfilerLegacyThreadLocalState>(new_config);
  c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);

  pushProfilingCallbacksLegacy();

  state->mark("__start_profile", false);
}

thread_event_lists disableProfilerLegacy(
    std::optional<ProfilerDisableOptions> profilerDisableOptions) {
  auto cleanupTLSState =
      profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
  auto consolidate =
      profilerDisableOptions ? profilerDisableOptions->consolidate : true;
  // All DebugInfoBase objects are scope-based and are supposed to be managed
  // via DebugInfoGuard.
  std::shared_ptr<c10::DebugInfoBase> state;
  if (cleanupTLSState) {
    state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  } else {
    state =
        c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE);
  }

  auto state_ptr = static_cast<ProfilerLegacyThreadLocalState*>(state.get());
  TORCH_CHECK(
      state_ptr && !state_ptr->config().disabled(),
      "Can't disable profiler when it's not running");

  cleanupTLSState ? state_ptr->removeCallback() : state_ptr->leakHandle();
  if (!consolidate ||
      state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) {
    return thread_event_lists();
  }

  state_ptr->mark("__stop_profile", false);
  // Note that this will erase the underlying events.
  return state_ptr->consolidate();
}

void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(state_ptr, "Profiler must be enabled.");
  state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
}

void LegacyEvent::record(bool record_cuda) {
  if (record_cuda) {
    torch::profiler::impl::cudaStubs()->record(&device_, &cuda_event, &cpu_ns_);
    return;
  }
  cpu_ns_ = c10::getTime();
}

/* static */ LegacyEvent LegacyEvent::fromIValue(
    const at::IValue& eventIValue) {
  TORCH_INTERNAL_ASSERT(
      eventIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = eventIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() >= NUM_EVENT_IVALUE_IDX,
      "Expected at least ",
      NUM_EVENT_IVALUE_IDX,
      " elements to reconstruct LegacyEvent.");

  // Reconstruct input shapes from ivalues.
  const auto& shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
  TORCH_INTERNAL_ASSERT(
      shapeListIValue.isList(),
      "Expected profiler shapes IValue to contain type c10::impl::GenericList.");

  auto shapeList = shapeListIValue.toList();
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(shapeList.size());
  for (const auto i : c10::irange(shapeList.size())) {
    std::vector<int64_t> s;
    const auto& shapeIValue = shapeList.get(i);
    TORCH_INTERNAL_ASSERT(
        shapeIValue.isList(),
        "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.");
    auto curShapesList = shapeIValue.toList();
    s.reserve(curShapesList.size());
    for (const auto j : c10::irange(curShapesList.size())) {
      s.emplace_back(curShapesList.get(j).toInt());
    }
    shapes.emplace_back(std::move(s));
  }

  LegacyEvent evt(
      static_cast<EventKind>(
          ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
      at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
      ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id
      static_cast<at::RecordFunctionHandle>(
          ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle
      std::move(shapes), // input shapes
      ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id
      true, // is remote
      ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage
      ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns
      ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded
      ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage
      c10::DeviceIndex(
          ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt()), // device
      static_cast<double>(
          ivalues.get(EventIValueIdx::CUDA_US).toInt()) // cuda_us
  );
  return evt;
}

at::IValue LegacyEvent::toIValue() const {
  c10::impl::GenericList eventIValueList(at::AnyType::get());
  eventIValueList.reserve(NUM_EVENT_IVALUE_IDX);
  eventIValueList.emplace_back(static_cast<int64_t>(kind_));
  eventIValueList.emplace_back(std::string(name_.str()));
  eventIValueList.emplace_back(static_cast<int64_t>(thread_id_));
  eventIValueList.emplace_back(static_cast<double>(handle_));
  eventIValueList.emplace_back(node_id_);
  eventIValueList.emplace_back(cpu_memory_usage_);
  eventIValueList.emplace_back(cpu_ns_);
  // CUDA event information
  bool cuda_profiling_enabled = hasCuda();
  eventIValueList.emplace_back(cuda_profiling_enabled);
  eventIValueList.emplace_back(static_cast<int64_t>(cuda_memory_usage_));
  eventIValueList.emplace_back(device_);
  eventIValueList.emplace_back(cuda_us_);
  // Shapes
  c10::impl::GenericList shapesList =
      c10::impl::GenericList(at::ListType::create(at::IntType::get()));
  shapesList.reserve(shapes_.size());
  for (const auto& shape : shapes_) {
    c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get());
    s.reserve(shape.size());
    for (const auto& k : shape) {
      s.emplace_back(k);
    }
    shapesList.emplace_back(s);
  }
  eventIValueList.emplace_back(shapesList);
  return at::IValue(eventIValueList);
}

double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
  TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
  TORCH_CHECK(
      e.device() == device(),
      c10::str(
          "Events are not on the same device: ", e.device(), " vs ", device()));
  if (isRemote() && e.isRemote()) {
    // Validate that cuda_us_ has been set properly.
    TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
    return static_cast<double>(e.cuda_us_ - cuda_us_);
  }
  return torch::profiler::impl::cudaStubs()->elapsed(
      &cuda_event, &e.cuda_event);
}

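// Template for one complete ("ph": "X") event in the Chrome Trace Event
// format; the resulting file can be loaded in chrome://tracing and
// compatible viewers.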
static const at::jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");

void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events) {
  TORCH_CHECK(out, "Could not open file");
  LegacyEvent* profiler_start = nullptr;
  for (LegacyEvent* e : events) {
    if (0 == strcmp(e->name(), "__start_profile")) {
      profiler_start = e;
      break;
    }
  }
  TORCH_CHECK(profiler_start, "Could not find __start_profile mark");

  struct PairHash {
    size_t operator()(
        const std::pair<at::RecordFunctionHandle, int64_t>& p) const noexcept {
      return std::hash<at::RecordFunctionHandle>()(p.first) ^
          std::hash<int64_t>()(p.second);
    }
  };
  std::unordered_map<
      std::pair<at::RecordFunctionHandle, int64_t>,
      LegacyEvent*,
      PairHash>
      events_map;
  out << "[\n";
  bool first = true;
  for (LegacyEvent* evt : events) {
    if (evt->kindStr() == "push") {
      events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
    } else if (evt->kindStr() == "pop") {
      if (!first) {
        out << ",\n";
      }
      first = false;
      auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
      TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
      LegacyEvent* evt_start = it->second;
      events_map.erase(it);

      at::jit::TemplateEnv env;
      env.s("name", evt_start->name());
      env.d("ts", profiler_start->cpuElapsedUs(*evt_start));
      env.d("dur", evt_start->cpuElapsedUs(*evt));
      env.d("tid", evt_start->threadId());
      out << event_template.format(env);
    }
  }
  out << "]\n";
}

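// Typical RecordProfile usage (a sketch; run_workload is an assumed
// placeholder): profile a scope and write a Chrome trace on destruction.
//
//   {
//     torch::autograd::profiler::RecordProfile guard("trace.json");
//     run_workload(); // hypothetical code under measurement
//   } // profiler disabled and trace written when guard goes out of scope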
RecordProfile::RecordProfile(std::ostream& out) : out_(out) {
  init();
}

RecordProfile::RecordProfile(const std::string& filename)
    : file_(std::make_unique<std::ofstream>(filename)), out_(*file_) {
  init();
}

void RecordProfile::init() {
  enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
      torch::profiler::impl::ProfilerState::CPU));
}

RecordProfile::~RecordProfile() {
  try {
    thread_event_lists event_lists = disableProfilerLegacy();
    std::vector<LegacyEvent*> events;
    for (auto& l : event_lists) {
      for (auto& e : l) {
        events.push_back(&e);
      }
    }
    processEvents(events);
  } catch (const std::exception& e) {
    LOG(ERROR) << e.what() << '\n';
  } catch (...) {
    LOG(ERROR) << "Unknown error" << '\n';
  }
}

void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
  writeProfilerEventsToStream(out_, events);
}

} // namespace torch::autograd::profiler