1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/raw/internal/method_union.h"
16
17 #include <array>
18
19 #include "pw_bytes/array.h"
20 #include "pw_protobuf/decoder.h"
21 #include "pw_protobuf/encoder.h"
22 #include "pw_rpc/internal/test_utils.h"
23 #include "pw_rpc/service.h"
24 #include "pw_rpc_test_protos/test.pwpb.h"
25 #include "pw_unit_test/framework.h"
26
27 namespace pw::rpc::internal {
28 namespace {
29
30 namespace TestRequest = ::pw::rpc::test::pwpb::TestRequest;
31 namespace TestResponse = ::pw::rpc::test::pwpb::TestResponse;
32
33 template <typename Implementation>
34 class FakeGeneratedService : public Service {
35 public:
FakeGeneratedService(uint32_t id)36 constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {}
37
38 static constexpr std::array<RawMethodUnion, 3> kMethods = {
39 GetRawMethodFor<&Implementation::DoNothing, MethodType::kUnary>(10u),
40 GetRawMethodFor<&Implementation::AddFive, MethodType::kUnary>(11u),
41 GetRawMethodFor<&Implementation::StartStream,
42 MethodType::kServerStreaming>(12u),
43 };
44 };
45
46 class FakeGeneratedServiceImpl
47 : public FakeGeneratedService<FakeGeneratedServiceImpl> {
48 public:
FakeGeneratedServiceImpl(uint32_t id)49 FakeGeneratedServiceImpl(uint32_t id) : FakeGeneratedService(id) {}
50
DoNothing(ConstByteSpan,RawUnaryResponder &)51 void DoNothing(ConstByteSpan, RawUnaryResponder&) {}
52
AddFive(ConstByteSpan request,RawUnaryResponder & responder)53 void AddFive(ConstByteSpan request, RawUnaryResponder& responder) {
54 DecodeRawTestRequest(request);
55
56 std::byte response[32] = {};
57 TestResponse::MemoryEncoder test_response(response);
58 ASSERT_EQ(OkStatus(), test_response.WriteValue(last_request.integer + 5));
59
60 ASSERT_EQ(OkStatus(),
61 responder.Finish(span(response).first(test_response.size()),
62 Status::Unauthenticated()));
63 }
64
StartStream(ConstByteSpan request,RawServerWriter & writer)65 void StartStream(ConstByteSpan request, RawServerWriter& writer) {
66 DecodeRawTestRequest(request);
67 last_writer = std::move(writer);
68 }
69
70 struct {
71 int64_t integer;
72 uint32_t status_code;
73 } last_request;
74 RawServerWriter last_writer;
75
76 private:
DecodeRawTestRequest(ConstByteSpan request)77 void DecodeRawTestRequest(ConstByteSpan request) {
78 protobuf::Decoder decoder(request);
79
80 while (decoder.Next().ok()) {
81 TestRequest::Fields field =
82 static_cast<TestRequest::Fields>(decoder.FieldNumber());
83
84 switch (field) {
85 case TestRequest::Fields::kInteger:
86 ASSERT_EQ(OkStatus(), decoder.ReadInt64(&last_request.integer));
87 break;
88 case TestRequest::Fields::kStatusCode:
89 ASSERT_EQ(OkStatus(), decoder.ReadUint32(&last_request.status_code));
90 break;
91 }
92 }
93 }
94 };
95
TEST(RawMethodUnion,InvokesUnary)96 TEST(RawMethodUnion, InvokesUnary) {
97 std::byte buffer[16];
98
99 TestRequest::MemoryEncoder test_request(buffer);
100 ASSERT_EQ(OkStatus(), test_request.WriteInteger(456));
101 ASSERT_EQ(OkStatus(), test_request.WriteStatusCode(7));
102
103 const Method& method =
104 std::get<1>(FakeGeneratedServiceImpl::kMethods).method();
105 ServerContextForTest<FakeGeneratedServiceImpl> context(method);
106 rpc_lock().lock();
107 method.Invoke(context.get(), context.request(test_request));
108
109 EXPECT_EQ(context.service().last_request.integer, 456);
110 EXPECT_EQ(context.service().last_request.status_code, 7u);
111
112 const Packet& response = context.output().last_packet();
113 EXPECT_EQ(response.status(), Status::Unauthenticated());
114
115 protobuf::Decoder decoder(response.payload());
116 ASSERT_TRUE(decoder.Next().ok());
117 int64_t value;
118 EXPECT_EQ(decoder.ReadInt64(&value), OkStatus());
119 EXPECT_EQ(value, 461);
120 }
121
TEST(RawMethodUnion,InvokesServerStreaming)122 TEST(RawMethodUnion, InvokesServerStreaming) {
123 std::byte buffer[16];
124
125 TestRequest::MemoryEncoder test_request(buffer);
126 ASSERT_EQ(OkStatus(), test_request.WriteInteger(777));
127 ASSERT_EQ(OkStatus(), test_request.WriteStatusCode(2));
128
129 const Method& method =
130 std::get<2>(FakeGeneratedServiceImpl::kMethods).method();
131 ServerContextForTest<FakeGeneratedServiceImpl> context(method);
132
133 rpc_lock().lock();
134 method.Invoke(context.get(), context.request(test_request));
135
136 EXPECT_EQ(0u, context.output().total_packets());
137 EXPECT_EQ(777, context.service().last_request.integer);
138 EXPECT_EQ(2u, context.service().last_request.status_code);
139 EXPECT_TRUE(context.service().last_writer.active());
140 EXPECT_EQ(OkStatus(), context.service().last_writer.Finish());
141 }
142
143 } // namespace
144 } // namespace pw::rpc::internal
145