xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/reducer_timer.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/util/ApproximateClock.h>
3 #include <torch/csrc/autograd/profiler.h>
4 
5 namespace c10d {
// Sentinel stored in Timer's timestamp members until the corresponding event
// is recorded. `inline` (C++17) gives the header a single definition shared by
// all translation units instead of one internal-linkage copy per TU.
inline constexpr int kUnsetTime = -1;
7 
current_time_in_nanos()8 inline int64_t current_time_in_nanos() {
9   return c10::getTime();
10 }
11 
12 class TORCH_API Timer {
13  private:
14   // The timestamp of forward call start time in each iteration.
15   int64_t forward_start_time = kUnsetTime;
16   // The timestamp of backward computation start and end time in each
17   // iteration.
18   int64_t backward_compute_start_time = kUnsetTime;
19   int64_t backward_compute_end_time = kUnsetTime;
20   // The timestamp of first communication call start time in each iteration.
21   int64_t backward_comm_start_time = kUnsetTime;
22   // The timestamp of last communication call end time in each iteration.
23   int64_t backward_comm_end_time = kUnsetTime;
24 
25  public:
26   enum class Event : uint8_t {
27     kForwardStart,
28     kBackwardComputeStart,
29     kBackwardComputeEnd,
30     kBackwardCommStart,
31     kBackwardCommEnd,
32   };
33 
34   // Record the current event, i.e., mark it as having occurred now. Default
35   // CPU implementation.
record(Event event)36   virtual void record(Event event) {
37     getTimeRef(event) = current_time_in_nanos();
38   }
39 
40   // Return the difference between when two events occurred, in nanoseconds.
41   // Or nullopt if one of them hasn't been recorded.
42   virtual std::optional<int64_t> measureDifference(Event start, Event end) = 0;
43 
44   virtual ~Timer() = default;
45 
46   // Return host-side timestamp, or nullopt if it has not yet been recorded.
getTimestamp(Event event)47   std::optional<int64_t> getTimestamp(Event event) {
48     auto time = getTimeRef(event);
49     if (time == kUnsetTime) {
50       return std::nullopt;
51     } else {
52       return time;
53     }
54   }
55 
56   // Return host-side time member variable corresponding to the given event.
getTimeRef(Event event)57   int64_t& getTimeRef(Event event) {
58     switch (event) {
59       case Event::kForwardStart:
60         return forward_start_time;
61       case Event::kBackwardComputeStart:
62         return backward_compute_start_time;
63       case Event::kBackwardComputeEnd:
64         return backward_compute_end_time;
65       case Event::kBackwardCommStart:
66         return backward_comm_start_time;
67       case Event::kBackwardCommEnd:
68         return backward_comm_end_time;
69       default:
70         TORCH_INTERNAL_ASSERT(false);
71     }
72   }
73 };
74 
// Declares TimerRegistry, the registry through which device-specific Timer
// implementations are looked up.
// NOTE(review): assuming the usual TORCH_DECLARE_TYPED_REGISTRY argument order
// (registry name, key type, object type, pointer template, constructor args),
// entries are keyed by c10::DeviceType and create std::unique_ptr<Timer> from a
// c10::Device — confirm against the macro definition in c10/util/Registry.h.
TORCH_DECLARE_TYPED_REGISTRY(
    TimerRegistry,
    c10::DeviceType,
    Timer,
    std::unique_ptr,
    c10::Device);
81 } // namespace c10d
82