1 // Copyright 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ////////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/subtle/wycheproof_util.h"
18
19 #include <fstream>
20 #include <iostream>
21 #include <memory>
22 #include <ostream>
23 #include <string>
24
25 #include "absl/status/status.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "include/rapidjson/document.h"
29 #include "include/rapidjson/istreamwrapper.h"
30 #include "tink/internal/test_file_util.h"
31 #include "tink/subtle/common_enums.h"
32 #include "tink/util/status.h"
33 #include "tink/util/statusor.h"
34
35 namespace crypto {
36 namespace tink {
37 namespace subtle {
38
39 namespace {
40
41 // TODO(tholenst): factor these helpers out to an "util"-class.
HexDecode(absl::string_view hex)42 util::StatusOr<std::string> HexDecode(absl::string_view hex) {
43 if (hex.size() % 2 != 0) {
44 return util::Status(absl::StatusCode::kInvalidArgument,
45 "Input has odd size.");
46 }
47 std::string decoded(hex.size() / 2, static_cast<char>(0));
48 for (size_t i = 0; i < hex.size(); ++i) {
49 char c = hex[i];
50 char val;
51 if ('0' <= c && c <= '9')
52 val = c - '0';
53 else if ('a' <= c && c <= 'f')
54 val = c - 'a' + 10;
55 else if ('A' <= c && c <= 'F')
56 val = c - 'A' + 10;
57 else
58 return util::Status(absl::StatusCode::kInvalidArgument,
59 "Not hexadecimal");
60 decoded[i / 2] = (decoded[i / 2] << 4) | val;
61 }
62 return decoded;
63 }
64
HexDecodeOrDie(absl::string_view hex)65 std::string HexDecodeOrDie(absl::string_view hex) {
66 return HexDecode(hex).value();
67 }
68
69 } // namespace
70
GetBytes(const rapidjson::Value & val)71 std::string WycheproofUtil::GetBytes(const rapidjson::Value &val) {
72 std::string s(val.GetString());
73 if (s.size() % 2 != 0) {
74 // ECDH private key may have odd length.
75 s = "0" + s;
76 }
77 return HexDecodeOrDie(s);
78 }
79
ReadTestVectors(const std::string & filename)80 std::unique_ptr<rapidjson::Document> WycheproofUtil::ReadTestVectors(
81 const std::string &filename) {
82 std::string test_vectors_path = crypto::tink::internal::RunfilesPath(
83 absl::StrCat("testvectors/", filename));
84 std::ifstream input_stream;
85 input_stream.open(test_vectors_path);
86 rapidjson::IStreamWrapper input(input_stream);
87 std::unique_ptr<rapidjson::Document> root(
88 new rapidjson::Document(rapidjson::kObjectType));
89 if (root->ParseStream(input).HasParseError()) {
90 std::cerr << "Failure parsing of test vectors from "
91 << test_vectors_path << std::endl;
92 exit(1);
93 }
94 return root;
95 }
96
GetHashType(const rapidjson::Value & val)97 HashType WycheproofUtil::GetHashType(const rapidjson::Value &val) {
98 std::string md(val.GetString());
99 if (md == "SHA-1") {
100 return HashType::SHA1;
101 } else if (md == "SHA-256") {
102 return HashType::SHA256;
103 } else if (md == "SHA-384") {
104 return HashType::UNKNOWN_HASH;
105 } else if (md == "SHA-512") {
106 return HashType::SHA512;
107 } else {
108 return HashType::UNKNOWN_HASH;
109 }
110 }
111
GetEllipticCurveType(const rapidjson::Value & val)112 EllipticCurveType WycheproofUtil::GetEllipticCurveType(
113 const rapidjson::Value &val) {
114 std::string curve(val.GetString());
115 if (curve == "secp256r1") {
116 return EllipticCurveType::NIST_P256;
117 } else if (curve == "secp384r1") {
118 return EllipticCurveType::NIST_P384;
119 } else if (curve == "secp521r1") {
120 return EllipticCurveType::NIST_P521;
121 } else {
122 return EllipticCurveType::UNKNOWN_CURVE;
123 }
124 }
125
GetInteger(const rapidjson::Value & val)126 std::string WycheproofUtil::GetInteger(const rapidjson::Value &val) {
127 std::string hex(val.GetString());
128 // Since val is a hexadecimal integer it can have an odd length.
129 if (hex.size() % 2 == 1) {
130 // Avoid a leading 0 byte.
131 if (hex[0] == '0') {
132 hex = std::string(hex, 1, hex.size() - 1);
133 } else {
134 hex = "0" + hex;
135 }
136 }
137 return HexDecode(hex).value();
138 }
139
140 } // namespace subtle
141 } // namespace tink
142 } // namespace crypto
143