xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/Work.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/ATen.h>

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <mutex>
#include <optional>
#include <string>
#include <vector>
7 
// Zero-millisecond timeout used as the default argument of c10d::Work::wait.
// NOTE(review): how wait() implementations interpret a zero timeout is decided
// by those implementations, not here.
//
// Kept at global scope (outside c10d) so existing unqualified uses still
// resolve. Declared `inline` (C++17) so every translation unit shares a single
// definition instead of each getting its own internal-linkage copy; copying it
// into wait()'s by-value parameter odr-uses the object in callers.
inline constexpr std::chrono::milliseconds kNoTimeout{0};
9 
10 namespace c10d {
11 
// String key "SEQ_NUM_STORE_KEY". NOTE(review): presumably used as the key
// under which sequence numbers are exchanged through a store — confirm at the
// call sites.
//
// `inline constexpr` gives one definition shared by all translation units (the
// C++17 form for header-defined constants); `constexpr` already makes the
// pointer itself const, so the declared type remains `const char* const`.
inline constexpr const char* kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";
13 
// Kind of operation a c10d Work object represents; stored in a single byte.
//
// Each enumerator spells out its numeric value so that inserting or moving an
// entry cannot silently renumber the others. NOTE(review): these numbers may
// be logged or exchanged outside this process — confirm before changing any.
//
// NOTE(review): _ALLGATHER_BASE, _REDUCE_SCATTER_BASE and _ALLREDUCE_SPARSE
// begin with an underscore followed by an uppercase letter, a form the C++
// standard reserves for the implementation; they are kept for compatibility.
enum class OpType : std::uint8_t {
  // Collective operations.
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,

  // Point-to-point operations.
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,

  // Remaining operation kinds.
  BARRIER = 15,
  _REDUCE_SCATTER_BASE = 16,
  COALESCED = 17,
  _ALLREDUCE_SPARSE = 18,

  // Used when no specific type is given (the Work constructor's default).
  UNKNOWN = 100,
};
36 
37 // Converts OpType to human readable string.
38 TORCH_API std::string opTypeToString(OpType opType);
39 
40 // Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE)
41 TORCH_API bool isP2POp(OpType opType, bool batchP2P = false);
42 
43 // Please do not use Work API, it is going away, to be
44 // replaced by ivalue::Future.
45 // Python binding for this class might change, please do not assume
46 // this will be bound using pybind.
47 class TORCH_API Work : public torch::CustomClassHolder {
48  public:
49   Work(
50       int rank = -1,
51       OpType opType = OpType::UNKNOWN,
52       const char* profilingTitle = nullptr,
53       const std::optional<std::vector<at::Tensor>>& inputTensors =
54           std::nullopt);
55 
56   ~Work() override;
57 
58   // Checks if request has completed. Non-blocking operation.
59   virtual bool isCompleted();
60 
61   // Returns if the work completed successfully.
62   // If false, the exception function can be called to get details.
63   virtual bool isSuccess() const;
64 
65   // Returns exception if isSuccess() returned false.
66   virtual std::exception_ptr exception() const;
67 
68   // Returns source rank if this objects represents a recv-from-any.
69   virtual int sourceRank() const;
70 
71   // Returns result tensors, if applicable.
72   // If work is not supposed to have result, we return empty list.
73   virtual std::vector<at::Tensor> result();
74 
75   // Ensures that operations on the output tensors that are invoked
76   // after this function returns are correctly sequenced after the
77   // asynchronous completion of this work.
78   //
79   // For CUDA tensors, it inserts stream synchronization such that
80   // the streams of the caller wait for completion of the
81   // asynchronous operations on the destination tensors.
82   //
83   // For CPU tensors, it is currently a nop.
84   //
85   // This function should only be used if the caller polls for
86   // completion through the `isCompleted` function, it has returned
87   // true, and the `isSuccess` function also has returned true.
88   //
89   virtual void synchronize();
90 
91   // Waits until request completes. Blocking operation.
92   // Throws if the work completed with an exception.
93   // Returns false if the work is aborted.
94   // Otherwise, it always returns true, indicating the work is completed.
95   //
96   // Functionally equivalent to:
97   //
98   //   while (!isCompleted()) { /* nop */ }
99   //   auto success = isSuccess();
100   //   if (!success) { std::rethrow_exception(exception()); }
101   //   return success;
102   //
103   virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);
104 
105   virtual void abort();
106 
107   // Returns a Future object that will be associated with the completion of
108   // work. Only NCCL backend is currently supported.
109   virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
110 
111   virtual float getDuration() const;
112 
113   virtual uint64_t getSequencenumber() const;
114 
115   OpType retrieveOpType() const;
116 
117   static c10::intrusive_ptr<Work> create_from_future(
118       const c10::intrusive_ptr<c10::ivalue::Future>&);
119 
120  protected:
121   // Completes the work object and optionally sets the exception in a
122   // thread-safe manner. Notifies all waiting condition variables as well.
123   void finish(std::exception_ptr exception = nullptr);
124 
125   // Similar to finish, but throws an exception if one is already set or
126   // provided by the user.
127   void finishAndThrow(std::exception_ptr exception);
128 
129   mutable std::mutex mutex_;
130   std::condition_variable cv_;
131   bool completed_ = false;
132   std::exception_ptr exception_;
133 
134   // Current rank of the node.
135   const int rank_;
136 
137   // Operation type that this work object refers to.
138   OpType opType_;
139 
140   // When profiling, the callback to record end of operation event. This
141   // callback needs to be called when collective operation is complete.
142   std::function<void()> recordFunctionEndCallback_;
143 };
144 
145 struct TORCH_API WorkInfo {
WorkInfoc10d::WorkInfo146   WorkInfo(
147       const OpType& opType,
148       const uint64_t seq,
149       const std::chrono::time_point<std::chrono::system_clock>& timeStarted,
150       const std::chrono::time_point<std::chrono::system_clock>& timeFinished,
151       const std::chrono::duration<float>& activeDuration)
152       : opType(opType),
153         seq(seq),
154         timeStarted(timeStarted),
155         timeFinished(timeFinished),
156         activeDuration(activeDuration) {}
157 
158   OpType opType;
159   uint64_t seq;
160   std::chrono::time_point<std::chrono::system_clock> timeStarted;
161   std::chrono::time_point<std::chrono::system_clock> timeFinished;
162   std::chrono::duration<float> activeDuration;
163 };
164 
165 } // namespace c10d
166