xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSEvent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2023 Apple Inc.
2 
3 #pragma once
4 
5 #include <ATen/mps/MPSStream.h>
6 #include <ctime>
7 #include <stack>
8 
9 namespace at::mps {
10 
11 // NOTE: don't create instances of this class directly.
12 // Use MPSEventPool to acquire instances of MPSEvent.
13 class MPSEvent {
14 public:
15   explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
16   ~MPSEvent();
17 
18   // records an event on the stream
19   void record(bool needsLock, bool syncEvent = false);
20   // makes all future work submitted to the stream wait for this event.
21   bool wait(bool needsLock, bool syncEvent = false);
22   // schedules a notifyListener callback for the event.
23   bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
24   // checks if events are already signaled.
25   bool query() const;
26   // blocks the CPU thread until all the GPU work that were scheduled
27   // prior to recording this event are completed.
28   bool synchronize();
29   // resets this event with new parameters in case it gets reused from the event pool
30   void reset(MPSStream* stream, bool enable_timing);
31   // returns the unique ID of the event instance
getID()32   id_t getID() const { return m_id; }
33   // returns the completion timestamp of the event
getCompletionTime()34   uint64_t getCompletionTime() const { return m_completion_time; }
35   // if already recorded, waits for cpu_sync_cv to be signaled
36   void waitForCpuSync();
37 
38 private:
39   id_t m_id;
40   // enables measuring the completion time of the notifyListener of this event
41   bool m_enable_timing;
42   uint64_t m_signalCounter = 0;
43   MPSStream* m_stream = nullptr;
44   MTLSharedEvent_t m_event = nullptr;
45   MTLSharedEventListener* m_listener = nullptr;
46   // used to sync the events created on this Stream with CPU
47   std::mutex m_cpu_sync_mutex{};
48   std::condition_variable m_cpu_sync_cv{};
49   // CondVar predicate to sync the events created on this Stream with CPU
50   bool m_cpu_sync_completed = false;
51   // used to compute elapsed time
52   uint64_t m_completion_time = 0;
53 
54   void recordLocked(bool syncEvent);
55   bool waitLocked(bool syncEvent);
56   bool notifyLocked(MTLSharedEventNotificationBlock block);
57   void notifyCpuSync();
getTime()58   static uint64_t getTime() {
59     return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
60   }
61 };
62 
63 typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
64 
65 class MPSEventPool {
66 public:
67   explicit MPSEventPool(MPSStream* default_stream);
68   ~MPSEventPool();
69 
70   MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
71   void emptyCache();
72 
73   // these are mainly used for MPSHooks and torch.mps.Event() bindings
74   id_t acquireEvent(bool enable_timing);
75   void releaseEvent(id_t event_id);
76   void recordEvent(id_t event_id, bool syncEvent);
77   void waitForEvent(id_t event_id, bool syncEvent);
78   void synchronizeEvent(id_t event_id);
79   bool queryEvent(id_t event_id);
80   // returns elapsed time between two recorded events in milliseconds
81   double elapsedTime(id_t start_event_id, id_t end_event_id);
82 
83 private:
84   MPSStream* m_default_stream = nullptr;
85   std::recursive_mutex m_mutex;
86   std::stack<std::unique_ptr<MPSEvent>> m_pool{};
87   // dictionary to associate event IDs with event objects
88   // used to retain in-use events out of the pool
89   // for torch.mps.Event() bindings.
90   std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
91   uint64_t m_event_counter = 0;
92   std::function<void(MPSEvent*)> m_default_deleter;
93 
94   MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
95 };
96 
97 // shared_ptr is used to get MPSEventPool destroyed after dependent instances
98 std::shared_ptr<MPSEventPool> getMPSEventPool();
99 
100 } // namespace at::mps
101