1 // Copyright © 2023 Apple Inc. 2 3 #pragma once 4 5 #include <ATen/mps/MPSStream.h> 6 #include <ctime> 7 #include <stack> 8 9 namespace at::mps { 10 11 // NOTE: don't create instances of this class directly. 12 // Use MPSEventPool to acquire instances of MPSEvent. 13 class MPSEvent { 14 public: 15 explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing); 16 ~MPSEvent(); 17 18 // records an event on the stream 19 void record(bool needsLock, bool syncEvent = false); 20 // makes all future work submitted to the stream wait for this event. 21 bool wait(bool needsLock, bool syncEvent = false); 22 // schedules a notifyListener callback for the event. 23 bool notify(bool needsLock, MTLSharedEventNotificationBlock block); 24 // checks if events are already signaled. 25 bool query() const; 26 // blocks the CPU thread until all the GPU work that were scheduled 27 // prior to recording this event are completed. 28 bool synchronize(); 29 // resets this event with new parameters in case it gets reused from the event pool 30 void reset(MPSStream* stream, bool enable_timing); 31 // returns the unique ID of the event instance getID()32 id_t getID() const { return m_id; } 33 // returns the completion timestamp of the event getCompletionTime()34 uint64_t getCompletionTime() const { return m_completion_time; } 35 // if already recorded, waits for cpu_sync_cv to be signaled 36 void waitForCpuSync(); 37 38 private: 39 id_t m_id; 40 // enables measuring the completion time of the notifyListener of this event 41 bool m_enable_timing; 42 uint64_t m_signalCounter = 0; 43 MPSStream* m_stream = nullptr; 44 MTLSharedEvent_t m_event = nullptr; 45 MTLSharedEventListener* m_listener = nullptr; 46 // used to sync the events created on this Stream with CPU 47 std::mutex m_cpu_sync_mutex{}; 48 std::condition_variable m_cpu_sync_cv{}; 49 // CondVar predicate to sync the events created on this Stream with CPU 50 bool m_cpu_sync_completed = false; 51 // used to compute elapsed time 52 uint64_t m_completion_time = 0; 53 54 void recordLocked(bool syncEvent); 55 bool waitLocked(bool syncEvent); 56 bool notifyLocked(MTLSharedEventNotificationBlock block); 57 void notifyCpuSync(); getTime()58 static uint64_t getTime() { 59 return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); 60 } 61 }; 62 63 typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr; 64 65 class MPSEventPool { 66 public: 67 explicit MPSEventPool(MPSStream* default_stream); 68 ~MPSEventPool(); 69 70 MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream); 71 void emptyCache(); 72 73 // these are mainly used for MPSHooks and torch.mps.Event() bindings 74 id_t acquireEvent(bool enable_timing); 75 void releaseEvent(id_t event_id); 76 void recordEvent(id_t event_id, bool syncEvent); 77 void waitForEvent(id_t event_id, bool syncEvent); 78 void synchronizeEvent(id_t event_id); 79 bool queryEvent(id_t event_id); 80 // returns elapsed time between two recorded events in milliseconds 81 double elapsedTime(id_t start_event_id, id_t end_event_id); 82 83 private: 84 MPSStream* m_default_stream = nullptr; 85 std::recursive_mutex m_mutex; 86 std::stack<std::unique_ptr<MPSEvent>> m_pool{}; 87 // dictionary to associate event IDs with event objects 88 // used to retain in-use events out of the pool 89 // for torch.mps.Event() bindings. 90 std::unordered_map<id_t, MPSEventPtr> m_in_use_events{}; 91 uint64_t m_event_counter = 0; 92 std::function<void(MPSEvent*)> m_default_deleter; 93 94 MPSEvent* getInUseEvent(id_t event_id, bool locked = true); 95 }; 96 97 // shared_ptr is used to get MPSEventPool destroyed after dependent instances 98 std::shared_ptr<MPSEventPool> getMPSEventPool(); 99 100 } // namespace at::mps 101