// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GRPC_SRC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H
#define GRPC_SRC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H

// Scaffolding to allow the per-call part of a filter to be authored in a
// promise style. Most of this will be removed once the promises conversion is
// completed.

#include <grpc/support/port_platform.h>

#include <stdint.h>
#include <stdlib.h>

#include <atomic>
#include <initializer_list>
#include <memory>
#include <new>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"

#include <grpc/event_engine/event_engine.h>
#include <grpc/grpc.h>
#include <grpc/support/log.h>

#include "src/core/lib/channel/call_finalization.h"
#include "src/core/lib/channel/channel_fwd.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/event_engine/default_event_engine.h"  // IWYU pragma: keep
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/call_combiner.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/surface/call.h"
#include "src/core/lib/transport/error_utils.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"

namespace grpc_core {

class ChannelFilter {
 public:
  class Args {
   public:
    Args() : Args(nullptr, nullptr) {}
    explicit Args(grpc_channel_stack* channel_stack,
                  grpc_channel_element* channel_element)
        : channel_stack_(channel_stack), channel_element_(channel_element) {}

    grpc_channel_stack* channel_stack() const { return channel_stack_; }
    grpc_channel_element* uninitialized_channel_element() {
      return channel_element_;
    }

   private:
    friend class ChannelFilter;
    grpc_channel_stack* channel_stack_;
    grpc_channel_element* channel_element_;
  };

  // Perform post-initialization step (if any).
  virtual void PostInit() {}

  // Construct a promise for one call.
  virtual ArenaPromise<ServerMetadataHandle> MakeCallPromise(
      CallArgs call_args, NextPromiseFactory next_promise_factory) = 0;
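  // As an illustrative sketch (not a prescription): a pass-through filter
  // could implement this by simply deferring to the rest of the stack,
  //   return next_promise_factory(std::move(call_args));
  // and a filter that needs to observe the result could wrap that promise
  // (e.g. with Map()) to inspect or mutate the returned ServerMetadataHandle.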

  // Start a legacy transport op
  // Return true if the op was handled, false if it should be passed to the
  // next filter.
  // TODO(ctiller): design a new API for this - we probably don't want big op
  // structures going forward.
  virtual bool StartTransportOp(grpc_transport_op*) { return false; }

  // Perform a legacy get info call
  // Return true if the op was handled, false if it should be passed to the
  // next filter.
  // TODO(ctiller): design a new API for this
  virtual bool GetChannelInfo(const grpc_channel_info*) { return false; }

  virtual ~ChannelFilter() = default;

  grpc_event_engine::experimental::EventEngine*
  hack_until_per_channel_stack_event_engines_land_get_event_engine() {
    return event_engine_.get();
  }

 private:
  // TODO(ctiller): remove once per-channel-stack EventEngines land
  std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_ =
      grpc_event_engine::experimental::GetDefaultEventEngine();
};

// Designator for whether a filter is client side or server side.
// Please don't use this outside calls to MakePromiseBasedFilter - it's
// intended to be deleted once the promise conversion is complete.
enum class FilterEndpoint {
  kClient,
  kServer,
};

// Flags for MakePromiseBasedFilter.
static constexpr uint8_t kFilterExaminesServerInitialMetadata = 1;
static constexpr uint8_t kFilterIsLast = 2;
static constexpr uint8_t kFilterExaminesOutboundMessages = 4;
static constexpr uint8_t kFilterExaminesInboundMessages = 8;
static constexpr uint8_t kFilterExaminesCallContext = 16;
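
// These flags form a bitmask: combine them with bitwise-or when instantiating
// a filter. An illustrative sketch (MyFilter is a hypothetical filter type):
//   MakePromiseBasedFilter<MyFilter, FilterEndpoint::kServer,
//                          kFilterExaminesServerInitialMetadata |
//                              kFilterExaminesInboundMessages>("my_filter");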

namespace promise_filter_detail {

// Proxy channel filter for initialization failure, since we must leave a
// valid filter in place.
class InvalidChannelFilter : public ChannelFilter {
 public:
  ArenaPromise<ServerMetadataHandle> MakeCallPromise(
      CallArgs, NextPromiseFactory) override {
    abort();
  }
};

// Call data shared between all implementations of promise-based filters.
class BaseCallData : public Activity, private Wakeable {
 protected:
  // Hook to allow interception of messages on the send/receive path by
  // PipeSender and PipeReceiver, as appropriate according to whether we're
  // client or server.
  class Interceptor {
   public:
    virtual PipeSender<MessageHandle>* Push() = 0;
    virtual PipeReceiver<MessageHandle>* Pull() = 0;
    virtual PipeReceiver<MessageHandle>* original_receiver() = 0;
    virtual PipeSender<MessageHandle>* original_sender() = 0;
    virtual void GotPipe(PipeSender<MessageHandle>*) = 0;
    virtual void GotPipe(PipeReceiver<MessageHandle>*) = 0;
    virtual ~Interceptor() = default;
  };

  BaseCallData(grpc_call_element* elem, const grpc_call_element_args* args,
               uint8_t flags,
               absl::FunctionRef<Interceptor*()> make_send_interceptor,
               absl::FunctionRef<Interceptor*()> make_recv_interceptor);

 public:
  ~BaseCallData() override;

  void set_pollent(grpc_polling_entity* pollent) {
    GPR_ASSERT(nullptr ==
               pollent_.exchange(pollent, std::memory_order_release));
  }

  // Activity implementation (partial).
  void Orphan() final;
  Waker MakeNonOwningWaker() final;
  Waker MakeOwningWaker() final;

  std::string ActivityDebugTag(WakeupMask) const override {
    return DebugTag();
  }

  void Finalize(const grpc_call_final_info* final_info) {
    finalization_.Run(final_info);
  }

  virtual void StartBatch(grpc_transport_stream_op_batch* batch) = 0;

 protected:
  class ScopedContext
      : public promise_detail::Context<Arena>,
        public promise_detail::Context<grpc_call_context_element>,
        public promise_detail::Context<grpc_polling_entity>,
        public promise_detail::Context<CallFinalization>,
        public promise_detail::Context<
            grpc_event_engine::experimental::EventEngine>,
        public promise_detail::Context<CallContext> {
   public:
    explicit ScopedContext(BaseCallData* call_data)
        : promise_detail::Context<Arena>(call_data->arena_),
          promise_detail::Context<grpc_call_context_element>(
              call_data->context_),
          promise_detail::Context<grpc_polling_entity>(
              call_data->pollent_.load(std::memory_order_acquire)),
          promise_detail::Context<CallFinalization>(&call_data->finalization_),
          promise_detail::Context<
              grpc_event_engine::experimental::EventEngine>(
              call_data->event_engine_),
          promise_detail::Context<CallContext>(call_data->call_context_) {}
  };

  class Flusher {
   public:
    explicit Flusher(BaseCallData* call);
    // Calls closures, schedules batches, relinquishes call combiner.
    ~Flusher();

    void Resume(grpc_transport_stream_op_batch* batch) {
      GPR_ASSERT(!call_->is_last());
      if (batch->HasOp()) {
        release_.push_back(batch);
      } else if (batch->on_complete != nullptr) {
        Complete(batch);
      }
    }

    void Cancel(grpc_transport_stream_op_batch* batch,
                grpc_error_handle error) {
      grpc_transport_stream_op_batch_queue_finish_with_failure(
          batch, error, &call_closures_);
    }

    void Complete(grpc_transport_stream_op_batch* batch) {
      call_closures_.Add(batch->on_complete, absl::OkStatus(),
                         "Flusher::Complete");
    }

    void AddClosure(grpc_closure* closure, grpc_error_handle error,
                    const char* reason) {
      call_closures_.Add(closure, error, reason);
    }

    BaseCallData* call() const { return call_; }

   private:
    absl::InlinedVector<grpc_transport_stream_op_batch*, 1> release_;
    CallCombinerClosureList call_closures_;
    BaseCallData* const call_;
  };

  // Smart-pointer-like wrapper around a batch.
  // Creation captures the batch with a ref count of one.
  // Copying increments the ref count.
  // Must be moved from, resumed, or cancelled before destruction.
  class CapturedBatch final {
   public:
    CapturedBatch();
    explicit CapturedBatch(grpc_transport_stream_op_batch* batch);
    ~CapturedBatch();
    CapturedBatch(const CapturedBatch&);
    CapturedBatch& operator=(const CapturedBatch&);
    CapturedBatch(CapturedBatch&&) noexcept;
    CapturedBatch& operator=(CapturedBatch&&) noexcept;

    grpc_transport_stream_op_batch* operator->() { return batch_; }
    bool is_captured() const { return batch_ != nullptr; }

    // Resume processing this batch (releases one ref, passes it down the
    // stack)
    void ResumeWith(Flusher* releaser);
    // Cancel this batch immediately (releases all refs)
    void CancelWith(grpc_error_handle error, Flusher* releaser);
    // Complete this batch (pass it up) assuming refs drop to zero
    void CompleteWith(Flusher* releaser);

    void Swap(CapturedBatch* other) { std::swap(batch_, other->batch_); }

   private:
    grpc_transport_stream_op_batch* batch_;
  };
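
  // An illustrative sketch of the intended usage (names are hypothetical),
  // based on the documentation above: each instance must be moved from,
  // resumed, cancelled, or completed before it is destroyed.
  //   CapturedBatch captured(batch);   // capture: ref count of one
  //   captured.ResumeWith(&flusher);   // or CancelWith()/CompleteWith()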

  static Arena::PoolPtr<grpc_metadata_batch> WrapMetadata(
      grpc_metadata_batch* p) {
    return Arena::PoolPtr<grpc_metadata_batch>(p,
                                               Arena::PooledDeleter(nullptr));
  }

  class ReceiveInterceptor final : public Interceptor {
   public:
    explicit ReceiveInterceptor(Arena* arena) : pipe_{arena} {}

    PipeReceiver<MessageHandle>* original_receiver() override {
      return &pipe_.receiver;
    }
    PipeSender<MessageHandle>* original_sender() override { abort(); }

    void GotPipe(PipeReceiver<MessageHandle>* receiver) override {
      GPR_ASSERT(receiver_ == nullptr);
      receiver_ = receiver;
    }

    void GotPipe(PipeSender<MessageHandle>*) override { abort(); }

    PipeSender<MessageHandle>* Push() override { return &pipe_.sender; }
    PipeReceiver<MessageHandle>* Pull() override {
      GPR_ASSERT(receiver_ != nullptr);
      return receiver_;
    }

   private:
    Pipe<MessageHandle> pipe_;
    PipeReceiver<MessageHandle>* receiver_ = nullptr;
  };

  class SendInterceptor final : public Interceptor {
   public:
    explicit SendInterceptor(Arena* arena) : pipe_{arena} {}

    PipeReceiver<MessageHandle>* original_receiver() override { abort(); }
    PipeSender<MessageHandle>* original_sender() override {
      return &pipe_.sender;
    }

    void GotPipe(PipeReceiver<MessageHandle>*) override { abort(); }

    void GotPipe(PipeSender<MessageHandle>* sender) override {
      GPR_ASSERT(sender_ == nullptr);
      sender_ = sender;
    }

    PipeSender<MessageHandle>* Push() override {
      GPR_ASSERT(sender_ != nullptr);
      return sender_;
    }
    PipeReceiver<MessageHandle>* Pull() override { return &pipe_.receiver; }

   private:
    Pipe<MessageHandle> pipe_;
    PipeSender<MessageHandle>* sender_ = nullptr;
  };

  // State machine for sending messages: handles intercepting send_message ops
  // and forwarding them through pipes to the promise, then getting the result
  // down the stack.
  // Split into its own class so that we don't spend the memory instantiating
  // these members for filters that don't need to intercept sent messages.
  class SendMessage {
   public:
    SendMessage(BaseCallData* base, Interceptor* interceptor)
        : base_(base), interceptor_(interceptor) {}
    ~SendMessage() { interceptor_->~Interceptor(); }

    Interceptor* interceptor() { return interceptor_; }

    // Start a send_message op.
    void StartOp(CapturedBatch batch);
    // Publish the outbound pipe to the filter.
    // This happens when the promise requests to call the next filter: until
    // this occurs messages can't be sent as we don't know the pipe that the
    // promise expects to send on.
    template <typename T>
    void GotPipe(T* pipe);
    // Called from client/server polling to do the send message part of the
    // work.
    void WakeInsideCombiner(Flusher* flusher, bool allow_push_to_pipe);
    // Call is completed, we have trailing metadata. Close things out.
    void Done(const ServerMetadata& metadata, Flusher* flusher);
    // Return true if we have a batch captured (for debug logs)
    bool HaveCapturedBatch() const { return batch_.is_captured(); }
    // Return true if we're not actively sending a message.
    bool IsIdle() const;
    // Return true if we've released the message for forwarding down the stack.
    bool IsForwarded() const { return state_ == State::kForwardedBatch; }

   private:
    enum class State : uint8_t {
      // Starting state: no batch started, no outgoing pipe configured.
      kInitial,
      // We have an outgoing pipe, but no batch started.
      // (this is the steady state).
      kIdle,
      // We have a batch started, but no outgoing pipe configured.
      // Stall until we have one.
      kGotBatchNoPipe,
      // We have a batch, and an outgoing pipe. On the next poll we'll push the
      // message into the pipe to the promise.
      kGotBatch,
      // We've pushed a message into the promise, and we're now waiting for it
      // to pop out the other end so we can forward it down the stack.
      kPushedToPipe,
      // We've forwarded a message down the stack, and now we're waiting for
      // completion.
      kForwardedBatch,
      // We've got the completion callback, we'll close things out during poll
      // and then forward completion callbacks up and transition back to idle.
      kBatchCompleted,
      // We're almost done, but need to poll first.
      kCancelledButNotYetPolled,
      // We're done.
      kCancelled,
      // We're done, but we haven't gotten a status yet.
      kCancelledButNoStatus,
    };
    static const char* StateString(State);

    void OnComplete(absl::Status status);

    BaseCallData* const base_;
    State state_ = State::kInitial;
    Interceptor* const interceptor_;
    absl::optional<PipeSender<MessageHandle>::PushType> push_;
    absl::optional<PipeReceiverNextType<MessageHandle>> next_;
    CapturedBatch batch_;
    grpc_closure* intercepted_on_complete_;
    grpc_closure on_complete_ =
        MakeMemberClosure<SendMessage, &SendMessage::OnComplete>(this);
    absl::Status completed_status_;
  };

  // State machine for receiving messages: handles intercepting recv_message
  // ops, forwarding them down the stack, and then publishing the result via
  // pipes to the promise (and ultimately calling the right callbacks for the
  // batch when our promise has completed processing of them).
  // Split into its own class so that we don't spend the memory instantiating
  // these members for filters that don't need to intercept received messages.
  class ReceiveMessage {
   public:
    ReceiveMessage(BaseCallData* base, Interceptor* interceptor)
        : base_(base), interceptor_(interceptor) {}
    ~ReceiveMessage() { interceptor_->~Interceptor(); }

    Interceptor* interceptor() { return interceptor_; }

    // Start a recv_message op.
    void StartOp(CapturedBatch& batch);
    // Publish the inbound pipe to the filter.
    // This happens when the promise requests to call the next filter: until
    // this occurs messages can't be received as we don't know the pipe that
    // the promise expects to forward them with.
    template <typename T>
    void GotPipe(T* pipe);
    // Called from client/server polling to do the receive message part of the
    // work.
    void WakeInsideCombiner(Flusher* flusher, bool allow_push_to_pipe);
    // Call is completed, we have trailing metadata. Close things out.
    void Done(const ServerMetadata& metadata, Flusher* flusher);

   private:
    enum class State : uint8_t {
      // Starting state: no batch started, no incoming pipe configured.
      kInitial,
      // We have an incoming pipe, but no batch started.
      // (this is the steady state).
      kIdle,
      // We received a batch and forwarded it on, but have not got an incoming
      // pipe configured.
      kForwardedBatchNoPipe,
      // We received a batch and forwarded it on.
      kForwardedBatch,
      // We got the completion for the recv_message, but we don't yet have a
      // pipe configured. Stall until this changes.
      kBatchCompletedNoPipe,
      // We got the completion for the recv_message, and we have a pipe
      // configured: next poll will push the message into the pipe for the
      // filter to process.
      kBatchCompleted,
      // We've pushed a message into the promise, and we're now waiting for it
      // to pop out the other end so we can forward it up the stack.
      kPushedToPipe,
      // We've got a message out of the pipe, now we need to wait for
      // processing to completely quiesce in the promise prior to forwarding
      // the completion up the stack.
      kPulledFromPipe,
      // We're done.
      kCancelled,
      // Call got terminated whilst we were idle: we need to close the sender
      // pipe next poll.
      kCancelledWhilstIdle,
      // Call got terminated whilst we had forwarded a recv_message down the
      // stack: we need to keep track of that until we get the completion so
      // that we do the right thing in OnComplete.
      kCancelledWhilstForwarding,
      // The same, but before we got the pipe.
      kCancelledWhilstForwardingNoPipe,
      // Call got terminated whilst we had a recv_message batch completed, and
      // we've now received the completion.
      // On the next poll we'll close things out and forward on completions,
      // then transition to cancelled.
      kBatchCompletedButCancelled,
      // The same, but before we got the pipe.
      kBatchCompletedButCancelledNoPipe,
      // Completed successfully while we're processing a recv message - see
      // kPushedToPipe.
      kCompletedWhilePushedToPipe,
      // Completed successfully while we're processing a recv message - see
      // kPulledFromPipe.
      kCompletedWhilePulledFromPipe,
      // Completed successfully while we were waiting to process
      // kBatchCompleted.
      kCompletedWhileBatchCompleted,
    };
    static const char* StateString(State);

    void OnComplete(absl::Status status);

    BaseCallData* const base_;
    Interceptor* const interceptor_;
    State state_ = State::kInitial;
    uint32_t scratch_flags_;
    absl::optional<SliceBuffer>* intercepted_slice_buffer_;
    uint32_t* intercepted_flags_;
    absl::optional<PipeSender<MessageHandle>::PushType> push_;
    absl::optional<PipeReceiverNextType<MessageHandle>> next_;
    absl::Status completed_status_;
    grpc_closure* intercepted_on_complete_;
    grpc_closure on_complete_ =
        MakeMemberClosure<ReceiveMessage, &ReceiveMessage::OnComplete>(this);
  };

  Arena* arena() { return arena_; }
  grpc_call_element* elem() const { return elem_; }
  CallCombiner* call_combiner() const { return call_combiner_; }
  Timestamp deadline() const { return deadline_; }
  grpc_call_stack* call_stack() const { return call_stack_; }
  Pipe<ServerMetadataHandle>* server_initial_metadata_pipe() const {
    return server_initial_metadata_pipe_;
  }
  SendMessage* send_message() const { return send_message_; }
  ReceiveMessage* receive_message() const { return receive_message_; }

  bool is_last() const {
    return grpc_call_stack_element(call_stack_, call_stack_->count - 1) ==
           elem_;
  }

  virtual void WakeInsideCombiner(Flusher* flusher) = 0;

  virtual absl::string_view ClientOrServerString() const = 0;
  std::string LogTag() const;

 private:
  // Wakeable implementation.
  void Wakeup(WakeupMask) final;
  void WakeupAsync(WakeupMask) final { Crash("not implemented"); }
  void Drop(WakeupMask) final;

  virtual void OnWakeup() = 0;

  grpc_call_stack* const call_stack_;
  grpc_call_element* const elem_;
  Arena* const arena_;
  CallCombiner* const call_combiner_;
  const Timestamp deadline_;
  CallFinalization finalization_;
  CallContext* call_context_ = nullptr;
  grpc_call_context_element* const context_;
  std::atomic<grpc_polling_entity*> pollent_{nullptr};
  Pipe<ServerMetadataHandle>* const server_initial_metadata_pipe_;
  SendMessage* const send_message_;
  ReceiveMessage* const receive_message_;
  grpc_event_engine::experimental::EventEngine* event_engine_;
};

class ClientCallData : public BaseCallData {
 public:
  ClientCallData(grpc_call_element* elem, const grpc_call_element_args* args,
                 uint8_t flags);
  ~ClientCallData() override;

  // Activity implementation.
  void ForceImmediateRepoll(WakeupMask) final;
  // Handle one grpc_transport_stream_op_batch
  void StartBatch(grpc_transport_stream_op_batch* batch) override;

  std::string DebugTag() const override;

 private:
  // At what stage is our handling of send initial metadata?
  enum class SendInitialState {
    // Start state: no op seen
    kInitial,
    // We've seen the op, and started the promise in response to it, but have
    // not yet sent the op to the next filter.
    kQueued,
    // We've sent the op to the next filter.
    kForwarded,
    // We were cancelled.
    kCancelled
  };
  // At what stage is our handling of recv trailing metadata?
  enum class RecvTrailingState {
    // Start state: no op seen
    kInitial,
    // We saw the op, and since it was bundled with send initial metadata, we
    // queued it until the send initial metadata can be sent to the next
    // filter.
    kQueued,
    // We've forwarded the op to the next filter.
    kForwarded,
    // The op has completed from below, but we haven't yet forwarded it up
    // (the promise gets to interject and mutate it).
    kComplete,
    // We've called the recv_metadata_ready callback from the original
    // recv_trailing_metadata op that was presented to us.
    kResponded,
    // We've been cancelled and handled that locally.
    // (i.e. whilst the recv_trailing_metadata op is queued in this filter).
    kCancelled
  };

  static const char* StateString(SendInitialState);
  static const char* StateString(RecvTrailingState);
  std::string DebugString() const;

  struct RecvInitialMetadata;
  class PollContext;

  // Handle cancellation.
  void Cancel(grpc_error_handle error, Flusher* flusher);
  // Begin running the promise - which will ultimately take some initial
  // metadata and return some trailing metadata.
  void StartPromise(Flusher* flusher);
  // Interject our callback into the op batch for recv trailing metadata
  // ready. Stash a pointer to the trailing metadata that will be filled in,
  // so we can manipulate it later.
  void HookRecvTrailingMetadata(CapturedBatch batch);
  // Construct a promise that will "call" the next filter.
  // Effectively:
  //   - put the modified initial metadata into the batch to be sent down.
  //   - return a wrapper around PollTrailingMetadata as the promise.
  ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
  // Wrapper to make it look like we're calling the next filter as a promise.
  // First poll: send the send_initial_metadata op down the stack.
  // All polls: await receiving the trailing metadata, then return it to the
  // application.
  Poll<ServerMetadataHandle> PollTrailingMetadata();
  static void RecvTrailingMetadataReadyCallback(void* arg,
                                                grpc_error_handle error);
  void RecvTrailingMetadataReady(grpc_error_handle error);
  void RecvInitialMetadataReady(grpc_error_handle error);
  // Given an error, fill in ServerMetadataHandle to represent that error.
  void SetStatusFromError(grpc_metadata_batch* metadata,
                          grpc_error_handle error);
  // Wakeup and poll the promise if appropriate.
  void WakeInsideCombiner(Flusher* flusher) override;
  void OnWakeup() override;

  absl::string_view ClientOrServerString() const override { return "CLI"; }

  // Contained promise
  ArenaPromise<ServerMetadataHandle> promise_;
  // Queued batch containing at least a send_initial_metadata op.
  CapturedBatch send_initial_metadata_batch_;
  // Pointer to where trailing metadata will be stored.
  grpc_metadata_batch* recv_trailing_metadata_ = nullptr;
  // Trailing metadata as returned by the promise, if we hadn't received
  // trailing metadata from below yet (so we can substitute it in).
  ServerMetadataHandle cancelling_metadata_;
  // State tracking recv initial metadata for filters that care about it.
  RecvInitialMetadata* recv_initial_metadata_ = nullptr;
  // Closure to call when we're done with the trailing metadata.
  grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
  // Our closure pointing to RecvTrailingMetadataReadyCallback.
  grpc_closure recv_trailing_metadata_ready_;
  // Error received during cancellation.
  grpc_error_handle cancelled_error_;
  // State of the send_initial_metadata op.
  SendInitialState send_initial_state_ = SendInitialState::kInitial;
  // State of the recv_trailing_metadata op.
  RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial;
  // Polling related data. Non-null if we're actively polling
  PollContext* poll_ctx_ = nullptr;
  // Initial metadata outstanding token
  ClientInitialMetadataOutstandingToken initial_metadata_outstanding_token_;
};

class ServerCallData : public BaseCallData {
 public:
  ServerCallData(grpc_call_element* elem, const grpc_call_element_args* args,
                 uint8_t flags);
  ~ServerCallData() override;

  // Activity implementation.
  void ForceImmediateRepoll(WakeupMask) final;
  // Handle one grpc_transport_stream_op_batch
  void StartBatch(grpc_transport_stream_op_batch* batch) override;

  std::string DebugTag() const override;

 protected:
  absl::string_view ClientOrServerString() const override { return "SVR"; }

 private:
  // At what stage is our handling of recv initial metadata?
  enum class RecvInitialState {
    // Start state: no op seen
    kInitial,
    // Op seen, and forwarded to the next filter.
    // Now waiting for the callback.
    kForwarded,
    // The op has completed from below, but we haven't yet forwarded it up
    // (the promise gets to interject and mutate it).
    kComplete,
    // We've sent the response to the next filter up.
    kResponded,
  };
  // At what stage is our handling of send trailing metadata?
  enum class SendTrailingState {
    // Start state: no op seen
    kInitial,
    // We saw the op, but it was with a send message op (or one was in
    // progress) - so we'll wait for that to complete before processing the
    // trailing metadata.
    kQueuedBehindSendMessage,
    // We saw the op, and are waiting for the promise to complete
    // to forward it. First however we need to close sends.
    kQueuedButHaventClosedSends,
    // We saw the op, and are waiting for the promise to complete
    // to forward it.
    kQueued,
    // We've forwarded the op to the next filter.
    kForwarded,
    // We were cancelled.
    kCancelled
  };

  static const char* StateString(RecvInitialState state);
  static const char* StateString(SendTrailingState state);
  std::string DebugString() const;

  class PollContext;
  struct SendInitialMetadata;

  // Shut things down when the call completes.
  void Completed(grpc_error_handle error, Flusher* flusher);
  // Construct a promise that will "call" the next filter.
  // Effectively:
  //   - put the modified initial metadata into the batch being sent up.
  //   - return a wrapper around PollTrailingMetadata as the promise.
  ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
  // Wrapper to make it look like we're calling the next filter as a promise.
  // All polls: await sending the trailing metadata, then forward it down the
  // stack.
  Poll<ServerMetadataHandle> PollTrailingMetadata();
  static void RecvInitialMetadataReadyCallback(void* arg,
                                               grpc_error_handle error);
  void RecvInitialMetadataReady(grpc_error_handle error);
  static void RecvTrailingMetadataReadyCallback(void* arg,
                                                grpc_error_handle error);
  void RecvTrailingMetadataReady(grpc_error_handle error);
  // Wakeup and poll the promise if appropriate.
  void WakeInsideCombiner(Flusher* flusher) override;
  void OnWakeup() override;

  // Contained promise
  ArenaPromise<ServerMetadataHandle> promise_;
  // Pointer to where initial metadata will be stored.
  grpc_metadata_batch* recv_initial_metadata_ = nullptr;
  // Pointer to where trailing metadata will be stored.
  grpc_metadata_batch* recv_trailing_metadata_ = nullptr;
  // State for sending initial metadata.
  SendInitialMetadata* send_initial_metadata_ = nullptr;
  // Closure to call when we're done with the initial metadata.
  grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
  // Our closure pointing to RecvInitialMetadataReadyCallback.
  grpc_closure recv_initial_metadata_ready_;
  // Closure to call when we're done with the trailing metadata.
  grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
  // Our closure pointing to RecvTrailingMetadataReadyCallback.
  grpc_closure recv_trailing_metadata_ready_;
  // Error received during cancellation.
  grpc_error_handle cancelled_error_;
  // Trailing metadata batch
  CapturedBatch send_trailing_metadata_batch_;
  // State of the recv_initial_metadata op.
  RecvInitialState recv_initial_state_ = RecvInitialState::kInitial;
  // State of the send_trailing_metadata op.
  SendTrailingState send_trailing_state_ = SendTrailingState::kInitial;
  // Current poll context (or nullptr if not polling).
  PollContext* poll_ctx_ = nullptr;
  // Whether to forward the recv_initial_metadata op at the end of promise
  // wakeup.
  bool forward_recv_initial_metadata_callback_ = false;
};

// Specific call data per channel filter.
// Note that we further specialize for clients and servers since their
// implementations are very different.
template <FilterEndpoint endpoint>
class CallData;

// Client implementation of call data.
template <>
class CallData<FilterEndpoint::kClient> : public ClientCallData {
 public:
  using ClientCallData::ClientCallData;
};

// Server implementation of call data.
template <>
class CallData<FilterEndpoint::kServer> : public ServerCallData {
 public:
  using ServerCallData::ServerCallData;
};

struct BaseCallDataMethods {
  static void SetPollsetOrPollsetSet(grpc_call_element* elem,
                                     grpc_polling_entity* pollent) {
    static_cast<BaseCallData*>(elem->call_data)->set_pollent(pollent);
  }

  static void DestructCallData(grpc_call_element* elem,
                               const grpc_call_final_info* final_info) {
    auto* cd = static_cast<BaseCallData*>(elem->call_data);
    cd->Finalize(final_info);
    cd->~BaseCallData();
  }

  static void StartTransportStreamOpBatch(
      grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
    static_cast<BaseCallData*>(elem->call_data)->StartBatch(batch);
  }
};

template <typename CallData, uint8_t kFlags>
struct CallDataFilterWithFlagsMethods {
  static absl::Status InitCallElem(grpc_call_element* elem,
                                   const grpc_call_element_args* args) {
    new (elem->call_data) CallData(elem, args, kFlags);
    return absl::OkStatus();
  }

  static void DestroyCallElem(grpc_call_element* elem,
                              const grpc_call_final_info* final_info,
                              grpc_closure* then_schedule_closure) {
    BaseCallDataMethods::DestructCallData(elem, final_info);
    if ((kFlags & kFilterIsLast) != 0) {
      ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus());
    } else {
      GPR_ASSERT(then_schedule_closure == nullptr);
    }
  }
};

struct ChannelFilterMethods {
  static ArenaPromise<ServerMetadataHandle> MakeCallPromise(
      grpc_channel_element* elem, CallArgs call_args,
      NextPromiseFactory next_promise_factory) {
    return static_cast<ChannelFilter*>(elem->channel_data)
        ->MakeCallPromise(std::move(call_args),
                          std::move(next_promise_factory));
  }

  static void StartTransportOp(grpc_channel_element* elem,
                               grpc_transport_op* op) {
    if (!static_cast<ChannelFilter*>(elem->channel_data)
             ->StartTransportOp(op)) {
      grpc_channel_next_op(elem, op);
    }
  }

  static void PostInitChannelElem(grpc_channel_stack*,
                                  grpc_channel_element* elem) {
    static_cast<ChannelFilter*>(elem->channel_data)->PostInit();
  }

  static void DestroyChannelElem(grpc_channel_element* elem) {
    static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter();
  }

  static void GetChannelInfo(grpc_channel_element* elem,
                             const grpc_channel_info* info) {
    if (!static_cast<ChannelFilter*>(elem->channel_data)
             ->GetChannelInfo(info)) {
      grpc_channel_next_get_info(elem, info);
    }
  }
};

template <typename F, uint8_t kFlags>
struct ChannelFilterWithFlagsMethods {
  static absl::Status InitChannelElem(grpc_channel_element* elem,
                                      grpc_channel_element_args* args) {
    GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0));
    auto status = F::Create(args->channel_args,
                            ChannelFilter::Args(args->channel_stack, elem));
    if (!status.ok()) {
      static_assert(
          sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F),
          "InvalidChannelFilter must fit in F");
      new (elem->channel_data) promise_filter_detail::InvalidChannelFilter();
      return absl_status_to_grpc_error(status.status());
    }
    new (elem->channel_data) F(std::move(*status));
    return absl::OkStatus();
  }
};

}  // namespace promise_filter_detail

// F implements ChannelFilter and additionally provides a static Create
// factory method:
//   class SomeChannelFilter : public ChannelFilter {
//    public:
//     static absl::StatusOr<SomeChannelFilter> Create(
//         ChannelArgs channel_args, ChannelFilter::Args filter_args);
//   };
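//
// An illustrative sketch of registering such a filter (the filter type,
// variable name, and flags here are hypothetical):
//   const grpc_channel_filter kSomeChannelFilterVtable =
//       MakePromiseBasedFilter<SomeChannelFilter, FilterEndpoint::kClient,
//                              kFilterExaminesServerInitialMetadata>(
//           "some_filter");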
template <typename F, FilterEndpoint kEndpoint, uint8_t kFlags = 0>
absl::enable_if_t<std::is_base_of<ChannelFilter, F>::value, grpc_channel_filter>
MakePromiseBasedFilter(const char* name) {
  using CallData = promise_filter_detail::CallData<kEndpoint>;

  return grpc_channel_filter{
      // start_transport_stream_op_batch
      promise_filter_detail::BaseCallDataMethods::StartTransportStreamOpBatch,
      // make_call_promise
      promise_filter_detail::ChannelFilterMethods::MakeCallPromise,
      // start_transport_op
      promise_filter_detail::ChannelFilterMethods::StartTransportOp,
      // sizeof_call_data
      sizeof(CallData),
      // init_call_elem
      promise_filter_detail::CallDataFilterWithFlagsMethods<
          CallData, kFlags>::InitCallElem,
      // set_pollset_or_pollset_set
      promise_filter_detail::BaseCallDataMethods::SetPollsetOrPollsetSet,
      // destroy_call_elem
      promise_filter_detail::CallDataFilterWithFlagsMethods<
          CallData, kFlags>::DestroyCallElem,
      // sizeof_channel_data
      sizeof(F),
      // init_channel_elem
      promise_filter_detail::ChannelFilterWithFlagsMethods<
          F, kFlags>::InitChannelElem,
      // post_init_channel_elem
      promise_filter_detail::ChannelFilterMethods::PostInitChannelElem,
      // destroy_channel_elem
      promise_filter_detail::ChannelFilterMethods::DestroyChannelElem,
      // get_channel_info
      promise_filter_detail::ChannelFilterMethods::GetChannelInfo,
      // name
      name,
  };
}

}  // namespace grpc_core

#endif  // GRPC_SRC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H