1 #pragma once 2 3 #include <c10/core/DeviceType.h> 4 #include <c10/core/Stream.h> 5 #include <c10/core/impl/DeviceGuardImplInterface.h> 6 #include <c10/util/Exception.h> 7 8 namespace c10::impl { 9 10 template <typename T> 11 struct InlineEvent final { 12 InlineEvent() = delete; 13 InlineEvent( 14 const DeviceType _device_type, 15 const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) 16 : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {} 17 18 // Copy constructor and copy assignment operator (deleted) 19 InlineEvent(const InlineEvent&) = delete; 20 InlineEvent& operator=(const InlineEvent&) = delete; 21 22 // Move constructor and move assignment operator InlineEventfinal23 InlineEvent(InlineEvent&& other) noexcept 24 : event_(other.event_), 25 backend_(std::move(other.backend_)), 26 device_type_(other.device_type_), 27 device_index_(other.device_index_), 28 flag_(other.flag_), 29 was_marked_for_recording_(other.was_marked_for_recording_) { 30 other.event_ = nullptr; 31 } 32 InlineEvent& operator=(InlineEvent&& other) noexcept { 33 swap(other); 34 return *this; 35 } 36 swapfinal37 void swap(InlineEvent& other) noexcept { 38 std::swap(event_, other.event_); 39 std::swap(backend_, other.backend_); 40 std::swap(device_type_, other.device_type_); 41 std::swap(device_index_, other.device_index_); 42 std::swap(flag_, other.flag_); 43 std::swap(was_marked_for_recording_, other.was_marked_for_recording_); 44 } 45 ~InlineEventfinal46 ~InlineEvent() noexcept { 47 if (event_) 48 backend_.destroyEvent(event_, device_index_); 49 } 50 device_typefinal51 DeviceType device_type() const noexcept { 52 return device_type_; 53 } device_indexfinal54 DeviceIndex device_index() const noexcept { 55 return device_index_; 56 } flagfinal57 EventFlag flag() const noexcept { 58 return flag_; 59 } was_marked_for_recordingfinal60 bool was_marked_for_recording() const noexcept { 61 return was_marked_for_recording_; 62 } 63 recordOncefinal64 void recordOnce(const Stream& stream) { 65 if (!was_marked_for_recording_) 66 record(stream); 67 } 68 recordfinal69 void record(const Stream& stream) { 70 TORCH_CHECK( 71 stream.device_type() == device_type_, 72 "Event device type ", 73 DeviceTypeName(device_type_), 74 " does not match recording stream's device type ", 75 DeviceTypeName(stream.device_type()), 76 "."); 77 78 backend_.record(&event_, stream, device_index_, flag_); 79 was_marked_for_recording_ = true; 80 device_index_ = stream.device_index(); 81 } 82 blockfinal83 void block(const Stream& stream) const { 84 if (!was_marked_for_recording_) 85 return; 86 87 TORCH_CHECK( 88 stream.device_type() == device_type_, 89 "Event device type ", 90 DeviceTypeName(device_type_), 91 " does not match blocking stream's device type ", 92 DeviceTypeName(stream.device_type()), 93 "."); 94 95 backend_.block(event_, stream); 96 } 97 queryfinal98 bool query() const { 99 if (!was_marked_for_recording_) 100 return true; 101 return backend_.queryEvent(event_); 102 } 103 eventIdfinal104 void* eventId() const { 105 return event_; 106 } 107 elapsedTimefinal108 double elapsedTime(const InlineEvent& other) const { 109 TORCH_CHECK( 110 other.was_marked_for_recording(), 111 "other was not marked for recording."); 112 TORCH_CHECK( 113 was_marked_for_recording(), "self was not marked for recording."); 114 TORCH_CHECK( 115 other.device_type() == device_type_, 116 "Event device type ", 117 DeviceTypeName(device_type_), 118 " does not match other's device type ", 119 DeviceTypeName(other.device_type()), 120 "."); 121 return backend_.elapsedTime(event_, other.event_, device_index_); 122 } 123 synchronizefinal124 void synchronize() const { 125 if (!was_marked_for_recording_) 126 return; 127 backend_.synchronizeEvent(event_); 128 } 129 130 private: 131 void* event_ = nullptr; 132 T backend_; 133 DeviceType device_type_; 134 DeviceIndex device_index_ = -1; 135 EventFlag flag_ = EventFlag::PYTORCH_DEFAULT; 136 bool was_marked_for_recording_ = false; 137 }; 138 139 } // namespace c10::impl 140