xref: /aosp_15_r20/external/tink/cc/subtle/decrypting_random_access_stream_test.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/subtle/decrypting_random_access_stream.h"
18 
19 #include <algorithm>
20 #include <cstring>
21 #include <limits>
22 #include <memory>
23 #include <sstream>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "gtest/gtest.h"
29 #include "absl/memory/memory.h"
30 #include "absl/status/status.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/string_view.h"
33 #include "tink/internal/test_random_access_stream.h"
34 #include "tink/output_stream.h"
35 #include "tink/random_access_stream.h"
36 #include "tink/streaming_aead.h"
37 #include "tink/subtle/random.h"
38 #include "tink/subtle/test_util.h"
39 #include "tink/util/ostream_output_stream.h"
40 #include "tink/util/status.h"
41 #include "tink/util/test_matchers.h"
42 
43 namespace crypto {
44 namespace tink {
45 namespace subtle {
46 namespace {
47 
48 using ::crypto::tink::internal::TestRandomAccessStream;
49 using crypto::tink::subtle::test::DummyStreamingAead;
50 using crypto::tink::subtle::test::DummyStreamSegmentDecrypter;
51 using crypto::tink::test::IsOk;
52 using crypto::tink::test::StatusIs;
53 using subtle::test::WriteToStream;
54 using testing::HasSubstr;
55 
56 // A dummy RandomAccessStream that fakes its size.
57 class DummyRandomAccessStream : public RandomAccessStream {
58  public:
DummyRandomAccessStream(int64_t size,int ct_offset)59   explicit DummyRandomAccessStream(int64_t size, int ct_offset)
60       : size_(size), ct_offset_(ct_offset) {}
61 
PRead(int64_t position,int count,crypto::tink::util::Buffer * dest_buffer)62   crypto::tink::util::Status PRead(
63       int64_t position, int count,
64       crypto::tink::util::Buffer* dest_buffer) override {
65     if (position == ct_offset_) {
66       // Someone attempts to read the header, return the same dummy value that
67       // DummyStreamSegmentDecrypter expects.
68       auto status = dest_buffer->set_size(count);
69       if (!status.ok()) return status;
70       std::memset(dest_buffer->get_mem_block(), 'h', count);
71     }
72     return util::OkStatus();
73   }
74 
size()75   crypto::tink::util::StatusOr<int64_t> size() override { return size_; }
76 
77  private:
78   int64_t size_;
79   int ct_offset_;
80 };
81 
82 // Returns a ciphertext resulting from encryption of 'pt' with 'aad' as
83 // associated data, using 'saead'.
GetCiphertext(StreamingAead * saead,absl::string_view pt,absl::string_view aad,int ct_offset)84 std::string GetCiphertext(StreamingAead* saead, absl::string_view pt,
85                           absl::string_view aad, int ct_offset) {
86   // Prepare ciphertext destination stream.
87   auto ct_stream = absl::make_unique<std::stringstream>();
88   // Write ct_offset 'o'-characters for the ciphertext offset.
89   *ct_stream << std::string(ct_offset, 'o');
90   // A reference to the ciphertext buffer.
91   auto ct_buf = ct_stream->rdbuf();
92   std::unique_ptr<OutputStream> ct_destination(
93       absl::make_unique<util::OstreamOutputStream>(std::move(ct_stream)));
94 
95   // Compute the ciphertext.
96   auto enc_stream_result =
97       saead->NewEncryptingStream(std::move(ct_destination), aad);
98   EXPECT_THAT(enc_stream_result, IsOk());
99   EXPECT_THAT(WriteToStream(enc_stream_result.value().get(), pt), IsOk());
100 
101   return ct_buf->str();
102 }
103 
104 // Creates an RandomAccessStream that contains ciphertext resulting
105 // from encryption of 'pt' with 'aad' as associated data, using 'saead'.
GetCiphertextSource(StreamingAead * saead,absl::string_view pt,absl::string_view aad,int ct_offset)106 std::unique_ptr<RandomAccessStream> GetCiphertextSource(StreamingAead* saead,
107                                                         absl::string_view pt,
108                                                         absl::string_view aad,
109                                                         int ct_offset) {
110   return std::make_unique<TestRandomAccessStream>(
111       GetCiphertext(saead, pt, aad, ct_offset));
112 }
113 
// New() must reject a segment decrypter configured with a negative
// ciphertext offset.
TEST(DecryptingRandomAccessStreamTest, NegativeCiphertextOffset) {
  const int pt_segment_size = 100;
  const int header_size = 20;
  const int ct_offset = -1;
  const int64_t ciphertext_size = 100;

  auto ct_source =
      absl::make_unique<DummyRandomAccessStream>(ciphertext_size, ct_offset);
  auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      pt_segment_size, header_size, ct_offset);
  auto dec_stream_result = DecryptingRandomAccessStream::New(
      std::move(seg_decrypter), std::move(ct_source));

  EXPECT_THAT(dec_stream_result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("The ciphertext offset must be non-negative")));
}
130 
// New() must reject parameters that leave no room for plaintext in the first
// segment.
TEST(DecryptingRandomAccessStreamTest,
     SizeOfFirstSegmentIsSmallerOrEqualToZero) {
  const int header_size = 20;
  const int ct_offset = 0;
  // The offset and header together fill the whole first segment, so its
  // plaintext part has size zero.
  const int pt_segment_size = ct_offset + header_size;
  const int64_t ciphertext_size = 100;

  auto ct_source =
      absl::make_unique<DummyRandomAccessStream>(ciphertext_size, ct_offset);
  auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      pt_segment_size, header_size, ct_offset);
  auto dec_stream_result = DecryptingRandomAccessStream::New(
      std::move(seg_decrypter), std::move(ct_source));

  EXPECT_THAT(dec_stream_result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("greater than 0")));
}
150 
// A ciphertext size implying more than 2^32 segments is accepted by New(),
// but size() must then fail with kInvalidArgument.
TEST(DecryptingRandomAccessStreamTest, TooManySegments) {
  int header_size = 1;
  int ct_offset = 0;
  // Use a valid pt_segment_size which is larger than ct_offset + header_size.
  int pt_segment_size = ct_offset + header_size + 1;
  auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
      pt_segment_size, header_size, ct_offset);

  // Use an invalid segment_count larger than 2^32.
  int64_t segment_count =
      static_cast<int64_t>(std::numeric_limits<uint32_t>::max()) + 2;
  // Based on this calculation:
  // segment_count = ciphertext_size / ciphertext_segment_size
  // -> ciphertext_size = segment_count * ciphertext_segment_size
  int64_t ciphertext_size =
      segment_count * seg_decrypter->get_ciphertext_segment_size();
  auto dec_stream_result = DecryptingRandomAccessStream::New(
      std::move(seg_decrypter),
      absl::make_unique<DummyRandomAccessStream>(ciphertext_size, ct_offset));
  // Fatal: calling value() on a non-OK result below would abort the binary.
  ASSERT_THAT(dec_stream_result, IsOk());
  auto dec_stream = std::move(dec_stream_result.value());

  auto result = dec_stream->size();
  EXPECT_THAT(result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("too many segments")));
}
178 
// Round trip: decrypting a freshly encrypted ciphertext reports the plaintext
// size and, read end to end, reproduces the plaintext (the final read ends
// with an OUT_OF_RANGE "EOF" status).
TEST(DecryptingRandomAccessStreamTest, BasicDecryption) {
  for (int pt_size : {1, 5, 20, 42, 100, 1000, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 100, 123}) {
      for (int header_size : {5, 10, 15}) {
        for (int ct_offset : {0, 1, 5, 12}) {
          SCOPED_TRACE(absl::StrCat(
              "pt_size = ", pt_size, ", pt_segment_size = ", pt_segment_size,
              ", header_size = ", header_size, ", ct_offset = ", ct_offset));
          DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
          // Pre-compute the ciphertext.
          auto ciphertext =
              GetCiphertextSource(&saead, plaintext, "some aad", ct_offset);
          // Check the decryption of the pre-computed ciphertext.
          auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
              pt_segment_size, header_size, ct_offset);
          auto dec_stream_result = DecryptingRandomAccessStream::New(
              std::move(seg_decrypter), std::move(ciphertext));
          // Fatal: calling value() on a non-OK result would abort the binary.
          ASSERT_THAT(dec_stream_result, IsOk());
          auto dec_stream = std::move(dec_stream_result.value());
          auto size_result = dec_stream->size();
          ASSERT_THAT(size_result, IsOk());
          EXPECT_EQ(pt_size, size_result.value());
          std::string decrypted;
          auto status = internal::ReadAllFromRandomAccessStream(
              dec_stream.get(), decrypted);
          EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange,
                                       HasSubstr("EOF")));
          EXPECT_EQ(plaintext, decrypted);
        }
      }
    }
  }
}
211 
// Reads chunks of several sizes at several positions and checks that each
// read yields exactly the corresponding slice of the plaintext. A read that
// starts past the end of the plaintext must fail with kInvalidArgument and
// leave the buffer empty.
TEST(DecryptingRandomAccessStreamTest, SelectiveDecryption) {
  for (int pt_size : {1, 20, 42, 100, 1000, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 100, 200}) {
      for (int header_size : {5, 10, 20}) {
        for (int ct_offset : {0, 1, 10}) {
          SCOPED_TRACE(absl::StrCat(
              "pt_size = ", pt_size, ", pt_segment_size = ", pt_segment_size,
              ", header_size = ", header_size, ", ct_offset = ", ct_offset));
          DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
          // Pre-compute the ciphertext.
          auto ciphertext =
              GetCiphertextSource(&saead, plaintext, "some aad", ct_offset);
          // Check the decryption of the pre-computed ciphertext.
          auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
              pt_segment_size, header_size, ct_offset);
          auto dec_stream_result = DecryptingRandomAccessStream::New(
              std::move(seg_decrypter), std::move(ciphertext));
          // Fatal: calling value() on a non-OK result would abort the binary.
          ASSERT_THAT(dec_stream_result, IsOk());
          auto dec_stream = std::move(dec_stream_result.value());
          // For pt_size == 1 this list contains 2, i.e. a position past the
          // end of the plaintext.
          for (int position : {0, 1, 2, pt_size / 2, pt_size - 1}) {
            for (int chunk_size : {1, pt_size / 2, pt_size}) {
              SCOPED_TRACE(absl::StrCat("position = ", position,
                                        ", chunk_size = ", chunk_size));
              // Allocate at least one byte even when chunk_size is 0.
              auto buffer =
                  std::move(util::Buffer::New(std::max(chunk_size, 1)).value());
              auto status =
                  dec_stream->PRead(position, chunk_size, buffer.get());
              if (position > pt_size) {
                EXPECT_THAT(status,
                            StatusIs(absl::StatusCode::kInvalidArgument));
                EXPECT_EQ(0, buffer->size());
                // Do not form plaintext.data() + position here: a pointer
                // beyond one-past-the-end is undefined behavior.
                continue;
              }
              EXPECT_TRUE(status.ok() ||
                          status.code() == absl::StatusCode::kOutOfRange);
              const int expected_size =
                  std::min(chunk_size, pt_size - position);
              EXPECT_EQ(expected_size, buffer->size());
              // Comparing as strings checks length and content together and
              // never reads beyond either buffer, even if the size is wrong.
              EXPECT_EQ(plaintext.substr(position, expected_size),
                        std::string(buffer->get_mem_block(), buffer->size()));
            }
          }
        }
      }
    }
  }
}
259 
// Truncated ciphertexts must make a full read from position 0 fail with
// kInvalidArgument. The truncations tested are: only offset + header left,
// one byte short, and shortened by one plaintext or one ciphertext segment
// length.
TEST(DecryptingRandomAccessStreamTest, TruncatedCiphertextDecryption) {
  for (int pt_size : {100, 200, 1000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 70}) {
      for (int header_size : {5, 10, 20}) {
        for (int ct_offset : {0, 1, 10}) {
          SCOPED_TRACE(absl::StrCat(
              "pt_size = ", pt_size, ", pt_segment_size = ", pt_segment_size,
              ", header_size = ", header_size, ", ct_offset = ", ct_offset));
          DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
          // Pre-compute the ciphertext.
          std::string ct =
              GetCiphertext(&saead, plaintext, "some aad", ct_offset);
          const int ct_size = static_cast<int>(ct.size());
          // This decrypter is only used to look up the ciphertext segment
          // size; each stream below gets its own decrypter.
          auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
              pt_segment_size, header_size, ct_offset);
          const int ct_segment_size =
              seg_decrypter->get_ciphertext_segment_size();
          // Attempt to read the whole plaintext from the start in one PRead.
          const int chunk_size = pt_size;
          const int position = 0;
          for (int trunc_ct_size :
               {header_size + ct_offset, ct_size - 1,
                ct_size - pt_segment_size, ct_size - ct_segment_size}) {
            SCOPED_TRACE(absl::StrCat("ct_size = ", ct.size(),
                                      ", trunc_ct_size = ", trunc_ct_size,
                                      ", chunk_size = ", chunk_size));
            auto trunc_ct = std::make_unique<TestRandomAccessStream>(
                ct.substr(0, trunc_ct_size));
            auto per_stream_seg_decrypter =
                absl::make_unique<DummyStreamSegmentDecrypter>(
                    pt_segment_size, header_size, ct_offset);
            auto dec_stream_result = DecryptingRandomAccessStream::New(
                std::move(per_stream_seg_decrypter), std::move(trunc_ct));
            // Fatal: calling value() on a non-OK result would abort the
            // binary.
            ASSERT_THAT(dec_stream_result, IsOk());
            auto dec_stream = std::move(dec_stream_result.value());
            auto buffer = std::move(util::Buffer::New(chunk_size).value());
            auto status =
                dec_stream->PRead(position, chunk_size, buffer.get());
            EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
          }
        }
      }
    }
  }
}
305 
// Argument and EOF handling of PRead. A negative count, a negative position,
// or a position past the end gives kInvalidArgument; a read starting exactly
// at the end gives kOutOfRange.
TEST(DecryptingRandomAccessStreamTest, OutOfRangeDecryption) {
  for (int pt_size : {0, 20, 42, 100, 1000, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (int pt_segment_size : {50, 100, 123}) {
      for (int header_size : {5, 10, 20}) {
        SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
                                  ", pt_segment_size = ", pt_segment_size,
                                  ", header_size = ", header_size));
        int ct_offset = 0;
        DummyStreamingAead saead(pt_segment_size, header_size, ct_offset);
        // Pre-compute the ciphertext.
        auto ciphertext =
            GetCiphertextSource(&saead, plaintext, "some aad", ct_offset);
        // Check the decryption of the pre-computed ciphertext.
        auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
            pt_segment_size, header_size, ct_offset);
        auto dec_stream_result = DecryptingRandomAccessStream::New(
            std::move(seg_decrypter), std::move(ciphertext));
        // Fatal: calling value() on a non-OK result would abort the binary.
        ASSERT_THAT(dec_stream_result, IsOk());
        auto dec_stream = std::move(dec_stream_result.value());
        int chunk_size = 1;
        auto buffer = std::move(util::Buffer::New(chunk_size).value());
        int position = pt_size;
        // Negative chunk size.
        auto status = dec_stream->PRead(position, -1, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));

        // Negative position.
        status = dec_stream->PRead(-1, chunk_size, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));

        // Reading at EOF.
        status = dec_stream->PRead(position, chunk_size, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange));

        // Reading past EOF.
        status = dec_stream->PRead(position + 1, chunk_size, buffer.get());
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
      }
    }
  }
}
348 
// A stream of random bytes (including an empty stream) is accepted by New(),
// but the first read from it must fail with kInvalidArgument.
TEST(DecryptingRandomAccessStreamTest, WrongCiphertext) {
  int pt_segment_size = 42;
  int header_size = 10;
  int ct_offset = 0;
  for (int ct_size : {0, 10, 100}) {
    SCOPED_TRACE(absl::StrCat("ct_size = ", ct_size));
    // Try decrypting a wrong ciphertext.
    auto wrong_ct = std::make_unique<TestRandomAccessStream>(
        subtle::Random::GetRandomBytes(ct_size));
    auto seg_decrypter = absl::make_unique<DummyStreamSegmentDecrypter>(
        pt_segment_size, header_size, ct_offset);
    auto dec_stream_result = DecryptingRandomAccessStream::New(
        std::move(seg_decrypter), std::move(wrong_ct));
    // Fatal: calling value() on a non-OK result would abort the binary.
    ASSERT_THAT(dec_stream_result, IsOk());
    auto dec_stream = std::move(dec_stream_result.value());
    int chunk_size = 1;
    int position = 0;
    auto buffer = std::move(util::Buffer::New(chunk_size).value());
    auto status = dec_stream->PRead(position, chunk_size, buffer.get());
    EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
  }
}
372 
// DecryptingRandomAccessStream::New() must refuse a null segment decrypter.
TEST(DecryptingRandomAccessStreamTest, NullSegmentDecrypter) {
  std::unique_ptr<RandomAccessStream> ct_source =
      std::make_unique<TestRandomAccessStream>("some ciphertext contents");

  auto dec_stream_result =
      DecryptingRandomAccessStream::New(nullptr, std::move(ct_source));

  EXPECT_THAT(dec_stream_result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("segment_decrypter must be non-null")));
}
382 
// DecryptingRandomAccessStream::New() must refuse a null ciphertext source.
TEST(DecryptingRandomAccessStreamTest, NullCiphertextSource) {
  const int pt_segment_size = 42;
  const int header_size = 10;
  const int ct_offset = 0;

  auto dec_stream_result = DecryptingRandomAccessStream::New(
      absl::make_unique<DummyStreamSegmentDecrypter>(pt_segment_size,
                                                     header_size, ct_offset),
      nullptr);

  // NOTE(review): "cipertext" is misspelled; presumably this mirrors the
  // implementation's error text, so any correction must change both.
  EXPECT_THAT(dec_stream_result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("cipertext_source must be non-null")));
}
395 
396 }  // namespace
397 }  // namespace subtle
398 }  // namespace tink
399 }  // namespace crypto
400