xref: /aosp_15_r20/external/grpc-grpc/src/core/lib/transport/call_spine.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
16 #define GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
17 
18 #include <grpc/support/port_platform.h>
19 
20 #include <grpc/support/log.h>
21 
22 #include "src/core/lib/promise/detail/status.h"
23 #include "src/core/lib/promise/for_each.h"
24 #include "src/core/lib/promise/if.h"
25 #include "src/core/lib/promise/latch.h"
26 #include "src/core/lib/promise/party.h"
27 #include "src/core/lib/promise/pipe.h"
28 #include "src/core/lib/promise/prioritized_race.h"
29 #include "src/core/lib/promise/status_flag.h"
30 #include "src/core/lib/promise/try_seq.h"
31 #include "src/core/lib/transport/message.h"
32 #include "src/core/lib/transport/metadata.h"
33 
34 namespace grpc_core {
35 
36 // The common middle part of a call - a reference is held by each of
37 // CallInitiator and CallHandler - which provide interfaces that are appropriate
38 // for each side of a call.
39 // The spine will ultimately host the pipes, filters, and context for one part
40 // of a call: ie top-half client channel, sub channel call, server call.
41 // TODO(ctiller): eventually drop this when we don't need to reference into
42 // legacy promise calls anymore
43 class CallSpineInterface {
44  public:
45   virtual ~CallSpineInterface() = default;
46   virtual Pipe<ClientMetadataHandle>& client_initial_metadata() = 0;
47   virtual Pipe<ServerMetadataHandle>& server_initial_metadata() = 0;
48   virtual Pipe<MessageHandle>& client_to_server_messages() = 0;
49   virtual Pipe<MessageHandle>& server_to_client_messages() = 0;
50   virtual Pipe<ServerMetadataHandle>& server_trailing_metadata() = 0;
51   virtual Latch<ServerMetadataHandle>& cancel_latch() = 0;
52   // Add a callback to be called when server trailing metadata is received.
OnDone(absl::AnyInvocable<void ()> fn)53   void OnDone(absl::AnyInvocable<void()> fn) {
54     if (on_done_ == nullptr) {
55       on_done_ = std::move(fn);
56       return;
57     }
58     on_done_ = [first = std::move(fn), next = std::move(on_done_)]() mutable {
59       first();
60       next();
61     };
62   }
CallOnDone()63   void CallOnDone() {
64     if (on_done_ != nullptr) std::exchange(on_done_, nullptr)();
65   }
66   virtual Party& party() = 0;
67   virtual Arena* arena() = 0;
68   virtual void IncrementRefCount() = 0;
69   virtual void Unref() = 0;
70 
71   // Cancel the call with the given metadata.
72   // Regarding the `MUST_USE_RESULT absl::nullopt_t`:
73   // Most cancellation calls right now happen in pipe interceptors;
74   // there `nullopt` indicates terminate processing of this pipe and close with
75   // error.
76   // It's convenient then to have the Cancel operation (setting the latch to
77   // terminate the call) be the last thing that occurs in a pipe interceptor,
78   // and this construction supports that (and has helped the author not write
79   // some bugs).
Cancel(ServerMetadataHandle metadata)80   GRPC_MUST_USE_RESULT absl::nullopt_t Cancel(ServerMetadataHandle metadata) {
81     GPR_DEBUG_ASSERT(GetContext<Activity>() == &party());
82     auto& c = cancel_latch();
83     if (c.is_set()) return absl::nullopt;
84     c.Set(std::move(metadata));
85     CallOnDone();
86     client_initial_metadata().sender.CloseWithError();
87     server_initial_metadata().sender.CloseWithError();
88     client_to_server_messages().sender.CloseWithError();
89     server_to_client_messages().sender.CloseWithError();
90     server_trailing_metadata().sender.CloseWithError();
91     return absl::nullopt;
92   }
93 
WaitForCancel()94   auto WaitForCancel() {
95     GPR_DEBUG_ASSERT(GetContext<Activity>() == &party());
96     return cancel_latch().Wait();
97   }
98 
99   // Wrap a promise so that if it returns failure it automatically cancels
100   // the rest of the call.
101   // The resulting (returned) promise will resolve to Empty.
102   template <typename Promise>
CancelIfFails(Promise promise)103   auto CancelIfFails(Promise promise) {
104     GPR_DEBUG_ASSERT(GetContext<Activity>() == &party());
105     using P = promise_detail::PromiseLike<Promise>;
106     using ResultType = typename P::Result;
107     return Map(std::move(promise), [this](ResultType r) {
108       if (!IsStatusOk(r)) {
109         std::ignore = Cancel(StatusCast<ServerMetadataHandle>(r));
110       }
111       return r;
112     });
113   }
114 
115   // Spawn a promise that returns Empty{} and save some boilerplate handling
116   // that detail.
117   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)118   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
119     party().Spawn(name, std::move(promise_factory), [](Empty) {});
120   }
121 
122   // Spawn a promise that returns some status-like type; if the status
123   // represents failure automatically cancel the rest of the call.
124   template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)125   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
126     using FactoryType =
127         promise_detail::OncePromiseFactory<void, PromiseFactory>;
128     using PromiseType = typename FactoryType::Promise;
129     using ResultType = typename PromiseType::Result;
130     static_assert(
131         std::is_same<bool,
132                      decltype(IsStatusOk(std::declval<ResultType>()))>::value,
133         "SpawnGuarded promise must return a status-like object");
134     party().Spawn(name, std::move(promise_factory), [this](ResultType r) {
135       if (!IsStatusOk(r)) {
136         if (grpc_trace_promise_primitives.enabled()) {
137           gpr_log(GPR_DEBUG, "SpawnGuarded sees failure: %s",
138                   r.ToString().c_str());
139         }
140         std::ignore = Cancel(StatusCast<ServerMetadataHandle>(std::move(r)));
141       }
142     });
143   }
144 
145  private:
146   absl::AnyInvocable<void()> on_done_{nullptr};
147 };
148 
149 class CallSpine final : public CallSpineInterface, public Party {
150  public:
Create(grpc_event_engine::experimental::EventEngine * event_engine,Arena * arena)151   static RefCountedPtr<CallSpine> Create(
152       grpc_event_engine::experimental::EventEngine* event_engine,
153       Arena* arena) {
154     return RefCountedPtr<CallSpine>(arena->New<CallSpine>(event_engine, arena));
155   }
156 
client_initial_metadata()157   Pipe<ClientMetadataHandle>& client_initial_metadata() override {
158     return client_initial_metadata_;
159   }
server_initial_metadata()160   Pipe<ServerMetadataHandle>& server_initial_metadata() override {
161     return server_initial_metadata_;
162   }
client_to_server_messages()163   Pipe<MessageHandle>& client_to_server_messages() override {
164     return client_to_server_messages_;
165   }
server_to_client_messages()166   Pipe<MessageHandle>& server_to_client_messages() override {
167     return server_to_client_messages_;
168   }
server_trailing_metadata()169   Pipe<ServerMetadataHandle>& server_trailing_metadata() override {
170     return server_trailing_metadata_;
171   }
cancel_latch()172   Latch<ServerMetadataHandle>& cancel_latch() override { return cancel_latch_; }
party()173   Party& party() override { return *this; }
arena()174   Arena* arena() override { return arena_; }
IncrementRefCount()175   void IncrementRefCount() override { Party::IncrementRefCount(); }
Unref()176   void Unref() override { Party::Unref(); }
177 
178  private:
179   friend class Arena;
CallSpine(grpc_event_engine::experimental::EventEngine * event_engine,Arena * arena)180   CallSpine(grpc_event_engine::experimental::EventEngine* event_engine,
181             Arena* arena)
182       : Party(1), arena_(arena), event_engine_(event_engine) {}
183 
184   class ScopedContext : public ScopedActivity,
185                         public promise_detail::Context<Arena> {
186    public:
ScopedContext(CallSpine * spine)187     explicit ScopedContext(CallSpine* spine)
188         : ScopedActivity(&spine->party()), Context<Arena>(spine->arena()) {}
189   };
190 
RunParty()191   bool RunParty() override {
192     ScopedContext context(this);
193     return Party::RunParty();
194   }
195 
PartyOver()196   void PartyOver() override {
197     Arena* a = arena();
198     {
199       ScopedContext context(this);
200       CancelRemainingParticipants();
201       a->DestroyManagedNewObjects();
202     }
203     this->~CallSpine();
204     a->Destroy();
205   }
206 
event_engine()207   grpc_event_engine::experimental::EventEngine* event_engine() const override {
208     return event_engine_;
209   }
210 
211   Arena* arena_;
212   // Initial metadata from client to server
213   Pipe<ClientMetadataHandle> client_initial_metadata_{arena()};
214   // Initial metadata from server to client
215   Pipe<ServerMetadataHandle> server_initial_metadata_{arena()};
216   // Messages travelling from the application to the transport.
217   Pipe<MessageHandle> client_to_server_messages_{arena()};
218   // Messages travelling from the transport to the application.
219   Pipe<MessageHandle> server_to_client_messages_{arena()};
220   // Trailing metadata from server to client
221   Pipe<ServerMetadataHandle> server_trailing_metadata_{arena()};
222   // Latch that can be set to terminate the call
223   Latch<ServerMetadataHandle> cancel_latch_;
224   // Event engine associated with this call
225   grpc_event_engine::experimental::EventEngine* const event_engine_;
226 };
227 
228 class CallInitiator {
229  public:
CallInitiator(RefCountedPtr<CallSpineInterface> spine)230   explicit CallInitiator(RefCountedPtr<CallSpineInterface> spine)
231       : spine_(std::move(spine)) {}
232 
PushClientInitialMetadata(ClientMetadataHandle md)233   auto PushClientInitialMetadata(ClientMetadataHandle md) {
234     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
235     return Map(spine_->client_initial_metadata().sender.Push(std::move(md)),
236                [](bool ok) { return StatusFlag(ok); });
237   }
238 
PullServerInitialMetadata()239   auto PullServerInitialMetadata() {
240     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
241     return Map(spine_->server_initial_metadata().receiver.Next(),
242                [](NextResult<ServerMetadataHandle> md)
243                    -> ValueOrFailure<absl::optional<ServerMetadataHandle>> {
244                  if (!md.has_value()) {
245                    if (md.cancelled()) return Failure{};
246                    return absl::optional<ServerMetadataHandle>();
247                  }
248                  return absl::optional<ServerMetadataHandle>(std::move(*md));
249                });
250   }
251 
PullServerTrailingMetadata()252   auto PullServerTrailingMetadata() {
253     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
254     return PrioritizedRace(
255         Seq(spine_->server_trailing_metadata().receiver.Next(),
256             [spine = spine_](NextResult<ServerMetadataHandle> md) mutable {
257               return [md = std::move(md),
258                       spine]() mutable -> Poll<ServerMetadataHandle> {
259                 // If the pipe was closed at cancellation time, we'll see no
260                 // value here. Return pending and allow the cancellation to win
261                 // the race.
262                 if (!md.has_value()) return Pending{};
263                 spine->server_trailing_metadata().sender.Close();
264                 return std::move(*md);
265               };
266             }),
267         Map(spine_->WaitForCancel(),
268             [spine = spine_](ServerMetadataHandle md) -> ServerMetadataHandle {
269               spine->server_trailing_metadata().sender.CloseWithError();
270               return md;
271             }));
272   }
273 
PullMessage()274   auto PullMessage() {
275     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
276     return spine_->server_to_client_messages().receiver.Next();
277   }
278 
PushMessage(MessageHandle message)279   auto PushMessage(MessageHandle message) {
280     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
281     return Map(
282         spine_->client_to_server_messages().sender.Push(std::move(message)),
283         [](bool r) { return StatusFlag(r); });
284   }
285 
FinishSends()286   void FinishSends() {
287     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
288     spine_->client_to_server_messages().sender.Close();
289   }
290 
291   template <typename Promise>
CancelIfFails(Promise promise)292   auto CancelIfFails(Promise promise) {
293     return spine_->CancelIfFails(std::move(promise));
294   }
295 
Cancel()296   void Cancel() {
297     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
298     std::ignore =
299         spine_->Cancel(ServerMetadataFromStatus(absl::CancelledError()));
300   }
301 
OnDone(absl::AnyInvocable<void ()> fn)302   void OnDone(absl::AnyInvocable<void()> fn) { spine_->OnDone(std::move(fn)); }
303 
304   template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)305   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
306     spine_->SpawnGuarded(name, std::move(promise_factory));
307   }
308 
309   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)310   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
311     spine_->SpawnInfallible(name, std::move(promise_factory));
312   }
313 
314   template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)315   auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
316     return spine_->party().SpawnWaitable(name, std::move(promise_factory));
317   }
318 
arena()319   Arena* arena() { return spine_->arena(); }
320 
321  private:
322   RefCountedPtr<CallSpineInterface> spine_;
323 };
324 
325 class CallHandler {
326  public:
CallHandler(RefCountedPtr<CallSpineInterface> spine)327   explicit CallHandler(RefCountedPtr<CallSpineInterface> spine)
328       : spine_(std::move(spine)) {}
329 
PullClientInitialMetadata()330   auto PullClientInitialMetadata() {
331     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
332     return Map(spine_->client_initial_metadata().receiver.Next(),
333                [](NextResult<ClientMetadataHandle> md)
334                    -> ValueOrFailure<ClientMetadataHandle> {
335                  if (!md.has_value()) return Failure{};
336                  return std::move(*md);
337                });
338   }
339 
PushServerInitialMetadata(absl::optional<ServerMetadataHandle> md)340   auto PushServerInitialMetadata(absl::optional<ServerMetadataHandle> md) {
341     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
342     return If(
343         md.has_value(),
344         [&md, this]() {
345           return Map(
346               spine_->server_initial_metadata().sender.Push(std::move(*md)),
347               [](bool ok) { return StatusFlag(ok); });
348         },
349         [this]() {
350           spine_->server_initial_metadata().sender.Close();
351           return []() -> StatusFlag { return Success{}; };
352         });
353   }
354 
PushServerTrailingMetadata(ServerMetadataHandle md)355   auto PushServerTrailingMetadata(ServerMetadataHandle md) {
356     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
357     spine_->server_initial_metadata().sender.Close();
358     spine_->server_to_client_messages().sender.Close();
359     spine_->client_to_server_messages().receiver.CloseWithError();
360     spine_->CallOnDone();
361     return Map(spine_->server_trailing_metadata().sender.Push(std::move(md)),
362                [](bool ok) { return StatusFlag(ok); });
363   }
364 
PullMessage()365   auto PullMessage() {
366     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
367     return spine_->client_to_server_messages().receiver.Next();
368   }
369 
PushMessage(MessageHandle message)370   auto PushMessage(MessageHandle message) {
371     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
372     return Map(
373         spine_->server_to_client_messages().sender.Push(std::move(message)),
374         [](bool ok) { return StatusFlag(ok); });
375   }
376 
Cancel(ServerMetadataHandle status)377   void Cancel(ServerMetadataHandle status) {
378     GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
379     std::ignore = spine_->Cancel(std::move(status));
380   }
381 
OnDone(absl::AnyInvocable<void ()> fn)382   void OnDone(absl::AnyInvocable<void()> fn) { spine_->OnDone(std::move(fn)); }
383 
384   template <typename Promise>
CancelIfFails(Promise promise)385   auto CancelIfFails(Promise promise) {
386     return spine_->CancelIfFails(std::move(promise));
387   }
388 
389   template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)390   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
391     spine_->SpawnGuarded(name, std::move(promise_factory));
392   }
393 
394   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)395   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
396     spine_->SpawnInfallible(name, std::move(promise_factory));
397   }
398 
399   template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)400   auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
401     return spine_->party().SpawnWaitable(name, std::move(promise_factory));
402   }
403 
arena()404   Arena* arena() { return spine_->arena(); }
405 
406  private:
407   RefCountedPtr<CallSpineInterface> spine_;
408 };
409 
410 struct CallInitiatorAndHandler {
411   CallInitiator initiator;
412   CallHandler handler;
413 };
414 
415 CallInitiatorAndHandler MakeCall(
416     grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena);
417 
// Adapt a call half into a message source: the returned object's Next()
// forwards to the half's PullMessage(). The half is moved into the adapter
// and is reachable through its `h` member.
template <typename CallHalf>
auto OutgoingMessages(CallHalf h) {
  struct MessageSource {
    CallHalf h;
    auto Next() { return h.PullMessage(); }
  };
  MessageSource source{std::move(h)};
  return source;
}
426 
427 // Forward a call from `call_handler` to `call_initiator` (with initial metadata
428 // `client_initial_metadata`)
429 void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
430                  ClientMetadataHandle client_initial_metadata);
431 
432 }  // namespace grpc_core
433 
434 #endif  // GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
435