1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/subtle/streaming_aead_decrypting_stream.h"
18
19 #include <algorithm>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/memory/memory.h"
26 #include "absl/status/status.h"
27 #include "tink/input_stream.h"
28 #include "tink/subtle/stream_segment_decrypter.h"
29 #include "tink/util/status.h"
30 #include "tink/util/statusor.h"
31
32 using crypto::tink::InputStream;
33 using crypto::tink::util::Status;
34 using crypto::tink::util::StatusOr;
35
36 namespace crypto {
37 namespace tink {
38 namespace subtle {
39
40 namespace {
41
42 // Reads at most 'count' bytes from the specified 'input_stream',
43 // and puts them into 'output', where both 'input_stream' and 'output'
44 // must be non-null.
45 // Will try to read exactly 'count' bytes, unless the end of stream
46 // is reached (then returns status OUT_OF_RANGE) or an error occurs
47 // (an other non-OK status).
48 // Before returning, resizes 'output' accordingly, to reflect
49 // the actual number of bytes read.
50
ReadFromStream(InputStream * input_stream,int count,std::vector<uint8_t> * output)51 util::Status ReadFromStream(InputStream* input_stream, int count,
52 std::vector<uint8_t>* output) {
53 if (count <= 0 || input_stream == nullptr || output == nullptr) {
54 return Status(absl::StatusCode::kInternal, "Illegal read from a stream");
55 }
56 const void* buffer;
57 int bytes_to_be_read = count;
58 int read_bytes; // bytes read in one Next()-call
59 int needed_bytes; // bytes actually needed
60 output->resize(count);
61 while (bytes_to_be_read > 0) {
62 auto next_result = input_stream->Next(&buffer);
63 if (next_result.status().code() == absl::StatusCode::kOutOfRange) {
64 // End of stream.
65 output->resize(count - bytes_to_be_read);
66 return next_result.status();
67 }
68 if (!next_result.ok()) return next_result.status();
69 read_bytes = next_result.value();
70 needed_bytes = std::min(read_bytes, bytes_to_be_read);
71 memcpy(output->data() + (count - bytes_to_be_read), buffer, needed_bytes);
72 bytes_to_be_read -= needed_bytes;
73 }
74 if (read_bytes > needed_bytes) {
75 input_stream->BackUp(read_bytes - needed_bytes);
76 }
77 return util::OkStatus();
78 }
79
80 } // anonymous namespace
81
82 // static
New(std::unique_ptr<StreamSegmentDecrypter> segment_decrypter,std::unique_ptr<InputStream> ciphertext_source)83 StatusOr<std::unique_ptr<InputStream>> StreamingAeadDecryptingStream::New(
84 std::unique_ptr<StreamSegmentDecrypter> segment_decrypter,
85 std::unique_ptr<InputStream> ciphertext_source) {
86 if (segment_decrypter == nullptr) {
87 return Status(absl::StatusCode::kInvalidArgument,
88 "segment_decrypter must be non-null");
89 }
90 if (ciphertext_source == nullptr) {
91 return Status(absl::StatusCode::kInvalidArgument,
92 "cipertext_source must be non-null");
93 }
94 std::unique_ptr<StreamingAeadDecryptingStream> dec_stream(
95 new StreamingAeadDecryptingStream());
96 dec_stream->segment_decrypter_ = std::move(segment_decrypter);
97 dec_stream->ct_source_ = std::move(ciphertext_source);
98 int first_segment_size =
99 dec_stream->segment_decrypter_->get_ciphertext_segment_size() -
100 dec_stream->segment_decrypter_->get_ciphertext_offset() -
101 dec_stream->segment_decrypter_->get_header_size();
102 if (first_segment_size <= 0) {
103 return Status(absl::StatusCode::kInternal,
104 "Size of the first segment must be greater than 0.");
105 }
106 dec_stream->ct_buffer_.resize(first_segment_size);
107 dec_stream->position_ = 0;
108 dec_stream->segment_number_ = 0;
109 dec_stream->is_initialized_ = false;
110 dec_stream->read_last_segment_ = false;
111 dec_stream->count_backedup_ = first_segment_size;
112 dec_stream->pt_buffer_offset_ = 0;
113 dec_stream->status_ = util::OkStatus();
114 return {std::move(dec_stream)};
115 }
116
Next(const void ** data)117 StatusOr<int> StreamingAeadDecryptingStream::Next(const void** data) {
118 if (!status_.ok()) return status_;
119
120 // The first call to Next().
121 if (!is_initialized_) {
122 std::vector<uint8_t> header;
123 status_ = ReadFromStream(ct_source_.get(),
124 segment_decrypter_->get_header_size(), &header);
125 if (status_.code() == absl::StatusCode::kOutOfRange) {
126 status_ = Status(absl::StatusCode::kInvalidArgument,
127 "Could not read stream header.");
128 }
129 if (!status_.ok()) return status_;
130 status_ = segment_decrypter_->Init(header);
131 if (!status_.ok()) return status_;
132 is_initialized_ = true;
133 count_backedup_ = 0;
134 status_ = ReadFromStream(ct_source_.get(), ct_buffer_.size(), &ct_buffer_);
135 if (!status_.ok() && (status_.code() != absl::StatusCode::kOutOfRange)) {
136 return status_;
137 }
138 read_last_segment_ = (status_.code() == absl::StatusCode::kOutOfRange);
139 status_ = segment_decrypter_->DecryptSegment(
140 ct_buffer_,
141 /* segment_number = */ segment_number_,
142 /* is_last_segment = */ read_last_segment_,
143 &pt_buffer_);
144 if (!status_.ok() && !read_last_segment_) {
145 // Try decrypting as the last segment, if haven't tried yet.
146 read_last_segment_ = true;
147 status_ = segment_decrypter_->DecryptSegment(
148 ct_buffer_,
149 /* segment_number = */ segment_number_,
150 /* is_last_segment = */ read_last_segment_,
151 &pt_buffer_);
152 }
153 if (!status_.ok()) return status_;
154 *data = pt_buffer_.data();
155 position_ = pt_buffer_.size();
156 return pt_buffer_.size();
157 }
158
159 // If some bytes were backed up, return them first.
160 if (count_backedup_ > 0) {
161 position_ += count_backedup_;
162 pt_buffer_offset_ = pt_buffer_.size() - count_backedup_;
163 int backedup = count_backedup_;
164 count_backedup_ = 0;
165 *data = pt_buffer_.data() + pt_buffer_offset_;
166 return backedup;
167 }
168
169 // We're past the first segment, and no space was backed up, so we
170 // try to get and decrypt the next ciphertext segment, if any.
171 if (read_last_segment_) {
172 status_ = Status(absl::StatusCode::kOutOfRange, "Reached end of stream.");
173 return status_;
174 }
175 segment_number_++;
176 ct_buffer_.resize(segment_decrypter_->get_ciphertext_segment_size());
177 status_ = ReadFromStream(ct_source_.get(), ct_buffer_.size(), &ct_buffer_);
178 if (!status_.ok() && (status_.code() != absl::StatusCode::kOutOfRange)) {
179 return status_;
180 }
181 read_last_segment_ = (status_.code() == absl::StatusCode::kOutOfRange);
182 status_ = segment_decrypter_->DecryptSegment(
183 ct_buffer_,
184 /* segment_number = */ segment_number_,
185 /* is_last_segment = */ read_last_segment_,
186 &pt_buffer_);
187 if (!status_.ok() && !read_last_segment_) {
188 // Try decrypting as the last segment, if haven't tried yet.
189 read_last_segment_ = true;
190 status_ = segment_decrypter_->DecryptSegment(
191 ct_buffer_,
192 /* segment_number = */ segment_number_,
193 /* is_last_segment = */ read_last_segment_,
194 &pt_buffer_);
195 }
196 if (!status_.ok()) return status_;
197 *data = pt_buffer_.data();
198 pt_buffer_offset_ = 0;
199 position_ += pt_buffer_.size();
200 return pt_buffer_.size();
201 }
202
BackUp(int count)203 void StreamingAeadDecryptingStream::BackUp(int count) {
204 if (!is_initialized_ || !status_.ok() || count < 1) return;
205 int curr_buffer_size = pt_buffer_.size() - pt_buffer_offset_;
206 int actual_count = std::min(count, curr_buffer_size - count_backedup_);
207 count_backedup_ += actual_count;
208 position_ -= actual_count;
209 }
210
Position() const211 int64_t StreamingAeadDecryptingStream::Position() const {
212 return position_;
213 }
214
215 } // namespace subtle
216 } // namespace tink
217 } // namespace crypto
218