// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/default_comm_hooks.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
4 #include <torch/csrc/distributed/c10d/comm.hpp>
5 
6 namespace c10d {
7 
// Identifiers for the communication hooks that c10d implements natively in
// C++. The storage type is pinned to one byte (uint8_t), and every enumerator
// carries an explicit value, so the numeric encoding does not shift if the
// enumerators are ever reordered.
// NOTE(review): the values presumably mirror a Python-side binding; confirm
// before renumbering.
enum class BuiltinCommHookType : uint8_t {
  // Plain allreduce hook (by name, AllReduceCommHook; confirm at the
  // registration site).
  ALLREDUCE = 1,
  // FP16 compression hook (by name, FP16CompressCommHook; confirm at the
  // registration site).
  FP16_COMPRESS = 2,
};
12 
13 class AllReduceCommHook
14     : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
15  public:
AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup> & state)16   explicit AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
17       : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
18 
19   ~AllReduceCommHook() override = default;
20 
21   c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
22 };
23 
24 class FP16CompressCommHook
25     : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
26  public:
FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup> & state)27   explicit FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
28       : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
29 
30   ~FP16CompressCommHook() override = default;
31 
32   c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
33 };
34 
35 // Almost same as AllReduceCommHook, but without division inside the hook.
36 // This enables the optimization of fusing copy and division and saves one scan
37 // over all the input parameters, when no communication hook is provided by the
38 // user. Only used internally and not released as a public built-in
39 // communication hook.
40 class _AllReduceBySumCommHook
41     : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
42  public:
_AllReduceBySumCommHook(const c10::intrusive_ptr<ProcessGroup> & state)43   explicit _AllReduceBySumCommHook(
44       const c10::intrusive_ptr<ProcessGroup>& state)
45       : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
46 
47   ~_AllReduceBySumCommHook() override = default;
48 
49   c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
50 };
51 
52 } // namespace c10d
53