xref: /aosp_15_r20/external/tink/cc/subtle/streaming_aead_decrypting_stream.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/subtle/streaming_aead_decrypting_stream.h"
18 
19 #include <algorithm>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/status/status.h"
27 #include "tink/input_stream.h"
28 #include "tink/subtle/stream_segment_decrypter.h"
29 #include "tink/util/status.h"
30 #include "tink/util/statusor.h"
31 
32 using crypto::tink::InputStream;
33 using crypto::tink::util::Status;
34 using crypto::tink::util::StatusOr;
35 
36 namespace crypto {
37 namespace tink {
38 namespace subtle {
39 
40 namespace {
41 
42 // Reads at most 'count' bytes from the specified 'input_stream',
43 // and puts them into 'output', where both 'input_stream' and 'output'
44 // must be non-null.
45 // Will try to read exactly 'count' bytes, unless the end of stream
46 // is reached (then returns status OUT_OF_RANGE) or an error occurs
47 // (an other non-OK status).
48 // Before returning, resizes 'output' accordingly, to reflect
49 // the actual number of bytes read.
50 
ReadFromStream(InputStream * input_stream,int count,std::vector<uint8_t> * output)51 util::Status ReadFromStream(InputStream* input_stream, int count,
52                             std::vector<uint8_t>* output) {
53   if (count <= 0 || input_stream == nullptr || output == nullptr) {
54     return Status(absl::StatusCode::kInternal, "Illegal read from a stream");
55   }
56   const void* buffer;
57   int bytes_to_be_read = count;
58   int read_bytes;    // bytes read in one Next()-call
59   int needed_bytes;  // bytes actually needed
60   output->resize(count);
61   while (bytes_to_be_read > 0) {
62     auto next_result = input_stream->Next(&buffer);
63     if (next_result.status().code() == absl::StatusCode::kOutOfRange) {
64       // End of stream.
65       output->resize(count - bytes_to_be_read);
66       return next_result.status();
67     }
68     if (!next_result.ok()) return next_result.status();
69     read_bytes = next_result.value();
70     needed_bytes = std::min(read_bytes, bytes_to_be_read);
71     memcpy(output->data() + (count - bytes_to_be_read), buffer, needed_bytes);
72     bytes_to_be_read -= needed_bytes;
73   }
74   if (read_bytes > needed_bytes) {
75     input_stream->BackUp(read_bytes - needed_bytes);
76   }
77   return util::OkStatus();
78 }
79 
80 }  // anonymous namespace
81 
82 // static
New(std::unique_ptr<StreamSegmentDecrypter> segment_decrypter,std::unique_ptr<InputStream> ciphertext_source)83 StatusOr<std::unique_ptr<InputStream>> StreamingAeadDecryptingStream::New(
84     std::unique_ptr<StreamSegmentDecrypter> segment_decrypter,
85     std::unique_ptr<InputStream> ciphertext_source) {
86   if (segment_decrypter == nullptr) {
87     return Status(absl::StatusCode::kInvalidArgument,
88                   "segment_decrypter must be non-null");
89   }
90   if (ciphertext_source == nullptr) {
91     return Status(absl::StatusCode::kInvalidArgument,
92                   "cipertext_source must be non-null");
93   }
94   std::unique_ptr<StreamingAeadDecryptingStream> dec_stream(
95       new StreamingAeadDecryptingStream());
96   dec_stream->segment_decrypter_ = std::move(segment_decrypter);
97   dec_stream->ct_source_ = std::move(ciphertext_source);
98   int first_segment_size =
99       dec_stream->segment_decrypter_->get_ciphertext_segment_size() -
100       dec_stream->segment_decrypter_->get_ciphertext_offset() -
101       dec_stream->segment_decrypter_->get_header_size();
102   if (first_segment_size <= 0) {
103     return Status(absl::StatusCode::kInternal,
104                   "Size of the first segment must be greater than 0.");
105   }
106   dec_stream->ct_buffer_.resize(first_segment_size);
107   dec_stream->position_ = 0;
108   dec_stream->segment_number_ = 0;
109   dec_stream->is_initialized_ = false;
110   dec_stream->read_last_segment_ = false;
111   dec_stream->count_backedup_ = first_segment_size;
112   dec_stream->pt_buffer_offset_ = 0;
113   dec_stream->status_ = util::OkStatus();
114   return {std::move(dec_stream)};
115 }
116 
Next(const void ** data)117 StatusOr<int> StreamingAeadDecryptingStream::Next(const void** data) {
118   if (!status_.ok()) return status_;
119 
120   // The first call to Next().
121   if (!is_initialized_) {
122     std::vector<uint8_t> header;
123     status_ = ReadFromStream(ct_source_.get(),
124                              segment_decrypter_->get_header_size(), &header);
125     if (status_.code() == absl::StatusCode::kOutOfRange) {
126       status_ = Status(absl::StatusCode::kInvalidArgument,
127                        "Could not read stream header.");
128     }
129     if (!status_.ok()) return status_;
130     status_ = segment_decrypter_->Init(header);
131     if (!status_.ok()) return status_;
132     is_initialized_ = true;
133     count_backedup_ = 0;
134     status_ = ReadFromStream(ct_source_.get(), ct_buffer_.size(), &ct_buffer_);
135     if (!status_.ok() && (status_.code() != absl::StatusCode::kOutOfRange)) {
136       return status_;
137     }
138     read_last_segment_ = (status_.code() == absl::StatusCode::kOutOfRange);
139     status_ = segment_decrypter_->DecryptSegment(
140         ct_buffer_,
141         /* segment_number = */ segment_number_,
142         /* is_last_segment = */ read_last_segment_,
143         &pt_buffer_);
144     if (!status_.ok() && !read_last_segment_) {
145       // Try decrypting as the last segment, if haven't tried yet.
146       read_last_segment_ = true;
147       status_ = segment_decrypter_->DecryptSegment(
148           ct_buffer_,
149           /* segment_number = */ segment_number_,
150           /* is_last_segment = */ read_last_segment_,
151           &pt_buffer_);
152     }
153     if (!status_.ok()) return status_;
154     *data = pt_buffer_.data();
155     position_ = pt_buffer_.size();
156     return pt_buffer_.size();
157   }
158 
159   // If some bytes were backed up, return them first.
160   if (count_backedup_ > 0) {
161     position_ += count_backedup_;
162     pt_buffer_offset_ = pt_buffer_.size() - count_backedup_;
163     int backedup = count_backedup_;
164     count_backedup_ = 0;
165     *data = pt_buffer_.data() + pt_buffer_offset_;
166     return backedup;
167   }
168 
169   // We're past the first segment, and no space was backed up, so we
170   // try to get and decrypt the next ciphertext segment, if any.
171   if (read_last_segment_) {
172     status_ = Status(absl::StatusCode::kOutOfRange, "Reached end of stream.");
173     return status_;
174   }
175   segment_number_++;
176   ct_buffer_.resize(segment_decrypter_->get_ciphertext_segment_size());
177   status_ = ReadFromStream(ct_source_.get(), ct_buffer_.size(), &ct_buffer_);
178   if (!status_.ok() && (status_.code() != absl::StatusCode::kOutOfRange)) {
179     return status_;
180   }
181   read_last_segment_ = (status_.code() == absl::StatusCode::kOutOfRange);
182   status_ = segment_decrypter_->DecryptSegment(
183       ct_buffer_,
184       /* segment_number = */ segment_number_,
185       /* is_last_segment = */ read_last_segment_,
186       &pt_buffer_);
187   if (!status_.ok() && !read_last_segment_) {
188     // Try decrypting as the last segment, if haven't tried yet.
189     read_last_segment_ = true;
190     status_ = segment_decrypter_->DecryptSegment(
191         ct_buffer_,
192         /* segment_number = */ segment_number_,
193         /* is_last_segment = */ read_last_segment_,
194         &pt_buffer_);
195   }
196   if (!status_.ok()) return status_;
197   *data = pt_buffer_.data();
198   pt_buffer_offset_ = 0;
199   position_ += pt_buffer_.size();
200   return pt_buffer_.size();
201 }
202 
BackUp(int count)203 void StreamingAeadDecryptingStream::BackUp(int count) {
204   if (!is_initialized_ || !status_.ok() || count < 1) return;
205   int curr_buffer_size = pt_buffer_.size() - pt_buffer_offset_;
206   int actual_count = std::min(count, curr_buffer_size - count_backedup_);
207   count_backedup_ += actual_count;
208   position_ -= actual_count;
209 }
210 
Position() const211 int64_t StreamingAeadDecryptingStream::Position() const {
212   return position_;
213 }
214 
215 }  // namespace subtle
216 }  // namespace tink
217 }  // namespace crypto
218