1 #pragma once 2 3 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> 4 #include <torch/csrc/distributed/c10d/comm.hpp> 5 6 namespace c10d { 7 8 enum class BuiltinCommHookType : uint8_t { 9 ALLREDUCE = 1, 10 FP16_COMPRESS = 2, 11 }; 12 13 class AllReduceCommHook 14 : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> { 15 public: AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup> & state)16 explicit AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup>& state) 17 : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {} 18 19 ~AllReduceCommHook() override = default; 20 21 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; 22 }; 23 24 class FP16CompressCommHook 25 : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> { 26 public: FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup> & state)27 explicit FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup>& state) 28 : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {} 29 30 ~FP16CompressCommHook() override = default; 31 32 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; 33 }; 34 35 // Almost same as AllReduceCommHook, but without division inside the hook. 36 // This enables the optimization of fusing copy and division and saves one scan 37 // over all the input parameters, when no communication hook is provided by the 38 // user. Only used internally and not released as a public built-in 39 // communication hook. 40 class _AllReduceBySumCommHook 41 : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> { 42 public: _AllReduceBySumCommHook(const c10::intrusive_ptr<ProcessGroup> & state)43 explicit _AllReduceBySumCommHook( 44 const c10::intrusive_ptr<ProcessGroup>& state) 45 : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {} 46 47 ~_AllReduceBySumCommHook() override = default; 48 49 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; 50 }; 51 52 } // namespace c10d 53