xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/tensorpipe_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_TENSORPIPE
4 
5 #include <torch/csrc/distributed/rpc/utils.h>
6 
7 namespace tensorpipe {
8 class Message;
9 class Allocation;
10 class Descriptor;
11 } // namespace tensorpipe
12 
13 namespace torch::distributed::rpc {
14 
// Returns the stream in `streams` whose device matches `device`.
// NOTE(review): behavior when no stream in `streams` matches `device` is not
// visible from this header — confirm against the .cpp implementation.
TORCH_API const c10::Stream& getStreamForDevice(
    const std::vector<c10::Stream>& streams,
    const c10::Device& device);
18 
// Inspired by c10/core/impl/DeviceGuardImplInterface.h.

// Abstract interface encapsulating the device-type-specific logic needed to
// send a tensor's storage over TensorPipe and to allocate room for an incoming
// one. One converter is registered per DeviceType via
// C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER (below) and looked up through
// getDeviceTypeConverter().
class TensorpipeDeviceTypeConverter {
 public:
  // Ideally we'd want this to also return a tensorpipe::Message::Tensor object
  // but we cannot forward-declare that class (because it's nested), and we
  // cannot include the TensorPipe headers because it's a private dependency.
  // Thus we bend over backwards and entrust this method with appending that
  // object to the `tensors` field of the tensorpipe::Message object we pass.
  //
  // Returns the bytes of a copy of the storage when one had to be made, so the
  // caller can keep that copy alive for the duration of the write; nullopt
  // otherwise. NOTE(review): inferred from
  // TensorpipeWriteBuffers::copiedTensors below — confirm against the .cpp
  // implementations.
  virtual std::optional<std::vector<char>> prepareTensorForSending(
      const c10::Storage& storage,
      const std::vector<c10::Stream>& streams,
      tensorpipe::Message& message) const = 0;

  // Same as above: this method cannot return a tensorpipe::Allocation::Tensor,
  // thus it appends it to the `tensors` field of the tensorpipe::Allocation.
  //
  // Allocates `length` bytes on device `deviceIndex` and returns the owning
  // DataPtr; a pointer to the buffer is also recorded in `allocation`.
  virtual at::DataPtr allocateTensorForReceiving(
      c10::DeviceIndex deviceIndex,
      size_t length,
      const std::vector<c10::Stream>& streams,
      tensorpipe::Allocation& allocation) const = 0;

  // Polymorphic base: virtual destructor so derived converters can be
  // destroyed through a base pointer.
  virtual ~TensorpipeDeviceTypeConverter() = default;
};
43 
// Global registry of converters, one (possibly null) slot per DeviceType,
// indexed by the numeric value of the device type. The slots are atomic
// pointers so registration (during static initialization) and lookup can
// happen from different threads without a data race. Defined out-of-line;
// populated via TensorpipeDeviceTypeConverterRegistrar below.
extern TORCH_API std::array<
    std::atomic<const TensorpipeDeviceTypeConverter*>,
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_type_converter_registry;
48 
// Helper whose constructor registers the given converter for the given
// DeviceType — presumably by storing it into device_type_converter_registry;
// the constructor is defined out-of-line, so confirm there. Instantiated as a
// file-local static by C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER below.
class TORCH_API TensorpipeDeviceTypeConverterRegistrar {
 public:
  TensorpipeDeviceTypeConverterRegistrar(
      DeviceType,
      const TensorpipeDeviceTypeConverter*);
};
55 
// Registers an instance of `TensorpipeDeviceTypeConverter` (a concrete
// converter class) for device type `DevType` (an unqualified c10::DeviceType
// enumerator, e.g. CUDA). Expands to a file-local static registrar whose
// constructor performs the registration during static initialization. The
// converter is heap-allocated and handed over as a raw pointer; no visible
// owner ever deletes it, so it effectively lives for the whole program.
//
// Fix: the variable stem is pasted from the macro parameter `DevType`; the
// previous `g_##DeviceType` pasted the literal token `DeviceType` (not a
// parameter of this macro), which only worked because C10_ANONYMOUS_VARIABLE
// uniquifies the name, and produced the same stem for every device type.
#define C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(                     \
    DevType, TensorpipeDeviceTypeConverter)                                \
  static ::torch::distributed::rpc::TensorpipeDeviceTypeConverterRegistrar \
      C10_ANONYMOUS_VARIABLE(g_##DevType)(                                 \
          ::c10::DeviceType::DevType, new TensorpipeDeviceTypeConverter());
