xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/tensorpipe_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>

#ifdef USE_TENSORPIPE

#include <c10/util/irange.h>
#include <limits>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
#include <tensorpipe/tensorpipe.h>
C10_DIAGNOSTIC_POP()

namespace torch::distributed::rpc {
namespace {

// The TensorPipe agent splits the RPC message's information across multiple
// payloads. This allows the agent to provide the data to TensorPipe without
// performing a copy into a single contiguous buffer, and without storing it as
// metadata, which would be less efficient.

// First come the rpc::Message::type() and ::id().
constexpr int kTpMessageTypeIdx = 0;
constexpr int kTpMessageIdIdx = 1;
// Then comes the rpc::Message::payload().
constexpr int kTpMessagePayloadIdx = 2;
// Last comes the pickle of rpc::Message::tensors() (with the tensors themselves
// stored as, well, tensors in the tensorpipe::Message).
constexpr int kTpMessagePickleIdx = 3;

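// Helper mapping a c10::DeviceIndex to a device: -1 means CPU, any other
// index refers to the CUDA device with that index.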
inline c10::Device indexToDevice(c10::DeviceIndex index) {
  if (index == -1) {
    return c10::Device(at::kCPU);
  } else {
    return c10::Device(at::kCUDA, index);
  }
}

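// CPU implementation of TensorpipeDeviceTypeConverter: it wraps CPU storages
// in tensorpipe::CpuBuffer entries when sending, and allocates plain CPU
// memory for TensorPipe to receive into.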
class TensorpipeCpuConverter : public TensorpipeDeviceTypeConverter {
 public:
  std::optional<std::vector<char>> prepareTensorForSending(
      const c10::Storage& storage,
      const std::vector<c10::Stream>& /* streams */,
      tensorpipe::Message& message) const override {
    // Enforce a memory copy if the tensor was created from torch::from_blob,
    // which means it doesn't own its memory.
    bool storageHasDeleter = storage.data_ptr().get_context() != nullptr;
    if (!storageHasDeleter) {
      std::vector<char> storageData(
          static_cast<const char*>(storage.data()),
          static_cast<const char*>(storage.data()) + storage.nbytes());

      tensorpipe::CpuBuffer buffer;
      buffer.ptr = storageData.data();

      tensorpipe::Message::Tensor tensor;
      tensor.buffer = buffer;
      tensor.length = storageData.size();

      message.tensors.push_back(std::move(tensor));

      return std::make_optional(std::move(storageData));
    } else {
      tensorpipe::CpuBuffer buffer;
      buffer.ptr = static_cast<char*>(storage.mutable_data());

      tensorpipe::Message::Tensor tensor;
      tensor.buffer = buffer;
      tensor.length = storage.nbytes();

      message.tensors.push_back(std::move(tensor));

      return std::nullopt;
    }
  }

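  // Allocates a CPU buffer of the requested length and registers it in the
  // allocation, so TensorPipe can write the incoming tensor data directly
  // into it.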
  at::DataPtr allocateTensorForReceiving(
      c10::DeviceIndex /* deviceIndex */,
      size_t length,
      const std::vector<c10::Stream>& /* streams */,
      tensorpipe::Allocation& allocation) const override {
    at::DataPtr dataPtr = at::getCPUAllocator()->allocate(length);

    tensorpipe::CpuBuffer buffer;
    buffer.ptr = dataPtr.get();

    tensorpipe::Allocation::Tensor tensor;
    tensor.buffer = buffer;

    allocation.tensors.push_back(std::move(tensor));

    return dataPtr;
  }
};

C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(CPU, TensorpipeCpuConverter);

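// Maps TensorPipe's device type strings back to c10 device types; only CPU
// and CUDA buffers are expected here.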
c10::DeviceType convertDeviceType(const std::string& tpDeviceType) {
  if (tpDeviceType == tensorpipe::kCpuDeviceType) {
    return c10::kCPU;
  } else if (tpDeviceType == tensorpipe::kCudaDeviceType) {
    return c10::kCUDA;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unrecognized TensorPipe buffer type.");
  }
}

} // namespace

// As the vector of streams will typically be very small (1-8 items), we expect
// a linear search to be as fast as (or faster than) a hashmap lookup.
const c10::Stream& getStreamForDevice(
    const std::vector<c10::Stream>& streams,
    const c10::Device& device) {
  for (const c10::Stream& stream : streams) {
    if (stream.device() == device) {
      return stream;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "No stream found for device ", device);
}

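// Global registry mapping each c10::DeviceType to its converter. Entries are
// installed at static-initialization time via the
// C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER macro (see the CPU converter
// above).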
std::array<
    std::atomic<const TensorpipeDeviceTypeConverter*>,
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_type_converter_registry;

TensorpipeDeviceTypeConverterRegistrar::TensorpipeDeviceTypeConverterRegistrar(
    DeviceType type,
    const TensorpipeDeviceTypeConverter* impl) {
  device_type_converter_registry[static_cast<size_t>(type)].store(impl);
}

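// Converts an rpc::Message into a tensorpipe::Message ready to be written. The
// returned TensorpipeWriteBuffers owns the memory the tensorpipe::Message
// points into (type, id, payload, pickle, and any forced CPU copies), so it
// needs to stay alive for as long as the tensorpipe::Message is in use.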
std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
    const c10::intrusive_ptr<Message>& rpcMessage,
    std::vector<c10::Device> devices,
    const std::vector<c10::Stream>& streams) {
  tensorpipe::Message tpMessage;
  TensorpipeWriteBuffers buffers;

  // Metadata
  buffers.type = std::make_unique<MessageType>(rpcMessage->type());
  buffers.id = std::make_unique<int64_t>(rpcMessage->id());
  // kTpMessageTypeIdx = 0
  tpMessage.payloads.push_back(
      tensorpipe::Message::Payload{buffers.type.get(), sizeof(MessageType)});
  // kTpMessageIdIdx = 1
  tpMessage.payloads.push_back(
      tensorpipe::Message::Payload{buffers.id.get(), sizeof(int64_t)});

  // Payload
  buffers.payload = std::move(rpcMessage->payload());
  // TensorPipe uses the same Message class for both reading and writing, thus
  // it uses non-const pointers even though it doesn't modify them when writing.
  char* payloadPtr = buffers.payload.data();
  // kTpMessagePayloadIdx = 2
  tpMessage.payloads.push_back(
      tensorpipe::Message::Payload{payloadPtr, buffers.payload.size()});

  {
    // The function below might allocate new tensors if there are Tensor views.
    // Apply a stream guard here so that those allocations happen on the given
    // streams.
    c10::MultiStreamGuard guard(streams);
    // Tensors
    buffers.tensors = cloneSparseTensors(rpcMessage->tensors()).vec();
  }

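  // Pickle the tensor list. The pickled bytes (buffers.pickle) only carry
  // metadata and references; the actual tensor storages are obtained below via
  // pickler.tensorData() and sent as separate TensorPipe tensors.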
  torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
    buffers.pickle.insert(
        buffers.pickle.end(),
        static_cast<const char*>(buf),
        static_cast<const char*>(buf) + sz);
    return sz;
  });
  pickler.protocol();
  pickler.pushIValue(buffers.tensors);
  pickler.stop();
  // kTpMessagePickleIdx = 3
  tpMessage.payloads.push_back(tensorpipe::Message::Payload{
      buffers.pickle.data(), buffers.pickle.size()});
  const std::vector<torch::Tensor>& tensorDataVec = pickler.tensorData();
  tpMessage.tensors.reserve(tensorDataVec.size());
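  // For each tensor, let the converter for its device type append an entry to
  // tpMessage.tensors (possibly forcing a CPU copy), and record the device the
  // receiver should place it on.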
  for (const auto i : c10::irange(tensorDataVec.size())) {
    const torch::Tensor& tensor = tensorDataVec[i];

    const TensorpipeDeviceTypeConverter* converter =
        getDeviceTypeConverter(tensor.device().type());
    TORCH_CHECK(
        converter != nullptr,
        "Attempting to send a Tensor with unexpected device type ",
        tensor.device());

    TORCH_INTERNAL_ASSERT(tpMessage.tensors.size() == i);
    std::optional<std::vector<char>> maybeCopiedTensor =
        converter->prepareTensorForSending(
            tensor.storage(), streams, tpMessage);
    TORCH_INTERNAL_ASSERT(tpMessage.tensors.size() == i + 1);

    tensorpipe::Device targetDevice = devices.empty() || devices[i].is_cpu()
        ? tensorpipe::Device{tensorpipe::kCpuDeviceType, 0}
        : tensorpipe::Device{tensorpipe::kCudaDeviceType, devices[i].index()};
    tpMessage.tensors.back().targetDevice = std::move(targetDevice);

    if (maybeCopiedTensor.has_value()) {
      buffers.copiedTensors.push_back(std::move(maybeCopiedTensor).value());
    }
  }

  return std::make_tuple(std::move(tpMessage), std::move(buffers));
}

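// Allocates the buffers that an incoming message will be received into, based
// on the descriptor provided by TensorPipe. The returned TensorpipeReadBuffers
// owns that memory and is later handed to tensorpipeDeserialize once the read
// has completed.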
std::pair<tensorpipe::Allocation, TensorpipeReadBuffers> tensorpipeAllocate(
    const tensorpipe::Descriptor& tpDescriptor,
    const std::vector<c10::Stream>& streams) {
  tensorpipe::Allocation tpAllocation;
  TensorpipeReadBuffers buffers;

  TORCH_INTERNAL_ASSERT(
      tpDescriptor.payloads.size() == 4,
      "message expected to contain 4 payloads, whereas it contained ",
      tpDescriptor.payloads.size(),
      " payloads");
  tpAllocation.payloads.resize(tpDescriptor.payloads.size());

  TORCH_INTERNAL_ASSERT(
      tpDescriptor.payloads[kTpMessageTypeIdx].length == sizeof(MessageType),
      "first payload expected to contain ",
      sizeof(MessageType),
      " bytes, whereas it contained ",
      tpDescriptor.payloads[kTpMessageTypeIdx].length,
      " bytes");
  buffers.type = std::make_unique<MessageType>();
  tpAllocation.payloads[kTpMessageTypeIdx].data = buffers.type.get();

  TORCH_INTERNAL_ASSERT(
      tpDescriptor.payloads[kTpMessageIdIdx].length == sizeof(int64_t),
      "second payload expected to contain ",
      sizeof(int64_t),
      " bytes, whereas it contained ",
      tpDescriptor.payloads[kTpMessageIdIdx].length,
      " bytes");
  buffers.id = std::make_unique<int64_t>();
  tpAllocation.payloads[kTpMessageIdIdx].data = buffers.id.get();

  // FIXME The two resizes below zero out the vectors, which is not needed.
  buffers.payload.resize(tpDescriptor.payloads[kTpMessagePayloadIdx].length);
  tpAllocation.payloads[kTpMessagePayloadIdx].data = buffers.payload.data();

  buffers.pickle.resize(tpDescriptor.payloads[kTpMessagePickleIdx].length);
  tpAllocation.payloads[kTpMessagePickleIdx].data = buffers.pickle.data();

  size_t numTensors = tpDescriptor.tensors.size();
  tpAllocation.tensors.reserve(numTensors);
  for (const auto tensorIdx : c10::irange(numTensors)) {
    const tensorpipe::Descriptor::Tensor& tensor =
        tpDescriptor.tensors[tensorIdx];
    TORCH_INTERNAL_ASSERT(tensor.targetDevice.has_value());
    c10::DeviceType targetDeviceType =
        convertDeviceType(tensor.targetDevice->type);

    const TensorpipeDeviceTypeConverter* converter =
        getDeviceTypeConverter(targetDeviceType);
    TORCH_INTERNAL_ASSERT(
        converter != nullptr,
        "Attempting to receive a Tensor with unexpected device type ",
        targetDeviceType);

    TORCH_INTERNAL_ASSERT(tpAllocation.tensors.size() == tensorIdx);
    TORCH_INTERNAL_ASSERT(
        tensor.targetDevice->index <=
        std::numeric_limits<c10::DeviceIndex>::max());
    at::DataPtr dataPtr = converter->allocateTensorForReceiving(
        static_cast<c10::DeviceIndex>(tensor.targetDevice->index),
        tensor.length,
        streams,
        tpAllocation);
    TORCH_INTERNAL_ASSERT(tpAllocation.tensors.size() == tensorIdx + 1);

    buffers.tensors.push_back(std::move(dataPtr));
  }

  return {std::move(tpAllocation), std::move(buffers)};
}

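// Reassembles an rpc::Message from the buffers TensorPipe has filled in:
// unpickles the tensor list (resolving each tensor's storage from the DataPtrs
// allocated above) and re-attaches the payload, type, and id.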
c10::intrusive_ptr<Message> tensorpipeDeserialize(
    tensorpipe::Descriptor&& tpDescriptor,
    TensorpipeReadBuffers&& buffers) {
  // Tensors
  std::vector<at::Tensor> tensors;
  const char* pickleData = buffers.pickle.data();
  size_t pickleLen = buffers.pickle.size();
  size_t picklePos = 0;
  auto pickleReadFunc = [&](char* buf, size_t n) -> size_t {
    if (picklePos >= pickleLen || n == 0) {
      return 0;
    }
    size_t toCopy = std::min(picklePos + n, pickleLen) - picklePos;
    memcpy(buf, pickleData + picklePos, toCopy);
    picklePos += toCopy;
    return toCopy;
  };
  auto tensorReadFunc = [&](const std::string& ename) -> at::DataPtr {
    unsigned long index = std::stoul(ename);
    return std::move(buffers.tensors.at(index));
  };

  // No need to pass a typeResolver here, as this unpickler only processes
  // strings and tensors.
  torch::jit::Unpickler unpickler(
      pickleReadFunc,
      nullptr,
      nullptr,
      tensorReadFunc,
      {},
      /* use_storage_device */ true);

  auto ival = unpickler.parse_ivalue();
  for (auto&& t : ival.toTensorList()) {
    tensors.emplace_back(std::move(t));
  }

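  // Sanity check: every CUDA tensor must have landed on the device the sender
  // asked for.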
  for (const auto i : c10::irange(tpDescriptor.tensors.size())) {
    auto& tensor = tpDescriptor.tensors[i];
    if (tensor.targetDevice.has_value() &&
        tensor.targetDevice->type == tensorpipe::kCudaDeviceType) {
      TORCH_INTERNAL_ASSERT(
          tensors[i].device() == indexToDevice(tensor.targetDevice->index),
          "Tensor ",
          i,
          " in message ",
          *buffers.id,
          " was expected to be received on device ",
          tensor.targetDevice->index,
          ", but got it on ",
          tensors[i].device());
    }
  }

  return c10::make_intrusive<Message>(
      std::move(buffers.payload),
      std::move(tensors),
      *buffers.type,
      *buffers.id);
}
} // namespace torch::distributed::rpc

#endif // USE_TENSORPIPE