1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
18 #define FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
19
20 #include <stddef.h>
21
22 #include <algorithm>
23 #include <deque>
24 #include <memory>
25 #include <string>
26
27 #include "absl/base/attributes.h"
28 #include "absl/status/status.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/protos/federated_api.grpc.pb.h"
31 #include "grpcpp/impl/codegen/call_op_set.h"
32 #include "grpcpp/impl/codegen/sync_stream.h"
33 #include "google/protobuf/io/gzip_stream.h"
34 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
35
36 namespace fcp {
37 namespace client {
38
39 /**
40 * A class which implements the chunking protocol for the federated learning
41 * API.
42 *
43 * Can be used by both client and server.
44 *
45 * @tparam Outgoing The type of the outgoing protocol buffer message.
46 * @tparam Incoming The type of the incoming protocol buffer message.
47 */
48 template <typename Outgoing, typename Incoming>
49 class GrpcChunkedBidiStream {
50 public:
51 struct GrpcChunkedBidiStreamOptions {
52 int32_t chunk_size_for_upload = -1;
53 int32_t max_pending_chunks = -1;
54 google::internal::federatedml::v2::CompressionLevel compression_level{};
55 };
56 GrpcChunkedBidiStream(
57 grpc::internal::WriterInterface<Outgoing>* writer_interface,
58 grpc::internal::ReaderInterface<Incoming>* reader_interface);
59 GrpcChunkedBidiStream(
60 grpc::internal::WriterInterface<Outgoing>* writer_interface,
61 grpc::internal::ReaderInterface<Incoming>* reader_interface,
62 GrpcChunkedBidiStreamOptions options);
63 virtual ~GrpcChunkedBidiStream() = default;
64
65 // GrpcChunkedBidiStream is neither copyable nor movable.
66 GrpcChunkedBidiStream(const GrpcChunkedBidiStream&) = delete;
67 GrpcChunkedBidiStream& operator=(const GrpcChunkedBidiStream&) = delete;
68
69 ABSL_MUST_USE_RESULT absl::Status Send(Outgoing* message);
70 ABSL_MUST_USE_RESULT absl::Status Receive(Incoming* message);
71 void Close();
72 int64_t ChunkingLayerBytesSent();
73 int64_t ChunkingLayerBytesReceived();
74
75 private:
76 ABSL_MUST_USE_RESULT absl::Status TryDecorateCheckinRequest(
77 Outgoing* message);
78 ABSL_MUST_USE_RESULT absl::Status ChunkMessage(const Outgoing& message);
79 ABSL_MUST_USE_RESULT absl::Status TrySendPending();
80 ABSL_MUST_USE_RESULT absl::Status TrySend(const Outgoing& message);
81 ABSL_MUST_USE_RESULT absl::Status SendAck(int32_t chunk_index);
82 ABSL_MUST_USE_RESULT absl::Status SendRaw(const Outgoing& message,
83 bool disable_compression = false);
84 ABSL_MUST_USE_RESULT absl::Status TrySnoopCheckinResponse(Incoming* message);
85 ABSL_MUST_USE_RESULT absl::Status TryAssemblePending(Incoming* message,
86 bool* message_assembled);
87 ABSL_MUST_USE_RESULT absl::Status AssemblePending(Incoming* message,
88 bool* message_assembled);
89 ABSL_MUST_USE_RESULT absl::Status ReceiveRaw(Incoming* message);
90
91 grpc::internal::WriterInterface<Outgoing>* writer_interface_;
92 grpc::internal::ReaderInterface<Incoming>* reader_interface_;
93
94 struct {
95 int32_t uncompressed_size = -1;
96 google::internal::federatedml::v2::CompressionLevel compression_level{};
97 int32_t blob_size_bytes = -1;
98 std::deque<std::string> deque;
99 std::string composite;
100 int64_t total_bytes_downloaded = 0;
101 } incoming_;
102
103 struct {
104 int32_t chunk_size_for_upload = 0;
105 int32_t max_pending_chunks = 0;
106 int32_t pending_chunks = 0;
107 google::internal::federatedml::v2::CompressionLevel compression_level{};
108 std::deque<std::unique_ptr<Outgoing>> deque;
109 int64_t total_bytes_uploaded = 0;
110
Add__anon6d200f920208111 google::internal::federatedml::v2::ChunkedTransferMessage* Add() {
112 deque.push_back(std::make_unique<Outgoing>());
113 return deque.back()->mutable_chunked_transfer();
114 }
115 } outgoing_;
116 };
117
// Using-declarations shared by the member definitions below. The macro is
// expanded inside each function body, so the names are scoped to that body.
// NOTE(review): this macro is never #undef'd in this header, so it stays
// defined for every file that includes it.
#define COMMON_USING_DIRECTIVES                                    \
  using google::internal::federatedml::v2::ChunkedTransferMessage; \
  using google::internal::federatedml::v2::ClientStreamMessage;    \
  using google::internal::federatedml::v2::CompressionLevel;       \
  using google::internal::federatedml::v2::ServerStreamMessage;    \
  using google::protobuf::io::ArrayInputStream;                    \
  using google::protobuf::io::GzipInputStream;                     \
  using google::protobuf::io::GzipOutputStream;                    \
  using google::protobuf::io::StringOutputStream;                  \
  using google::protobuf::io::ZeroCopyOutputStream;
128
129 template <typename Outgoing, typename Incoming>
GrpcChunkedBidiStream(grpc::internal::WriterInterface<Outgoing> * writer_interface,grpc::internal::ReaderInterface<Incoming> * reader_interface)130 GrpcChunkedBidiStream<Outgoing, Incoming>::GrpcChunkedBidiStream(
131 grpc::internal::WriterInterface<Outgoing>* writer_interface,
132 grpc::internal::ReaderInterface<Incoming>* reader_interface)
133 : GrpcChunkedBidiStream(writer_interface, reader_interface,
134 GrpcChunkedBidiStreamOptions()) {}
135
136 template <typename Outgoing, typename Incoming>
GrpcChunkedBidiStream(grpc::internal::WriterInterface<Outgoing> * writer_interface,grpc::internal::ReaderInterface<Incoming> * reader_interface,GrpcChunkedBidiStreamOptions options)137 GrpcChunkedBidiStream<Outgoing, Incoming>::GrpcChunkedBidiStream(
138 grpc::internal::WriterInterface<Outgoing>* writer_interface,
139 grpc::internal::ReaderInterface<Incoming>* reader_interface,
140 GrpcChunkedBidiStreamOptions options)
141 : writer_interface_(writer_interface), reader_interface_(reader_interface) {
142 outgoing_.chunk_size_for_upload = options.chunk_size_for_upload;
143 outgoing_.max_pending_chunks = options.max_pending_chunks;
144 outgoing_.compression_level = options.compression_level;
145 }
146
147 template <typename Outgoing, typename Incoming>
Send(Outgoing * message)148 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::Send(
149 Outgoing* message) {
150 COMMON_USING_DIRECTIVES;
151 FCP_RETURN_IF_ERROR(TryDecorateCheckinRequest(message));
152 switch (message->kind_case()) {
153 case Outgoing::KindCase::kChunkedTransfer:
154 Close();
155 return absl::InvalidArgumentError(
156 absl::StrCat("Message is pre-chunked: ", message->DebugString()));
157 default:
158 break;
159 }
160
161 return TrySend(*message);
162 }
163
164 template <typename Outgoing, typename Incoming>
Receive(Incoming * message)165 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::Receive(
166 Incoming* message) {
167 COMMON_USING_DIRECTIVES;
168 Status status;
169 bool message_assembled = false;
170
171 do {
172 FCP_RETURN_IF_ERROR(status = ReceiveRaw(message));
173 switch (message->kind_case()) {
174 case Incoming::KindCase::kChunkedTransfer:
175 if (message->chunked_transfer().kind_case() ==
176 ChunkedTransferMessage::kAck) {
177 --outgoing_.pending_chunks;
178 FCP_RETURN_IF_ERROR(status = TrySendPending());
179 } else {
180 FCP_RETURN_IF_ERROR(
181 status = TryAssemblePending(message, &message_assembled));
182 }
183 break;
184 default:
185 if (incoming_.uncompressed_size != -1)
186 return absl::InvalidArgumentError("Chunk reassembly in progress.");
187 message_assembled = true;
188 break;
189 }
190 } while (!message_assembled);
191
192 FCP_RETURN_IF_ERROR(status = TrySnoopCheckinResponse(message));
193 return status;
194 }
195
196 template <>
197 inline absl::Status
198 GrpcChunkedBidiStream<google::internal::federatedml::v2::ClientStreamMessage,
199 google::internal::federatedml::v2::ServerStreamMessage>::
TryDecorateCheckinRequest(google::internal::federatedml::v2::ClientStreamMessage * message)200 TryDecorateCheckinRequest(
201 google::internal::federatedml::v2::ClientStreamMessage* message) {
202 COMMON_USING_DIRECTIVES;
203 if (message->kind_case() !=
204 ClientStreamMessage::kEligibilityEvalCheckinRequest &&
205 message->kind_case() != ClientStreamMessage::kCheckinRequest)
206 return absl::OkStatus();
207 // Both an EligibilityEvalCheckinRequest or a CheckinRequest message need to
208 // specify a ProtocolOptionsRequest message.
209 auto options = (message->has_eligibility_eval_checkin_request()
210 ? message->mutable_eligibility_eval_checkin_request()
211 ->mutable_protocol_options_request()
212 : message->mutable_checkin_request()
213 ->mutable_protocol_options_request());
214 options->set_supports_chunked_blob_transfer(true);
215 options->add_supported_compression_levels(CompressionLevel::UNCOMPRESSED);
216 options->add_supported_compression_levels(CompressionLevel::ZLIB_DEFAULT);
217 options->add_supported_compression_levels(
218 CompressionLevel::ZLIB_BEST_COMPRESSION);
219 options->add_supported_compression_levels(CompressionLevel::ZLIB_BEST_SPEED);
220 return absl::OkStatus();
221 }
222
223 template <typename Outgoing, typename Incoming>
224 absl::Status
TryDecorateCheckinRequest(Outgoing *)225 GrpcChunkedBidiStream<Outgoing, Incoming>::TryDecorateCheckinRequest(
226 Outgoing*) {
227 return absl::OkStatus();
228 }
229
230 template <typename Outgoing, typename Incoming>
ChunkMessage(const Outgoing & message)231 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkMessage(
232 const Outgoing& message) {
233 COMMON_USING_DIRECTIVES;
234
235 auto start = outgoing_.Add()->mutable_start();
236 start->set_compression_level(outgoing_.compression_level);
237
238 // TODO(team): Replace with a more efficient serialization mechanism.
239 std::string output;
240 if (outgoing_.compression_level == CompressionLevel::UNCOMPRESSED) {
241 if (!message.AppendToString(&output))
242 return absl::InternalError("Could not append to string.");
243 } else {
244 StringOutputStream string_output_stream(&output);
245 GzipOutputStream::Options options;
246 options.format = GzipOutputStream::ZLIB;
247 switch (outgoing_.compression_level) {
248 case CompressionLevel::ZLIB_DEFAULT:
249 options.compression_level = Z_DEFAULT_COMPRESSION;
250 break;
251 case CompressionLevel::ZLIB_BEST_COMPRESSION:
252 options.compression_level = Z_BEST_COMPRESSION;
253 break;
254 case CompressionLevel::ZLIB_BEST_SPEED:
255 options.compression_level = Z_BEST_SPEED;
256 break;
257 default:
258 Close();
259 return absl::InternalError("Unsupported compression level.");
260 }
261 GzipOutputStream compressed_stream(&string_output_stream, options);
262 if (!message.SerializeToZeroCopyStream(&compressed_stream) ||
263 !compressed_stream.Close())
264 return absl::InvalidArgumentError(
265 absl::StrCat("Failed to serialize message: ",
266 compressed_stream.ZlibErrorMessage()));
267 }
268
269 auto blob_size_bytes = static_cast<int32_t>(output.size());
270 int32_t chunk_index = 0;
271 if (!blob_size_bytes) blob_size_bytes = 1; // Force one empty packet.
272 for (size_t offset = 0; offset < blob_size_bytes;
273 offset += std::min(blob_size_bytes, outgoing_.chunk_size_for_upload),
274 ++chunk_index) {
275 auto data = outgoing_.Add()->mutable_data();
276 data->set_chunk_index(chunk_index);
277 data->set_chunk_bytes(output.substr(
278 offset, static_cast<size_t>(outgoing_.chunk_size_for_upload)));
279 }
280
281 start->set_uncompressed_size(static_cast<int32_t>(message.ByteSizeLong()));
282 start->set_blob_size_bytes(blob_size_bytes);
283
284 auto end = outgoing_.Add()->mutable_end();
285 end->set_chunk_count(chunk_index);
286 return absl::OkStatus();
287 }
288
289 template <typename Outgoing, typename Incoming>
TrySendPending()290 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySendPending() {
291 COMMON_USING_DIRECTIVES;
292 auto status = absl::OkStatus();
293 while (!outgoing_.deque.empty() &&
294 outgoing_.pending_chunks < outgoing_.max_pending_chunks) {
295 auto& front = outgoing_.deque.front();
296 FCP_RETURN_IF_ERROR(status =
297 SendRaw(*front, outgoing_.compression_level > 0));
298 if (front->chunked_transfer().kind_case() == ChunkedTransferMessage::kData)
299 ++outgoing_.pending_chunks;
300 outgoing_.deque.pop_front();
301 }
302 return status;
303 }
304
305 template <typename Outgoing, typename Incoming>
TrySend(const Outgoing & message)306 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySend(
307 const Outgoing& message) {
308 COMMON_USING_DIRECTIVES;
309 if (outgoing_.chunk_size_for_upload <= 0 || outgoing_.max_pending_chunks <= 0)
310 return SendRaw(message); // No chunking.
311 absl::Status status;
312 if (!(status = ChunkMessage(message)).ok()) {
313 Close();
314 return status;
315 }
316 return TrySendPending();
317 }
318
319 template <typename Outgoing, typename Incoming>
SendAck(int32_t chunk_index)320 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::SendAck(
321 int32_t chunk_index) {
322 Outgoing ack;
323 ack.mutable_chunked_transfer()->mutable_ack()->set_chunk_index(chunk_index);
324 return SendRaw(ack);
325 }
326
327 template <typename Outgoing, typename Incoming>
SendRaw(const Outgoing & message,bool disable_compression)328 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::SendRaw(
329 const Outgoing& message, bool disable_compression) {
330 if (!writer_interface_)
331 return absl::FailedPreconditionError("Send on closed stream.");
332 grpc::WriteOptions write_options;
333 if (disable_compression) write_options.set_no_compression();
334 outgoing_.total_bytes_uploaded += message.ByteSizeLong();
335 if (!writer_interface_->Write(message, write_options)) {
336 Close();
337 return absl::AbortedError("End of stream.");
338 }
339 return absl::OkStatus();
340 }
341
342 // If this class is used on the client side, we need to break the abstraction
343 // that messages are opaque in order to read the chunking parameters sent by the
344 // server to determine how to carry out the remainder of the protocol.
345 // Inspect the checkin response to record these chunking options.
346 template <>
347 inline absl::Status
348 GrpcChunkedBidiStream<google::internal::federatedml::v2::ClientStreamMessage,
349 google::internal::federatedml::v2::ServerStreamMessage>::
TrySnoopCheckinResponse(google::internal::federatedml::v2::ServerStreamMessage * message)350 TrySnoopCheckinResponse(
351 google::internal::federatedml::v2::ServerStreamMessage* message) {
352 COMMON_USING_DIRECTIVES;
353 if (message->kind_case() !=
354 ServerStreamMessage::kEligibilityEvalCheckinResponse &&
355 message->kind_case() != ServerStreamMessage::kCheckinResponse)
356 return absl::OkStatus();
357 if (incoming_.uncompressed_size != -1)
358 return absl::InvalidArgumentError("Chunk reassembly in progress.");
359 // We adopt any new protocol options we may receive, even if we previously
360 // received some options already. I.e. a ProtocolOptionsResponse received in a
361 // CheckinResponse will overwrite any ProtocolOptionsResponse that was
362 // previously received in a EligibilityEvalCheckinResponse.
363 // OTOH, we also don't require that every EligibilityEvalCheckinResponse or
364 // CheckinResponse message actually has a ProtocolOptionsResponse message set
365 // (e.g. CheckinResponse may not have a ProtocolOptionsResponse if one was
366 // already returned inside a prior EligibilityEvalCheckinResponse).
367 if (message->eligibility_eval_checkin_response()
368 .has_protocol_options_response() ||
369 message->checkin_response().has_protocol_options_response()) {
370 auto options =
371 (message->has_eligibility_eval_checkin_response()
372 ? message->eligibility_eval_checkin_response()
373 .protocol_options_response()
374 : message->checkin_response().protocol_options_response());
375 outgoing_.chunk_size_for_upload = options.chunk_size_for_upload();
376 outgoing_.max_pending_chunks = options.max_pending_chunks();
377 outgoing_.compression_level = options.compression_level();
378 }
379 return absl::OkStatus();
380 }
381
382 // If this class is being used by the server, this is a no-op as the server
383 // determines the chunking options.
384 template <typename Outgoing, typename Incoming>
TrySnoopCheckinResponse(Incoming *)385 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySnoopCheckinResponse(
386 Incoming*) {
387 return absl::OkStatus();
388 }
389
390 template <typename Outgoing, typename Incoming>
TryAssemblePending(Incoming * message,bool * message_assembled)391 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TryAssemblePending(
392 Incoming* message, bool* message_assembled) {
393 COMMON_USING_DIRECTIVES;
394 *message_assembled = false;
395 auto chunk = message->chunked_transfer();
396 switch (chunk.kind_case()) {
397 case ChunkedTransferMessage::kStart:
398 if (!incoming_.deque.empty() || incoming_.uncompressed_size != -1)
399 return absl::InternalError("Unexpected Start.");
400 incoming_.uncompressed_size = chunk.start().uncompressed_size();
401 incoming_.compression_level = chunk.start().compression_level();
402 incoming_.blob_size_bytes = chunk.start().blob_size_bytes();
403 break;
404 case ChunkedTransferMessage::kData:
405 if (chunk.data().chunk_index() != incoming_.deque.size())
406 return absl::InternalError("Unexpected Data.");
407 incoming_.deque.emplace_back(chunk.data().chunk_bytes());
408 incoming_.composite.append(incoming_.deque.back());
409 return SendAck(static_cast<int32_t>(incoming_.deque.size() - 1));
410 case ChunkedTransferMessage::kEnd:
411 if (incoming_.deque.empty() ||
412 chunk.end().chunk_count() != incoming_.deque.size())
413 return absl::InternalError("Unexpected End.");
414 return AssemblePending(message, message_assembled);
415 case ChunkedTransferMessage::kAck:
416 return absl::InternalError("Unexpected Ack.");
417 default:
418 return absl::InternalError(
419 absl::StrCat("Unexpected message subtype: ",
420 message->chunked_transfer().kind_case()));
421 }
422
423 return absl::OkStatus();
424 }
425
426 template <typename Outgoing, typename Incoming>
AssemblePending(Incoming * message,bool * message_assembled)427 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::AssemblePending(
428 Incoming* message, bool* message_assembled) {
429 COMMON_USING_DIRECTIVES;
430 // TODO(team): Replace with a more efficient deserialization mechanism.
431 if (incoming_.compression_level == CompressionLevel::UNCOMPRESSED) {
432 if (!message->ParseFromString(incoming_.composite))
433 return absl::InternalError(absl::StrCat("Could not parse from string. ",
434 incoming_.composite.size()));
435 } else {
436 ArrayInputStream string_input_stream(
437 incoming_.composite.c_str(),
438 static_cast<int>(incoming_.composite.size()));
439 GzipInputStream compressed_stream(&string_input_stream);
440 if (!message->ParseFromZeroCopyStream(&compressed_stream))
441 return absl::InternalError("Could not parse proto from input stream.");
442 }
443 *message_assembled = true;
444 incoming_.uncompressed_size = -1;
445 incoming_.blob_size_bytes = -1;
446 incoming_.deque.clear();
447 incoming_.composite.clear();
448 return absl::OkStatus();
449 }
450
451 template <typename Outgoing, typename Incoming>
ReceiveRaw(Incoming * message)452 absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::ReceiveRaw(
453 Incoming* message) {
454 if (!reader_interface_)
455 return absl::FailedPreconditionError("Receive on closed stream.");
456 if (!reader_interface_->Read(message)) {
457 Close();
458 return absl::AbortedError("End of stream.");
459 }
460 incoming_.total_bytes_downloaded += message->ByteSizeLong();
461 return absl::OkStatus();
462 }
463
464 template <typename Outgoing, typename Incoming>
Close()465 void GrpcChunkedBidiStream<Outgoing, Incoming>::Close() {
466 writer_interface_ = nullptr;
467 reader_interface_ = nullptr;
468 }
469
470 template <typename Outgoing, typename Incoming>
471 int64_t
ChunkingLayerBytesReceived()472 GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkingLayerBytesReceived() {
473 return incoming_.total_bytes_downloaded;
474 }
475
476 template <typename Outgoing, typename Incoming>
ChunkingLayerBytesSent()477 int64_t GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkingLayerBytesSent() {
478 return outgoing_.total_bytes_uploaded;
479 }
480
481 } // namespace client
482 } // namespace fcp
483
484 #endif // FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
485