1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <chrono> 5 #include <mutex> 6 #include <vector> 7 8 constexpr auto kNoTimeout = std::chrono::milliseconds(0); 9 10 namespace c10d { 11 12 constexpr const char* const kSeqNumStoreKey = "SEQ_NUM_STORE_KEY"; 13 14 enum class OpType : std::uint8_t { 15 BROADCAST = 0, 16 ALLREDUCE = 1, 17 ALLREDUCE_COALESCED = 2, 18 REDUCE = 3, 19 ALLGATHER = 4, 20 _ALLGATHER_BASE = 5, 21 ALLGATHER_COALESCED = 6, 22 GATHER = 7, 23 SCATTER = 8, 24 REDUCE_SCATTER = 9, 25 ALLTOALL_BASE = 10, 26 ALLTOALL = 11, 27 SEND = 12, 28 RECV = 13, 29 RECVANYSOURCE = 14, 30 BARRIER = 15, 31 _REDUCE_SCATTER_BASE = 16, 32 COALESCED = 17, 33 _ALLREDUCE_SPARSE = 18, 34 UNKNOWN = 100, 35 }; 36 37 // Converts OpType to human readable string. 38 TORCH_API std::string opTypeToString(OpType opType); 39 40 // Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) 41 TORCH_API bool isP2POp(OpType opType, bool batchP2P = false); 42 43 // Please do not use Work API, it is going away, to be 44 // replaced by ivalue::Future. 45 // Python binding for this class might change, please do not assume 46 // this will be bound using pybind. 47 class TORCH_API Work : public torch::CustomClassHolder { 48 public: 49 Work( 50 int rank = -1, 51 OpType opType = OpType::UNKNOWN, 52 const char* profilingTitle = nullptr, 53 const std::optional<std::vector<at::Tensor>>& inputTensors = 54 std::nullopt); 55 56 ~Work() override; 57 58 // Checks if request has completed. Non-blocking operation. 59 virtual bool isCompleted(); 60 61 // Returns if the work completed successfully. 62 // If false, the exception function can be called to get details. 63 virtual bool isSuccess() const; 64 65 // Returns exception if isSuccess() returned false. 66 virtual std::exception_ptr exception() const; 67 68 // Returns source rank if this objects represents a recv-from-any. 69 virtual int sourceRank() const; 70 71 // Returns result tensors, if applicable. 72 // If work is not supposed to have result, we return empty list. 73 virtual std::vector<at::Tensor> result(); 74 75 // Ensures that operations on the output tensors that are invoked 76 // after this function returns are correctly sequenced after the 77 // asynchronous completion of this work. 78 // 79 // For CUDA tensors, it inserts stream synchronization such that 80 // the streams of the caller wait for completion of the 81 // asynchronous operations on the destination tensors. 82 // 83 // For CPU tensors, it is currently a nop. 84 // 85 // This function should only be used if the caller polls for 86 // completion through the `isCompleted` function, it has returned 87 // true, and the `isSuccess` function also has returned true. 88 // 89 virtual void synchronize(); 90 91 // Waits until request completes. Blocking operation. 92 // Throws if the work completed with an exception. 93 // Returns false if the work is aborted. 94 // Otherwise, it always returns true, indicating the work is completed. 95 // 96 // Functionally equivalent to: 97 // 98 // while (!isCompleted()) { /* nop */ } 99 // auto success = isSuccess(); 100 // if (!success) { std::rethrow_exception(exception()); } 101 // return success; 102 // 103 virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout); 104 105 virtual void abort(); 106 107 // Returns a Future object that will be associated with the completion of 108 // work. Only NCCL backend is currently supported. 109 virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture(); 110 111 virtual float getDuration() const; 112 113 virtual uint64_t getSequencenumber() const; 114 115 OpType retrieveOpType() const; 116 117 static c10::intrusive_ptr<Work> create_from_future( 118 const c10::intrusive_ptr<c10::ivalue::Future>&); 119 120 protected: 121 // Completes the work object and optionally sets the exception in a 122 // thread-safe manner. Notifies all waiting condition variables as well. 123 void finish(std::exception_ptr exception = nullptr); 124 125 // Similar to finish, but throws an exception if one is already set or 126 // provided by the user. 127 void finishAndThrow(std::exception_ptr exception); 128 129 mutable std::mutex mutex_; 130 std::condition_variable cv_; 131 bool completed_ = false; 132 std::exception_ptr exception_; 133 134 // Current rank of the node. 135 const int rank_; 136 137 // Operation type that this work object refers to. 138 OpType opType_; 139 140 // When profiling, the callback to record end of operation event. This 141 // callback needs to be called when collective operation is complete. 142 std::function<void()> recordFunctionEndCallback_; 143 }; 144 145 struct TORCH_API WorkInfo { WorkInfoc10d::WorkInfo146 WorkInfo( 147 const OpType& opType, 148 const uint64_t seq, 149 const std::chrono::time_point<std::chrono::system_clock>& timeStarted, 150 const std::chrono::time_point<std::chrono::system_clock>& timeFinished, 151 const std::chrono::duration<float>& activeDuration) 152 : opType(opType), 153 seq(seq), 154 timeStarted(timeStarted), 155 timeFinished(timeFinished), 156 activeDuration(activeDuration) {} 157 158 OpType opType; 159 uint64_t seq; 160 std::chrono::time_point<std::chrono::system_clock> timeStarted; 161 std::chrono::time_point<std::chrono::system_clock> timeFinished; 162 std::chrono::duration<float> activeDuration; 163 }; 164 165 } // namespace c10d 166