#include <cstdio>
#include <sstream>

#ifdef TORCH_CUDA_USE_NVTX3
#include <nvtx3/nvtx3.hpp>
#else
#include <nvToolsExt.h>
#endif

#include <c10/cuda/CUDAGuard.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
14
15 namespace torch {
16 namespace profiler {
17 namespace impl {
18 namespace {
19
cudaCheck(cudaError_t result,const char * file,int line)20 static inline void cudaCheck(cudaError_t result, const char* file, int line) {
21 if (result != cudaSuccess) {
22 std::stringstream ss;
23 ss << file << ":" << line << ": ";
24 if (result == cudaErrorInitializationError) {
25 // It is common for users to use DataLoader with multiple workers
26 // and the autograd profiler. Throw a nice error message here.
27 ss << "CUDA initialization error. "
28 << "This can occur if one runs the profiler in CUDA mode on code "
29 << "that creates a DataLoader with num_workers > 0. This operation "
30 << "is currently unsupported; potential workarounds are: "
31 << "(1) don't use the profiler in CUDA mode or (2) use num_workers=0 "
32 << "in the DataLoader or (3) Don't profile the data loading portion "
33 << "of your code. https://github.com/pytorch/pytorch/issues/6313 "
34 << "tracks profiler support for multi-worker DataLoader.";
35 } else {
36 ss << cudaGetErrorString(result);
37 }
38 throw std::runtime_error(ss.str());
39 }
40 }
41 #define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__);
42
43 struct CUDAMethods : public ProfilerStubs {
recordtorch::profiler::impl::__anon3eda17280111::CUDAMethods44 void record(
45 c10::DeviceIndex* device,
46 ProfilerVoidEventStub* event,
47 int64_t* cpu_ns) const override {
48 if (device) {
49 TORCH_CUDA_CHECK(c10::cuda::GetDevice(device));
50 }
51 CUevent_st* cuda_event_ptr{nullptr};
52 TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr));
53 *event = std::shared_ptr<CUevent_st>(cuda_event_ptr, [](CUevent_st* ptr) {
54 TORCH_CUDA_CHECK(cudaEventDestroy(ptr));
55 });
56 auto stream = at::cuda::getCurrentCUDAStream();
57 if (cpu_ns) {
58 *cpu_ns = c10::getTime();
59 }
60 TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
61 }
62
elapsedtorch::profiler::impl::__anon3eda17280111::CUDAMethods63 float elapsed(
64 const ProfilerVoidEventStub* event_,
65 const ProfilerVoidEventStub* event2_) const override {
66 auto event = (const ProfilerEventStub*)(event_);
67 auto event2 = (const ProfilerEventStub*)(event2_);
68 TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
69 TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
70 float ms = 0;
71 TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get()));
72 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
73 return ms * 1000.0;
74 }
75
marktorch::profiler::impl::__anon3eda17280111::CUDAMethods76 void mark(const char* name) const override {
77 ::nvtxMark(name);
78 }
79
rangePushtorch::profiler::impl::__anon3eda17280111::CUDAMethods80 void rangePush(const char* name) const override {
81 ::nvtxRangePushA(name);
82 }
83
rangePoptorch::profiler::impl::__anon3eda17280111::CUDAMethods84 void rangePop() const override {
85 ::nvtxRangePop();
86 }
87
onEachDevicetorch::profiler::impl::__anon3eda17280111::CUDAMethods88 void onEachDevice(std::function<void(int)> op) const override {
89 at::cuda::OptionalCUDAGuard device_guard;
90 for (const auto i : c10::irange(at::cuda::device_count())) {
91 device_guard.set_index(i);
92 op(i);
93 }
94 }
95
synchronizetorch::profiler::impl::__anon3eda17280111::CUDAMethods96 void synchronize() const override {
97 TORCH_CUDA_CHECK(cudaDeviceSynchronize());
98 }
99
enabledtorch::profiler::impl::__anon3eda17280111::CUDAMethods100 bool enabled() const override {
101 return true;
102 }
103 };
104
105 struct RegisterCUDAMethods {
RegisterCUDAMethodstorch::profiler::impl::__anon3eda17280111::RegisterCUDAMethods106 RegisterCUDAMethods() {
107 static CUDAMethods methods;
108 registerCUDAMethods(&methods);
109 }
110 };
111 RegisterCUDAMethods reg;
112
113 } // namespace
114 } // namespace impl
115 } // namespace profiler
116 } // namespace torch
117