61 
getDeviceTypeConverter(DeviceType type)62 inline const TensorpipeDeviceTypeConverter* getDeviceTypeConverter(
63     DeviceType type) {
64   return device_type_converter_registry[static_cast<size_t>(type)].load();
65 }
66 
// A struct that holds pointers that keep alive all the memory that will be
// accessed by TensorPipe during a write operation.
struct TensorpipeWriteBuffers {
  // Allocate on heap so pointers stay valid as we move the holder.
  std::unique_ptr<MessageType> type;
  std::unique_ptr<int64_t> id;
  // Raw payload bytes of the message being written.
  std::vector<char> payload;
  // Pickle bytes produced while serializing the message (presumably by
  // tensorpipeSerialize — confirm against the .cpp).
  std::vector<char> pickle;
  // This contains the original tensors and the clones of the sparse tensors.
  std::vector<torch::Tensor> tensors;
  // This contains the copies of the data of the tensors that didn't own their
  // memory, e.g., the ones created from torch::from_blob() with no deleter.
  std::vector<std::vector<char>> copiedTensors;
};
81 
// A struct that holds pointers that keep alive all the memory that will be
// accessed by TensorPipe during a read operation.
struct TensorpipeReadBuffers {
  // Allocate on heap so pointers stay valid as we move the holder.
  std::unique_ptr<MessageType> type;
  std::unique_ptr<int64_t> id;
  // Buffer receiving the raw payload bytes of the incoming message.
  std::vector<char> payload;
  // Buffer receiving the pickle bytes of the incoming message.
  std::vector<char> pickle;
  // Owning buffers for the incoming tensor data, allocated by the per-device
  // converters (see TensorpipeDeviceTypeConverter::allocateTensorForReceiving).
  std::vector<c10::DataPtr> tensors;
};
92 
// Convert an RPC message into a TensorPipe message, plus a holder to all the
// data that must be kept alive while the write is performed asynchronously.
// `devices` selects target devices for the message's tensors and `streams`
// supplies the streams used during serialization — NOTE(review): the exact
// per-tensor mapping is not visible from this header; confirm in the .cpp.
TORCH_API std::tuple<tensorpipe::Message, TensorpipeWriteBuffers>
tensorpipeSerialize(
    const c10::intrusive_ptr<Message>& rpcMessage,
    std::vector<c10::Device> devices,
    const std::vector<c10::Stream>& streams);
100 
// Allocate the buffers that will hold the incoming data. They will be managed
// by the returned holder, which must be kept alive until the asynchronous read
// has finished. Pointers to these buffers will be stored in the returned
// tensorpipe::Allocation struct.
// `tpDescriptor` describes the incoming message; `streams` supplies the
// streams available to the per-device converters while allocating.
TORCH_API std::pair<tensorpipe::Allocation, TensorpipeReadBuffers>
tensorpipeAllocate(
    const tensorpipe::Descriptor& tpDescriptor,
    const std::vector<c10::Stream>& streams);
109 
// Convert a TensorPipe message back into an RPC message. This requires the data
// to be available and can thus only be performed once the asynchronous read has
// completed. The holder can be destroyed once this function returns.
// Both arguments are taken by rvalue reference and are consumed by the call.
TORCH_API c10::intrusive_ptr<Message> tensorpipeDeserialize(
    tensorpipe::Descriptor&& tpDescriptor,
    TensorpipeReadBuffers&& holder);
116 
117 } // namespace torch::distributed::rpc
118 
119 #endif // USE_TENSORPIPE
120