1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H
16 #define GRPC_SRC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H
17 
18 // Scaffolding to allow the per-call part of a filter to be authored in a
19 // promise-style. Most of this will be removed once the promises conversion is
20 // completed.
21 
22 #include <grpc/support/port_platform.h>
23 
24 #include <stdint.h>
25 #include <stdlib.h>
26 
27 #include <atomic>
28 #include <initializer_list>
29 #include <memory>
30 #include <new>
31 #include <string>
32 #include <type_traits>
33 #include <utility>
34 
35 #include "absl/container/inlined_vector.h"
36 #include "absl/functional/function_ref.h"
37 #include "absl/meta/type_traits.h"
38 #include "absl/status/status.h"
39 #include "absl/strings/string_view.h"
40 #include "absl/types/optional.h"
41 
42 #include <grpc/event_engine/event_engine.h>
43 #include <grpc/grpc.h>
44 #include <grpc/support/log.h>
45 
46 #include "src/core/lib/channel/call_finalization.h"
47 #include "src/core/lib/channel/channel_fwd.h"
48 #include "src/core/lib/channel/channel_stack.h"
49 #include "src/core/lib/channel/context.h"
50 #include "src/core/lib/event_engine/default_event_engine.h"  // IWYU pragma: keep
51 #include "src/core/lib/gprpp/crash.h"
52 #include "src/core/lib/gprpp/debug_location.h"
53 #include "src/core/lib/gprpp/time.h"
54 #include "src/core/lib/iomgr/call_combiner.h"
55 #include "src/core/lib/iomgr/closure.h"
56 #include "src/core/lib/iomgr/error.h"
57 #include "src/core/lib/iomgr/exec_ctx.h"
58 #include "src/core/lib/iomgr/polling_entity.h"
59 #include "src/core/lib/promise/activity.h"
60 #include "src/core/lib/promise/arena_promise.h"
61 #include "src/core/lib/promise/context.h"
62 #include "src/core/lib/promise/pipe.h"
63 #include "src/core/lib/promise/poll.h"
64 #include "src/core/lib/resource_quota/arena.h"
65 #include "src/core/lib/slice/slice_buffer.h"
66 #include "src/core/lib/surface/call.h"
67 #include "src/core/lib/transport/error_utils.h"
68 #include "src/core/lib/transport/metadata_batch.h"
69 #include "src/core/lib/transport/transport.h"
70 
71 namespace grpc_core {
72 
73 class ChannelFilter {
74  public:
75   class Args {
76    public:
Args()77     Args() : Args(nullptr, nullptr) {}
Args(grpc_channel_stack * channel_stack,grpc_channel_element * channel_element)78     explicit Args(grpc_channel_stack* channel_stack,
79                   grpc_channel_element* channel_element)
80         : channel_stack_(channel_stack), channel_element_(channel_element) {}
81 
channel_stack()82     grpc_channel_stack* channel_stack() const { return channel_stack_; }
uninitialized_channel_element()83     grpc_channel_element* uninitialized_channel_element() {
84       return channel_element_;
85     }
86 
87    private:
88     friend class ChannelFilter;
89     grpc_channel_stack* channel_stack_;
90     grpc_channel_element* channel_element_;
91   };
92 
93   // Perform post-initialization step (if any).
PostInit()94   virtual void PostInit() {}
95 
96   // Construct a promise for one call.
97   virtual ArenaPromise<ServerMetadataHandle> MakeCallPromise(
98       CallArgs call_args, NextPromiseFactory next_promise_factory) = 0;
99 
100   // Start a legacy transport op
101   // Return true if the op was handled, false if it should be passed to the
102   // next filter.
103   // TODO(ctiller): design a new API for this - we probably don't want big op
104   // structures going forward.
StartTransportOp(grpc_transport_op *)105   virtual bool StartTransportOp(grpc_transport_op*) { return false; }
106 
107   // Perform a legacy get info call
108   // Return true if the op was handled, false if it should be passed to the
109   // next filter.
110   // TODO(ctiller): design a new API for this
GetChannelInfo(const grpc_channel_info *)111   virtual bool GetChannelInfo(const grpc_channel_info*) { return false; }
112 
113   virtual ~ChannelFilter() = default;
114 
115   grpc_event_engine::experimental::EventEngine*
hack_until_per_channel_stack_event_engines_land_get_event_engine()116   hack_until_per_channel_stack_event_engines_land_get_event_engine() {
117     return event_engine_.get();
118   }
119 
120  private:
121   // TODO(ctiller): remove once per-channel-stack EventEngines land
122   std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_ =
123       grpc_event_engine::experimental::GetDefaultEventEngine();
124 };
125 
// Says which side of a call a filter is built for: the client or the server.
// Please don't use this outside calls to MakePromiseBasedFilter - it's
// intended to be deleted once the promise conversion is complete.
enum class FilterEndpoint { kClient, kServer };
133 
// Flags for MakePromiseBasedFilter. Each constant is a distinct single bit, so
// several may be combined with bitwise OR into one uint8_t flags value.
static constexpr uint8_t kFilterExaminesServerInitialMetadata = 1u << 0;
static constexpr uint8_t kFilterIsLast = 1u << 1;
static constexpr uint8_t kFilterExaminesOutboundMessages = 1u << 2;
static constexpr uint8_t kFilterExaminesInboundMessages = 1u << 3;
static constexpr uint8_t kFilterExaminesCallContext = 1u << 4;
140 
141 namespace promise_filter_detail {
142 
143 // Proxy channel filter for initialization failure, since we must leave a
144 // valid filter in place.
145 class InvalidChannelFilter : public ChannelFilter {
146  public:
MakeCallPromise(CallArgs,NextPromiseFactory)147   ArenaPromise<ServerMetadataHandle> MakeCallPromise(
148       CallArgs, NextPromiseFactory) override {
149     abort();
150   }
151 };
152 
153 // Call data shared between all implementations of promise-based filters.
154 class BaseCallData : public Activity, private Wakeable {
155  protected:
156   // Hook to allow interception of messages on the send/receive path by
157   // PipeSender and PipeReceiver, as appropriate according to whether we're
158   // client or server.
159   class Interceptor {
160    public:
161     virtual PipeSender<MessageHandle>* Push() = 0;
162     virtual PipeReceiver<MessageHandle>* Pull() = 0;
163     virtual PipeReceiver<MessageHandle>* original_receiver() = 0;
164     virtual PipeSender<MessageHandle>* original_sender() = 0;
165     virtual void GotPipe(PipeSender<MessageHandle>*) = 0;
166     virtual void GotPipe(PipeReceiver<MessageHandle>*) = 0;
167     virtual ~Interceptor() = default;
168   };
169 
170   BaseCallData(grpc_call_element* elem, const grpc_call_element_args* args,
171                uint8_t flags,
172                absl::FunctionRef<Interceptor*()> make_send_interceptor,
173                absl::FunctionRef<Interceptor*()> make_recv_interceptor);
174 
175  public:
176   ~BaseCallData() override;
177 
set_pollent(grpc_polling_entity * pollent)178   void set_pollent(grpc_polling_entity* pollent) {
179     GPR_ASSERT(nullptr ==
180                pollent_.exchange(pollent, std::memory_order_release));
181   }
182 
183   // Activity implementation (partial).
184   void Orphan() final;
185   Waker MakeNonOwningWaker() final;
186   Waker MakeOwningWaker() final;
187 
ActivityDebugTag(WakeupMask)188   std::string ActivityDebugTag(WakeupMask) const override { return DebugTag(); }
189 
Finalize(const grpc_call_final_info * final_info)190   void Finalize(const grpc_call_final_info* final_info) {
191     finalization_.Run(final_info);
192   }
193 
194   virtual void StartBatch(grpc_transport_stream_op_batch* batch) = 0;
195 
196  protected:
197   class ScopedContext
198       : public promise_detail::Context<Arena>,
199         public promise_detail::Context<grpc_call_context_element>,
200         public promise_detail::Context<grpc_polling_entity>,
201         public promise_detail::Context<CallFinalization>,
202         public promise_detail::Context<
203             grpc_event_engine::experimental::EventEngine>,
204         public promise_detail::Context<CallContext> {
205    public:
ScopedContext(BaseCallData * call_data)206     explicit ScopedContext(BaseCallData* call_data)
207         : promise_detail::Context<Arena>(call_data->arena_),
208           promise_detail::Context<grpc_call_context_element>(
209               call_data->context_),
210           promise_detail::Context<grpc_polling_entity>(
211               call_data->pollent_.load(std::memory_order_acquire)),
212           promise_detail::Context<CallFinalization>(&call_data->finalization_),
213           promise_detail::Context<grpc_event_engine::experimental::EventEngine>(
214               call_data->event_engine_),
215           promise_detail::Context<CallContext>(call_data->call_context_) {}
216   };
217 
218   class Flusher {
219    public:
220     explicit Flusher(BaseCallData* call);
221     // Calls closures, schedules batches, relinquishes call combiner.
222     ~Flusher();
223 
Resume(grpc_transport_stream_op_batch * batch)224     void Resume(grpc_transport_stream_op_batch* batch) {
225       GPR_ASSERT(!call_->is_last());
226       if (batch->HasOp()) {
227         release_.push_back(batch);
228       } else if (batch->on_complete != nullptr) {
229         Complete(batch);
230       }
231     }
232 
Cancel(grpc_transport_stream_op_batch * batch,grpc_error_handle error)233     void Cancel(grpc_transport_stream_op_batch* batch,
234                 grpc_error_handle error) {
235       grpc_transport_stream_op_batch_queue_finish_with_failure(batch, error,
236                                                                &call_closures_);
237     }
238 
Complete(grpc_transport_stream_op_batch * batch)239     void Complete(grpc_transport_stream_op_batch* batch) {
240       call_closures_.Add(batch->on_complete, absl::OkStatus(),
241                          "Flusher::Complete");
242     }
243 
AddClosure(grpc_closure * closure,grpc_error_handle error,const char * reason)244     void AddClosure(grpc_closure* closure, grpc_error_handle error,
245                     const char* reason) {
246       call_closures_.Add(closure, error, reason);
247     }
248 
call()249     BaseCallData* call() const { return call_; }
250 
251    private:
252     absl::InlinedVector<grpc_transport_stream_op_batch*, 1> release_;
253     CallCombinerClosureList call_closures_;
254     BaseCallData* const call_;
255   };
256 
257   // Smart pointer like wrapper around a batch.
258   // Creation makes a ref count of one capture.
259   // Copying increments.
260   // Must be moved from or resumed or cancelled before destruction.
261   class CapturedBatch final {
262    public:
263     CapturedBatch();
264     explicit CapturedBatch(grpc_transport_stream_op_batch* batch);
265     ~CapturedBatch();
266     CapturedBatch(const CapturedBatch&);
267     CapturedBatch& operator=(const CapturedBatch&);
268     CapturedBatch(CapturedBatch&&) noexcept;
269     CapturedBatch& operator=(CapturedBatch&&) noexcept;
270 
271     grpc_transport_stream_op_batch* operator->() { return batch_; }
is_captured()272     bool is_captured() const { return batch_ != nullptr; }
273 
274     // Resume processing this batch (releases one ref, passes it down the
275     // stack)
276     void ResumeWith(Flusher* releaser);
277     // Cancel this batch immediately (releases all refs)
278     void CancelWith(grpc_error_handle error, Flusher* releaser);
279     // Complete this batch (pass it up) assuming refs drop to zero
280     void CompleteWith(Flusher* releaser);
281 
Swap(CapturedBatch * other)282     void Swap(CapturedBatch* other) { std::swap(batch_, other->batch_); }
283 
284    private:
285     grpc_transport_stream_op_batch* batch_;
286   };
287 
WrapMetadata(grpc_metadata_batch * p)288   static Arena::PoolPtr<grpc_metadata_batch> WrapMetadata(
289       grpc_metadata_batch* p) {
290     return Arena::PoolPtr<grpc_metadata_batch>(p,
291                                                Arena::PooledDeleter(nullptr));
292   }
293 
294   class ReceiveInterceptor final : public Interceptor {
295    public:
ReceiveInterceptor(Arena * arena)296     explicit ReceiveInterceptor(Arena* arena) : pipe_{arena} {}
297 
original_receiver()298     PipeReceiver<MessageHandle>* original_receiver() override {
299       return &pipe_.receiver;
300     }
original_sender()301     PipeSender<MessageHandle>* original_sender() override { abort(); }
302 
GotPipe(PipeReceiver<MessageHandle> * receiver)303     void GotPipe(PipeReceiver<MessageHandle>* receiver) override {
304       GPR_ASSERT(receiver_ == nullptr);
305       receiver_ = receiver;
306     }
307 
GotPipe(PipeSender<MessageHandle> *)308     void GotPipe(PipeSender<MessageHandle>*) override { abort(); }
309 
Push()310     PipeSender<MessageHandle>* Push() override { return &pipe_.sender; }
Pull()311     PipeReceiver<MessageHandle>* Pull() override {
312       GPR_ASSERT(receiver_ != nullptr);
313       return receiver_;
314     }
315 
316    private:
317     Pipe<MessageHandle> pipe_;
318     PipeReceiver<MessageHandle>* receiver_ = nullptr;
319   };
320 
321   class SendInterceptor final : public Interceptor {
322    public:
SendInterceptor(Arena * arena)323     explicit SendInterceptor(Arena* arena) : pipe_{arena} {}
324 
original_receiver()325     PipeReceiver<MessageHandle>* original_receiver() override { abort(); }
original_sender()326     PipeSender<MessageHandle>* original_sender() override {
327       return &pipe_.sender;
328     }
329 
GotPipe(PipeReceiver<MessageHandle> *)330     void GotPipe(PipeReceiver<MessageHandle>*) override { abort(); }
331 
GotPipe(PipeSender<MessageHandle> * sender)332     void GotPipe(PipeSender<MessageHandle>* sender) override {
333       GPR_ASSERT(sender_ == nullptr);
334       sender_ = sender;
335     }
336 
Push()337     PipeSender<MessageHandle>* Push() override {
338       GPR_ASSERT(sender_ != nullptr);
339       return sender_;
340     }
Pull()341     PipeReceiver<MessageHandle>* Pull() override { return &pipe_.receiver; }
342 
343    private:
344     Pipe<MessageHandle> pipe_;
345     PipeSender<MessageHandle>* sender_ = nullptr;
346   };
347 
348   // State machine for sending messages: handles intercepting send_message ops
349   // and forwarding them through pipes to the promise, then getting the result
350   // down the stack.
351   // Split into its own class so that we don't spend the memory instantiating
352   // these members for filters that don't need to intercept sent messages.
353   class SendMessage {
354    public:
SendMessage(BaseCallData * base,Interceptor * interceptor)355     SendMessage(BaseCallData* base, Interceptor* interceptor)
356         : base_(base), interceptor_(interceptor) {}
~SendMessage()357     ~SendMessage() { interceptor_->~Interceptor(); }
358 
interceptor()359     Interceptor* interceptor() { return interceptor_; }
360 
361     // Start a send_message op.
362     void StartOp(CapturedBatch batch);
363     // Publish the outbound pipe to the filter.
364     // This happens when the promise requests to call the next filter: until
365     // this occurs messages can't be sent as we don't know the pipe that the
366     // promise expects to send on.
367     template <typename T>
368     void GotPipe(T* pipe);
369     // Called from client/server polling to do the send message part of the
370     // work.
371     void WakeInsideCombiner(Flusher* flusher, bool allow_push_to_pipe);
372     // Call is completed, we have trailing metadata. Close things out.
373     void Done(const ServerMetadata& metadata, Flusher* flusher);
374     // Return true if we have a batch captured (for debug logs)
HaveCapturedBatch()375     bool HaveCapturedBatch() const { return batch_.is_captured(); }
376     // Return true if we're not actively sending a message.
377     bool IsIdle() const;
378     // Return true if we've released the message for forwarding down the stack.
IsForwarded()379     bool IsForwarded() const { return state_ == State::kForwardedBatch; }
380 
381    private:
382     enum class State : uint8_t {
383       // Starting state: no batch started, no outgoing pipe configured.
384       kInitial,
385       // We have an outgoing pipe, but no batch started.
386       // (this is the steady state).
387       kIdle,
388       // We have a batch started, but no outgoing pipe configured.
389       // Stall until we have one.
390       kGotBatchNoPipe,
391       // We have a batch, and an outgoing pipe. On the next poll we'll push the
392       // message into the pipe to the promise.
393       kGotBatch,
394       // We've pushed a message into the promise, and we're now waiting for it
395       // to pop out the other end so we can forward it down the stack.
396       kPushedToPipe,
397       // We've forwarded a message down the stack, and now we're waiting for
398       // completion.
399       kForwardedBatch,
400       // We've got the completion callback, we'll close things out during poll
401       // and then forward completion callbacks up and transition back to idle.
402       kBatchCompleted,
403       // We're almost done, but need to poll first.
404       kCancelledButNotYetPolled,
405       // We're done.
406       kCancelled,
407       // We're done, but we haven't gotten a status yet
408       kCancelledButNoStatus,
409     };
410     static const char* StateString(State);
411 
412     void OnComplete(absl::Status status);
413 
414     BaseCallData* const base_;
415     State state_ = State::kInitial;
416     Interceptor* const interceptor_;
417     absl::optional<PipeSender<MessageHandle>::PushType> push_;
418     absl::optional<PipeReceiverNextType<MessageHandle>> next_;
419     CapturedBatch batch_;
420     grpc_closure* intercepted_on_complete_;
421     grpc_closure on_complete_ =
422         MakeMemberClosure<SendMessage, &SendMessage::OnComplete>(this);
423     absl::Status completed_status_;
424   };
425 
426   // State machine for receiving messages: handles intercepting recv_message
427   // ops, forwarding them down the stack, and then publishing the result via
428   // pipes to the promise (and ultimately calling the right callbacks for the
429   // batch when our promise has completed processing of them).
430   // Split into its own class so that we don't spend the memory instantiating
431   // these members for filters that don't need to intercept sent messages.
432   class ReceiveMessage {
433    public:
ReceiveMessage(BaseCallData * base,Interceptor * interceptor)434     ReceiveMessage(BaseCallData* base, Interceptor* interceptor)
435         : base_(base), interceptor_(interceptor) {}
~ReceiveMessage()436     ~ReceiveMessage() { interceptor_->~Interceptor(); }
437 
interceptor()438     Interceptor* interceptor() { return interceptor_; }
439 
440     // Start a recv_message op.
441     void StartOp(CapturedBatch& batch);
442     // Publish the inbound pipe to the filter.
443     // This happens when the promise requests to call the next filter: until
444     // this occurs messages can't be received as we don't know the pipe that the
445     // promise expects to forward them with.
446     template <typename T>
447     void GotPipe(T* pipe);
448     // Called from client/server polling to do the receive message part of the
449     // work.
450     void WakeInsideCombiner(Flusher* flusher, bool allow_push_to_pipe);
451     // Call is completed, we have trailing metadata. Close things out.
452     void Done(const ServerMetadata& metadata, Flusher* flusher);
453 
454    private:
455     enum class State : uint8_t {
456       // Starting state: no batch started, no incoming pipe configured.
457       kInitial,
458       // We have an incoming pipe, but no batch started.
459       // (this is the steady state).
460       kIdle,
461       // We received a batch and forwarded it on, but have not got an incoming
462       // pipe configured.
463       kForwardedBatchNoPipe,
464       // We received a batch and forwarded it on.
465       kForwardedBatch,
466       // We got the completion for the recv_message, but we don't yet have a
467       // pipe configured. Stall until this changes.
468       kBatchCompletedNoPipe,
469       // We got the completion for the recv_message, and we have a pipe
470       // configured: next poll will push the message into the pipe for the
471       // filter to process.
472       kBatchCompleted,
473       // We've pushed a message into the promise, and we're now waiting for it
474       // to pop out the other end so we can forward it up the stack.
475       kPushedToPipe,
476       // We've got a message out of the pipe, now we need to wait for processing
477       // to completely quiesce in the promise prior to forwarding the completion
478       // up the stack.
479       kPulledFromPipe,
480       // We're done.
481       kCancelled,
482       // Call got terminated whilst we were idle: we need to close the sender
483       // pipe next poll.
484       kCancelledWhilstIdle,
485       // Call got terminated whilst we had forwarded a recv_message down the
486       // stack: we need to keep track of that until we get the completion so
487       // that we do the right thing in OnComplete.
488       kCancelledWhilstForwarding,
489       // The same, but before we got the pipe
490       kCancelledWhilstForwardingNoPipe,
491       // Call got terminated whilst we had a recv_message batch completed, and
492       // we've now received the completion.
493       // On the next poll we'll close things out and forward on completions,
494       // then transition to cancelled.
495       kBatchCompletedButCancelled,
496       // The same, but before we got the pipe
497       kBatchCompletedButCancelledNoPipe,
498       // Completed successfully while we're processing a recv message - see
499       // kPushedToPipe.
500       kCompletedWhilePushedToPipe,
501       // Completed successfully while we're processing a recv message - see
502       // kPulledFromPipe.
503       kCompletedWhilePulledFromPipe,
504       // Completed successfully while we were waiting to process
505       // kBatchCompleted.
506       kCompletedWhileBatchCompleted,
507     };
508     static const char* StateString(State);
509 
510     void OnComplete(absl::Status status);
511 
512     BaseCallData* const base_;
513     Interceptor* const interceptor_;
514     State state_ = State::kInitial;
515     uint32_t scratch_flags_;
516     absl::optional<SliceBuffer>* intercepted_slice_buffer_;
517     uint32_t* intercepted_flags_;
518     absl::optional<PipeSender<MessageHandle>::PushType> push_;
519     absl::optional<PipeReceiverNextType<MessageHandle>> next_;
520     absl::Status completed_status_;
521     grpc_closure* intercepted_on_complete_;
522     grpc_closure on_complete_ =
523         MakeMemberClosure<ReceiveMessage, &ReceiveMessage::OnComplete>(this);
524   };
525 
arena()526   Arena* arena() { return arena_; }
elem()527   grpc_call_element* elem() const { return elem_; }
call_combiner()528   CallCombiner* call_combiner() const { return call_combiner_; }
deadline()529   Timestamp deadline() const { return deadline_; }
call_stack()530   grpc_call_stack* call_stack() const { return call_stack_; }
server_initial_metadata_pipe()531   Pipe<ServerMetadataHandle>* server_initial_metadata_pipe() const {
532     return server_initial_metadata_pipe_;
533   }
send_message()534   SendMessage* send_message() const { return send_message_; }
receive_message()535   ReceiveMessage* receive_message() const { return receive_message_; }
536 
is_last()537   bool is_last() const {
538     return grpc_call_stack_element(call_stack_, call_stack_->count - 1) ==
539            elem_;
540   }
541 
542   virtual void WakeInsideCombiner(Flusher* flusher) = 0;
543 
544   virtual absl::string_view ClientOrServerString() const = 0;
545   std::string LogTag() const;
546 
547  private:
548   // Wakeable implementation.
549   void Wakeup(WakeupMask) final;
WakeupAsync(WakeupMask)550   void WakeupAsync(WakeupMask) final { Crash("not implemented"); }
551   void Drop(WakeupMask) final;
552 
553   virtual void OnWakeup() = 0;
554 
555   grpc_call_stack* const call_stack_;
556   grpc_call_element* const elem_;
557   Arena* const arena_;
558   CallCombiner* const call_combiner_;
559   const Timestamp deadline_;
560   CallFinalization finalization_;
561   CallContext* call_context_ = nullptr;
562   grpc_call_context_element* const context_;
563   std::atomic<grpc_polling_entity*> pollent_{nullptr};
564   Pipe<ServerMetadataHandle>* const server_initial_metadata_pipe_;
565   SendMessage* const send_message_;
566   ReceiveMessage* const receive_message_;
567   grpc_event_engine::experimental::EventEngine* event_engine_;
568 };
569 
// Client-side call data for a promise-based filter: adapts the batches that
// flow through a client channel stack to the filter's per-call promise.
class ClientCallData : public BaseCallData {
 public:
  ClientCallData(grpc_call_element* elem, const grpc_call_element_args* args,
                 uint8_t flags);
  ~ClientCallData() override;

  // Activity implementation.
  void ForceImmediateRepoll(WakeupMask) final;
  // Handle one grpc_transport_stream_op_batch
  void StartBatch(grpc_transport_stream_op_batch* batch) override;

  std::string DebugTag() const override;

 private:
  // At what stage is our handling of send initial metadata?
  enum class SendInitialState {
    // Start state: no op seen
    kInitial,
    // We've seen the op, and started the promise in response to it, but have
    // not yet sent the op to the next filter.
    kQueued,
    // We've sent the op to the next filter.
    kForwarded,
    // We were cancelled.
    kCancelled
  };
  // At what stage is our handling of recv trailing metadata?
  enum class RecvTrailingState {
    // Start state: no op seen
    kInitial,
    // We saw the op, and since it was bundled with send initial metadata, we
    // queued it until the send initial metadata can be sent to the next
    // filter.
    kQueued,
    // We've forwarded the op to the next filter.
    kForwarded,
    // The op has completed from below, but we haven't yet forwarded it up
    // (the promise gets to interject and mutate it).
    kComplete,
    // We've called the recv_trailing_metadata_ready callback from the original
    // recv_trailing_metadata op that was presented to us.
    kResponded,
    // We've been cancelled and handled that locally.
    // (i.e. whilst the recv_trailing_metadata op is queued in this filter).
    kCancelled
  };

  static const char* StateString(SendInitialState);
  static const char* StateString(RecvTrailingState);
  std::string DebugString() const;

  struct RecvInitialMetadata;
  class PollContext;

  // Handle cancellation.
  void Cancel(grpc_error_handle error, Flusher* flusher);
  // Begin running the promise - which will ultimately take some initial
  // metadata and return some trailing metadata.
  void StartPromise(Flusher* flusher);
  // Interject our callback into the op batch for recv trailing metadata
  // ready. Stash a pointer to the trailing metadata that will be filled in,
  // so we can manipulate it later.
  void HookRecvTrailingMetadata(CapturedBatch batch);
  // Construct a promise that will "call" the next filter.
  // Effectively:
  //   - put the modified initial metadata into the batch to be sent down.
  //   - return a wrapper around PollTrailingMetadata as the promise.
  ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
  // Wrapper to make it look like we're calling the next filter as a promise.
  // First poll: send the send_initial_metadata op down the stack.
  // All polls: await receiving the trailing metadata, then return it to the
  // application.
  Poll<ServerMetadataHandle> PollTrailingMetadata();
  static void RecvTrailingMetadataReadyCallback(void* arg,
                                                grpc_error_handle error);
  void RecvTrailingMetadataReady(grpc_error_handle error);
  void RecvInitialMetadataReady(grpc_error_handle error);
  // Given an error, fill in the metadata batch pointed to by `metadata` so
  // that it represents that error.
  void SetStatusFromError(grpc_metadata_batch* metadata,
                          grpc_error_handle error);
  // Wakeup and poll the promise if appropriate.
  void WakeInsideCombiner(Flusher* flusher) override;
  void OnWakeup() override;

  absl::string_view ClientOrServerString() const override { return "CLI"; }

  // Contained promise
  ArenaPromise<ServerMetadataHandle> promise_;
  // Queued batch containing at least a send_initial_metadata op.
  CapturedBatch send_initial_metadata_batch_;
  // Pointer to where trailing metadata will be stored.
  grpc_metadata_batch* recv_trailing_metadata_ = nullptr;
  // Trailing metadata as returned by the promise, if we hadn't received
  // trailing metadata from below yet (so we can substitute it in).
  ServerMetadataHandle cancelling_metadata_;
  // State tracking recv initial metadata for filters that care about it.
  RecvInitialMetadata* recv_initial_metadata_ = nullptr;
  // Closure to call when we're done with the trailing metadata.
  grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
  // Our closure pointing to RecvTrailingMetadataReadyCallback.
  grpc_closure recv_trailing_metadata_ready_;
  // Error received during cancellation.
  grpc_error_handle cancelled_error_;
  // State of the send_initial_metadata op.
  SendInitialState send_initial_state_ = SendInitialState::kInitial;
  // State of the recv_trailing_metadata op.
  RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial;
  // Polling related data. Non-null if we're actively polling
  PollContext* poll_ctx_ = nullptr;
  // Initial metadata outstanding token
  ClientInitialMetadataOutstandingToken initial_metadata_outstanding_token_;
};
682 
// Server-side call data for promise-based filters; the kServer specialization
// of CallData derives from this. It receives the call's
// grpc_transport_stream_op_batch traffic via StartBatch and drives the
// filter's promise (held in promise_) alongside it.
class ServerCallData : public BaseCallData {
 public:
  ServerCallData(grpc_call_element* elem, const grpc_call_element_args* args,
                 uint8_t flags);
  ~ServerCallData() override;

  // Activity implementation.
  void ForceImmediateRepoll(WakeupMask) final;
  // Handle one grpc_transport_stream_op_batch
  void StartBatch(grpc_transport_stream_op_batch* batch) override;

  std::string DebugTag() const override;

 protected:
  // Identifies this as the server side ("SVR"); presumably used to tag
  // debug/log output — confirm against BaseCallData.
  absl::string_view ClientOrServerString() const override { return "SVR"; }

 private:
  // At what stage is our handling of recv initial metadata?
  enum class RecvInitialState {
    // Start state: no op seen
    kInitial,
    // Op seen, and forwarded to the next filter.
    // Now waiting for the callback.
    kForwarded,
    // The op has completed from below, but we haven't yet forwarded it up
    // (the promise gets to interject and mutate it).
    kComplete,
    // We've sent the response to the next filter up.
    kResponded,
  };
  // At what stage is our handling of send trailing metadata?
  enum class SendTrailingState {
    // Start state: no op seen
    kInitial,
    // We saw the op, but it was with a send message op (or one was in progress)
    // - so we'll wait for that to complete before processing the trailing
    // metadata.
    kQueuedBehindSendMessage,
    // We saw the op, and are waiting for the promise to complete
    // to forward it. First however we need to close sends.
    kQueuedButHaventClosedSends,
    // We saw the op, and are waiting for the promise to complete
    // to forward it.
    kQueued,
    // We've forwarded the op to the next filter.
    kForwarded,
    // We were cancelled.
    kCancelled
  };

  // Names for the state enums above; presumably used by DebugString().
  static const char* StateString(RecvInitialState state);
  static const char* StateString(SendTrailingState state);
  std::string DebugString() const;

  class PollContext;
  struct SendInitialMetadata;

  // Shut things down when the call completes.
  void Completed(grpc_error_handle error, Flusher* flusher);
  // Construct a promise that will "call" the next filter.
  // Effectively:
  //   - put the modified initial metadata into the batch being sent up.
  //   - return a wrapper around PollTrailingMetadata as the promise.
  ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
  // Wrapper to make it look like we're calling the next filter as a promise.
  // All polls: await sending the trailing metadata, then forward it down the
  // stack.
  Poll<ServerMetadataHandle> PollTrailingMetadata();
  // Static trampolines (closure-callback signature) plus the member functions
  // that handle completion of recv_initial_metadata / recv_trailing_metadata.
  static void RecvInitialMetadataReadyCallback(void* arg,
                                               grpc_error_handle error);
  void RecvInitialMetadataReady(grpc_error_handle error);
  static void RecvTrailingMetadataReadyCallback(void* arg,
                                                grpc_error_handle error);
  void RecvTrailingMetadataReady(grpc_error_handle error);
  // Wakeup and poll the promise if appropriate.
  void WakeInsideCombiner(Flusher* flusher) override;
  void OnWakeup() override;

  // Contained promise
  ArenaPromise<ServerMetadataHandle> promise_;
  // Pointer to where initial metadata will be stored.
  grpc_metadata_batch* recv_initial_metadata_ = nullptr;
  // Pointer to where trailing metadata will be stored.
  grpc_metadata_batch* recv_trailing_metadata_ = nullptr;
  // State for sending initial metadata.
  SendInitialMetadata* send_initial_metadata_ = nullptr;
  // Closure to call when we're done with the initial metadata.
  grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
  // Our closure pointing to RecvInitialMetadataReadyCallback.
  grpc_closure recv_initial_metadata_ready_;
  // Closure to call when we're done with the trailing metadata.
  grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
  // Our closure pointing to RecvTrailingMetadataReadyCallback.
  grpc_closure recv_trailing_metadata_ready_;
  // Error received during cancellation.
  grpc_error_handle cancelled_error_;
  // Trailing metadata batch
  CapturedBatch send_trailing_metadata_batch_;
  // State of the recv_initial_metadata op.
  RecvInitialState recv_initial_state_ = RecvInitialState::kInitial;
  // State of the send_trailing_metadata op.
  SendTrailingState send_trailing_state_ = SendTrailingState::kInitial;
  // Current poll context (or nullptr if not polling).
  PollContext* poll_ctx_ = nullptr;
  // Whether to forward the recv_initial_metadata op at the end of promise
  // wakeup.
  bool forward_recv_initial_metadata_callback_ = false;
};
791 
// Specific call data per channel filter.
// Note that we further specialize for clients and servers since their
// implementations are very different.
// The primary template is only declared; only the two specializations below
// are defined, so using any other FilterEndpoint fails to compile.
template <FilterEndpoint endpoint>
class CallData;

// Client implementation of call data.
// Adds nothing of its own: it inherits every constructor of ClientCallData.
template <>
class CallData<FilterEndpoint::kClient> : public ClientCallData {
 public:
  using ClientCallData::ClientCallData;
};

// Server implementation of call data.
// Adds nothing of its own: it inherits every constructor of ServerCallData.
template <>
class CallData<FilterEndpoint::kServer> : public ServerCallData {
 public:
  using ServerCallData::ServerCallData;
};
811 
812 struct BaseCallDataMethods {
SetPollsetOrPollsetSetBaseCallDataMethods813   static void SetPollsetOrPollsetSet(grpc_call_element* elem,
814                                      grpc_polling_entity* pollent) {
815     static_cast<BaseCallData*>(elem->call_data)->set_pollent(pollent);
816   }
817 
DestructCallDataBaseCallDataMethods818   static void DestructCallData(grpc_call_element* elem,
819                                const grpc_call_final_info* final_info) {
820     auto* cd = static_cast<BaseCallData*>(elem->call_data);
821     cd->Finalize(final_info);
822     cd->~BaseCallData();
823   }
824 
StartTransportStreamOpBatchBaseCallDataMethods825   static void StartTransportStreamOpBatch(
826       grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
827     static_cast<BaseCallData*>(elem->call_data)->StartBatch(batch);
828   }
829 };
830 
831 template <typename CallData, uint8_t kFlags>
832 struct CallDataFilterWithFlagsMethods {
InitCallElemCallDataFilterWithFlagsMethods833   static absl::Status InitCallElem(grpc_call_element* elem,
834                                    const grpc_call_element_args* args) {
835     new (elem->call_data) CallData(elem, args, kFlags);
836     return absl::OkStatus();
837   }
838 
DestroyCallElemCallDataFilterWithFlagsMethods839   static void DestroyCallElem(grpc_call_element* elem,
840                               const grpc_call_final_info* final_info,
841                               grpc_closure* then_schedule_closure) {
842     BaseCallDataMethods::DestructCallData(elem, final_info);
843     if ((kFlags & kFilterIsLast) != 0) {
844       ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, absl::OkStatus());
845     } else {
846       GPR_ASSERT(then_schedule_closure == nullptr);
847     }
848   }
849 };
850 
851 struct ChannelFilterMethods {
MakeCallPromiseChannelFilterMethods852   static ArenaPromise<ServerMetadataHandle> MakeCallPromise(
853       grpc_channel_element* elem, CallArgs call_args,
854       NextPromiseFactory next_promise_factory) {
855     return static_cast<ChannelFilter*>(elem->channel_data)
856         ->MakeCallPromise(std::move(call_args),
857                           std::move(next_promise_factory));
858   }
859 
StartTransportOpChannelFilterMethods860   static void StartTransportOp(grpc_channel_element* elem,
861                                grpc_transport_op* op) {
862     if (!static_cast<ChannelFilter*>(elem->channel_data)
863              ->StartTransportOp(op)) {
864       grpc_channel_next_op(elem, op);
865     }
866   }
867 
PostInitChannelElemChannelFilterMethods868   static void PostInitChannelElem(grpc_channel_stack*,
869                                   grpc_channel_element* elem) {
870     static_cast<ChannelFilter*>(elem->channel_data)->PostInit();
871   }
872 
DestroyChannelElemChannelFilterMethods873   static void DestroyChannelElem(grpc_channel_element* elem) {
874     static_cast<ChannelFilter*>(elem->channel_data)->~ChannelFilter();
875   }
876 
GetChannelInfoChannelFilterMethods877   static void GetChannelInfo(grpc_channel_element* elem,
878                              const grpc_channel_info* info) {
879     if (!static_cast<ChannelFilter*>(elem->channel_data)
880              ->GetChannelInfo(info)) {
881       grpc_channel_next_get_info(elem, info);
882     }
883   }
884 };
885 
886 template <typename F, uint8_t kFlags>
887 struct ChannelFilterWithFlagsMethods {
InitChannelElemChannelFilterWithFlagsMethods888   static absl::Status InitChannelElem(grpc_channel_element* elem,
889                                       grpc_channel_element_args* args) {
890     GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0));
891     auto status = F::Create(args->channel_args,
892                             ChannelFilter::Args(args->channel_stack, elem));
893     if (!status.ok()) {
894       static_assert(
895           sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F),
896           "InvalidChannelFilter must fit in F");
897       new (elem->channel_data) promise_filter_detail::InvalidChannelFilter();
898       return absl_status_to_grpc_error(status.status());
899     }
900     new (elem->channel_data) F(std::move(*status));
901     return absl::OkStatus();
902   }
903 };
904 
905 }  // namespace promise_filter_detail
906 
907 // F implements ChannelFilter and :
908 // class SomeChannelFilter : public ChannelFilter {
909 //  public:
910 //   static absl::StatusOr<SomeChannelFilter> Create(
911 //       ChannelArgs channel_args, ChannelFilter::Args filter_args);
912 // };
913 template <typename F, FilterEndpoint kEndpoint, uint8_t kFlags = 0>
914 absl::enable_if_t<std::is_base_of<ChannelFilter, F>::value, grpc_channel_filter>
MakePromiseBasedFilter(const char * name)915 MakePromiseBasedFilter(const char* name) {
916   using CallData = promise_filter_detail::CallData<kEndpoint>;
917 
918   return grpc_channel_filter{
919       // start_transport_stream_op_batch
920       promise_filter_detail::BaseCallDataMethods::StartTransportStreamOpBatch,
921       // make_call_promise
922       promise_filter_detail::ChannelFilterMethods::MakeCallPromise,
923       // start_transport_op
924       promise_filter_detail::ChannelFilterMethods::StartTransportOp,
925       // sizeof_call_data
926       sizeof(CallData),
927       // init_call_elem
928       promise_filter_detail::CallDataFilterWithFlagsMethods<
929           CallData, kFlags>::InitCallElem,
930       // set_pollset_or_pollset_set
931       promise_filter_detail::BaseCallDataMethods::SetPollsetOrPollsetSet,
932       // destroy_call_elem
933       promise_filter_detail::CallDataFilterWithFlagsMethods<
934           CallData, kFlags>::DestroyCallElem,
935       // sizeof_channel_data
936       sizeof(F),
937       // init_channel_elem
938       promise_filter_detail::ChannelFilterWithFlagsMethods<
939           F, kFlags>::InitChannelElem,
940       // post_init_channel_elem
941       promise_filter_detail::ChannelFilterMethods::PostInitChannelElem,
942       // destroy_channel_elem
943       promise_filter_detail::ChannelFilterMethods::DestroyChannelElem,
944       // get_channel_info
945       promise_filter_detail::ChannelFilterMethods::GetChannelInfo,
946       // name
947       name,
948   };
949 }
950 
951 }  // namespace grpc_core
952 
953 #endif  // GRPC_SRC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H
954