xref: /aosp_15_r20/external/grpc-grpc/test/cpp/end2end/interceptors_util.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #ifndef GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
20 #define GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
21 
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include "absl/strings/str_format.h"

#include <grpcpp/channel.h>

#include "src/core/lib/gprpp/crash.h"
#include "src/proto/grpc/testing/echo.grpc.pb.h"
#include "test/cpp/util/string_ref_helper.h"
33 
34 namespace grpc {
35 namespace testing {
// This interceptor does nothing except keep global counts of how many batches
// reached each of three hook points: sending initial metadata, receiving
// initial metadata, and sending a cancel.
38 class PhonyInterceptor : public experimental::Interceptor {
39  public:
PhonyInterceptor()40   PhonyInterceptor() {}
41 
Intercept(experimental::InterceptorBatchMethods * methods)42   void Intercept(experimental::InterceptorBatchMethods* methods) override {
43     if (methods->QueryInterceptionHookPoint(
44             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
45       num_times_run_++;
46     } else if (methods->QueryInterceptionHookPoint(
47                    experimental::InterceptionHookPoints::
48                        POST_RECV_INITIAL_METADATA)) {
49       num_times_run_reverse_++;
50     } else if (methods->QueryInterceptionHookPoint(
51                    experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
52       num_times_cancel_++;
53     }
54     methods->Proceed();
55   }
56 
Reset()57   static void Reset() {
58     num_times_run_.store(0);
59     num_times_run_reverse_.store(0);
60     num_times_cancel_.store(0);
61   }
62 
GetNumTimesRun()63   static int GetNumTimesRun() {
64     EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
65     return num_times_run_.load();
66   }
67 
GetNumTimesCancel()68   static int GetNumTimesCancel() { return num_times_cancel_.load(); }
69 
70  private:
71   static std::atomic<int> num_times_run_;
72   static std::atomic<int> num_times_run_reverse_;
73   static std::atomic<int> num_times_cancel_;
74 };
75 
76 class PhonyInterceptorFactory
77     : public experimental::ClientInterceptorFactoryInterface,
78       public experimental::ServerInterceptorFactoryInterface {
79  public:
CreateClientInterceptor(experimental::ClientRpcInfo *)80   experimental::Interceptor* CreateClientInterceptor(
81       experimental::ClientRpcInfo* /*info*/) override {
82     return new PhonyInterceptor();
83   }
84 
CreateServerInterceptor(experimental::ServerRpcInfo *)85   experimental::Interceptor* CreateServerInterceptor(
86       experimental::ServerRpcInfo* /*info*/) override {
87     return new PhonyInterceptor();
88   }
89 };
90 
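// This interceptor checks, when it is created, that the intercepted client RPC
// targets the expected method and carries the expected stats suffix; it then
// passes every batch through unchanged.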
92 class TestInterceptor : public experimental::Interceptor {
93  public:
TestInterceptor(const std::string & method,const char * suffix_for_stats,experimental::ClientRpcInfo * info)94   TestInterceptor(const std::string& method, const char* suffix_for_stats,
95                   experimental::ClientRpcInfo* info) {
96     EXPECT_EQ(info->method(), method);
97 
98     if (suffix_for_stats == nullptr || info->suffix_for_stats() == nullptr) {
99       EXPECT_EQ(info->suffix_for_stats(), suffix_for_stats);
100     } else {
101       EXPECT_EQ(strcmp(info->suffix_for_stats(), suffix_for_stats), 0);
102     }
103   }
104 
Intercept(experimental::InterceptorBatchMethods * methods)105   void Intercept(experimental::InterceptorBatchMethods* methods) override {
106     methods->Proceed();
107   }
108 };
109 
110 class TestInterceptorFactory
111     : public experimental::ClientInterceptorFactoryInterface {
112  public:
TestInterceptorFactory(const std::string & method,const char * suffix_for_stats)113   TestInterceptorFactory(const std::string& method,
114                          const char* suffix_for_stats)
115       : method_(method), suffix_for_stats_(suffix_for_stats) {}
116 
CreateClientInterceptor(experimental::ClientRpcInfo * info)117   experimental::Interceptor* CreateClientInterceptor(
118       experimental::ClientRpcInfo* info) override {
119     return new TestInterceptor(method_, suffix_for_stats_, info);
120   }
121 
122  private:
123   std::string method_;
124   const char* suffix_for_stats_;
125 };
126 
127 // This interceptor factory returns nullptr on interceptor creation
128 class NullInterceptorFactory
129     : public experimental::ClientInterceptorFactoryInterface,
130       public experimental::ServerInterceptorFactoryInterface {
131  public:
CreateClientInterceptor(experimental::ClientRpcInfo *)132   experimental::Interceptor* CreateClientInterceptor(
133       experimental::ClientRpcInfo* /*info*/) override {
134     return nullptr;
135   }
136 
CreateServerInterceptor(experimental::ServerRpcInfo *)137   experimental::Interceptor* CreateServerInterceptor(
138       experimental::ServerRpcInfo* /*info*/) override {
139     return nullptr;
140   }
141 };
142 
143 class EchoTestServiceStreamingImpl : public EchoTestService::Service {
144  public:
~EchoTestServiceStreamingImpl()145   ~EchoTestServiceStreamingImpl() override {}
146 
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response)147   Status Echo(ServerContext* context, const EchoRequest* request,
148               EchoResponse* response) override {
149     auto client_metadata = context->client_metadata();
150     for (const auto& pair : client_metadata) {
151       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
152     }
153     response->set_message(request->message());
154     return Status::OK;
155   }
156 
BidiStream(ServerContext * context,grpc::ServerReaderWriter<EchoResponse,EchoRequest> * stream)157   Status BidiStream(
158       ServerContext* context,
159       grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
160     EchoRequest req;
161     EchoResponse resp;
162     auto client_metadata = context->client_metadata();
163     for (const auto& pair : client_metadata) {
164       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
165     }
166 
167     while (stream->Read(&req)) {
168       resp.set_message(req.message());
169       EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
170     }
171     return Status::OK;
172   }
173 
RequestStream(ServerContext * context,ServerReader<EchoRequest> * reader,EchoResponse * resp)174   Status RequestStream(ServerContext* context,
175                        ServerReader<EchoRequest>* reader,
176                        EchoResponse* resp) override {
177     auto client_metadata = context->client_metadata();
178     for (const auto& pair : client_metadata) {
179       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
180     }
181 
182     EchoRequest req;
183     string response_str;
184     while (reader->Read(&req)) {
185       response_str += req.message();
186     }
187     resp->set_message(response_str);
188     return Status::OK;
189   }
190 
ResponseStream(ServerContext * context,const EchoRequest * req,ServerWriter<EchoResponse> * writer)191   Status ResponseStream(ServerContext* context, const EchoRequest* req,
192                         ServerWriter<EchoResponse>* writer) override {
193     auto client_metadata = context->client_metadata();
194     for (const auto& pair : client_metadata) {
195       context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
196     }
197 
198     EchoResponse resp;
199     resp.set_message(req->message());
200     for (int i = 0; i < 10; i++) {
201       EXPECT_TRUE(writer->Write(resp));
202     }
203     return Status::OK;
204   }
205 };
206 
// Message count shared by the streaming helpers below; presumably the number
// of messages each streaming Make*Call sends/expects -- TODO confirm against
// the out-of-line definitions.
constexpr int kNumStreamingMessages = 10;

// The helpers below are defined out of line. The descriptions are inferred
// from their names and should be confirmed against the definitions.

// Presumably issues a synchronous unary Echo RPC on `channel`, building the
// stub with `options`.
void MakeCall(const std::shared_ptr<Channel>& channel,
              const StubOptions& options = StubOptions());

// Presumably issues a synchronous client-streaming RPC on `channel`.
void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);

// Presumably issues a synchronous server-streaming RPC on `channel`.
void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);

// Presumably issues a synchronous bidi-streaming RPC on `channel`.
void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);

// Presumably issues a unary RPC on `channel` through the completion-queue
// async API.
void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);

// Presumably the completion-queue async client-streaming variant.
void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);

// Presumably the completion-queue async server-streaming variant.
void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);

// Presumably the completion-queue async bidi-streaming variant.
void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);

// Presumably issues a unary RPC on `channel` through the callback API.
void MakeCallbackCall(const std::shared_ptr<Channel>& channel);

// Presumably returns true if `map` contains an entry whose key is `key` and
// whose value is `value` (overload for string_ref-keyed metadata maps).
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
                   const string& key, const string& value);

// Same check for metadata held as owning std::string pairs.
bool CheckMetadata(const std::multimap<std::string, std::string>& map,
                   const string& key, const string& value);

// Returns a list of client interceptor factories; presumably phony
// (counting-only) ones, judging by the name -- confirm in the definition.
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreatePhonyClientInterceptors();
236 
// Encodes a small integer as a completion-queue tag. The value is widened to
// std::intptr_t before the pointer cast so that tag() and detag() convert
// through the same pointer-sized integer type, instead of casting a
// (possibly narrower) int straight to void*.
inline void* tag(int i) {
  return reinterpret_cast<void*>(static_cast<std::intptr_t>(i));
}
// Inverse of tag(): recovers the integer a tag was built from.
inline int detag(void* p) {
  return static_cast<int>(reinterpret_cast<std::intptr_t>(p));
}
241 
242 class Verifier {
243  public:
Verifier()244   Verifier() : lambda_run_(false) {}
245   // Expect sets the expected ok value for a specific tag
Expect(int i,bool expect_ok)246   Verifier& Expect(int i, bool expect_ok) {
247     return ExpectUnless(i, expect_ok, false);
248   }
249   // ExpectUnless sets the expected ok value for a specific tag
250   // unless the tag was already marked seen (as a result of ExpectMaybe)
ExpectUnless(int i,bool expect_ok,bool seen)251   Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
252     if (!seen) {
253       expectations_[tag(i)] = expect_ok;
254     }
255     return *this;
256   }
257   // ExpectMaybe sets the expected ok value for a specific tag, but does not
258   // require it to appear
259   // If it does, sets *seen to true
ExpectMaybe(int i,bool expect_ok,bool * seen)260   Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
261     if (!*seen) {
262       maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
263     }
264     return *this;
265   }
266 
267   // Next waits for 1 async tag to complete, checks its
268   // expectations, and returns the tag
Next(CompletionQueue * cq,bool ignore_ok)269   int Next(CompletionQueue* cq, bool ignore_ok) {
270     bool ok;
271     void* got_tag;
272     EXPECT_TRUE(cq->Next(&got_tag, &ok));
273     GotTag(got_tag, ok, ignore_ok);
274     return detag(got_tag);
275   }
276 
277   template <typename T>
DoOnceThenAsyncNext(CompletionQueue * cq,void ** got_tag,bool * ok,T deadline,std::function<void (void)> lambda)278   CompletionQueue::NextStatus DoOnceThenAsyncNext(
279       CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
280       std::function<void(void)> lambda) {
281     if (lambda_run_) {
282       return cq->AsyncNext(got_tag, ok, deadline);
283     } else {
284       lambda_run_ = true;
285       return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
286     }
287   }
288 
289   // Verify keeps calling Next until all currently set
290   // expected tags are complete
Verify(CompletionQueue * cq)291   void Verify(CompletionQueue* cq) { Verify(cq, false); }
292 
293   // This version of Verify allows optionally ignoring the
294   // outcome of the expectation
Verify(CompletionQueue * cq,bool ignore_ok)295   void Verify(CompletionQueue* cq, bool ignore_ok) {
296     GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
297     while (!expectations_.empty()) {
298       Next(cq, ignore_ok);
299     }
300   }
301 
302   // This version of Verify stops after a certain deadline, and uses the
303   // DoThenAsyncNext API
304   // to call the lambda
Verify(CompletionQueue * cq,std::chrono::system_clock::time_point deadline,const std::function<void (void)> & lambda)305   void Verify(CompletionQueue* cq,
306               std::chrono::system_clock::time_point deadline,
307               const std::function<void(void)>& lambda) {
308     if (expectations_.empty()) {
309       bool ok;
310       void* got_tag;
311       EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
312                 CompletionQueue::TIMEOUT);
313     } else {
314       while (!expectations_.empty()) {
315         bool ok;
316         void* got_tag;
317         EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
318                   CompletionQueue::GOT_EVENT);
319         GotTag(got_tag, ok, false);
320       }
321     }
322   }
323 
324  private:
GotTag(void * got_tag,bool ok,bool ignore_ok)325   void GotTag(void* got_tag, bool ok, bool ignore_ok) {
326     auto it = expectations_.find(got_tag);
327     if (it != expectations_.end()) {
328       if (!ignore_ok) {
329         EXPECT_EQ(it->second, ok);
330       }
331       expectations_.erase(it);
332     } else {
333       auto it2 = maybe_expectations_.find(got_tag);
334       if (it2 != maybe_expectations_.end()) {
335         if (it2->second.seen != nullptr) {
336           EXPECT_FALSE(*it2->second.seen);
337           *it2->second.seen = true;
338         }
339         if (!ignore_ok) {
340           EXPECT_EQ(it2->second.ok, ok);
341         }
342       } else {
343         grpc_core::Crash(absl::StrFormat("Unexpected tag: %p", got_tag));
344       }
345     }
346   }
347 
348   struct MaybeExpect {
349     bool ok;
350     bool* seen;
351   };
352 
353   std::map<void*, bool> expectations_;
354   std::map<void*, MaybeExpect> maybe_expectations_;
355   bool lambda_run_;
356 };
357 
358 }  // namespace testing
359 }  // namespace grpc
360 
361 #endif  // GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
362