// Copyright 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////////////

#ifndef TINK_SUBTLE_TEST_UTIL_H_
#define TINK_SUBTLE_TEST_UTIL_H_

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tink/input_stream.h"
#include "tink/output_stream.h"
#include "tink/subtle/nonce_based_streaming_aead.h"
#include "tink/subtle/stream_segment_decrypter.h"
#include "tink/subtle/stream_segment_encrypter.h"
#include "tink/util/status.h"
#include "tink/util/statusor.h"

namespace crypto {
namespace tink {
namespace subtle {
namespace test {

// Various utilities for testing.
///////////////////////////////////////////////////////////////////////////////

// Writes 'contents' the specified 'output_stream', and if 'close_stream'
// is true, then closes the stream.
// Returns the status of output_stream->Close()-operation, or a non-OK status
// of a prior output_stream->Next()-operation, if any.
48 util::Status WriteToStream(OutputStream* output_stream, 49 absl::string_view contents, 50 bool close_stream = true); 51 52 // Reads all bytes from the specified 'input_stream', and puts 53 // them into 'output', where both 'input_stream' and 'output must be non-null. 54 // Returns a non-OK status only if reading fails for some reason. 55 // If the end of stream is reached ('input_stream' returns OUT_OF_RANGE), 56 // then this function returns OK. 57 util::Status ReadFromStream(InputStream* input_stream, std::string* output); 58 59 // A dummy encrypter that "encrypts" by just appending to the plaintext 60 // the current segment number and a marker byte indicating whether 61 // the segment is last one. 62 class DummyStreamSegmentEncrypter : public StreamSegmentEncrypter { 63 public: 64 // Size of the per-segment tag added upon encryption. 65 static constexpr int kSegmentTagSize = sizeof(int64_t) + 1; 66 67 // Bytes for marking whether a given segment is the last one. 68 static constexpr char kLastSegment = 'l'; 69 static constexpr char kNotLastSegment = 'n'; 70 DummyStreamSegmentEncrypter(int pt_segment_size,int header_size,int ct_offset)71 DummyStreamSegmentEncrypter(int pt_segment_size, 72 int header_size, 73 int ct_offset) : 74 pt_segment_size_(pt_segment_size), 75 ct_offset_(ct_offset), 76 segment_number_(0) { 77 // Fill the header with 'header_size' copies of letter 'h' 78 header_.resize(0); 79 header_.resize(header_size, static_cast<uint8_t>('h')); 80 generated_output_size_ = header_size; 81 } 82 83 // Generates an expected ciphertext for the given 'plaintext'. GenerateCiphertext(absl::string_view plaintext)84 std::string GenerateCiphertext(absl::string_view plaintext) { 85 std::string ct(header_.begin(), header_.end()); 86 int64_t seg_no = 0; 87 int pos = 0; 88 do { 89 int seg_len = pt_segment_size_; 90 if (pos == 0) { // The first segment. 91 seg_len -= (ct_offset_ + header_.size()); 92 } 93 if (seg_len > plaintext.size() - pos) { // The last segment. 
94 seg_len = plaintext.size() - pos; 95 } 96 ct.append(plaintext.substr(pos, seg_len).data(), seg_len); 97 pos += seg_len; 98 ct.append(reinterpret_cast<const char*>(&seg_no), sizeof(seg_no)); 99 ct.append(1, pos < plaintext.size() ? kNotLastSegment : kLastSegment); 100 seg_no++; 101 } while (pos < plaintext.size()); 102 return ct; 103 } 104 EncryptSegment(const std::vector<uint8_t> & plaintext,bool is_last_segment,std::vector<uint8_t> * ciphertext_buffer)105 util::Status EncryptSegment( 106 const std::vector<uint8_t>& plaintext, 107 bool is_last_segment, 108 std::vector<uint8_t>* ciphertext_buffer) override { 109 ciphertext_buffer->resize(plaintext.size() + kSegmentTagSize); 110 memcpy(ciphertext_buffer->data(), plaintext.data(), plaintext.size()); 111 memcpy(ciphertext_buffer->data() + plaintext.size(), 112 &segment_number_, sizeof(segment_number_)); 113 // The last byte of the a ciphertext segment. 114 ciphertext_buffer->back() = 115 is_last_segment ? kLastSegment : kNotLastSegment; 116 generated_output_size_ += ciphertext_buffer->size(); 117 IncSegmentNumber(); 118 return util::OkStatus(); 119 } 120 get_header()121 const std::vector<uint8_t>& get_header() const override { 122 return header_; 123 } 124 get_segment_number()125 int64_t get_segment_number() const override { 126 return segment_number_; 127 } 128 get_plaintext_segment_size()129 int get_plaintext_segment_size() const override { 130 return pt_segment_size_; 131 } 132 get_ciphertext_segment_size()133 int get_ciphertext_segment_size() const override { 134 return pt_segment_size_ + kSegmentTagSize; 135 } 136 get_ciphertext_offset()137 int get_ciphertext_offset() const override { 138 return ct_offset_; 139 } 140 141 ~DummyStreamSegmentEncrypter() override = default; 142 get_generated_output_size()143 int get_generated_output_size() { 144 return generated_output_size_; 145 } 146 147 protected: IncSegmentNumber()148 void IncSegmentNumber() override { 149 segment_number_++; 150 } 151 152 private: 153 
std::vector<uint8_t> header_; 154 int pt_segment_size_; 155 int ct_offset_; 156 int64_t segment_number_; 157 int64_t generated_output_size_; 158 }; // class DummyStreamSegmentEncrypter 159 160 // A dummy decrypter that "decrypts" segments encrypted by 161 // DummyStreamSegmentEncrypter. 162 class DummyStreamSegmentDecrypter : public StreamSegmentDecrypter { 163 public: DummyStreamSegmentDecrypter(int pt_segment_size,int header_size,int ct_offset)164 DummyStreamSegmentDecrypter(int pt_segment_size, 165 int header_size, 166 int ct_offset) : 167 pt_segment_size_(pt_segment_size), 168 ct_offset_(ct_offset) { 169 // Fill the header with 'header_size' copies of letter 'h' 170 header_.resize(0); 171 header_.resize(header_size, static_cast<uint8_t>('h')); 172 generated_output_size_ = 0; 173 } 174 Init(const std::vector<uint8_t> & header)175 util::Status Init(const std::vector<uint8_t>& header) override { 176 if (header_.size() != header.size() || 177 memcmp(header_.data(), header.data(), header_.size()) != 0) { 178 return util::Status(absl::StatusCode::kInvalidArgument, 179 "Invalid stream header"); 180 } 181 return util::OkStatus(); 182 } 183 get_header_size()184 int get_header_size() const override { 185 return header_.size(); 186 } 187 DecryptSegment(const std::vector<uint8_t> & ciphertext,int64_t segment_number,bool is_last_segment,std::vector<uint8_t> * plaintext_buffer)188 util::Status DecryptSegment( 189 const std::vector<uint8_t>& ciphertext, 190 int64_t segment_number, 191 bool is_last_segment, 192 std::vector<uint8_t>* plaintext_buffer) override { 193 if (ciphertext.size() < DummyStreamSegmentEncrypter::kSegmentTagSize) { 194 return util::Status(absl::StatusCode::kInvalidArgument, 195 "Ciphertext segment too short"); 196 } 197 if (ciphertext.back() != 198 (is_last_segment ? 
DummyStreamSegmentEncrypter::kLastSegment : 199 DummyStreamSegmentEncrypter::kNotLastSegment)) { 200 return util::Status(absl::StatusCode::kInvalidArgument, 201 "unexpected last-segment marker"); 202 } 203 int pt_size = 204 ciphertext.size() - DummyStreamSegmentEncrypter::kSegmentTagSize; 205 if (memcmp(ciphertext.data() + pt_size, 206 reinterpret_cast<const char*>(&segment_number), 207 sizeof(segment_number)) != 0) { 208 return util::Status(absl::StatusCode::kInvalidArgument, 209 "wrong segment number"); 210 } 211 plaintext_buffer->resize(pt_size); 212 memcpy(plaintext_buffer->data(), ciphertext.data(), pt_size); 213 generated_output_size_ += pt_size; 214 return util::OkStatus(); 215 } 216 217 get_plaintext_segment_size()218 int get_plaintext_segment_size() const override { 219 return pt_segment_size_; 220 } 221 get_ciphertext_segment_size()222 int get_ciphertext_segment_size() const override { 223 return pt_segment_size_ + DummyStreamSegmentEncrypter::kSegmentTagSize; 224 } 225 get_ciphertext_offset()226 int get_ciphertext_offset() const override { 227 return ct_offset_; 228 } 229 230 ~DummyStreamSegmentDecrypter() override = default; 231 get_generated_output_size()232 int get_generated_output_size() { 233 return generated_output_size_; 234 } 235 236 private: 237 std::vector<uint8_t> header_; 238 int pt_segment_size_; 239 int ct_offset_; 240 int64_t generated_output_size_; 241 }; // class DummyStreamSegmentDecrypter 242 243 class DummyStreamingAead : public NonceBasedStreamingAead { 244 public: DummyStreamingAead(int pt_segment_size,int header_size,int ct_offset)245 DummyStreamingAead(int pt_segment_size, int header_size, int ct_offset) 246 : pt_segment_size_(pt_segment_size), 247 header_size_(header_size), 248 ct_offset_(ct_offset) {} 249 250 protected: NewSegmentEncrypter(absl::string_view associated_data)251 util::StatusOr<std::unique_ptr<StreamSegmentEncrypter>> NewSegmentEncrypter( 252 absl::string_view associated_data) const override { 253 return 
{absl::make_unique<DummyStreamSegmentEncrypter>( 254 pt_segment_size_, header_size_, ct_offset_)}; 255 } 256 NewSegmentDecrypter(absl::string_view associated_data)257 util::StatusOr<std::unique_ptr<StreamSegmentDecrypter>> NewSegmentDecrypter( 258 absl::string_view associated_data) const override { 259 return {absl::make_unique<DummyStreamSegmentDecrypter>( 260 pt_segment_size_, header_size_, ct_offset_)}; 261 } 262 263 private: 264 int pt_segment_size_; 265 int header_size_; 266 int ct_offset_; 267 }; 268 269 } // namespace test 270 } // namespace subtle 271 } // namespace tink 272 } // namespace crypto 273 274 #endif // TINK_SUBTLE_TEST_UTIL_H_ 275