1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #ifndef GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
20 #define GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
21
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include "absl/strings/str_format.h"

#include <grpcpp/channel.h>

#include "src/core/lib/gprpp/crash.h"
#include "src/proto/grpc/testing/echo.grpc.pb.h"
#include "test/cpp/util/string_ref_helper.h"
33
34 namespace grpc {
35 namespace testing {
36 // This interceptor does nothing. Just keeps a global count on the number of
37 // times it was invoked.
38 class PhonyInterceptor : public experimental::Interceptor {
39 public:
PhonyInterceptor()40 PhonyInterceptor() {}
41
Intercept(experimental::InterceptorBatchMethods * methods)42 void Intercept(experimental::InterceptorBatchMethods* methods) override {
43 if (methods->QueryInterceptionHookPoint(
44 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
45 num_times_run_++;
46 } else if (methods->QueryInterceptionHookPoint(
47 experimental::InterceptionHookPoints::
48 POST_RECV_INITIAL_METADATA)) {
49 num_times_run_reverse_++;
50 } else if (methods->QueryInterceptionHookPoint(
51 experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
52 num_times_cancel_++;
53 }
54 methods->Proceed();
55 }
56
Reset()57 static void Reset() {
58 num_times_run_.store(0);
59 num_times_run_reverse_.store(0);
60 num_times_cancel_.store(0);
61 }
62
GetNumTimesRun()63 static int GetNumTimesRun() {
64 EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
65 return num_times_run_.load();
66 }
67
GetNumTimesCancel()68 static int GetNumTimesCancel() { return num_times_cancel_.load(); }
69
70 private:
71 static std::atomic<int> num_times_run_;
72 static std::atomic<int> num_times_run_reverse_;
73 static std::atomic<int> num_times_cancel_;
74 };
75
76 class PhonyInterceptorFactory
77 : public experimental::ClientInterceptorFactoryInterface,
78 public experimental::ServerInterceptorFactoryInterface {
79 public:
CreateClientInterceptor(experimental::ClientRpcInfo *)80 experimental::Interceptor* CreateClientInterceptor(
81 experimental::ClientRpcInfo* /*info*/) override {
82 return new PhonyInterceptor();
83 }
84
CreateServerInterceptor(experimental::ServerRpcInfo *)85 experimental::Interceptor* CreateServerInterceptor(
86 experimental::ServerRpcInfo* /*info*/) override {
87 return new PhonyInterceptor();
88 }
89 };
90
91 // This interceptor can be used to test the interception mechanism.
92 class TestInterceptor : public experimental::Interceptor {
93 public:
TestInterceptor(const std::string & method,const char * suffix_for_stats,experimental::ClientRpcInfo * info)94 TestInterceptor(const std::string& method, const char* suffix_for_stats,
95 experimental::ClientRpcInfo* info) {
96 EXPECT_EQ(info->method(), method);
97
98 if (suffix_for_stats == nullptr || info->suffix_for_stats() == nullptr) {
99 EXPECT_EQ(info->suffix_for_stats(), suffix_for_stats);
100 } else {
101 EXPECT_EQ(strcmp(info->suffix_for_stats(), suffix_for_stats), 0);
102 }
103 }
104
Intercept(experimental::InterceptorBatchMethods * methods)105 void Intercept(experimental::InterceptorBatchMethods* methods) override {
106 methods->Proceed();
107 }
108 };
109
110 class TestInterceptorFactory
111 : public experimental::ClientInterceptorFactoryInterface {
112 public:
TestInterceptorFactory(const std::string & method,const char * suffix_for_stats)113 TestInterceptorFactory(const std::string& method,
114 const char* suffix_for_stats)
115 : method_(method), suffix_for_stats_(suffix_for_stats) {}
116
CreateClientInterceptor(experimental::ClientRpcInfo * info)117 experimental::Interceptor* CreateClientInterceptor(
118 experimental::ClientRpcInfo* info) override {
119 return new TestInterceptor(method_, suffix_for_stats_, info);
120 }
121
122 private:
123 std::string method_;
124 const char* suffix_for_stats_;
125 };
126
127 // This interceptor factory returns nullptr on interceptor creation
128 class NullInterceptorFactory
129 : public experimental::ClientInterceptorFactoryInterface,
130 public experimental::ServerInterceptorFactoryInterface {
131 public:
CreateClientInterceptor(experimental::ClientRpcInfo *)132 experimental::Interceptor* CreateClientInterceptor(
133 experimental::ClientRpcInfo* /*info*/) override {
134 return nullptr;
135 }
136
CreateServerInterceptor(experimental::ServerRpcInfo *)137 experimental::Interceptor* CreateServerInterceptor(
138 experimental::ServerRpcInfo* /*info*/) override {
139 return nullptr;
140 }
141 };
142
143 class EchoTestServiceStreamingImpl : public EchoTestService::Service {
144 public:
~EchoTestServiceStreamingImpl()145 ~EchoTestServiceStreamingImpl() override {}
146
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response)147 Status Echo(ServerContext* context, const EchoRequest* request,
148 EchoResponse* response) override {
149 auto client_metadata = context->client_metadata();
150 for (const auto& pair : client_metadata) {
151 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
152 }
153 response->set_message(request->message());
154 return Status::OK;
155 }
156
BidiStream(ServerContext * context,grpc::ServerReaderWriter<EchoResponse,EchoRequest> * stream)157 Status BidiStream(
158 ServerContext* context,
159 grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
160 EchoRequest req;
161 EchoResponse resp;
162 auto client_metadata = context->client_metadata();
163 for (const auto& pair : client_metadata) {
164 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
165 }
166
167 while (stream->Read(&req)) {
168 resp.set_message(req.message());
169 EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
170 }
171 return Status::OK;
172 }
173
RequestStream(ServerContext * context,ServerReader<EchoRequest> * reader,EchoResponse * resp)174 Status RequestStream(ServerContext* context,
175 ServerReader<EchoRequest>* reader,
176 EchoResponse* resp) override {
177 auto client_metadata = context->client_metadata();
178 for (const auto& pair : client_metadata) {
179 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
180 }
181
182 EchoRequest req;
183 string response_str;
184 while (reader->Read(&req)) {
185 response_str += req.message();
186 }
187 resp->set_message(response_str);
188 return Status::OK;
189 }
190
ResponseStream(ServerContext * context,const EchoRequest * req,ServerWriter<EchoResponse> * writer)191 Status ResponseStream(ServerContext* context, const EchoRequest* req,
192 ServerWriter<EchoResponse>* writer) override {
193 auto client_metadata = context->client_metadata();
194 for (const auto& pair : client_metadata) {
195 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
196 }
197
198 EchoResponse resp;
199 resp.set_message(req->message());
200 for (int i = 0; i < 10; i++) {
201 EXPECT_TRUE(writer->Write(resp));
202 }
203 return Status::OK;
204 }
205 };
206
207 constexpr int kNumStreamingMessages = 10;
208
// Client-side helpers shared by the interceptor tests. Only declarations live
// in this header; the definitions are out of line.
// NOTE(review): judging by their names, each helper issues one RPC of the
// named kind (unary / client-streaming / server-streaming / bidi; sync,
// CompletionQueue-async, or callback API) over |channel|. Confirm against the
// definitions.
void MakeCall(const std::shared_ptr<Channel>& channel,
              const StubOptions& options = StubOptions());

void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeCallbackCall(const std::shared_ptr<Channel>& channel);

// Overloaded for multimaps keyed by grpc::string_ref and by std::string.
// NOTE(review): presumably returns true when |map| holds an entry whose key
// is |key| and whose value is |value|. The definition is not visible here.
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
                   const string& key, const string& value);

bool CheckMetadata(const std::multimap<std::string, std::string>& map,
                   const string& key, const string& value);

// NOTE(review): presumably returns a set of PhonyInterceptorFactory instances
// to install on a client channel. Confirm count and contents against the
// definition.
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreatePhonyClientInterceptors();
236
// Converts a small integer into an opaque CompletionQueue tag, and back.
// Both directions go through intptr_t, so int -> void* -> int round-trips
// losslessly, including for negative values. This also avoids the
// int-to-pointer-of-different-size conversion of casting the int directly.
inline void* tag(int i) {
  return reinterpret_cast<void*>(static_cast<intptr_t>(i));
}
inline int detag(void* p) {
  return static_cast<int>(reinterpret_cast<intptr_t>(p));
}
241
242 class Verifier {
243 public:
Verifier()244 Verifier() : lambda_run_(false) {}
245 // Expect sets the expected ok value for a specific tag
Expect(int i,bool expect_ok)246 Verifier& Expect(int i, bool expect_ok) {
247 return ExpectUnless(i, expect_ok, false);
248 }
249 // ExpectUnless sets the expected ok value for a specific tag
250 // unless the tag was already marked seen (as a result of ExpectMaybe)
ExpectUnless(int i,bool expect_ok,bool seen)251 Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
252 if (!seen) {
253 expectations_[tag(i)] = expect_ok;
254 }
255 return *this;
256 }
257 // ExpectMaybe sets the expected ok value for a specific tag, but does not
258 // require it to appear
259 // If it does, sets *seen to true
ExpectMaybe(int i,bool expect_ok,bool * seen)260 Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
261 if (!*seen) {
262 maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
263 }
264 return *this;
265 }
266
267 // Next waits for 1 async tag to complete, checks its
268 // expectations, and returns the tag
Next(CompletionQueue * cq,bool ignore_ok)269 int Next(CompletionQueue* cq, bool ignore_ok) {
270 bool ok;
271 void* got_tag;
272 EXPECT_TRUE(cq->Next(&got_tag, &ok));
273 GotTag(got_tag, ok, ignore_ok);
274 return detag(got_tag);
275 }
276
277 template <typename T>
DoOnceThenAsyncNext(CompletionQueue * cq,void ** got_tag,bool * ok,T deadline,std::function<void (void)> lambda)278 CompletionQueue::NextStatus DoOnceThenAsyncNext(
279 CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
280 std::function<void(void)> lambda) {
281 if (lambda_run_) {
282 return cq->AsyncNext(got_tag, ok, deadline);
283 } else {
284 lambda_run_ = true;
285 return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
286 }
287 }
288
289 // Verify keeps calling Next until all currently set
290 // expected tags are complete
Verify(CompletionQueue * cq)291 void Verify(CompletionQueue* cq) { Verify(cq, false); }
292
293 // This version of Verify allows optionally ignoring the
294 // outcome of the expectation
Verify(CompletionQueue * cq,bool ignore_ok)295 void Verify(CompletionQueue* cq, bool ignore_ok) {
296 GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
297 while (!expectations_.empty()) {
298 Next(cq, ignore_ok);
299 }
300 }
301
302 // This version of Verify stops after a certain deadline, and uses the
303 // DoThenAsyncNext API
304 // to call the lambda
Verify(CompletionQueue * cq,std::chrono::system_clock::time_point deadline,const std::function<void (void)> & lambda)305 void Verify(CompletionQueue* cq,
306 std::chrono::system_clock::time_point deadline,
307 const std::function<void(void)>& lambda) {
308 if (expectations_.empty()) {
309 bool ok;
310 void* got_tag;
311 EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
312 CompletionQueue::TIMEOUT);
313 } else {
314 while (!expectations_.empty()) {
315 bool ok;
316 void* got_tag;
317 EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
318 CompletionQueue::GOT_EVENT);
319 GotTag(got_tag, ok, false);
320 }
321 }
322 }
323
324 private:
GotTag(void * got_tag,bool ok,bool ignore_ok)325 void GotTag(void* got_tag, bool ok, bool ignore_ok) {
326 auto it = expectations_.find(got_tag);
327 if (it != expectations_.end()) {
328 if (!ignore_ok) {
329 EXPECT_EQ(it->second, ok);
330 }
331 expectations_.erase(it);
332 } else {
333 auto it2 = maybe_expectations_.find(got_tag);
334 if (it2 != maybe_expectations_.end()) {
335 if (it2->second.seen != nullptr) {
336 EXPECT_FALSE(*it2->second.seen);
337 *it2->second.seen = true;
338 }
339 if (!ignore_ok) {
340 EXPECT_EQ(it2->second.ok, ok);
341 }
342 } else {
343 grpc_core::Crash(absl::StrFormat("Unexpected tag: %p", got_tag));
344 }
345 }
346 }
347
348 struct MaybeExpect {
349 bool ok;
350 bool* seen;
351 };
352
353 std::map<void*, bool> expectations_;
354 std::map<void*, MaybeExpect> maybe_expectations_;
355 bool lambda_run_;
356 };
357
358 } // namespace testing
359 } // namespace grpc
360
361 #endif // GRPC_TEST_CPP_END2END_INTERCEPTORS_UTIL_H
362