1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/common/channel/message_framer.h"
6
7 #include <stddef.h>
8
9 #include <algorithm>
10 #include <string>
11
12 #include "cast/common/channel/proto/cast_channel.pb.h"
13 #include "gtest/gtest.h"
14 #include "util/big_endian.h"
15 #include "util/std_util.h"
16
17 namespace openscreen {
18 namespace cast {
19 namespace message_serialization {
20
21 using ::cast::channel::CastMessage;
22
namespace {

// Size of the length prefix that precedes every framed message body.
constexpr size_t kHeaderSize = sizeof(uint32_t);

// Largest message body the framer accepts; Cast specifies 64 KiB.
constexpr size_t kMaxBodySize = 64 * 1024;

}  // namespace
31
32 class CastFramerTest : public testing::Test {
33 public:
CastFramerTest()34 CastFramerTest() : buffer_(kHeaderSize + kMaxBodySize) {}
35
SetUp()36 void SetUp() override {
37 cast_message_.set_protocol_version(CastMessage::CASTV2_1_0);
38 cast_message_.set_source_id("source");
39 cast_message_.set_destination_id("destination");
40 cast_message_.set_namespace_("namespace");
41 cast_message_.set_payload_type(CastMessage::STRING);
42 cast_message_.set_payload_utf8("payload");
43 ErrorOr<std::vector<uint8_t>> result = Serialize(cast_message_);
44 ASSERT_TRUE(result.is_value());
45 cast_message_serial_ = std::move(result.value());
46 }
47
WriteToBuffer(const std::vector<uint8_t> & data)48 void WriteToBuffer(const std::vector<uint8_t>& data) {
49 memcpy(&buffer_[0], data.data(), data.size());
50 }
51
GetSpan(size_t size)52 absl::Span<uint8_t> GetSpan(size_t size) {
53 return absl::Span<uint8_t>(&buffer_[0], size);
54 }
GetSpan()55 absl::Span<uint8_t> GetSpan() { return GetSpan(cast_message_serial_.size()); }
56
57 protected:
58 CastMessage cast_message_;
59 std::vector<uint8_t> cast_message_serial_;
60 std::vector<uint8_t> buffer_;
61 };
62
// Feeds successively longer prefixes of a framed message to TryDeserialize.
// A partial header and a complete header without its body must both be
// reported as kInsufficientBuffer; the full frame must decode back to the
// original message.
TEST_F(CastFramerTest, TestMessageFramerCompleteMessage) {
  WriteToBuffer(cast_message_serial_);

  // Only 1 of the kHeaderSize header bytes is available.
  ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(1));
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code());

  // The whole header is available but none of the body. Every span handed to
  // TryDeserialize starts at the beginning of the buffer, so it must cover all
  // kHeaderSize bytes to reach the "waiting for body" state.
  result = TryDeserialize(GetSpan(kHeaderSize));
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code());

  // The complete frame (header + body) is available.
  result = TryDeserialize(GetSpan());
  ASSERT_TRUE(result);
  EXPECT_EQ(result.value().length, cast_message_serial_.size());
  const CastMessage& message = result.value().message;
  EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString());
}
84
// Serialize must fail for a message whose payload alone is one byte larger
// than kMaxBodySize.
TEST_F(CastFramerTest, TestSerializeErrorMessageTooLarge) {
  CastMessage big_message = cast_message_;
  big_message.set_payload_utf8(std::string(kMaxBodySize + 1, 'x'));
  EXPECT_FALSE(Serialize(big_message));
}
93
// When the buffer already holds the whole frame, a single TryDeserialize call
// succeeds. It must report a length equal to the serialized size and return a
// message identical to the original.
TEST_F(CastFramerTest, TestCompleteMessageAtOnce) {
  WriteToBuffer(cast_message_serial_);

  ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan());
  ASSERT_TRUE(result);

  const DeserializeResult& decoded = result.value();
  EXPECT_EQ(decoded.length, cast_message_serial_.size());
  EXPECT_EQ(decoded.message.SerializeAsString(),
            cast_message_.SerializeAsString());
}
103
// A header whose length field (every byte 88, i.e. 0x58585858) is far above
// kMaxBodySize must be rejected as soon as the header alone is available.
TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage) {
  std::vector<uint8_t> mangled_cast_message = cast_message_serial_;
  std::fill_n(mangled_cast_message.begin(), kHeaderSize, 88);
  WriteToBuffer(mangled_cast_message);

  ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(kHeaderSize));
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code());
}
116
// A header announcing a body just one byte over the limit must be rejected.
TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage2) {
  // Header bytes 00 01 00 01 encode a body size of 0x00010001 = 65537, which
  // is kMaxBodySize (65536) + 1.
  const uint8_t kOversizedHeader[] = {0x00, 0x01, 0x00, 0x01};

  std::vector<uint8_t> mangled_cast_message = cast_message_serial_;
  std::copy(std::begin(kOversizedHeader), std::end(kOversizedHeader),
            mangled_cast_message.begin());
  WriteToBuffer(mangled_cast_message);

  ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(kHeaderSize));
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code());
}
130
// A frame with a valid header but a body that is not a parseable CastMessage
// proto must be rejected once the whole frame is available.
TEST_F(CastFramerTest, TestUnparsableBodyProto) {
  // Keep the original header (so the length is legal) but overwrite every body
  // byte with 'x'. A single fill suffices; the former loop repeated the same
  // whole-range fill once per body byte.
  std::vector<uint8_t> mangled_cast_message = cast_message_serial_;
  std::fill(mangled_cast_message.begin() + kHeaderSize,
            mangled_cast_message.end(), 'x');
  WriteToBuffer(mangled_cast_message);

  // Header only: the framer still needs the body.
  ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(kHeaderSize));
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code());

  // Header and body: the body fails to parse, so expect an error.
  result = TryDeserialize(GetSpan());
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code());
}
150
151 } // namespace message_serialization
152 } // namespace cast
153 } // namespace openscreen
154