1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "absl/memory/memory.h"

#include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h"
#include "test/core/util/test_config.h"
27
28 namespace grpc_binder {
29 namespace {
30
31 // TODO(waynetu): These are hacks to make callbacks aware of their stream IDs
32 // and sequence numbers. Remove/Refactor these hacks when possible.
33 template <typename T>
Decode(const T &)34 std::pair<StreamIdentifier, int> Decode(const T& /*data*/) {
35 assert(false && "This should not be called");
36 return {};
37 }
38
39 template <>
Decode(const std::string & data)40 std::pair<StreamIdentifier, int> Decode<std::string>(const std::string& data) {
41 assert(data.size() == sizeof(StreamIdentifier) + sizeof(int));
42 StreamIdentifier id{};
43 int seq_num{};
44 std::memcpy(&id, data.data(), sizeof(StreamIdentifier));
45 std::memcpy(&seq_num, data.data() + sizeof(StreamIdentifier), sizeof(int));
46 return std::make_pair(id, seq_num);
47 }
48
49 template <>
Decode(const Metadata & data)50 std::pair<StreamIdentifier, int> Decode<Metadata>(const Metadata& data) {
51 assert(data.size() == 1);
52 const std::string& encoding = data[0].first;
53 return Decode(encoding);
54 }
55
56 template <typename T>
Encode(StreamIdentifier,int)57 T Encode(StreamIdentifier /*id*/, int /*seq_num*/) {
58 assert(false && "This should not be called");
59 return {};
60 }
61
62 template <>
Encode(StreamIdentifier id,int seq_num)63 std::string Encode<std::string>(StreamIdentifier id, int seq_num) {
64 char result[sizeof(StreamIdentifier) + sizeof(int)];
65 std::memcpy(result, &id, sizeof(StreamIdentifier));
66 std::memcpy(result + sizeof(StreamIdentifier), &seq_num, sizeof(int));
67 return std::string(result, sizeof(StreamIdentifier) + sizeof(int));
68 }
69
70 template <>
Encode(StreamIdentifier id,int seq_num)71 Metadata Encode<Metadata>(StreamIdentifier id, int seq_num) {
72 return {{Encode<std::string>(id, seq_num), ""}};
73 }
74
75 MATCHER_P2(StreamIdAndSeqNumMatch, id, seq_num, "") {
76 auto p = Decode(arg.value());
77 return p.first == id && p.second == seq_num;
78 }
79
80 // MockCallback is used to verify the every callback passed to transaction
81 // receiver will eventually be invoked with the artifact of its corresponding
82 // binder transaction.
83 template <typename FirstArg, typename... TrailingArgs>
84 class MockCallback {
85 public:
MockCallback(StreamIdentifier id,int seq_num)86 explicit MockCallback(StreamIdentifier id, int seq_num)
87 : id_(id), seq_num_(seq_num) {}
88
89 MOCK_METHOD(void, ActualCallback, (FirstArg), ());
90
GetHandle()91 std::function<void(FirstArg, TrailingArgs...)> GetHandle() {
92 return [this](FirstArg first_arg, TrailingArgs...) {
93 this->ActualCallback(first_arg);
94 };
95 }
96
ExpectCallbackInvocation()97 void ExpectCallbackInvocation() {
98 EXPECT_CALL(*this, ActualCallback(StreamIdAndSeqNumMatch(id_, seq_num_)));
99 }
100
101 private:
102 StreamIdentifier id_;
103 int seq_num_;
104 };
105
106 using MockInitialMetadataCallback = MockCallback<absl::StatusOr<Metadata>>;
107 using MockMessageCallback = MockCallback<absl::StatusOr<std::string>>;
108 using MockTrailingMetadataCallback =
109 MockCallback<absl::StatusOr<Metadata>, int>;
110
111 class MockOpBatch {
112 public:
MockOpBatch(StreamIdentifier id,int flag,int seq_num)113 MockOpBatch(StreamIdentifier id, int flag, int seq_num)
114 : id_(id), flag_(flag), seq_num_(seq_num) {
115 if (flag_ & kFlagPrefix) {
116 initial_metadata_callback_ =
117 std::make_unique<MockInitialMetadataCallback>(id_, seq_num_);
118 }
119 if (flag_ & kFlagMessageData) {
120 message_callback_ = std::make_unique<MockMessageCallback>(id_, seq_num_);
121 }
122 if (flag_ & kFlagSuffix) {
123 trailing_metadata_callback_ =
124 std::make_unique<MockTrailingMetadataCallback>(id_, seq_num_);
125 }
126 }
127
Complete(TransportStreamReceiver & receiver)128 void Complete(TransportStreamReceiver& receiver) {
129 if (flag_ & kFlagPrefix) {
130 initial_metadata_callback_->ExpectCallbackInvocation();
131 receiver.NotifyRecvInitialMetadata(id_, Encode<Metadata>(id_, seq_num_));
132 }
133 if (flag_ & kFlagMessageData) {
134 message_callback_->ExpectCallbackInvocation();
135 receiver.NotifyRecvMessage(id_, Encode<std::string>(id_, seq_num_));
136 }
137 if (flag_ & kFlagSuffix) {
138 trailing_metadata_callback_->ExpectCallbackInvocation();
139 receiver.NotifyRecvTrailingMetadata(id_, Encode<Metadata>(id_, seq_num_),
140 0);
141 }
142 }
143
RequestRecv(TransportStreamReceiver & receiver)144 void RequestRecv(TransportStreamReceiver& receiver) {
145 if (flag_ & kFlagPrefix) {
146 receiver.RegisterRecvInitialMetadata(
147 id_, initial_metadata_callback_->GetHandle());
148 }
149 if (flag_ & kFlagMessageData) {
150 receiver.RegisterRecvMessage(id_, message_callback_->GetHandle());
151 }
152 if (flag_ & kFlagSuffix) {
153 receiver.RegisterRecvTrailingMetadata(
154 id_, trailing_metadata_callback_->GetHandle());
155 }
156 }
157
NextBatch(int flag) const158 MockOpBatch NextBatch(int flag) const {
159 return MockOpBatch(id_, flag, seq_num_ + 1);
160 }
161
162 private:
163 std::unique_ptr<MockInitialMetadataCallback> initial_metadata_callback_;
164 std::unique_ptr<MockMessageCallback> message_callback_;
165 std::unique_ptr<MockTrailingMetadataCallback> trailing_metadata_callback_;
166 int id_, flag_, seq_num_;
167 };
168
169 class TransportStreamReceiverTest : public ::testing::Test {
170 protected:
NewGrpcStream(int flag)171 MockOpBatch NewGrpcStream(int flag) {
172 return MockOpBatch(current_id_++, flag, 0);
173 }
174
175 StreamIdentifier current_id_ = 0;
176 };
177
178 const int kFlagAll = kFlagPrefix | kFlagMessageData | kFlagSuffix;
179
180 } // namespace
181
// A single stream registers all of its callbacks before any data arrives.
TEST_F(TransportStreamReceiverTest, MultipleStreamRequestThenComplete) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto stream = NewGrpcStream(kFlagAll);
  stream.RequestRecv(receiver);
  stream.Complete(receiver);
}
188
// A single stream receives all of its data before any callback is registered.
// Each callback is still expected to fire once it is registered.
TEST_F(TransportStreamReceiverTest, MultipleStreamCompleteThenRequest) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto stream = NewGrpcStream(kFlagAll);
  stream.Complete(receiver);
  stream.RequestRecv(receiver);
}
195
// Two streams receive data in reverse creation order, then register their
// callbacks in creation order.
TEST_F(TransportStreamReceiverTest, MultipleStreamInterleaved) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto first_stream = NewGrpcStream(kFlagAll);
  auto second_stream = NewGrpcStream(kFlagAll);
  second_stream.Complete(receiver);
  first_stream.Complete(receiver);
  first_stream.RequestRecv(receiver);
  second_stream.RequestRecv(receiver);
}
205
// Two streams register their callbacks first, then receive data in reverse
// creation order.
TEST_F(TransportStreamReceiverTest, MultipleStreamInterleavedReversed) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto first_stream = NewGrpcStream(kFlagAll);
  auto second_stream = NewGrpcStream(kFlagAll);
  first_stream.RequestRecv(receiver);
  second_stream.RequestRecv(receiver);
  second_stream.Complete(receiver);
  first_stream.Complete(receiver);
}
215
// Three streams mix both orders: some register before their data arrives,
// another receives data first. The third stream is created midway.
TEST_F(TransportStreamReceiverTest, MultipleStreamMoreInterleaved) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto first_stream = NewGrpcStream(kFlagAll);
  auto second_stream = NewGrpcStream(kFlagAll);
  first_stream.RequestRecv(receiver);
  second_stream.Complete(receiver);
  auto third_stream = NewGrpcStream(kFlagAll);
  third_stream.RequestRecv(receiver);
  first_stream.Complete(receiver);
  second_stream.RequestRecv(receiver);
  third_stream.Complete(receiver);
}
228
// Unary call: initial metadata, one message and trailing metadata arrive as
// consecutive batches on one stream. All are requested before any completes.
TEST_F(TransportStreamReceiverTest, SingleStreamUnaryCall) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto initial_metadata = NewGrpcStream(kFlagPrefix);
  auto message = initial_metadata.NextBatch(kFlagMessageData);
  auto trailing_metadata = message.NextBatch(kFlagSuffix);
  initial_metadata.RequestRecv(receiver);
  message.RequestRecv(receiver);
  trailing_metadata.RequestRecv(receiver);
  initial_metadata.Complete(receiver);
  message.Complete(receiver);
  trailing_metadata.Complete(receiver);
}
241
// Streaming call: initial metadata followed by several messages on one stream,
// alternating between "request first" and "data first" orderings.
TEST_F(TransportStreamReceiverTest, SingleStreamStreamingCall) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto initial_metadata = NewGrpcStream(kFlagPrefix);
  initial_metadata.RequestRecv(receiver);
  initial_metadata.Complete(receiver);
  // Data arrives before the callback is registered.
  auto message1 = initial_metadata.NextBatch(kFlagMessageData);
  message1.Complete(receiver);
  message1.RequestRecv(receiver);
  // Callback is registered before the data arrives.
  auto message2 = message1.NextBatch(kFlagMessageData);
  message2.RequestRecv(receiver);
  message2.Complete(receiver);
  // Two messages arrive back to back before either is requested.
  auto message3 = message2.NextBatch(kFlagMessageData);
  auto message4 = message3.NextBatch(kFlagMessageData);
  message3.Complete(receiver);
  message4.Complete(receiver);
  message3.RequestRecv(receiver);
  message4.RequestRecv(receiver);
}
260
// Registers several message callbacks on one stream before any data arrives,
// so the receiver has to queue callbacks rather than payloads.
TEST_F(TransportStreamReceiverTest, DISABLED_SingleStreamBufferedCallbacks) {
  TransportStreamReceiverImpl receiver(/*is_client=*/true);
  auto initial_metadata = NewGrpcStream(kFlagPrefix);
  auto message1 = initial_metadata.NextBatch(kFlagMessageData);
  auto message2 = message1.NextBatch(kFlagMessageData);
  auto trailing_metadata = message2.NextBatch(kFlagSuffix);
  initial_metadata.RequestRecv(receiver);
  // TODO(waynetu): Can gRPC issue recv_message before it actually receives the
  // previous one?
  message1.RequestRecv(receiver);
  message2.RequestRecv(receiver);
  trailing_metadata.RequestRecv(receiver);
  initial_metadata.Complete(receiver);
  message1.Complete(receiver);
  message2.Complete(receiver);
  trailing_metadata.Complete(receiver);
}
278
279 // TODO(waynetu): Should we have some concurrent stress tests to make sure that
280 // thread safety is well taken care of?
281
282 } // namespace grpc_binder
283
main(int argc,char ** argv)284 int main(int argc, char** argv) {
285 ::testing::InitGoogleTest(&argc, argv);
286 grpc::testing::TestEnvironment env(&argc, argv);
287 return RUN_ALL_TESTS();
288 }
289