xref: /aosp_15_r20/external/grpc-grpc/include/grpcpp/impl/interceptor_common.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #ifndef GRPCPP_IMPL_INTERCEPTOR_COMMON_H
20 #define GRPCPP_IMPL_INTERCEPTOR_COMMON_H
21 
22 #include <array>
23 #include <functional>
24 
25 #include <grpc/impl/grpc_types.h>
26 #include <grpc/support/log.h>
27 #include <grpcpp/impl/call.h>
28 #include <grpcpp/impl/call_op_set_interface.h>
29 #include <grpcpp/impl/intercepted_channel.h>
30 #include <grpcpp/support/client_interceptor.h>
31 #include <grpcpp/support/server_interceptor.h>
32 
33 namespace grpc {
34 namespace internal {
35 
36 class InterceptorBatchMethodsImpl
37     : public experimental::InterceptorBatchMethods {
38  public:
InterceptorBatchMethodsImpl()39   InterceptorBatchMethodsImpl() {
40     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
41          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
42          i = static_cast<experimental::InterceptionHookPoints>(
43              static_cast<size_t>(i) + 1)) {
44       hooks_[static_cast<size_t>(i)] = false;
45     }
46   }
47 
~InterceptorBatchMethodsImpl()48   ~InterceptorBatchMethodsImpl() override {}
49 
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50   bool QueryInterceptionHookPoint(
51       experimental::InterceptionHookPoints type) override {
52     return hooks_[static_cast<size_t>(type)];
53   }
54 
Proceed()55   void Proceed() override {
56     if (call_->client_rpc_info() != nullptr) {
57       return ProceedClient();
58     }
59     GPR_ASSERT(call_->server_rpc_info() != nullptr);
60     ProceedServer();
61   }
62 
Hijack()63   void Hijack() override {
64     // Only the client can hijack when sending down initial metadata
65     GPR_ASSERT(!reverse_ && ops_ != nullptr &&
66                call_->client_rpc_info() != nullptr);
67     // It is illegal to call Hijack twice
68     GPR_ASSERT(!ran_hijacking_interceptor_);
69     auto* rpc_info = call_->client_rpc_info();
70     rpc_info->hijacked_ = true;
71     rpc_info->hijacked_interceptor_ = current_interceptor_index_;
72     ClearHookPoints();
73     ops_->SetHijackingState();
74     ran_hijacking_interceptor_ = true;
75     rpc_info->RunInterceptor(this, current_interceptor_index_);
76   }
77 
AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78   void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
79     hooks_[static_cast<size_t>(type)] = true;
80   }
81 
GetSerializedSendMessage()82   ByteBuffer* GetSerializedSendMessage() override {
83     GPR_ASSERT(orig_send_message_ != nullptr);
84     if (*orig_send_message_ != nullptr) {
85       GPR_ASSERT(serializer_(*orig_send_message_).ok());
86       *orig_send_message_ = nullptr;
87     }
88     return send_message_;
89   }
90 
GetSendMessage()91   const void* GetSendMessage() override {
92     GPR_ASSERT(orig_send_message_ != nullptr);
93     return *orig_send_message_;
94   }
95 
ModifySendMessage(const void * message)96   void ModifySendMessage(const void* message) override {
97     GPR_ASSERT(orig_send_message_ != nullptr);
98     *orig_send_message_ = message;
99   }
100 
GetSendMessageStatus()101   bool GetSendMessageStatus() override { return !*fail_send_message_; }
102 
GetSendInitialMetadata()103   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
104     return send_initial_metadata_;
105   }
106 
GetSendStatus()107   Status GetSendStatus() override {
108     return Status(static_cast<StatusCode>(*code_), *error_message_,
109                   *error_details_);
110   }
111 
ModifySendStatus(const Status & status)112   void ModifySendStatus(const Status& status) override {
113     *code_ = static_cast<grpc_status_code>(status.error_code());
114     *error_details_ = status.error_details();
115     *error_message_ = status.error_message();
116   }
117 
GetSendTrailingMetadata()118   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
119     return send_trailing_metadata_;
120   }
121 
GetRecvMessage()122   void* GetRecvMessage() override { return recv_message_; }
123 
GetRecvInitialMetadata()124   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
125       override {
126     return recv_initial_metadata_->map();
127   }
128 
GetRecvStatus()129   Status* GetRecvStatus() override { return recv_status_; }
130 
FailHijackedSendMessage()131   void FailHijackedSendMessage() override {
132     GPR_ASSERT(hooks_[static_cast<size_t>(
133         experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
134     *fail_send_message_ = true;
135   }
136 
GetRecvTrailingMetadata()137   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
138       override {
139     return recv_trailing_metadata_->map();
140   }
141 
SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142   void SetSendMessage(ByteBuffer* buf, const void** msg,
143                       bool* fail_send_message,
144                       std::function<Status(const void*)> serializer) {
145     send_message_ = buf;
146     orig_send_message_ = msg;
147     fail_send_message_ = fail_send_message;
148     serializer_ = serializer;
149   }
150 
SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151   void SetSendInitialMetadata(
152       std::multimap<std::string, std::string>* metadata) {
153     send_initial_metadata_ = metadata;
154   }
155 
SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156   void SetSendStatus(grpc_status_code* code, std::string* error_details,
157                      std::string* error_message) {
158     code_ = code;
159     error_details_ = error_details;
160     error_message_ = error_message;
161   }
162 
SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163   void SetSendTrailingMetadata(
164       std::multimap<std::string, std::string>* metadata) {
165     send_trailing_metadata_ = metadata;
166   }
167 
SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168   void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
169     recv_message_ = message;
170     hijacked_recv_message_failed_ = hijacked_recv_message_failed;
171   }
172 
SetRecvInitialMetadata(MetadataMap * map)173   void SetRecvInitialMetadata(MetadataMap* map) {
174     recv_initial_metadata_ = map;
175   }
176 
SetRecvStatus(Status * status)177   void SetRecvStatus(Status* status) { recv_status_ = status; }
178 
SetRecvTrailingMetadata(MetadataMap * map)179   void SetRecvTrailingMetadata(MetadataMap* map) {
180     recv_trailing_metadata_ = map;
181   }
182 
GetInterceptedChannel()183   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
184     auto* info = call_->client_rpc_info();
185     if (info == nullptr) {
186       return std::unique_ptr<ChannelInterface>(nullptr);
187     }
188     // The intercepted channel starts from the interceptor just after the
189     // current interceptor
190     return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
191         info->channel(), current_interceptor_index_ + 1));
192   }
193 
FailHijackedRecvMessage()194   void FailHijackedRecvMessage() override {
195     GPR_ASSERT(hooks_[static_cast<size_t>(
196         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
197     *hijacked_recv_message_failed_ = true;
198   }
199 
200   // Clears all state
ClearState()201   void ClearState() {
202     reverse_ = false;
203     ran_hijacking_interceptor_ = false;
204     ClearHookPoints();
205   }
206 
207   // Prepares for Post_recv operations
SetReverse()208   void SetReverse() {
209     reverse_ = true;
210     ran_hijacking_interceptor_ = false;
211     ClearHookPoints();
212   }
213 
214   // This needs to be set before interceptors are run
SetCall(Call * call)215   void SetCall(Call* call) { call_ = call; }
216 
217   // This needs to be set before interceptors are run using RunInterceptors().
218   // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
SetCallOpSetInterface(CallOpSetInterface * ops)219   void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
220 
221   // SetCall should have been called before this.
222   // Returns true if the interceptors list is empty
InterceptorsListEmpty()223   bool InterceptorsListEmpty() {
224     auto* client_rpc_info = call_->client_rpc_info();
225     if (client_rpc_info != nullptr) {
226       return client_rpc_info->interceptors_.empty();
227     }
228 
229     auto* server_rpc_info = call_->server_rpc_info();
230     return server_rpc_info == nullptr || server_rpc_info->interceptors_.empty();
231   }
232 
233   // This should be used only by subclasses of CallOpSetInterface. SetCall and
234   // SetCallOpSetInterface should have been called before this. After all the
235   // interceptors are done running, either ContinueFillOpsAfterInterception or
236   // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
237   // them is invoked if there were no interceptors registered.
RunInterceptors()238   bool RunInterceptors() {
239     GPR_ASSERT(ops_);
240     auto* client_rpc_info = call_->client_rpc_info();
241     if (client_rpc_info != nullptr) {
242       if (client_rpc_info->interceptors_.empty()) {
243         return true;
244       } else {
245         RunClientInterceptors();
246         return false;
247       }
248     }
249 
250     auto* server_rpc_info = call_->server_rpc_info();
251     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
252       return true;
253     }
254     RunServerInterceptors();
255     return false;
256   }
257 
258   // Returns true if no interceptors are run. Returns false otherwise if there
259   // are interceptors registered. After the interceptors are done running \a f
260   // will be invoked. This is to be used only by BaseAsyncRequest and
261   // SyncRequest.
RunInterceptors(std::function<void (void)> f)262   bool RunInterceptors(std::function<void(void)> f) {
263     // This is used only by the server for initial call request
264     GPR_ASSERT(reverse_ == true);
265     GPR_ASSERT(call_->client_rpc_info() == nullptr);
266     auto* server_rpc_info = call_->server_rpc_info();
267     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
268       return true;
269     }
270     callback_ = std::move(f);
271     RunServerInterceptors();
272     return false;
273   }
274 
275  private:
RunClientInterceptors()276   void RunClientInterceptors() {
277     auto* rpc_info = call_->client_rpc_info();
278     if (!reverse_) {
279       current_interceptor_index_ = 0;
280     } else {
281       if (rpc_info->hijacked_) {
282         current_interceptor_index_ = rpc_info->hijacked_interceptor_;
283       } else {
284         current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
285       }
286     }
287     rpc_info->RunInterceptor(this, current_interceptor_index_);
288   }
289 
RunServerInterceptors()290   void RunServerInterceptors() {
291     auto* rpc_info = call_->server_rpc_info();
292     if (!reverse_) {
293       current_interceptor_index_ = 0;
294     } else {
295       current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
296     }
297     rpc_info->RunInterceptor(this, current_interceptor_index_);
298   }
299 
ProceedClient()300   void ProceedClient() {
301     auto* rpc_info = call_->client_rpc_info();
302     if (rpc_info->hijacked_ && !reverse_ &&
303         current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
304         !ran_hijacking_interceptor_) {
305       // We now need to provide hijacked recv ops to this interceptor
306       ClearHookPoints();
307       ops_->SetHijackingState();
308       ran_hijacking_interceptor_ = true;
309       rpc_info->RunInterceptor(this, current_interceptor_index_);
310       return;
311     }
312     if (!reverse_) {
313       current_interceptor_index_++;
314       // We are going down the stack of interceptors
315       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
316         if (rpc_info->hijacked_ &&
317             current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
318           // This is a hijacked RPC and we are done with hijacking
319           ops_->ContinueFillOpsAfterInterception();
320         } else {
321           rpc_info->RunInterceptor(this, current_interceptor_index_);
322         }
323       } else {
324         // we are done running all the interceptors without any hijacking
325         ops_->ContinueFillOpsAfterInterception();
326       }
327     } else {
328       // We are going up the stack of interceptors
329       if (current_interceptor_index_ > 0) {
330         // Continue running interceptors
331         current_interceptor_index_--;
332         rpc_info->RunInterceptor(this, current_interceptor_index_);
333       } else {
334         // we are done running all the interceptors without any hijacking
335         ops_->ContinueFinalizeResultAfterInterception();
336       }
337     }
338   }
339 
ProceedServer()340   void ProceedServer() {
341     auto* rpc_info = call_->server_rpc_info();
342     if (!reverse_) {
343       current_interceptor_index_++;
344       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
345         return rpc_info->RunInterceptor(this, current_interceptor_index_);
346       } else if (ops_) {
347         return ops_->ContinueFillOpsAfterInterception();
348       }
349     } else {
350       // We are going up the stack of interceptors
351       if (current_interceptor_index_ > 0) {
352         // Continue running interceptors
353         current_interceptor_index_--;
354         return rpc_info->RunInterceptor(this, current_interceptor_index_);
355       } else if (ops_) {
356         return ops_->ContinueFinalizeResultAfterInterception();
357       }
358     }
359     GPR_ASSERT(callback_);
360     callback_();
361   }
362 
ClearHookPoints()363   void ClearHookPoints() {
364     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
365          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
366          i = static_cast<experimental::InterceptionHookPoints>(
367              static_cast<size_t>(i) + 1)) {
368       hooks_[static_cast<size_t>(i)] = false;
369     }
370   }
371 
372   std::array<bool,
373              static_cast<size_t>(
374                  experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
375       hooks_;
376 
377   size_t current_interceptor_index_ = 0;  // Current iterator
378   bool reverse_ = false;
379   bool ran_hijacking_interceptor_ = false;
380   Call* call_ = nullptr;  // The Call object is present along with CallOpSet
381                           // object/callback
382   CallOpSetInterface* ops_ = nullptr;
383   std::function<void(void)> callback_;
384 
385   ByteBuffer* send_message_ = nullptr;
386   bool* fail_send_message_ = nullptr;
387   const void** orig_send_message_ = nullptr;
388   std::function<Status(const void*)> serializer_;
389 
390   std::multimap<std::string, std::string>* send_initial_metadata_;
391 
392   grpc_status_code* code_ = nullptr;
393   std::string* error_details_ = nullptr;
394   std::string* error_message_ = nullptr;
395 
396   std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr;
397 
398   void* recv_message_ = nullptr;
399   bool* hijacked_recv_message_failed_ = nullptr;
400 
401   MetadataMap* recv_initial_metadata_ = nullptr;
402 
403   Status* recv_status_ = nullptr;
404 
405   MetadataMap* recv_trailing_metadata_ = nullptr;
406 };
407 
408 // A special implementation of InterceptorBatchMethods to send a Cancel
409 // notification down the interceptor stack
410 class CancelInterceptorBatchMethods
411     : public experimental::InterceptorBatchMethods {
412  public:
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)413   bool QueryInterceptionHookPoint(
414       experimental::InterceptionHookPoints type) override {
415     return type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL;
416   }
417 
Proceed()418   void Proceed() override {
419     // This is a no-op. For actual continuation of the RPC simply needs to
420     // return from the Intercept method
421   }
422 
Hijack()423   void Hijack() override {
424     // Only the client can hijack when sending down initial metadata
425     GPR_ASSERT(false &&
426                "It is illegal to call Hijack on a method which has a "
427                "Cancel notification");
428   }
429 
GetSerializedSendMessage()430   ByteBuffer* GetSerializedSendMessage() override {
431     GPR_ASSERT(false &&
432                "It is illegal to call GetSendMessage on a method which "
433                "has a Cancel notification");
434     return nullptr;
435   }
436 
GetSendMessageStatus()437   bool GetSendMessageStatus() override {
438     GPR_ASSERT(false &&
439                "It is illegal to call GetSendMessageStatus on a method which "
440                "has a Cancel notification");
441     return false;
442   }
443 
GetSendMessage()444   const void* GetSendMessage() override {
445     GPR_ASSERT(false &&
446                "It is illegal to call GetOriginalSendMessage on a method which "
447                "has a Cancel notification");
448     return nullptr;
449   }
450 
ModifySendMessage(const void *)451   void ModifySendMessage(const void* /*message*/) override {
452     GPR_ASSERT(false &&
453                "It is illegal to call ModifySendMessage on a method which "
454                "has a Cancel notification");
455   }
456 
GetSendInitialMetadata()457   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
458     GPR_ASSERT(false &&
459                "It is illegal to call GetSendInitialMetadata on a "
460                "method which has a Cancel notification");
461     return nullptr;
462   }
463 
GetSendStatus()464   Status GetSendStatus() override {
465     GPR_ASSERT(false &&
466                "It is illegal to call GetSendStatus on a method which "
467                "has a Cancel notification");
468     return Status();
469   }
470 
ModifySendStatus(const Status &)471   void ModifySendStatus(const Status& /*status*/) override {
472     GPR_ASSERT(false &&
473                "It is illegal to call ModifySendStatus on a method "
474                "which has a Cancel notification");
475   }
476 
GetSendTrailingMetadata()477   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
478     GPR_ASSERT(false &&
479                "It is illegal to call GetSendTrailingMetadata on a "
480                "method which has a Cancel notification");
481     return nullptr;
482   }
483 
GetRecvMessage()484   void* GetRecvMessage() override {
485     GPR_ASSERT(false &&
486                "It is illegal to call GetRecvMessage on a method which "
487                "has a Cancel notification");
488     return nullptr;
489   }
490 
GetRecvInitialMetadata()491   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
492       override {
493     GPR_ASSERT(false &&
494                "It is illegal to call GetRecvInitialMetadata on a "
495                "method which has a Cancel notification");
496     return nullptr;
497   }
498 
GetRecvStatus()499   Status* GetRecvStatus() override {
500     GPR_ASSERT(false &&
501                "It is illegal to call GetRecvStatus on a method which "
502                "has a Cancel notification");
503     return nullptr;
504   }
505 
GetRecvTrailingMetadata()506   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
507       override {
508     GPR_ASSERT(false &&
509                "It is illegal to call GetRecvTrailingMetadata on a "
510                "method which has a Cancel notification");
511     return nullptr;
512   }
513 
GetInterceptedChannel()514   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
515     GPR_ASSERT(false &&
516                "It is illegal to call GetInterceptedChannel on a "
517                "method which has a Cancel notification");
518     return std::unique_ptr<ChannelInterface>(nullptr);
519   }
520 
FailHijackedRecvMessage()521   void FailHijackedRecvMessage() override {
522     GPR_ASSERT(false &&
523                "It is illegal to call FailHijackedRecvMessage on a "
524                "method which has a Cancel notification");
525   }
526 
FailHijackedSendMessage()527   void FailHijackedSendMessage() override {
528     GPR_ASSERT(false &&
529                "It is illegal to call FailHijackedSendMessage on a "
530                "method which has a Cancel notification");
531   }
532 };
533 }  // namespace internal
534 }  // namespace grpc
535 
536 #endif  // GRPCPP_IMPL_INTERCEPTOR_COMMON_H
537