1 #pragma once
2 #include <c10/util/ApproximateClock.h>
3 #include <torch/csrc/autograd/profiler.h>
4
5 namespace c10d {
// Sentinel stored in a Timer timestamp slot that has not been recorded yet.
// Declared `inline` (C++17) so this header-defined constant has a single
// definition shared by every translation unit rather than one internal-linkage
// copy per TU; inline functions that ODR-use it then stay ODR-consistent.
inline constexpr int kUnsetTime = -1;
7
current_time_in_nanos()8 inline int64_t current_time_in_nanos() {
9 return c10::getTime();
10 }
11
12 class TORCH_API Timer {
13 private:
14 // The timestamp of forward call start time in each iteration.
15 int64_t forward_start_time = kUnsetTime;
16 // The timestamp of backward computation start and end time in each
17 // iteration.
18 int64_t backward_compute_start_time = kUnsetTime;
19 int64_t backward_compute_end_time = kUnsetTime;
20 // The timestamp of first communication call start time in each iteration.
21 int64_t backward_comm_start_time = kUnsetTime;
22 // The timestamp of last communication call end time in each iteration.
23 int64_t backward_comm_end_time = kUnsetTime;
24
25 public:
26 enum class Event : uint8_t {
27 kForwardStart,
28 kBackwardComputeStart,
29 kBackwardComputeEnd,
30 kBackwardCommStart,
31 kBackwardCommEnd,
32 };
33
34 // Record the current event, i.e., mark it as having occurred now. Default
35 // CPU implementation.
record(Event event)36 virtual void record(Event event) {
37 getTimeRef(event) = current_time_in_nanos();
38 }
39
40 // Return the difference between when two events occurred, in nanoseconds.
41 // Or nullopt if one of them hasn't been recorded.
42 virtual std::optional<int64_t> measureDifference(Event start, Event end) = 0;
43
44 virtual ~Timer() = default;
45
46 // Return host-side timestamp, or nullopt if it has not yet been recorded.
getTimestamp(Event event)47 std::optional<int64_t> getTimestamp(Event event) {
48 auto time = getTimeRef(event);
49 if (time == kUnsetTime) {
50 return std::nullopt;
51 } else {
52 return time;
53 }
54 }
55
56 // Return host-side time member variable corresponding to the given event.
getTimeRef(Event event)57 int64_t& getTimeRef(Event event) {
58 switch (event) {
59 case Event::kForwardStart:
60 return forward_start_time;
61 case Event::kBackwardComputeStart:
62 return backward_compute_start_time;
63 case Event::kBackwardComputeEnd:
64 return backward_compute_end_time;
65 case Event::kBackwardCommStart:
66 return backward_comm_start_time;
67 case Event::kBackwardCommEnd:
68 return backward_comm_end_time;
69 default:
70 TORCH_INTERNAL_ASSERT(false);
71 }
72 }
73 };
74
// Declares TimerRegistry: a typed registry keyed by c10::DeviceType whose
// entries produce Timer objects held in std::unique_ptr, constructed from a
// c10::Device argument.
// NOTE(review): the matching registry definition and per-device registrations
// are presumably in implementation files not visible here — confirm there.
TORCH_DECLARE_TYPED_REGISTRY(
    TimerRegistry,
    c10::DeviceType,
    Timer,
    std::unique_ptr,
    c10::Device);
81 } // namespace c10d
82