xref: /aosp_15_r20/external/federated-compute/fcp/protocol/grpc_chunked_bidi_stream.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
18 #define FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
19 
20 #include <stddef.h>
21 
22 #include <algorithm>
23 #include <deque>
24 #include <memory>
25 #include <string>
26 
27 #include "absl/base/attributes.h"
28 #include "absl/status/status.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/protos/federated_api.grpc.pb.h"
31 #include "grpcpp/impl/codegen/call_op_set.h"
32 #include "grpcpp/impl/codegen/sync_stream.h"
33 #include "google/protobuf/io/gzip_stream.h"
34 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
35 
36 namespace fcp {
37 namespace client {
38 
39 /**
40  * A class which implements the chunking protocol for the federated learning
41  * API.
42  *
43  * Can be used by both client and server.
44  *
45  * @tparam Outgoing The type of the outgoing protocol buffer message.
46  * @tparam Incoming The type of the incoming protocol buffer message.
47  */
48 template <typename Outgoing, typename Incoming>
49 class GrpcChunkedBidiStream {
50  public:
51   struct GrpcChunkedBidiStreamOptions {
52     int32_t chunk_size_for_upload = -1;
53     int32_t max_pending_chunks = -1;
54     google::internal::federatedml::v2::CompressionLevel compression_level{};
55   };
56   GrpcChunkedBidiStream(
57       grpc::internal::WriterInterface<Outgoing>* writer_interface,
58       grpc::internal::ReaderInterface<Incoming>* reader_interface);
59   GrpcChunkedBidiStream(
60       grpc::internal::WriterInterface<Outgoing>* writer_interface,
61       grpc::internal::ReaderInterface<Incoming>* reader_interface,
62       GrpcChunkedBidiStreamOptions options);
63   virtual ~GrpcChunkedBidiStream() = default;
64 
65   // GrpcChunkedBidiStream is neither copyable nor movable.
66   GrpcChunkedBidiStream(const GrpcChunkedBidiStream&) = delete;
67   GrpcChunkedBidiStream& operator=(const GrpcChunkedBidiStream&) = delete;
68 
69   ABSL_MUST_USE_RESULT absl::Status Send(Outgoing* message);
70   ABSL_MUST_USE_RESULT absl::Status Receive(Incoming* message);
71   void Close();
72   int64_t ChunkingLayerBytesSent();
73   int64_t ChunkingLayerBytesReceived();
74 
75  private:
76   ABSL_MUST_USE_RESULT absl::Status TryDecorateCheckinRequest(
77       Outgoing* message);
78   ABSL_MUST_USE_RESULT absl::Status ChunkMessage(const Outgoing& message);
79   ABSL_MUST_USE_RESULT absl::Status TrySendPending();
80   ABSL_MUST_USE_RESULT absl::Status TrySend(const Outgoing& message);
81   ABSL_MUST_USE_RESULT absl::Status SendAck(int32_t chunk_index);
82   ABSL_MUST_USE_RESULT absl::Status SendRaw(const Outgoing& message,
83                                             bool disable_compression = false);
84   ABSL_MUST_USE_RESULT absl::Status TrySnoopCheckinResponse(Incoming* message);
85   ABSL_MUST_USE_RESULT absl::Status TryAssemblePending(Incoming* message,
86                                                        bool* message_assembled);
87   ABSL_MUST_USE_RESULT absl::Status AssemblePending(Incoming* message,
88                                                     bool* message_assembled);
89   ABSL_MUST_USE_RESULT absl::Status ReceiveRaw(Incoming* message);
90 
91   grpc::internal::WriterInterface<Outgoing>* writer_interface_;
92   grpc::internal::ReaderInterface<Incoming>* reader_interface_;
93 
94   struct {
95     int32_t uncompressed_size = -1;
96     google::internal::federatedml::v2::CompressionLevel compression_level{};
97     int32_t blob_size_bytes = -1;
98     std::deque<std::string> deque;
99     std::string composite;
100     int64_t total_bytes_downloaded = 0;
101   } incoming_;
102 
103   struct {
104     int32_t chunk_size_for_upload = 0;
105     int32_t max_pending_chunks = 0;
106     int32_t pending_chunks = 0;
107     google::internal::federatedml::v2::CompressionLevel compression_level{};
108     std::deque<std::unique_ptr<Outgoing>> deque;
109     int64_t total_bytes_uploaded = 0;
110 
Add__anon6d200f920208111     google::internal::federatedml::v2::ChunkedTransferMessage* Add() {
112       deque.push_back(std::make_unique<Outgoing>());
113       return deque.back()->mutable_chunked_transfer();
114     }
115   } outgoing_;
116 };
117 
// Brings the protocol message types and the protobuf stream helpers used by
// the member definitions in this header into the scope where the macro is
// expanded. It is expanded as the first statement of those functions.
#define COMMON_USING_DIRECTIVES                                    \
  using google::internal::federatedml::v2::ChunkedTransferMessage; \
  using google::internal::federatedml::v2::ClientStreamMessage;    \
  using google::internal::federatedml::v2::CompressionLevel;       \
  using google::internal::federatedml::v2::ServerStreamMessage;    \
  using google::protobuf::io::ArrayInputStream;                    \
  using google::protobuf::io::GzipInputStream;                     \
  using google::protobuf::io::GzipOutputStream;                    \
  using google::protobuf::io::StringOutputStream;                  \
  using google::protobuf::io::ZeroCopyOutputStream;
128 
129 template <typename Outgoing, typename Incoming>
GrpcChunkedBidiStream(grpc::internal::WriterInterface<Outgoing> * writer_interface,grpc::internal::ReaderInterface<Incoming> * reader_interface)130 GrpcChunkedBidiStream<Outgoing, Incoming>::GrpcChunkedBidiStream(
131     grpc::internal::WriterInterface<Outgoing>* writer_interface,
132     grpc::internal::ReaderInterface<Incoming>* reader_interface)
133     : GrpcChunkedBidiStream(writer_interface, reader_interface,
134                             GrpcChunkedBidiStreamOptions()) {}
135 
136 template <typename Outgoing, typename Incoming>
GrpcChunkedBidiStream(grpc::internal::WriterInterface<Outgoing> * writer_interface,grpc::internal::ReaderInterface<Incoming> * reader_interface,GrpcChunkedBidiStreamOptions options)137 GrpcChunkedBidiStream<Outgoing, Incoming>::GrpcChunkedBidiStream(
138     grpc::internal::WriterInterface<Outgoing>* writer_interface,
139     grpc::internal::ReaderInterface<Incoming>* reader_interface,
140     GrpcChunkedBidiStreamOptions options)
141     : writer_interface_(writer_interface), reader_interface_(reader_interface) {
142   outgoing_.chunk_size_for_upload = options.chunk_size_for_upload;
143   outgoing_.max_pending_chunks = options.max_pending_chunks;
144   outgoing_.compression_level = options.compression_level;
145 }
146 
147 template <typename Outgoing, typename Incoming>
Send(Outgoing * message)148 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::Send(
149     Outgoing* message) {
150   COMMON_USING_DIRECTIVES;
151   FCP_RETURN_IF_ERROR(TryDecorateCheckinRequest(message));
152   switch (message->kind_case()) {
153     case Outgoing::KindCase::kChunkedTransfer:
154       Close();
155       return absl::InvalidArgumentError(
156           absl::StrCat("Message is pre-chunked: ", message->DebugString()));
157     default:
158       break;
159   }
160 
161   return TrySend(*message);
162 }
163 
164 template <typename Outgoing, typename Incoming>
Receive(Incoming * message)165 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::Receive(
166     Incoming* message) {
167   COMMON_USING_DIRECTIVES;
168   Status status;
169   bool message_assembled = false;
170 
171   do {
172     FCP_RETURN_IF_ERROR(status = ReceiveRaw(message));
173     switch (message->kind_case()) {
174       case Incoming::KindCase::kChunkedTransfer:
175         if (message->chunked_transfer().kind_case() ==
176             ChunkedTransferMessage::kAck) {
177           --outgoing_.pending_chunks;
178           FCP_RETURN_IF_ERROR(status = TrySendPending());
179         } else {
180           FCP_RETURN_IF_ERROR(
181               status = TryAssemblePending(message, &message_assembled));
182         }
183         break;
184       default:
185         if (incoming_.uncompressed_size != -1)
186           return absl::InvalidArgumentError("Chunk reassembly in progress.");
187         message_assembled = true;
188         break;
189     }
190   } while (!message_assembled);
191 
192   FCP_RETURN_IF_ERROR(status = TrySnoopCheckinResponse(message));
193   return status;
194 }
195 
196 template <>
197 inline absl::Status
198 GrpcChunkedBidiStream<google::internal::federatedml::v2::ClientStreamMessage,
199                       google::internal::federatedml::v2::ServerStreamMessage>::
TryDecorateCheckinRequest(google::internal::federatedml::v2::ClientStreamMessage * message)200     TryDecorateCheckinRequest(
201         google::internal::federatedml::v2::ClientStreamMessage* message) {
202   COMMON_USING_DIRECTIVES;
203   if (message->kind_case() !=
204           ClientStreamMessage::kEligibilityEvalCheckinRequest &&
205       message->kind_case() != ClientStreamMessage::kCheckinRequest)
206     return absl::OkStatus();
207   // Both an EligibilityEvalCheckinRequest or a CheckinRequest message need to
208   // specify a ProtocolOptionsRequest message.
209   auto options = (message->has_eligibility_eval_checkin_request()
210                       ? message->mutable_eligibility_eval_checkin_request()
211                             ->mutable_protocol_options_request()
212                       : message->mutable_checkin_request()
213                             ->mutable_protocol_options_request());
214   options->set_supports_chunked_blob_transfer(true);
215   options->add_supported_compression_levels(CompressionLevel::UNCOMPRESSED);
216   options->add_supported_compression_levels(CompressionLevel::ZLIB_DEFAULT);
217   options->add_supported_compression_levels(
218       CompressionLevel::ZLIB_BEST_COMPRESSION);
219   options->add_supported_compression_levels(CompressionLevel::ZLIB_BEST_SPEED);
220   return absl::OkStatus();
221 }
222 
223 template <typename Outgoing, typename Incoming>
224 absl::Status
TryDecorateCheckinRequest(Outgoing *)225 GrpcChunkedBidiStream<Outgoing, Incoming>::TryDecorateCheckinRequest(
226     Outgoing*) {
227   return absl::OkStatus();
228 }
229 
230 template <typename Outgoing, typename Incoming>
ChunkMessage(const Outgoing & message)231 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkMessage(
232     const Outgoing& message) {
233   COMMON_USING_DIRECTIVES;
234 
235   auto start = outgoing_.Add()->mutable_start();
236   start->set_compression_level(outgoing_.compression_level);
237 
238   // TODO(team): Replace with a more efficient serialization mechanism.
239   std::string output;
240   if (outgoing_.compression_level == CompressionLevel::UNCOMPRESSED) {
241     if (!message.AppendToString(&output))
242       return absl::InternalError("Could not append to string.");
243   } else {
244     StringOutputStream string_output_stream(&output);
245     GzipOutputStream::Options options;
246     options.format = GzipOutputStream::ZLIB;
247     switch (outgoing_.compression_level) {
248       case CompressionLevel::ZLIB_DEFAULT:
249         options.compression_level = Z_DEFAULT_COMPRESSION;
250         break;
251       case CompressionLevel::ZLIB_BEST_COMPRESSION:
252         options.compression_level = Z_BEST_COMPRESSION;
253         break;
254       case CompressionLevel::ZLIB_BEST_SPEED:
255         options.compression_level = Z_BEST_SPEED;
256         break;
257       default:
258         Close();
259         return absl::InternalError("Unsupported compression level.");
260     }
261     GzipOutputStream compressed_stream(&string_output_stream, options);
262     if (!message.SerializeToZeroCopyStream(&compressed_stream) ||
263         !compressed_stream.Close())
264       return absl::InvalidArgumentError(
265           absl::StrCat("Failed to serialize message: ",
266                        compressed_stream.ZlibErrorMessage()));
267   }
268 
269   auto blob_size_bytes = static_cast<int32_t>(output.size());
270   int32_t chunk_index = 0;
271   if (!blob_size_bytes) blob_size_bytes = 1;  // Force one empty packet.
272   for (size_t offset = 0; offset < blob_size_bytes;
273        offset += std::min(blob_size_bytes, outgoing_.chunk_size_for_upload),
274               ++chunk_index) {
275     auto data = outgoing_.Add()->mutable_data();
276     data->set_chunk_index(chunk_index);
277     data->set_chunk_bytes(output.substr(
278         offset, static_cast<size_t>(outgoing_.chunk_size_for_upload)));
279   }
280 
281   start->set_uncompressed_size(static_cast<int32_t>(message.ByteSizeLong()));
282   start->set_blob_size_bytes(blob_size_bytes);
283 
284   auto end = outgoing_.Add()->mutable_end();
285   end->set_chunk_count(chunk_index);
286   return absl::OkStatus();
287 }
288 
289 template <typename Outgoing, typename Incoming>
TrySendPending()290 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySendPending() {
291   COMMON_USING_DIRECTIVES;
292   auto status = absl::OkStatus();
293   while (!outgoing_.deque.empty() &&
294          outgoing_.pending_chunks < outgoing_.max_pending_chunks) {
295     auto& front = outgoing_.deque.front();
296     FCP_RETURN_IF_ERROR(status =
297                             SendRaw(*front, outgoing_.compression_level > 0));
298     if (front->chunked_transfer().kind_case() == ChunkedTransferMessage::kData)
299       ++outgoing_.pending_chunks;
300     outgoing_.deque.pop_front();
301   }
302   return status;
303 }
304 
305 template <typename Outgoing, typename Incoming>
TrySend(const Outgoing & message)306 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySend(
307     const Outgoing& message) {
308   COMMON_USING_DIRECTIVES;
309   if (outgoing_.chunk_size_for_upload <= 0 || outgoing_.max_pending_chunks <= 0)
310     return SendRaw(message);  // No chunking.
311   absl::Status status;
312   if (!(status = ChunkMessage(message)).ok()) {
313     Close();
314     return status;
315   }
316   return TrySendPending();
317 }
318 
319 template <typename Outgoing, typename Incoming>
SendAck(int32_t chunk_index)320 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::SendAck(
321     int32_t chunk_index) {
322   Outgoing ack;
323   ack.mutable_chunked_transfer()->mutable_ack()->set_chunk_index(chunk_index);
324   return SendRaw(ack);
325 }
326 
327 template <typename Outgoing, typename Incoming>
SendRaw(const Outgoing & message,bool disable_compression)328 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::SendRaw(
329     const Outgoing& message, bool disable_compression) {
330   if (!writer_interface_)
331     return absl::FailedPreconditionError("Send on closed stream.");
332   grpc::WriteOptions write_options;
333   if (disable_compression) write_options.set_no_compression();
334   outgoing_.total_bytes_uploaded += message.ByteSizeLong();
335   if (!writer_interface_->Write(message, write_options)) {
336     Close();
337     return absl::AbortedError("End of stream.");
338   }
339   return absl::OkStatus();
340 }
341 
342 // If this class is used on the client side, we need to break the abstraction
343 // that messages are opaque in order to read the chunking parameters sent by the
344 // server to determine how to carry out the remainder of the protocol.
345 // Inspect the checkin response to record these chunking options.
346 template <>
347 inline absl::Status
348 GrpcChunkedBidiStream<google::internal::federatedml::v2::ClientStreamMessage,
349                       google::internal::federatedml::v2::ServerStreamMessage>::
TrySnoopCheckinResponse(google::internal::federatedml::v2::ServerStreamMessage * message)350     TrySnoopCheckinResponse(
351         google::internal::federatedml::v2::ServerStreamMessage* message) {
352   COMMON_USING_DIRECTIVES;
353   if (message->kind_case() !=
354           ServerStreamMessage::kEligibilityEvalCheckinResponse &&
355       message->kind_case() != ServerStreamMessage::kCheckinResponse)
356     return absl::OkStatus();
357   if (incoming_.uncompressed_size != -1)
358     return absl::InvalidArgumentError("Chunk reassembly in progress.");
359   // We adopt any new protocol options we may receive, even if we previously
360   // received some options already. I.e. a ProtocolOptionsResponse received in a
361   // CheckinResponse will overwrite any ProtocolOptionsResponse that was
362   // previously received in a EligibilityEvalCheckinResponse.
363   // OTOH, we also don't require that every EligibilityEvalCheckinResponse or
364   // CheckinResponse message actually has a ProtocolOptionsResponse message set
365   // (e.g. CheckinResponse may not have a ProtocolOptionsResponse if one was
366   // already returned inside a prior EligibilityEvalCheckinResponse).
367   if (message->eligibility_eval_checkin_response()
368           .has_protocol_options_response() ||
369       message->checkin_response().has_protocol_options_response()) {
370     auto options =
371         (message->has_eligibility_eval_checkin_response()
372              ? message->eligibility_eval_checkin_response()
373                    .protocol_options_response()
374              : message->checkin_response().protocol_options_response());
375     outgoing_.chunk_size_for_upload = options.chunk_size_for_upload();
376     outgoing_.max_pending_chunks = options.max_pending_chunks();
377     outgoing_.compression_level = options.compression_level();
378   }
379   return absl::OkStatus();
380 }
381 
382 // If this class is being used by the server, this is a no-op as the server
383 // determines the chunking options.
384 template <typename Outgoing, typename Incoming>
TrySnoopCheckinResponse(Incoming *)385 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySnoopCheckinResponse(
386     Incoming*) {
387   return absl::OkStatus();
388 }
389 
390 template <typename Outgoing, typename Incoming>
TryAssemblePending(Incoming * message,bool * message_assembled)391 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TryAssemblePending(
392     Incoming* message, bool* message_assembled) {
393   COMMON_USING_DIRECTIVES;
394   *message_assembled = false;
395   auto chunk = message->chunked_transfer();
396   switch (chunk.kind_case()) {
397     case ChunkedTransferMessage::kStart:
398       if (!incoming_.deque.empty() || incoming_.uncompressed_size != -1)
399         return absl::InternalError("Unexpected Start.");
400       incoming_.uncompressed_size = chunk.start().uncompressed_size();
401       incoming_.compression_level = chunk.start().compression_level();
402       incoming_.blob_size_bytes = chunk.start().blob_size_bytes();
403       break;
404     case ChunkedTransferMessage::kData:
405       if (chunk.data().chunk_index() != incoming_.deque.size())
406         return absl::InternalError("Unexpected Data.");
407       incoming_.deque.emplace_back(chunk.data().chunk_bytes());
408       incoming_.composite.append(incoming_.deque.back());
409       return SendAck(static_cast<int32_t>(incoming_.deque.size() - 1));
410     case ChunkedTransferMessage::kEnd:
411       if (incoming_.deque.empty() ||
412           chunk.end().chunk_count() != incoming_.deque.size())
413         return absl::InternalError("Unexpected End.");
414       return AssemblePending(message, message_assembled);
415     case ChunkedTransferMessage::kAck:
416       return absl::InternalError("Unexpected Ack.");
417     default:
418       return absl::InternalError(
419           absl::StrCat("Unexpected message subtype: ",
420                        message->chunked_transfer().kind_case()));
421   }
422 
423   return absl::OkStatus();
424 }
425 
426 template <typename Outgoing, typename Incoming>
AssemblePending(Incoming * message,bool * message_assembled)427 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::AssemblePending(
428     Incoming* message, bool* message_assembled) {
429   COMMON_USING_DIRECTIVES;
430   // TODO(team): Replace with a more efficient deserialization mechanism.
431   if (incoming_.compression_level == CompressionLevel::UNCOMPRESSED) {
432     if (!message->ParseFromString(incoming_.composite))
433       return absl::InternalError(absl::StrCat("Could not parse from string. ",
434                                               incoming_.composite.size()));
435   } else {
436     ArrayInputStream string_input_stream(
437         incoming_.composite.c_str(),
438         static_cast<int>(incoming_.composite.size()));
439     GzipInputStream compressed_stream(&string_input_stream);
440     if (!message->ParseFromZeroCopyStream(&compressed_stream))
441       return absl::InternalError("Could not parse proto from input stream.");
442   }
443   *message_assembled = true;
444   incoming_.uncompressed_size = -1;
445   incoming_.blob_size_bytes = -1;
446   incoming_.deque.clear();
447   incoming_.composite.clear();
448   return absl::OkStatus();
449 }
450 
451 template <typename Outgoing, typename Incoming>
ReceiveRaw(Incoming * message)452 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::ReceiveRaw(
453     Incoming* message) {
454   if (!reader_interface_)
455     return absl::FailedPreconditionError("Receive on closed stream.");
456   if (!reader_interface_->Read(message)) {
457     Close();
458     return absl::AbortedError("End of stream.");
459   }
460   incoming_.total_bytes_downloaded += message->ByteSizeLong();
461   return absl::OkStatus();
462 }
463 
464 template <typename Outgoing, typename Incoming>
Close()465 void GrpcChunkedBidiStream<Outgoing, Incoming>::Close() {
466   writer_interface_ = nullptr;
467   reader_interface_ = nullptr;
468 }
469 
470 template <typename Outgoing, typename Incoming>
471 int64_t
ChunkingLayerBytesReceived()472 GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkingLayerBytesReceived() {
473   return incoming_.total_bytes_downloaded;
474 }
475 
476 template <typename Outgoing, typename Incoming>
ChunkingLayerBytesSent()477 int64_t GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkingLayerBytesSent() {
478   return outgoing_.total_bytes_uploaded;
479 }
480 
481 }  // namespace client
482 }  // namespace fcp
483 
484 #endif  // FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
485