1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/streamingaead/shared_input_stream.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <sstream>
22 #include <string>
23 #include <utility>
24
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/memory/memory.h"
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "tink/input_stream.h"
32 #include "tink/streamingaead/buffered_input_stream.h"
33 #include "tink/subtle/random.h"
34 #include "tink/subtle/test_util.h"
35 #include "tink/util/istream_input_stream.h"
36 #include "tink/util/status.h"
37 #include "tink/util/test_matchers.h"
38
39 namespace crypto {
40 namespace tink {
41 namespace streamingaead {
42 namespace {
43
44 using crypto::tink::test::IsOk;
45 using crypto::tink::test::StatusIs;
46 using subtle::test::ReadFromStream;
47
// Buffer size, in bytes, handed to the IstreamInputStream built by
// GetInputStream(). It is a true compile-time constant, so it is declared
// constexpr rather than as a mutable static.
constexpr int kBufferSize = 4096;
49
50 // Creates an InputStream with the specified contents.
GetInputStream(absl::string_view contents)51 std::unique_ptr<InputStream> GetInputStream(absl::string_view contents) {
52 // Prepare ciphertext source stream.
53 auto string_stream =
54 absl::make_unique<std::stringstream>(std::string(contents));
55 std::unique_ptr<InputStream> input_stream(
56 absl::make_unique<util::IstreamInputStream>(
57 std::move(string_stream), kBufferSize));
58 return input_stream;
59 }
60
61 // Attempts to read 'count' bytes from 'input_stream', and writes the read
62 // bytes to 'output'.
ReadFromStream(InputStream * input_stream,int count,std::string * output)63 util::Status ReadFromStream(InputStream* input_stream, int count,
64 std::string* output) {
65 if (input_stream == nullptr || output == nullptr || count < 0) {
66 return util::Status(absl::StatusCode::kInternal,
67 "Illegal read from a stream");
68 }
69 const void* buffer;
70 output->clear();
71 int bytes_to_read = count;
72 while (bytes_to_read > 0) {
73 auto next_result = input_stream->Next(&buffer);
74 if (next_result.status().code() == absl::StatusCode::kOutOfRange) {
75 // End of stream.
76 return util::OkStatus();
77 }
78 if (!next_result.ok()) return next_result.status();
79 auto read_bytes = next_result.value();
80 auto used_bytes = std::min(read_bytes, bytes_to_read);
81 if (used_bytes > 0) {
82 output->append(
83 std::string(reinterpret_cast<const char*>(buffer), used_bytes));
84 bytes_to_read -= used_bytes;
85 if (bytes_to_read == 0) input_stream->BackUp(read_bytes - used_bytes);
86 }
87 }
88 return util::OkStatus();
89 }
90
// Reads a prefix and then the remainder of the input through a
// SharedInputStream, for several input and prefix sizes, checking the data
// read and that the shared and underlying positions stay in agreement.
TEST(SharedInputStreamTest, BasicOperations) {
  for (const int input_size : {0, 1, 10, 100, 1000, 10000, 100000}) {
    const std::string contents = subtle::Random::GetRandomBytes(input_size);
    auto buffered_stream =
        std::make_shared<BufferedInputStream>(GetInputStream(contents));
    for (const int read_size : {0, 1, 10, 123, 300}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      {
        auto shared_stream =
            absl::make_unique<SharedInputStream>(buffered_stream.get());
        SharedInputStream* const stream = shared_stream.get();

        // First, pull the leading 'read_size' bytes.
        std::string head;
        EXPECT_THAT(ReadFromStream(stream, read_size, &head), IsOk());
        EXPECT_EQ(std::min(read_size, input_size), stream->Position());
        EXPECT_EQ(contents.substr(0, read_size), head);
        EXPECT_EQ(buffered_stream->Position(), stream->Position());

        // Next, drain whatever remains.
        std::string tail;
        EXPECT_THAT(ReadFromStream(stream, &tail), IsOk());
        EXPECT_EQ(input_size, stream->Position());
        EXPECT_EQ(contents, head + tail);
        EXPECT_EQ(buffered_stream->Position(), stream->Position());

        // A further read at the end of the stream should yield nothing.
        EXPECT_THAT(ReadFromStream(stream, &tail), IsOk());
        EXPECT_EQ("", tail);
        EXPECT_EQ(buffered_stream->Position(), stream->Position());
      }

      // The shared stream is gone; rewind the underlying buffered stream so
      // that the next iteration reads from the beginning again.
      EXPECT_THAT(buffered_stream->Rewind(), IsOk());
      EXPECT_EQ(0, buffered_stream->Position());
    }
  }
}
136
137
// Checks that the bytes returned by one Next() call can be handed back with
// BackUp() and are then read again, for several input and prefix sizes.
TEST(SharedInputStreamTest, SingleBackup) {
  for (auto input_size : {0, 1, 10, 100, 1000, 10000, 100000}) {
    std::string contents = subtle::Random::GetRandomBytes(input_size);
    auto input_stream = GetInputStream(contents);
    auto buffered_stream =
        std::make_shared<BufferedInputStream>(std::move(input_stream));
    for (auto read_size : {0, 1, 10, 123, 300, 1024}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      {
        auto shared_stream = absl::make_unique<SharedInputStream>(
            buffered_stream.get());

        // Read a part of the stream.
        std::string prefix;
        auto status = ReadFromStream(shared_stream.get(), read_size, &prefix);
        EXPECT_THAT(status, IsOk());
        EXPECT_EQ(std::min(read_size, input_size), shared_stream->Position());
        EXPECT_EQ(contents.substr(0, read_size), prefix);

        // Read the next block of the stream, and then back it up.
        const void* buf = nullptr;
        int pos = shared_stream->Position();
        auto next_result = shared_stream->Next(&buf);
        if (read_size < input_size) {
          // ASSERT rather than EXPECT: value() must not be called on a
          // result that does not hold a value.
          ASSERT_THAT(next_result, IsOk());
          auto next_size = next_result.value();
          EXPECT_EQ(pos + next_size, shared_stream->Position());
          shared_stream->BackUp(next_size);
          EXPECT_EQ(pos, shared_stream->Position());
          // A further BackUp() after the whole block was already returned is
          // expected to leave the position unchanged.
          shared_stream->BackUp(input_size);
          EXPECT_EQ(pos, shared_stream->Position());
        } else {
          // The prefix read already consumed the whole input.
          EXPECT_THAT(next_result.status(),
                      StatusIs(absl::StatusCode::kOutOfRange));
        }

        // Read the rest of the input.
        std::string rest;
        status = ReadFromStream(shared_stream.get(), &rest);
        EXPECT_THAT(status, IsOk());
        EXPECT_EQ(input_size, shared_stream->Position());
        EXPECT_EQ(contents, prefix + rest);
      }
      // Now that shared_stream is out of scope, we rewind the underlying
      // buffered_stream, so that the next read iteration starts from
      // the beginning.
      auto status = buffered_stream->Rewind();
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(0, buffered_stream->Position());
    }
  }
}
191
// Checks that several consecutive BackUp() calls accumulate, that
// non-positive backup sizes leave the position unchanged, and that the
// following Next() returns exactly the backed-up bytes. Runs twice on the
// same buffered stream to also cover Rewind().
TEST(SharedInputStreamTest, MultipleBackups) {
  int input_size = 70000;
  std::string contents = subtle::Random::GetRandomBytes(input_size);
  auto input_stream = GetInputStream(contents);
  auto buffered_stream =
      std::make_shared<BufferedInputStream>(std::move(input_stream));

  for (int i = 0; i < 2; i++) {  // Two rounds, to test with Rewind.
    auto status = buffered_stream->Rewind();
    EXPECT_THAT(status, IsOk());
    EXPECT_EQ(0, buffered_stream->Position());

    auto shared_stream = absl::make_unique<SharedInputStream>(
        buffered_stream.get());
    EXPECT_EQ(0, shared_stream->Position());

    const void* buffer = nullptr;
    auto next_result = shared_stream->Next(&buffer);
    // ASSERT rather than EXPECT: value() must not be called on a result that
    // does not hold a value.
    ASSERT_THAT(next_result, IsOk());
    auto next_size = next_result.value();
    EXPECT_EQ(contents.substr(0, next_size),
              std::string(static_cast<const char*>(buffer), next_size));

    // BackUp several times, but in total fewer bytes than returned by Next().
    int total_backup_size = 0;
    for (auto backup_size : {0, 1, 5, 0, 10, 100, -42, 400, 20, -100}) {
      shared_stream->BackUp(backup_size);
      // Negative sizes are expected to be ignored, so they count as zero.
      total_backup_size += std::max(0, backup_size);
      EXPECT_EQ(next_size - total_backup_size, shared_stream->Position());
      EXPECT_EQ(buffered_stream->Position(), shared_stream->Position());
    }
    // Fatal on failure: the substr() offset below relies on this ordering.
    ASSERT_GT(next_size, total_backup_size);

    // Call Next(), it should return exactly the backed up bytes.
    next_result = shared_stream->Next(&buffer);
    ASSERT_THAT(next_result, IsOk());
    EXPECT_EQ(total_backup_size, next_result.value());
    EXPECT_EQ(next_size, shared_stream->Position());
    EXPECT_EQ(buffered_stream->Position(), shared_stream->Position());
    EXPECT_EQ(contents.substr(next_size - total_backup_size, total_backup_size),
              std::string(static_cast<const char*>(buffer), total_backup_size));
  }
}
235
236
237 } // namespace
238 } // namespace streamingaead
239 } // namespace tink
240 } // namespace crypto
241