xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/python_comm_hook.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/c10d/comm.hpp>
4 
5 #include <ATen/ATen.h>
6 #include <ATen/core/ivalue.h>
7 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
8 #include <torch/csrc/utils/pybind.h>
9 
10 namespace c10d {
11 
12 class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
13  public:
14   // Takes a state and a callable hook. The inputs are Python objects.
15   // The state is passed to the hook in runHook method, and it can be used to
16   // maintain and update any state information during the execution of the hook.
17   // The hook performs user-specified processing and returns a future indicating
18   // asychronous communication of gradients.
PythonCommHook(py::object state,py::object hook)19   PythonCommHook(py::object state, py::object hook)
20       : state_(std::move(state)), hook_(std::move(hook)) {}
21 
22   ~PythonCommHook() override;
23 
24   c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
25 
26   at::Tensor parseHookResult(const c10::IValue& result) override;
27 
28  private:
29   // Only needed for stateful communication.
30   py::object state_;
31   py::object hook_;
32 };
33 
34 } // namespace c10d
35