xref: /aosp_15_r20/external/tink/cc/subtle/decrypting_random_access_stream.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/subtle/decrypting_random_access_stream.h"
18 
19 #include <algorithm>
20 #include <cstring>
21 #include <limits>
22 #include <memory>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/base/thread_annotations.h"
27 #include "absl/memory/memory.h"
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/synchronization/mutex.h"
31 #include "tink/random_access_stream.h"
32 #include "tink/subtle/stream_segment_decrypter.h"
33 #include "tink/util/buffer.h"
34 #include "tink/util/errors.h"
35 #include "tink/util/status.h"
36 #include "tink/util/statusor.h"
37 
38 namespace crypto {
39 namespace tink {
40 namespace subtle {
41 
42 using crypto::tink::RandomAccessStream;
43 using crypto::tink::ToStatusF;
44 using crypto::tink::util::Buffer;
45 using crypto::tink::util::Status;
46 using crypto::tink::util::StatusOr;
47 
48 // static
New(std::unique_ptr<StreamSegmentDecrypter> segment_decrypter,std::unique_ptr<RandomAccessStream> ciphertext_source)49 StatusOr<std::unique_ptr<RandomAccessStream>> DecryptingRandomAccessStream::New(
50     std::unique_ptr<StreamSegmentDecrypter> segment_decrypter,
51     std::unique_ptr<RandomAccessStream> ciphertext_source) {
52   if (segment_decrypter == nullptr) {
53     return Status(absl::StatusCode::kInvalidArgument,
54                   "segment_decrypter must be non-null");
55   }
56   if (ciphertext_source == nullptr) {
57     return Status(absl::StatusCode::kInvalidArgument,
58                   "cipertext_source must be non-null");
59   }
60   std::unique_ptr<DecryptingRandomAccessStream> dec_stream(
61       new DecryptingRandomAccessStream());
62   absl::MutexLock lock(&(dec_stream->status_mutex_));
63   dec_stream->segment_decrypter_ = std::move(segment_decrypter);
64   dec_stream->ct_source_ = std::move(ciphertext_source);
65 
66   if (dec_stream->segment_decrypter_->get_ciphertext_offset() < 0) {
67     return util::Status(absl::StatusCode::kInvalidArgument,
68                         "The ciphertext offset must be non-negative");
69   }
70   int first_segment_size =
71       dec_stream->segment_decrypter_->get_plaintext_segment_size() -
72       dec_stream->segment_decrypter_->get_ciphertext_offset() -
73       dec_stream->segment_decrypter_->get_header_size();
74   if (first_segment_size <= 0) {
75     return Status(absl::StatusCode::kInvalidArgument,
76                   "Size of the first segment must be greater than 0.");
77   }
78   dec_stream->status_ =
79       Status(absl::StatusCode::kUnavailable,
80              "The header hasn't been read yet.");
81   return {std::move(dec_stream)};
82 }
83 
PRead(int64_t position,int count,Buffer * dest_buffer)84 util::Status DecryptingRandomAccessStream::PRead(int64_t position, int count,
85                                                  Buffer* dest_buffer) {
86   if (dest_buffer == nullptr) {
87     return Status(absl::StatusCode::kInvalidArgument,
88                   "dest_buffer must be non-null");
89   }
90   auto status = dest_buffer->set_size(0);
91   if (!status.ok()) return status;
92   if (count < 0) {
93     return Status(absl::StatusCode::kInvalidArgument,
94                   "count cannot be negative");
95   }
96   if (count > dest_buffer->allocated_size()) {
97     return Status(absl::StatusCode::kInvalidArgument, "buffer too small");
98   }
99   if (position < 0) {
100     return Status(absl::StatusCode::kInvalidArgument,
101                   "position cannot be negative");
102   }
103 
104   {  // Initialize, if not initialized yet.
105     absl::MutexLock lock(&status_mutex_);
106     InitializeIfNeeded();
107     if (!status_.ok()) return status_;
108   }
109 
110   if (position > pt_size_) {
111     return Status(absl::StatusCode::kInvalidArgument, "position too large");
112   }
113   return PReadAndDecrypt(position, count, dest_buffer);
114 }
115 
116 // NOTE: As the initialization below requires availability of size() of the
117 // underlying ciphertext stream, the current implementation does not support
118 // dynamic encrypted streams, whose size is not known or can change over time
119 // (e.g. when one process produces an encrypted file/stream, while concurrently
120 // another process consumes the resulting encrypted stream).
121 //
122 // This is consistent with Java implementation of SeekableDecryptingChannel,
123 // and detects ciphertext truncation attacks.  However, a support for dynamic
124 // streams can be added in the future if needed.
InitializeIfNeeded()125 void DecryptingRandomAccessStream::InitializeIfNeeded()
126     ABSL_EXCLUSIVE_LOCKS_REQUIRED(status_mutex_) {
127   if (status_.code() != absl::StatusCode::kUnavailable) {
128     // Already initialized or stream failed permanently.
129     return;
130   }
131 
132   // Initialize segment decrypter from data in the stream header.
133   header_size_ = segment_decrypter_->get_header_size();
134   ct_offset_ = segment_decrypter_->get_ciphertext_offset();
135   auto buf_result = Buffer::New(header_size_);
136   if (!buf_result.ok()) {
137     status_ = buf_result.status();
138     return;
139   }
140   auto buf = std::move(buf_result.value());
141   status_ = ct_source_->PRead(ct_offset_, header_size_, buf.get());
142   if (!status_.ok()) {
143     if (status_.code() == absl::StatusCode::kOutOfRange) {
144       status_ =
145           Status(absl::StatusCode::kInvalidArgument, "could not read header");
146     }
147     return;
148   }
149   status_ = segment_decrypter_->Init(std::vector<uint8_t>(
150       buf->get_mem_block(), buf->get_mem_block() + header_size_));
151   if (!status_.ok()) return;
152   ct_segment_size_ = segment_decrypter_->get_ciphertext_segment_size();
153   pt_segment_size_ = segment_decrypter_->get_plaintext_segment_size();
154   ct_segment_overhead_ = ct_segment_size_ - pt_segment_size_;
155 
156   // Calculate the number of segments and the plaintext size.
157   StatusOr<int64_t> ct_size_result = ct_source_->size();
158   if (!ct_size_result.ok()) {
159     status_ = ct_size_result.status();
160     return;
161   }
162   int64_t ct_size = ct_size_result.value();
163   // ct_segment_size_ is always larger than 1, thus full_segment_count is always
164   // smaller than std::numeric_limits<int64_t>::max().
165   int64_t full_segment_count = ct_size / ct_segment_size_;
166   int64_t remainder_size = ct_size % ct_segment_size_;
167   if (remainder_size > 0) {
168     // This does not overflow because full_segment_count <
169     // std::numeric_limits<int64_t>::max().
170     segment_count_ = full_segment_count + 1;
171   } else {
172     segment_count_ = full_segment_count;
173   }
174   // Tink supports up to 2^32 segments.
175   if (segment_count_ - 1 > std::numeric_limits<uint32_t>::max()) {
176     status_ = Status(absl::StatusCode::kInvalidArgument,
177                      absl::StrCat("too many segments: ", segment_count_));
178     return;
179   }
180 
181   // This should not overflow because:
182   // * segment_count is int64 and smaller than 2^32, and
183   // * ct_segment_overhead_, ct_offset_ and header_size_ are small int numbers.
184   auto overhead =
185       ct_segment_overhead_ * segment_count_ + ct_offset_ + header_size_;
186   if (overhead > ct_size) {
187     status_ = Status(absl::StatusCode::kInvalidArgument,
188                      "ciphertext stream is too short");
189     return;
190   }
191   pt_size_ = ct_size - overhead;
192 }
193 
GetPlaintextOffset(int64_t pt_position)194 int DecryptingRandomAccessStream::GetPlaintextOffset(int64_t pt_position) {
195   if (GetSegmentNr(pt_position) == 0) return pt_position;
196   // Computed according to the formula:
197   // (pt_position - (pt_segment_size_ - ct_offset_ - header_size_))
198   //     % pt_segment_size_;
199   // pt_position + ct_offset_ + header_size_ is always smaller than size of
200   // the ciphertext, thus it should never overflow.
201   return (pt_position + ct_offset_ + header_size_) % pt_segment_size_;
202 }
203 
GetSegmentNr(int64_t pt_position)204 int64_t DecryptingRandomAccessStream::GetSegmentNr(int64_t pt_position) {
205   return (pt_position + ct_offset_ + header_size_) / pt_segment_size_;
206 }
207 
ReadAndDecryptSegment(int64_t segment_nr,Buffer * ct_buffer,std::vector<uint8_t> * pt_segment)208 util::Status DecryptingRandomAccessStream::ReadAndDecryptSegment(
209     int64_t segment_nr, Buffer* ct_buffer, std::vector<uint8_t>* pt_segment) {
210   int64_t ct_position = segment_nr * ct_segment_size_;
211   if (ct_position / ct_segment_size_ != segment_nr /* overflow occured! */) {
212     return Status(absl::StatusCode::kOutOfRange,
213                   absl::StrCat("segment_nr * ct_segment_size too large: ",
214                                segment_nr, ct_segment_size_));
215   }
216   int segment_size = ct_segment_size_;
217   if (segment_nr == 0) {
218     // The sum of ct_offset_ and header_size is always smaller than
219     // ct_segment_size_, which is an int, therefore the next two statements
220     // should never overflow.
221     ct_position = ct_offset_ + header_size_;
222     segment_size = ct_segment_size_ - ct_position;
223   }
224   bool is_last_segment = (segment_nr == segment_count_ - 1);
225   auto pread_status = ct_source_->PRead(ct_position, segment_size, ct_buffer);
226   if (pread_status.ok() ||
227       (is_last_segment && ct_buffer->size() > 0 &&
228        pread_status.code() == absl::StatusCode::kOutOfRange)) {
229     // some bytes were read
230     auto dec_status = segment_decrypter_->DecryptSegment(
231         std::vector<uint8_t>(ct_buffer->get_mem_block(),
232                              ct_buffer->get_mem_block() + ct_buffer->size()),
233         segment_nr, is_last_segment, pt_segment);
234     if (dec_status.ok()) {
235       return is_last_segment ?
236           Status(absl::StatusCode::kOutOfRange, "EOF") : util::OkStatus();
237     }
238     return dec_status;
239   }
240   return pread_status;
241 }
242 
PReadAndDecrypt(int64_t position,int count,Buffer * dest_buffer)243 util::Status DecryptingRandomAccessStream::PReadAndDecrypt(
244     int64_t position, int count, Buffer* dest_buffer) {
245   if (position < 0 || count < 0 || dest_buffer == nullptr
246       || count > dest_buffer->allocated_size() || dest_buffer->size() != 0) {
247     return Status(absl::StatusCode::kInternal,
248                   "Invalid parameters to PReadAndDecrypt");
249   }
250 
251   if (position > std::numeric_limits<int64_t>::max() - count) {
252     return Status(
253         absl::StatusCode::kOutOfRange,
254         absl::StrCat(
255             "Invalid parameters to PReadAndDecrypt; position too large: ",
256             position));
257   }
258 
259   auto pt_size_result = size();
260   if (pt_size_result.ok()) {
261     auto pt_size = pt_size_result.value();
262     if (position > pt_size) {
263       return Status(absl::StatusCode::kOutOfRange,
264                     "position is larger than stream size");
265     }
266   }
267   auto ct_buffer_result = Buffer::New(ct_segment_size_);
268   if (!ct_buffer_result.ok()) {
269     return ToStatusF(absl::StatusCode::kInvalidArgument,
270                      "Invalid ciphertext segment size %d.", ct_segment_size_);
271   }
272   auto ct_buffer = std::move(ct_buffer_result.value());
273   std::vector<uint8_t> pt_segment;
274   int remaining = count;
275   int read_count = 0;
276   int pt_offset = GetPlaintextOffset(position);
277   while (remaining > 0) {
278     auto segment_nr = GetSegmentNr(position + read_count);
279     auto status =
280         ReadAndDecryptSegment(segment_nr, ct_buffer.get(), &pt_segment);
281     if (status.ok() || status.code() == absl::StatusCode::kOutOfRange) {
282       int pt_count = pt_segment.size() - pt_offset;
283       int to_copy_count = std::min(pt_count, remaining);
284       auto s = dest_buffer->set_size(read_count + to_copy_count);
285       if (!s.ok()) return s;
286       std::memcpy(dest_buffer->get_mem_block() + read_count,
287                   pt_segment.data() + pt_offset, to_copy_count);
288       pt_offset = 0;
289       if (status.code() == absl::StatusCode::kOutOfRange &&
290           to_copy_count == pt_count)
291         return status;
292       read_count += to_copy_count;
293       remaining = count - dest_buffer->size();
294     } else {  // some other error happened
295       return status;
296     }
297   }
298   return util::OkStatus();
299 }
300 
size()301 StatusOr<int64_t> DecryptingRandomAccessStream::size() {
302   {  // Initialize, if not initialized yet.
303     absl::MutexLock lock(&status_mutex_);
304     InitializeIfNeeded();
305     if (!status_.ok()) return status_;
306   }
307   return pt_size_;
308 }
309 
310 }  // namespace subtle
311 }  // namespace tink
312 }  // namespace crypto
313