xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/request_callback.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 
5 namespace torch::distributed::rpc {
6 
7 // Functor which is invoked to process an RPC message. This is an abstract class
8 // with some common functionality across all request handlers. Users need to
9 // implement this interface to perform the actual business logic.
10 class TORCH_API RequestCallback {
11  public:
12   // Invoke the callback.
13   c10::intrusive_ptr<JitFuture> operator()(
14       Message& request,
15       std::vector<c10::Stream> streams) const;
16 
17   virtual ~RequestCallback() = default;
18 
19  protected:
20   // RpcAgent implementation should invoke ``RequestCallback`` to process
21   // received requests. There is no restriction on the implementation's
22   // threading model. This function takes an rvalue reference of the Message
23   // object. It is expected to return the future to a response message or
24   // message containing an exception. Different rpc agent implementations are
25   // expected to ensure delivery of the response/exception based on their
26   // implementation specific mechanisms.
27   virtual c10::intrusive_ptr<JitFuture> processMessage(
28       Message& request,
29       std::vector<c10::Stream> streams) const = 0;
30 };
31 
32 } // namespace torch::distributed::rpc
33