xref: /aosp_15_r20/external/pytorch/c10/core/impl/InlineEvent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/DeviceType.h>
4 #include <c10/core/Stream.h>
5 #include <c10/core/impl/DeviceGuardImplInterface.h>
6 #include <c10/util/Exception.h>
7 
8 namespace c10::impl {
9 
10 template <typename T>
11 struct InlineEvent final {
12   InlineEvent() = delete;
13   InlineEvent(
14       const DeviceType _device_type,
15       const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
16       : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {}
17 
18   // Copy constructor and copy assignment operator (deleted)
19   InlineEvent(const InlineEvent&) = delete;
20   InlineEvent& operator=(const InlineEvent&) = delete;
21 
22   // Move constructor and move assignment operator
InlineEventfinal23   InlineEvent(InlineEvent&& other) noexcept
24       : event_(other.event_),
25         backend_(std::move(other.backend_)),
26         device_type_(other.device_type_),
27         device_index_(other.device_index_),
28         flag_(other.flag_),
29         was_marked_for_recording_(other.was_marked_for_recording_) {
30     other.event_ = nullptr;
31   }
32   InlineEvent& operator=(InlineEvent&& other) noexcept {
33     swap(other);
34     return *this;
35   }
36 
swapfinal37   void swap(InlineEvent& other) noexcept {
38     std::swap(event_, other.event_);
39     std::swap(backend_, other.backend_);
40     std::swap(device_type_, other.device_type_);
41     std::swap(device_index_, other.device_index_);
42     std::swap(flag_, other.flag_);
43     std::swap(was_marked_for_recording_, other.was_marked_for_recording_);
44   }
45 
~InlineEventfinal46   ~InlineEvent() noexcept {
47     if (event_)
48       backend_.destroyEvent(event_, device_index_);
49   }
50 
device_typefinal51   DeviceType device_type() const noexcept {
52     return device_type_;
53   }
device_indexfinal54   DeviceIndex device_index() const noexcept {
55     return device_index_;
56   }
flagfinal57   EventFlag flag() const noexcept {
58     return flag_;
59   }
was_marked_for_recordingfinal60   bool was_marked_for_recording() const noexcept {
61     return was_marked_for_recording_;
62   }
63 
recordOncefinal64   void recordOnce(const Stream& stream) {
65     if (!was_marked_for_recording_)
66       record(stream);
67   }
68 
recordfinal69   void record(const Stream& stream) {
70     TORCH_CHECK(
71         stream.device_type() == device_type_,
72         "Event device type ",
73         DeviceTypeName(device_type_),
74         " does not match recording stream's device type ",
75         DeviceTypeName(stream.device_type()),
76         ".");
77 
78     backend_.record(&event_, stream, device_index_, flag_);
79     was_marked_for_recording_ = true;
80     device_index_ = stream.device_index();
81   }
82 
blockfinal83   void block(const Stream& stream) const {
84     if (!was_marked_for_recording_)
85       return;
86 
87     TORCH_CHECK(
88         stream.device_type() == device_type_,
89         "Event device type ",
90         DeviceTypeName(device_type_),
91         " does not match blocking stream's device type ",
92         DeviceTypeName(stream.device_type()),
93         ".");
94 
95     backend_.block(event_, stream);
96   }
97 
queryfinal98   bool query() const {
99     if (!was_marked_for_recording_)
100       return true;
101     return backend_.queryEvent(event_);
102   }
103 
eventIdfinal104   void* eventId() const {
105     return event_;
106   }
107 
elapsedTimefinal108   double elapsedTime(const InlineEvent& other) const {
109     TORCH_CHECK(
110         other.was_marked_for_recording(),
111         "other was not marked for recording.");
112     TORCH_CHECK(
113         was_marked_for_recording(), "self was not marked for recording.");
114     TORCH_CHECK(
115         other.device_type() == device_type_,
116         "Event device type ",
117         DeviceTypeName(device_type_),
118         " does not match other's device type ",
119         DeviceTypeName(other.device_type()),
120         ".");
121     return backend_.elapsedTime(event_, other.event_, device_index_);
122   }
123 
synchronizefinal124   void synchronize() const {
125     if (!was_marked_for_recording_)
126       return;
127     backend_.synchronizeEvent(event_);
128   }
129 
130  private:
131   void* event_ = nullptr;
132   T backend_;
133   DeviceType device_type_;
134   DeviceIndex device_index_ = -1;
135   EventFlag flag_ = EventFlag::PYTORCH_DEFAULT;
136   bool was_marked_for_recording_ = false;
137 };
138 
139 } // namespace c10::impl
140