1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/subtle/decrypting_random_access_stream.h"
18
19 #include <algorithm>
20 #include <cstring>
21 #include <limits>
22 #include <memory>
23 #include <sstream>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "gtest/gtest.h"
29 #include "absl/memory/memory.h"
30 #include "absl/status/status.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/string_view.h"
33 #include "tink/internal/test_random_access_stream.h"
34 #include "tink/output_stream.h"
35 #include "tink/random_access_stream.h"
36 #include "tink/streaming_aead.h"
37 #include "tink/subtle/random.h"
38 #include "tink/subtle/test_util.h"
39 #include "tink/util/ostream_output_stream.h"
40 #include "tink/util/status.h"
41 #include "tink/util/test_matchers.h"
42
43 namespace crypto {
44 namespace tink {
45 namespace subtle {
46 namespace {
47
48 using ::crypto::tink::internal::TestRandomAccessStream;
49 using crypto::tink::subtle::test::DummyStreamingAead;
50 using crypto::tink::subtle::test::DummyStreamSegmentDecrypter;
51 using crypto::tink::test::IsOk;
52 using crypto::tink::test::StatusIs;
53 using subtle::test::WriteToStream;
54 using testing::HasSubstr;
55
56 // A dummy RandomAccessStream that fakes its size.
57 class DummyRandomAccessStream : public RandomAccessStream {
58 public:
DummyRandomAccessStream(int64_t size,int ct_offset)59 explicit DummyRandomAccessStream(int64_t size, int ct_offset)
60 : size_(size), ct_offset_(ct_offset) {}
61
PRead(int64_t position,int count,crypto::tink::util::Buffer * dest_buffer)62 crypto::tink::util::Status PRead(
63 int64_t position, int count,
64 crypto::tink::util::Buffer* dest_buffer) override {
65 if (position == ct_offset_) {
66 // Someone attempts to read the header, return the same dummy value that
67 // DummyStreamSegmentDecrypter expects.
68 auto status = dest_buffer->set_size(count);
69 if (!status.ok()) return status;
70 std::memset(dest_buffer->get_mem_block(), 'h', count);
71 }
72 return util::OkStatus();
73 }
74
size()75 crypto::tink::util::StatusOr<int64_t> size() override { return size_; }
76
77 private:
78 int64_t size_;
79 int ct_offset_;
80 };
81
82 // Returns a ciphertext resulting from encryption of 'pt' with 'aad' as
83 // associated data, using 'saead'.
GetCiphertext(StreamingAead * saead,absl::string_view pt,absl::string_view aad,int ct_offset)84 std::string GetCiphertext(StreamingAead* saead, absl::string_view pt,
85 absl::string_view aad, int ct_offset) {
86 // Prepare ciphertext destination stream.
87 auto ct_stream = absl::make_unique<std::stringstream>();
88 // Write ct_offset 'o'-characters for the ciphertext offset.
89 *ct_stream << std::string(ct_offset, 'o');
90 // A reference to the ciphertext buffer.
91 auto ct_buf = ct_stream->rdbuf();
92 std::unique_ptr<OutputStream> ct_destination(
93 absl::make_unique<util::OstreamOutputStream>(std::move(ct_stream)));
94
95 // Compute the ciphertext.
96 auto enc_stream_result =
97 saead->NewEncryptingStream(std::move(ct_destination), aad);
98 EXPECT_THAT(enc_stream_result, IsOk());
99 EXPECT_THAT(WriteToStream(enc_stream_result.value().get(), pt), IsOk());
100
101 return ct_buf->str();
102 }
103
104 // Creates an RandomAccessStream that contains ciphertext resulting
105 // from encryption of 'pt' with 'aad' as associated data, using 'saead'.
GetCiphertextSource(StreamingAead * saead,absl::string_view pt,absl::string_view aad,int ct_offset)106 std::unique_ptr<RandomAccessStream> GetCiphertextSource(StreamingAead* saead,
107 absl::string_view pt,
108 absl::string_view aad,
109 int ct_offset) {
110 return std::make_unique<TestRandomAccessStream>(
111 GetCiphertext(saead, pt, aad, ct_offset));
112 }
113
// A negative ciphertext offset must be rejected by the factory.
TEST(DecryptingRandomAccessStreamTest, NegativeCiphertextOffset) {
  const int kPtSegmentSize = 100;
  const int kHeaderSize = 20;
  const int kCtOffset = -1;
  const int64_t kCiphertextSize = 100;

  auto decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      kPtSegmentSize, kHeaderSize, kCtOffset);
  auto source =
      absl::make_unique<DummyRandomAccessStream>(kCiphertextSize, kCtOffset);

  auto result =
      DecryptingRandomAccessStream::New(std::move(decrypter), std::move(source));
  EXPECT_THAT(result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("The ciphertext offset must be non-negative")));
}
130
// When ct_offset + header_size consumes the whole first segment, the factory
// must reject the configuration.
TEST(DecryptingRandomAccessStreamTest,
     SizeOfFirstSegmentIsSmallerOrEqualToZero) {
  const int kHeaderSize = 20;
  const int kCtOffset = 0;
  // pt_segment_size == ct_offset + header_size leaves a first segment of size
  // zero.
  const int kPtSegmentSize = kCtOffset + kHeaderSize;
  const int64_t kCiphertextSize = 100;

  auto decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      kPtSegmentSize, kHeaderSize, kCtOffset);
  auto source =
      absl::make_unique<DummyRandomAccessStream>(kCiphertextSize, kCtOffset);

  auto result =
      DecryptingRandomAccessStream::New(std::move(decrypter), std::move(source));
  EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument,
                                        HasSubstr("greater than 0")));
}
150
// A ciphertext implying more than 2^32 segments is accepted by the factory but
// must be reported as invalid when its plaintext size is queried.
TEST(DecryptingRandomAccessStreamTest, TooManySegments) {
  int header_size = 1;
  int ct_offset = 0;
  // Use a valid pt_segment_size which is larger than ct_offset + header_size.
  int pt_segment_size = ct_offset + header_size + 1;
  auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      pt_segment_size, header_size, ct_offset);

  // Use an invalid segment_count larger than 2^32 (exactly 2^32 + 1).
  int64_t segment_count =
      static_cast<int64_t>(std::numeric_limits<uint32_t>::max()) + 2;
  // Based on segment_count = ciphertext_size / ciphertext_segment_size, pick
  // ciphertext_size = segment_count * ciphertext_segment_size.
  int64_t ciphertext_size =
      segment_count * seg_decrypter->get_ciphertext_segment_size();
  auto dec_stream_result = DecryptingRandomAccessStream::New(
      std::move(seg_decrypter),
      absl::make_unique<DummyRandomAccessStream>(ciphertext_size, ct_offset));
  // ASSERT (not EXPECT): value() below would crash on an error status.
  ASSERT_THAT(dec_stream_result, IsOk());
  auto dec_stream = std::move(dec_stream_result.value());

  EXPECT_THAT(dec_stream->size().status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("too many segments")));
}
178
// Round-trips random plaintexts through DummyStreamingAead and checks that the
// decrypting stream reports the right size and yields the original bytes,
// ending with an OUT_OF_RANGE "EOF" status.
TEST(DecryptingRandomAccessStreamTest, BasicDecryption) {
  for (int pt_size : {1, 5, 20, 42, 100, 1000, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 100, 123}) {
      for (int header_size : {5, 10, 15}) {
        for (int ct_offset : {0, 1, 5, 12}) {
          SCOPED_TRACE(absl::StrCat(
              "pt_size = ", pt_size, ", pt_segment_size = ", pt_segment_size,
              ", header_size = ", header_size, ", ct_offset = ", ct_offset));
          DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
          // Pre-compute the ciphertext.
          auto ciphertext =
              GetCiphertextSource(&saead, plaintext, "some aad", ct_offset);
          // Check the decryption of the pre-computed ciphertext.
          auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
              pt_segment_size, header_size, ct_offset);
          auto dec_stream_result = DecryptingRandomAccessStream::New(
              std::move(seg_decrypter), std::move(ciphertext));
          // ASSERT (not EXPECT): value() below would crash on an error status.
          ASSERT_THAT(dec_stream_result, IsOk());
          auto dec_stream = std::move(dec_stream_result.value());
          auto size_result = dec_stream->size();
          ASSERT_THAT(size_result, IsOk());
          EXPECT_EQ(pt_size, size_result.value());
          std::string decrypted;
          auto status = internal::ReadAllFromRandomAccessStream(
              dec_stream.get(), decrypted);
          EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange,
                                       HasSubstr("EOF")));
          EXPECT_EQ(plaintext, decrypted);
        }
      }
    }
  }
}
211
// Reads chunks at assorted positions and checks that each read returns exactly
// the matching slice of the plaintext.
TEST(DecryptingRandomAccessStreamTest, SelectiveDecryption) {
  for (int pt_size : {1, 20, 42, 100, 1000, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 100, 200}) {
      for (int header_size : {5, 10, 20}) {
        for (int ct_offset : {0, 1, 10}) {
          SCOPED_TRACE(absl::StrCat(
              "pt_size = ", pt_size, ", pt_segment_size = ", pt_segment_size,
              ", header_size = ", header_size, ", ct_offset = ", ct_offset));
          DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
          // Pre-compute the ciphertext.
          auto ciphertext =
              GetCiphertextSource(&saead, plaintext, "some aad", ct_offset);
          // Check the decryption of the pre-computed ciphertext.
          auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
              pt_segment_size, header_size, ct_offset);
          auto dec_stream_result = DecryptingRandomAccessStream::New(
              std::move(seg_decrypter), std::move(ciphertext));
          // ASSERT (not EXPECT): value() below would crash on an error status.
          ASSERT_THAT(dec_stream_result, IsOk());
          auto dec_stream = std::move(dec_stream_result.value());
          // For pt_size == 1, position 2 lies past the end of the plaintext;
          // such reads are expected to fail with kInvalidArgument.
          for (int position : {0, 1, 2, pt_size / 2, pt_size - 1}) {
            for (int chunk_size : {1, pt_size / 2, pt_size}) {
              SCOPED_TRACE(absl::StrCat("position = ", position,
                                        ", chunk_size = ", chunk_size));
              // Allocate at least one byte even when chunk_size is 0
              // (pt_size / 2 with pt_size == 1).
              auto buffer = util::Buffer::New(std::max(chunk_size, 1)).value();
              auto status =
                  dec_stream->PRead(position, chunk_size, buffer.get());
              if (position <= pt_size) {
                EXPECT_TRUE(status.ok() ||
                            status.code() == absl::StatusCode::kOutOfRange);
              } else {
                EXPECT_THAT(status,
                            StatusIs(absl::StatusCode::kInvalidArgument));
              }
              EXPECT_EQ(std::min(chunk_size, std::max(pt_size - position, 0)),
                        buffer->size());
              // Compare the bytes read against the matching plaintext slice.
              // Clamp the start so we never do pointer arithmetic past the end
              // of 'plaintext' (undefined behavior when position > pt_size).
              const int expected_start = std::min(position, pt_size);
              EXPECT_EQ(plaintext.substr(expected_start, buffer->size()),
                        std::string(buffer->get_mem_block(), buffer->size()));
            }
          }
        }
      }
    }
  }
}
259
// Decrypting a truncated ciphertext must fail with kInvalidArgument rather than
// returning a short plaintext.
TEST(DecryptingRandomAccessStreamTest, TruncatedCiphertextDecryption) {
  for (int pt_size : {100, 200, 1000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 70}) {
      for (int header_size : {5, 10, 20}) {
        for (int ct_offset : {0, 1, 10}) {
          SCOPED_TRACE(absl::StrCat(
              "pt_size = ", pt_size, ", pt_segment_size = ", pt_segment_size,
              ", header_size = ", header_size, ", ct_offset = ", ct_offset));
          DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
          // Pre-compute the ciphertext.
          auto ct = GetCiphertext(&saead, plaintext, "some aad", ct_offset);
          const int ct_size = static_cast<int>(ct.size());
          // Used only to look up the ciphertext segment size; every decrypting
          // stream below gets its own instance because New() takes ownership.
          auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
              pt_segment_size, header_size, ct_offset);
          const int ct_segment_size =
              seg_decrypter->get_ciphertext_segment_size();
          // Always attempt to read the whole plaintext from the start.
          const int position = 0;
          const int chunk_size = pt_size;
          // Truncation points: only offset + header left, last byte missing,
          // and a plaintext- or ciphertext-segment's worth of bytes missing.
          for (int trunc_ct_size :
               {header_size + ct_offset, ct_size - 1, ct_size - pt_segment_size,
                ct_size - ct_segment_size}) {
            SCOPED_TRACE(absl::StrCat("ct_size = ", ct.size(),
                                      ", trunc_ct_size = ", trunc_ct_size,
                                      ", chunk_size = ", chunk_size));
            auto trunc_ct = std::make_unique<TestRandomAccessStream>(
                ct.substr(0, trunc_ct_size));
            auto per_stream_seg_decrypter =
                absl::make_unique<DummyStreamSegmentDecrypter>(
                    pt_segment_size, header_size, ct_offset);
            auto dec_stream_result = DecryptingRandomAccessStream::New(
                std::move(per_stream_seg_decrypter), std::move(trunc_ct));
            // ASSERT (not EXPECT): value() below would crash on an error.
            ASSERT_THAT(dec_stream_result, IsOk());
            auto dec_stream = std::move(dec_stream_result.value());
            auto buffer = util::Buffer::New(chunk_size).value();
            auto status =
                dec_stream->PRead(position, chunk_size, buffer.get());
            EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
          }
        }
      }
    }
  }
}
305
// Checks the argument validation of PRead: negative counts and positions are
// invalid, reading at EOF is OUT_OF_RANGE, and reading past EOF is invalid.
TEST(DecryptingRandomAccessStreamTest, OutOfRangeDecryption) {
  for (int pt_size : {0, 20, 42, 100, 1000, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 100, 123}) {
      for (int header_size : {5, 10, 20}) {
        SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
                                  ", pt_segment_size = ", pt_segment_size,
                                  ", header_size = ", header_size));
        int ct_offset = 0;
        DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
        // Pre-compute the ciphertext.
        auto ciphertext =
            GetCiphertextSource(&saead, plaintext, "some aad", ct_offset);
        // Check the decryption of the pre-computed ciphertext.
        auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
            pt_segment_size, header_size, ct_offset);
        auto dec_stream_result = DecryptingRandomAccessStream::New(
            std::move(seg_decrypter), std::move(ciphertext));
        // ASSERT (not EXPECT): value() below would crash on an error status.
        ASSERT_THAT(dec_stream_result, IsOk());
        auto dec_stream = std::move(dec_stream_result.value());
        const int chunk_size = 1;
        auto buffer = util::Buffer::New(chunk_size).value();
        // The plaintext occupies [0, pt_size), so pt_size is exactly EOF.
        const int position = pt_size;

        // Negative chunk size.
        auto status = dec_stream->PRead(position, -1, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));

        // Negative position.
        status = dec_stream->PRead(-1, chunk_size, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));

        // Reading at EOF.
        status = dec_stream->PRead(position, chunk_size, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange));

        // Reading past EOF.
        status = dec_stream->PRead(position + 1, chunk_size, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
      }
    }
  }
}
348
// Random bytes are not a valid ciphertext: the factory accepts the source, but
// the first read must fail with kInvalidArgument.
TEST(DecryptingRandomAccessStreamTest, WrongCiphertext) {
  int pt_segment_size = 42;
  int header_size = 10;
  int ct_offset = 0;
  for (int ct_size : {0, 10, 100}) {
    SCOPED_TRACE(absl::StrCat("ct_size = ", ct_size));
    // Try decrypting a wrong ciphertext.
    auto wrong_ct = std::make_unique<TestRandomAccessStream>(
        subtle::Random::GetRandomBytes(ct_size));
    auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
        pt_segment_size, header_size, ct_offset);
    auto dec_stream_result = DecryptingRandomAccessStream::New(
        std::move(seg_decrypter), std::move(wrong_ct));
    // ASSERT (not EXPECT): value() below would crash on an error status.
    ASSERT_THAT(dec_stream_result, IsOk());
    auto dec_stream = std::move(dec_stream_result.value());
    const int chunk_size = 1;
    const int position = 0;
    auto buffer = util::Buffer::New(chunk_size).value();
    auto status = dec_stream->PRead(position, chunk_size, buffer.get());
    EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
  }
}
372
// A missing segment decrypter must be rejected even when the ciphertext source
// is valid.
TEST(DecryptingRandomAccessStreamTest, NullSegmentDecrypter) {
  std::unique_ptr<RandomAccessStream> source =
      std::make_unique<TestRandomAccessStream>("some ciphertext contents");
  EXPECT_THAT(
      DecryptingRandomAccessStream::New(nullptr, std::move(source)).status(),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("segment_decrypter must be non-null")));
}
382
// A missing ciphertext source must be rejected even when the segment decrypter
// is valid.
TEST(DecryptingRandomAccessStreamTest, NullCiphertextSource) {
  const int kPtSegmentSize = 42;
  const int kHeaderSize = 10;
  const int kCtOffset = 0;
  auto decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      kPtSegmentSize, kHeaderSize, kCtOffset);
  // NOTE: the "cipertext" spelling must match the error text emitted by the
  // implementation; do not "fix" it here alone.
  EXPECT_THAT(
      DecryptingRandomAccessStream::New(std::move(decrypter), nullptr)
          .status(),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("cipertext_source must be non-null")));
}
395
396 } // namespace
397 } // namespace subtle
398 } // namespace tink
399 } // namespace crypto
400