xref: /aosp_15_r20/external/grpc-grpc/test/core/transport/binder/transport_stream_receiver_test.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "absl/memory/memory.h"

#include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h"
#include "test/core/util/test_config.h"
27 
28 namespace grpc_binder {
29 namespace {
30 
31 // TODO(waynetu): These are hacks to make callbacks aware of their stream IDs
32 // and sequence numbers. Remove/Refactor these hacks when possible.
33 template <typename T>
Decode(const T &)34 std::pair<StreamIdentifier, int> Decode(const T& /*data*/) {
35   assert(false && "This should not be called");
36   return {};
37 }
38 
39 template <>
Decode(const std::string & data)40 std::pair<StreamIdentifier, int> Decode<std::string>(const std::string& data) {
41   assert(data.size() == sizeof(StreamIdentifier) + sizeof(int));
42   StreamIdentifier id{};
43   int seq_num{};
44   std::memcpy(&id, data.data(), sizeof(StreamIdentifier));
45   std::memcpy(&seq_num, data.data() + sizeof(StreamIdentifier), sizeof(int));
46   return std::make_pair(id, seq_num);
47 }
48 
49 template <>
Decode(const Metadata & data)50 std::pair<StreamIdentifier, int> Decode<Metadata>(const Metadata& data) {
51   assert(data.size() == 1);
52   const std::string& encoding = data[0].first;
53   return Decode(encoding);
54 }
55 
56 template <typename T>
Encode(StreamIdentifier,int)57 T Encode(StreamIdentifier /*id*/, int /*seq_num*/) {
58   assert(false && "This should not be called");
59   return {};
60 }
61 
62 template <>
Encode(StreamIdentifier id,int seq_num)63 std::string Encode<std::string>(StreamIdentifier id, int seq_num) {
64   char result[sizeof(StreamIdentifier) + sizeof(int)];
65   std::memcpy(result, &id, sizeof(StreamIdentifier));
66   std::memcpy(result + sizeof(StreamIdentifier), &seq_num, sizeof(int));
67   return std::string(result, sizeof(StreamIdentifier) + sizeof(int));
68 }
69 
70 template <>
Encode(StreamIdentifier id,int seq_num)71 Metadata Encode<Metadata>(StreamIdentifier id, int seq_num) {
72   return {{Encode<std::string>(id, seq_num), ""}};
73 }
74 
75 MATCHER_P2(StreamIdAndSeqNumMatch, id, seq_num, "") {
76   auto p = Decode(arg.value());
77   return p.first == id && p.second == seq_num;
78 }
79 
80 // MockCallback is used to verify the every callback passed to transaction
81 // receiver will eventually be invoked with the artifact of its corresponding
82 // binder transaction.
83 template <typename FirstArg, typename... TrailingArgs>
84 class MockCallback {
85  public:
MockCallback(StreamIdentifier id,int seq_num)86   explicit MockCallback(StreamIdentifier id, int seq_num)
87       : id_(id), seq_num_(seq_num) {}
88 
89   MOCK_METHOD(void, ActualCallback, (FirstArg), ());
90 
GetHandle()91   std::function<void(FirstArg, TrailingArgs...)> GetHandle() {
92     return [this](FirstArg first_arg, TrailingArgs...) {
93       this->ActualCallback(first_arg);
94     };
95   }
96 
ExpectCallbackInvocation()97   void ExpectCallbackInvocation() {
98     EXPECT_CALL(*this, ActualCallback(StreamIdAndSeqNumMatch(id_, seq_num_)));
99   }
100 
101  private:
102   StreamIdentifier id_;
103   int seq_num_;
104 };
105 
106 using MockInitialMetadataCallback = MockCallback<absl::StatusOr<Metadata>>;
107 using MockMessageCallback = MockCallback<absl::StatusOr<std::string>>;
108 using MockTrailingMetadataCallback =
109     MockCallback<absl::StatusOr<Metadata>, int>;
110 
111 class MockOpBatch {
112  public:
MockOpBatch(StreamIdentifier id,int flag,int seq_num)113   MockOpBatch(StreamIdentifier id, int flag, int seq_num)
114       : id_(id), flag_(flag), seq_num_(seq_num) {
115     if (flag_ & kFlagPrefix) {
116       initial_metadata_callback_ =
117           std::make_unique<MockInitialMetadataCallback>(id_, seq_num_);
118     }
119     if (flag_ & kFlagMessageData) {
120       message_callback_ = std::make_unique<MockMessageCallback>(id_, seq_num_);
121     }
122     if (flag_ & kFlagSuffix) {
123       trailing_metadata_callback_ =
124           std::make_unique<MockTrailingMetadataCallback>(id_, seq_num_);
125     }
126   }
127 
Complete(TransportStreamReceiver & receiver)128   void Complete(TransportStreamReceiver& receiver) {
129     if (flag_ & kFlagPrefix) {
130       initial_metadata_callback_->ExpectCallbackInvocation();
131       receiver.NotifyRecvInitialMetadata(id_, Encode<Metadata>(id_, seq_num_));
132     }
133     if (flag_ & kFlagMessageData) {
134       message_callback_->ExpectCallbackInvocation();
135       receiver.NotifyRecvMessage(id_, Encode<std::string>(id_, seq_num_));
136     }
137     if (flag_ & kFlagSuffix) {
138       trailing_metadata_callback_->ExpectCallbackInvocation();
139       receiver.NotifyRecvTrailingMetadata(id_, Encode<Metadata>(id_, seq_num_),
140                                           0);
141     }
142   }
143 
RequestRecv(TransportStreamReceiver & receiver)144   void RequestRecv(TransportStreamReceiver& receiver) {
145     if (flag_ & kFlagPrefix) {
146       receiver.RegisterRecvInitialMetadata(
147           id_, initial_metadata_callback_->GetHandle());
148     }
149     if (flag_ & kFlagMessageData) {
150       receiver.RegisterRecvMessage(id_, message_callback_->GetHandle());
151     }
152     if (flag_ & kFlagSuffix) {
153       receiver.RegisterRecvTrailingMetadata(
154           id_, trailing_metadata_callback_->GetHandle());
155     }
156   }
157 
NextBatch(int flag) const158   MockOpBatch NextBatch(int flag) const {
159     return MockOpBatch(id_, flag, seq_num_ + 1);
160   }
161 
162  private:
163   std::unique_ptr<MockInitialMetadataCallback> initial_metadata_callback_;
164   std::unique_ptr<MockMessageCallback> message_callback_;
165   std::unique_ptr<MockTrailingMetadataCallback> trailing_metadata_callback_;
166   int id_, flag_, seq_num_;
167 };
168 
169 class TransportStreamReceiverTest : public ::testing::Test {
170  protected:
NewGrpcStream(int flag)171   MockOpBatch NewGrpcStream(int flag) {
172     return MockOpBatch(current_id_++, flag, 0);
173   }
174 
175   StreamIdentifier current_id_ = 0;
176 };
177 
178 const int kFlagAll = kFlagPrefix | kFlagMessageData | kFlagSuffix;
179 
180 }  // namespace
181 
TEST_F(TransportStreamReceiverTest, MultipleStreamRequestThenComplete) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  // Register every callback of a full batch, then deliver its data.
  MockOpBatch batch = NewGrpcStream(kFlagAll);
  batch.RequestRecv(stream_receiver);
  batch.Complete(stream_receiver);
}
188 
TEST_F(TransportStreamReceiverTest, MultipleStreamCompleteThenRequest) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  // Deliver a full batch's data before any of its callbacks are registered.
  MockOpBatch batch = NewGrpcStream(kFlagAll);
  batch.Complete(stream_receiver);
  batch.RequestRecv(stream_receiver);
}
195 
TEST_F(TransportStreamReceiverTest, MultipleStreamInterleaved) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  MockOpBatch first = NewGrpcStream(kFlagAll);
  MockOpBatch second = NewGrpcStream(kFlagAll);
  // Data arrives for the second stream first; callbacks are then registered
  // in stream order.
  second.Complete(stream_receiver);
  first.Complete(stream_receiver);
  first.RequestRecv(stream_receiver);
  second.RequestRecv(stream_receiver);
}
205 
TEST_F(TransportStreamReceiverTest, MultipleStreamInterleavedReversed) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  MockOpBatch first = NewGrpcStream(kFlagAll);
  MockOpBatch second = NewGrpcStream(kFlagAll);
  // Callbacks are registered in stream order; data then arrives for the
  // second stream first.
  first.RequestRecv(stream_receiver);
  second.RequestRecv(stream_receiver);
  second.Complete(stream_receiver);
  first.Complete(stream_receiver);
}
215 
TEST_F(TransportStreamReceiverTest, MultipleStreamMoreInterleaved) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  MockOpBatch first = NewGrpcStream(kFlagAll);
  MockOpBatch second = NewGrpcStream(kFlagAll);
  first.RequestRecv(stream_receiver);
  second.Complete(stream_receiver);
  // A third stream is opened while the first two are still incomplete.
  MockOpBatch third = NewGrpcStream(kFlagAll);
  third.RequestRecv(stream_receiver);
  first.Complete(stream_receiver);
  second.RequestRecv(stream_receiver);
  third.Complete(stream_receiver);
}
228 
TEST_F(TransportStreamReceiverTest, SingleStreamUnaryCall) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  // One stream: initial metadata, one message, then trailing metadata, each
  // with the next sequence number. All are registered before any arrives.
  MockOpBatch initial_md = NewGrpcStream(kFlagPrefix);
  MockOpBatch message = initial_md.NextBatch(kFlagMessageData);
  MockOpBatch trailing_md = message.NextBatch(kFlagSuffix);
  initial_md.RequestRecv(stream_receiver);
  message.RequestRecv(stream_receiver);
  trailing_md.RequestRecv(stream_receiver);
  initial_md.Complete(stream_receiver);
  message.Complete(stream_receiver);
  trailing_md.Complete(stream_receiver);
}
241 
TEST_F(TransportStreamReceiverTest, SingleStreamStreamingCall) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  // Initial metadata: register, then deliver.
  MockOpBatch initial_md = NewGrpcStream(kFlagPrefix);
  initial_md.RequestRecv(stream_receiver);
  initial_md.Complete(stream_receiver);
  // First message: deliver, then register.
  MockOpBatch msg1 = initial_md.NextBatch(kFlagMessageData);
  msg1.Complete(stream_receiver);
  msg1.RequestRecv(stream_receiver);
  // Second message: register, then deliver.
  MockOpBatch msg2 = msg1.NextBatch(kFlagMessageData);
  msg2.RequestRecv(stream_receiver);
  msg2.Complete(stream_receiver);
  // Two more messages are both delivered before either is registered.
  MockOpBatch msg3 = msg2.NextBatch(kFlagMessageData);
  MockOpBatch msg4 = msg3.NextBatch(kFlagMessageData);
  msg3.Complete(stream_receiver);
  msg4.Complete(stream_receiver);
  msg3.RequestRecv(stream_receiver);
  msg4.RequestRecv(stream_receiver);
}
260 
TEST_F(TransportStreamReceiverTest, DISABLED_SingleStreamBufferedCallbacks) {
  TransportStreamReceiverImpl stream_receiver(/*is_client=*/true);
  MockOpBatch initial_md = NewGrpcStream(kFlagPrefix);
  MockOpBatch msg1 = initial_md.NextBatch(kFlagMessageData);
  MockOpBatch msg2 = msg1.NextBatch(kFlagMessageData);
  MockOpBatch trailing_md = msg2.NextBatch(kFlagSuffix);
  initial_md.RequestRecv(stream_receiver);
  // TODO(waynetu): Can gRPC issue recv_message before it actually receives the
  // previous one?
  msg1.RequestRecv(stream_receiver);
  msg2.RequestRecv(stream_receiver);
  trailing_md.RequestRecv(stream_receiver);
  initial_md.Complete(stream_receiver);
  msg1.Complete(stream_receiver);
  msg2.Complete(stream_receiver);
  trailing_md.Complete(stream_receiver);
}
278 
279 // TODO(waynetu): Should we have some concurrent stress tests to make sure that
280 // thread safety is well taken care of?
281 
282 }  // namespace grpc_binder
283 
main(int argc,char ** argv)284 int main(int argc, char** argv) {
285   ::testing::InitGoogleTest(&argc, argv);
286   grpc::testing::TestEnvironment env(&argc, argv);
287   return RUN_ALL_TESTS();
288 }
289