1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
16 #define GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
17
18 #include <grpc/support/port_platform.h>
19
20 #include <grpc/support/log.h>
21
22 #include "src/core/lib/promise/detail/status.h"
23 #include "src/core/lib/promise/for_each.h"
24 #include "src/core/lib/promise/if.h"
25 #include "src/core/lib/promise/latch.h"
26 #include "src/core/lib/promise/party.h"
27 #include "src/core/lib/promise/pipe.h"
28 #include "src/core/lib/promise/prioritized_race.h"
29 #include "src/core/lib/promise/status_flag.h"
30 #include "src/core/lib/promise/try_seq.h"
31 #include "src/core/lib/transport/message.h"
32 #include "src/core/lib/transport/metadata.h"
33
34 namespace grpc_core {
35
36 // The common middle part of a call - a reference is held by each of
37 // CallInitiator and CallHandler - which provide interfaces that are appropriate
38 // for each side of a call.
39 // The spine will ultimately host the pipes, filters, and context for one part
40 // of a call: ie top-half client channel, sub channel call, server call.
41 // TODO(ctiller): eventually drop this when we don't need to reference into
42 // legacy promise calls anymore
43 class CallSpineInterface {
44 public:
45 virtual ~CallSpineInterface() = default;
46 virtual Pipe<ClientMetadataHandle>& client_initial_metadata() = 0;
47 virtual Pipe<ServerMetadataHandle>& server_initial_metadata() = 0;
48 virtual Pipe<MessageHandle>& client_to_server_messages() = 0;
49 virtual Pipe<MessageHandle>& server_to_client_messages() = 0;
50 virtual Pipe<ServerMetadataHandle>& server_trailing_metadata() = 0;
51 virtual Latch<ServerMetadataHandle>& cancel_latch() = 0;
52 // Add a callback to be called when server trailing metadata is received.
OnDone(absl::AnyInvocable<void ()> fn)53 void OnDone(absl::AnyInvocable<void()> fn) {
54 if (on_done_ == nullptr) {
55 on_done_ = std::move(fn);
56 return;
57 }
58 on_done_ = [first = std::move(fn), next = std::move(on_done_)]() mutable {
59 first();
60 next();
61 };
62 }
CallOnDone()63 void CallOnDone() {
64 if (on_done_ != nullptr) std::exchange(on_done_, nullptr)();
65 }
66 virtual Party& party() = 0;
67 virtual Arena* arena() = 0;
68 virtual void IncrementRefCount() = 0;
69 virtual void Unref() = 0;
70
71 // Cancel the call with the given metadata.
72 // Regarding the `MUST_USE_RESULT absl::nullopt_t`:
73 // Most cancellation calls right now happen in pipe interceptors;
74 // there `nullopt` indicates terminate processing of this pipe and close with
75 // error.
76 // It's convenient then to have the Cancel operation (setting the latch to
77 // terminate the call) be the last thing that occurs in a pipe interceptor,
78 // and this construction supports that (and has helped the author not write
79 // some bugs).
Cancel(ServerMetadataHandle metadata)80 GRPC_MUST_USE_RESULT absl::nullopt_t Cancel(ServerMetadataHandle metadata) {
81 GPR_DEBUG_ASSERT(GetContext<Activity>() == &party());
82 auto& c = cancel_latch();
83 if (c.is_set()) return absl::nullopt;
84 c.Set(std::move(metadata));
85 CallOnDone();
86 client_initial_metadata().sender.CloseWithError();
87 server_initial_metadata().sender.CloseWithError();
88 client_to_server_messages().sender.CloseWithError();
89 server_to_client_messages().sender.CloseWithError();
90 server_trailing_metadata().sender.CloseWithError();
91 return absl::nullopt;
92 }
93
WaitForCancel()94 auto WaitForCancel() {
95 GPR_DEBUG_ASSERT(GetContext<Activity>() == &party());
96 return cancel_latch().Wait();
97 }
98
99 // Wrap a promise so that if it returns failure it automatically cancels
100 // the rest of the call.
101 // The resulting (returned) promise will resolve to Empty.
102 template <typename Promise>
CancelIfFails(Promise promise)103 auto CancelIfFails(Promise promise) {
104 GPR_DEBUG_ASSERT(GetContext<Activity>() == &party());
105 using P = promise_detail::PromiseLike<Promise>;
106 using ResultType = typename P::Result;
107 return Map(std::move(promise), [this](ResultType r) {
108 if (!IsStatusOk(r)) {
109 std::ignore = Cancel(StatusCast<ServerMetadataHandle>(r));
110 }
111 return r;
112 });
113 }
114
115 // Spawn a promise that returns Empty{} and save some boilerplate handling
116 // that detail.
117 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)118 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
119 party().Spawn(name, std::move(promise_factory), [](Empty) {});
120 }
121
122 // Spawn a promise that returns some status-like type; if the status
123 // represents failure automatically cancel the rest of the call.
124 template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)125 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
126 using FactoryType =
127 promise_detail::OncePromiseFactory<void, PromiseFactory>;
128 using PromiseType = typename FactoryType::Promise;
129 using ResultType = typename PromiseType::Result;
130 static_assert(
131 std::is_same<bool,
132 decltype(IsStatusOk(std::declval<ResultType>()))>::value,
133 "SpawnGuarded promise must return a status-like object");
134 party().Spawn(name, std::move(promise_factory), [this](ResultType r) {
135 if (!IsStatusOk(r)) {
136 if (grpc_trace_promise_primitives.enabled()) {
137 gpr_log(GPR_DEBUG, "SpawnGuarded sees failure: %s",
138 r.ToString().c_str());
139 }
140 std::ignore = Cancel(StatusCast<ServerMetadataHandle>(std::move(r)));
141 }
142 });
143 }
144
145 private:
146 absl::AnyInvocable<void()> on_done_{nullptr};
147 };
148
149 class CallSpine final : public CallSpineInterface, public Party {
150 public:
Create(grpc_event_engine::experimental::EventEngine * event_engine,Arena * arena)151 static RefCountedPtr<CallSpine> Create(
152 grpc_event_engine::experimental::EventEngine* event_engine,
153 Arena* arena) {
154 return RefCountedPtr<CallSpine>(arena->New<CallSpine>(event_engine, arena));
155 }
156
client_initial_metadata()157 Pipe<ClientMetadataHandle>& client_initial_metadata() override {
158 return client_initial_metadata_;
159 }
server_initial_metadata()160 Pipe<ServerMetadataHandle>& server_initial_metadata() override {
161 return server_initial_metadata_;
162 }
client_to_server_messages()163 Pipe<MessageHandle>& client_to_server_messages() override {
164 return client_to_server_messages_;
165 }
server_to_client_messages()166 Pipe<MessageHandle>& server_to_client_messages() override {
167 return server_to_client_messages_;
168 }
server_trailing_metadata()169 Pipe<ServerMetadataHandle>& server_trailing_metadata() override {
170 return server_trailing_metadata_;
171 }
cancel_latch()172 Latch<ServerMetadataHandle>& cancel_latch() override { return cancel_latch_; }
party()173 Party& party() override { return *this; }
arena()174 Arena* arena() override { return arena_; }
IncrementRefCount()175 void IncrementRefCount() override { Party::IncrementRefCount(); }
Unref()176 void Unref() override { Party::Unref(); }
177
178 private:
179 friend class Arena;
CallSpine(grpc_event_engine::experimental::EventEngine * event_engine,Arena * arena)180 CallSpine(grpc_event_engine::experimental::EventEngine* event_engine,
181 Arena* arena)
182 : Party(1), arena_(arena), event_engine_(event_engine) {}
183
184 class ScopedContext : public ScopedActivity,
185 public promise_detail::Context<Arena> {
186 public:
ScopedContext(CallSpine * spine)187 explicit ScopedContext(CallSpine* spine)
188 : ScopedActivity(&spine->party()), Context<Arena>(spine->arena()) {}
189 };
190
RunParty()191 bool RunParty() override {
192 ScopedContext context(this);
193 return Party::RunParty();
194 }
195
PartyOver()196 void PartyOver() override {
197 Arena* a = arena();
198 {
199 ScopedContext context(this);
200 CancelRemainingParticipants();
201 a->DestroyManagedNewObjects();
202 }
203 this->~CallSpine();
204 a->Destroy();
205 }
206
event_engine()207 grpc_event_engine::experimental::EventEngine* event_engine() const override {
208 return event_engine_;
209 }
210
211 Arena* arena_;
212 // Initial metadata from client to server
213 Pipe<ClientMetadataHandle> client_initial_metadata_{arena()};
214 // Initial metadata from server to client
215 Pipe<ServerMetadataHandle> server_initial_metadata_{arena()};
216 // Messages travelling from the application to the transport.
217 Pipe<MessageHandle> client_to_server_messages_{arena()};
218 // Messages travelling from the transport to the application.
219 Pipe<MessageHandle> server_to_client_messages_{arena()};
220 // Trailing metadata from server to client
221 Pipe<ServerMetadataHandle> server_trailing_metadata_{arena()};
222 // Latch that can be set to terminate the call
223 Latch<ServerMetadataHandle> cancel_latch_;
224 // Event engine associated with this call
225 grpc_event_engine::experimental::EventEngine* const event_engine_;
226 };
227
228 class CallInitiator {
229 public:
CallInitiator(RefCountedPtr<CallSpineInterface> spine)230 explicit CallInitiator(RefCountedPtr<CallSpineInterface> spine)
231 : spine_(std::move(spine)) {}
232
PushClientInitialMetadata(ClientMetadataHandle md)233 auto PushClientInitialMetadata(ClientMetadataHandle md) {
234 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
235 return Map(spine_->client_initial_metadata().sender.Push(std::move(md)),
236 [](bool ok) { return StatusFlag(ok); });
237 }
238
PullServerInitialMetadata()239 auto PullServerInitialMetadata() {
240 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
241 return Map(spine_->server_initial_metadata().receiver.Next(),
242 [](NextResult<ServerMetadataHandle> md)
243 -> ValueOrFailure<absl::optional<ServerMetadataHandle>> {
244 if (!md.has_value()) {
245 if (md.cancelled()) return Failure{};
246 return absl::optional<ServerMetadataHandle>();
247 }
248 return absl::optional<ServerMetadataHandle>(std::move(*md));
249 });
250 }
251
PullServerTrailingMetadata()252 auto PullServerTrailingMetadata() {
253 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
254 return PrioritizedRace(
255 Seq(spine_->server_trailing_metadata().receiver.Next(),
256 [spine = spine_](NextResult<ServerMetadataHandle> md) mutable {
257 return [md = std::move(md),
258 spine]() mutable -> Poll<ServerMetadataHandle> {
259 // If the pipe was closed at cancellation time, we'll see no
260 // value here. Return pending and allow the cancellation to win
261 // the race.
262 if (!md.has_value()) return Pending{};
263 spine->server_trailing_metadata().sender.Close();
264 return std::move(*md);
265 };
266 }),
267 Map(spine_->WaitForCancel(),
268 [spine = spine_](ServerMetadataHandle md) -> ServerMetadataHandle {
269 spine->server_trailing_metadata().sender.CloseWithError();
270 return md;
271 }));
272 }
273
PullMessage()274 auto PullMessage() {
275 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
276 return spine_->server_to_client_messages().receiver.Next();
277 }
278
PushMessage(MessageHandle message)279 auto PushMessage(MessageHandle message) {
280 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
281 return Map(
282 spine_->client_to_server_messages().sender.Push(std::move(message)),
283 [](bool r) { return StatusFlag(r); });
284 }
285
FinishSends()286 void FinishSends() {
287 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
288 spine_->client_to_server_messages().sender.Close();
289 }
290
291 template <typename Promise>
CancelIfFails(Promise promise)292 auto CancelIfFails(Promise promise) {
293 return spine_->CancelIfFails(std::move(promise));
294 }
295
Cancel()296 void Cancel() {
297 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
298 std::ignore =
299 spine_->Cancel(ServerMetadataFromStatus(absl::CancelledError()));
300 }
301
OnDone(absl::AnyInvocable<void ()> fn)302 void OnDone(absl::AnyInvocable<void()> fn) { spine_->OnDone(std::move(fn)); }
303
304 template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)305 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
306 spine_->SpawnGuarded(name, std::move(promise_factory));
307 }
308
309 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)310 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
311 spine_->SpawnInfallible(name, std::move(promise_factory));
312 }
313
314 template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)315 auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
316 return spine_->party().SpawnWaitable(name, std::move(promise_factory));
317 }
318
arena()319 Arena* arena() { return spine_->arena(); }
320
321 private:
322 RefCountedPtr<CallSpineInterface> spine_;
323 };
324
325 class CallHandler {
326 public:
CallHandler(RefCountedPtr<CallSpineInterface> spine)327 explicit CallHandler(RefCountedPtr<CallSpineInterface> spine)
328 : spine_(std::move(spine)) {}
329
PullClientInitialMetadata()330 auto PullClientInitialMetadata() {
331 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
332 return Map(spine_->client_initial_metadata().receiver.Next(),
333 [](NextResult<ClientMetadataHandle> md)
334 -> ValueOrFailure<ClientMetadataHandle> {
335 if (!md.has_value()) return Failure{};
336 return std::move(*md);
337 });
338 }
339
PushServerInitialMetadata(absl::optional<ServerMetadataHandle> md)340 auto PushServerInitialMetadata(absl::optional<ServerMetadataHandle> md) {
341 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
342 return If(
343 md.has_value(),
344 [&md, this]() {
345 return Map(
346 spine_->server_initial_metadata().sender.Push(std::move(*md)),
347 [](bool ok) { return StatusFlag(ok); });
348 },
349 [this]() {
350 spine_->server_initial_metadata().sender.Close();
351 return []() -> StatusFlag { return Success{}; };
352 });
353 }
354
PushServerTrailingMetadata(ServerMetadataHandle md)355 auto PushServerTrailingMetadata(ServerMetadataHandle md) {
356 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
357 spine_->server_initial_metadata().sender.Close();
358 spine_->server_to_client_messages().sender.Close();
359 spine_->client_to_server_messages().receiver.CloseWithError();
360 spine_->CallOnDone();
361 return Map(spine_->server_trailing_metadata().sender.Push(std::move(md)),
362 [](bool ok) { return StatusFlag(ok); });
363 }
364
PullMessage()365 auto PullMessage() {
366 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
367 return spine_->client_to_server_messages().receiver.Next();
368 }
369
PushMessage(MessageHandle message)370 auto PushMessage(MessageHandle message) {
371 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
372 return Map(
373 spine_->server_to_client_messages().sender.Push(std::move(message)),
374 [](bool ok) { return StatusFlag(ok); });
375 }
376
Cancel(ServerMetadataHandle status)377 void Cancel(ServerMetadataHandle status) {
378 GPR_DEBUG_ASSERT(GetContext<Activity>() == &spine_->party());
379 std::ignore = spine_->Cancel(std::move(status));
380 }
381
OnDone(absl::AnyInvocable<void ()> fn)382 void OnDone(absl::AnyInvocable<void()> fn) { spine_->OnDone(std::move(fn)); }
383
384 template <typename Promise>
CancelIfFails(Promise promise)385 auto CancelIfFails(Promise promise) {
386 return spine_->CancelIfFails(std::move(promise));
387 }
388
389 template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)390 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
391 spine_->SpawnGuarded(name, std::move(promise_factory));
392 }
393
394 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)395 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
396 spine_->SpawnInfallible(name, std::move(promise_factory));
397 }
398
399 template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)400 auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
401 return spine_->party().SpawnWaitable(name, std::move(promise_factory));
402 }
403
arena()404 Arena* arena() { return spine_->arena(); }
405
406 private:
407 RefCountedPtr<CallSpineInterface> spine_;
408 };
409
410 struct CallInitiatorAndHandler {
411 CallInitiator initiator;
412 CallHandler handler;
413 };
414
415 CallInitiatorAndHandler MakeCall(
416 grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena);
417
// Adapts a call half (anything with a PullMessage() member, e.g.
// CallInitiator or CallHandler) into a source object whose Next() forwards to
// that PullMessage(). The adapter holds its own copy of the half, taken by
// value.
template <typename CallHalf>
auto OutgoingMessages(CallHalf h) {
  struct MessageSource {
    CallHalf half;
    auto Next() { return half.PullMessage(); }
  };
  return MessageSource{std::move(h)};
}
426
427 // Forward a call from `call_handler` to `call_initiator` (with initial metadata
428 // `client_initial_metadata`)
429 void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
430 ClientMetadataHandle client_initial_metadata);
431
432 } // namespace grpc_core
433
434 #endif // GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
435