1// Copyright 2019 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// ////////////////////////////////////////////////////////////////////////////// 16package hcvault_test 17 18import ( 19 "bytes" 20 "crypto/tls" 21 "encoding/base64" 22 "encoding/json" 23 "errors" 24 "fmt" 25 "net" 26 "net/http" 27 "os" 28 "path/filepath" 29 "strings" 30 "testing" 31 32 "github.com/google/tink/go/integration/hcvault" 33) 34 35const ( 36 keyURITmpl = "hcvault://localhost:%d/transit/keys/key-1" 37 token = "mytoken" 38) 39 40var ( 41 vaultKey = filepath.Join(os.Getenv("TEST_WORKSPACE"), "/integration/hcvault/testdata/server.key") 42 vaultCert = filepath.Join(os.Getenv("TEST_WORKSPACE"), "/integration/hcvault/testdata/server.crt") 43) 44 45func TestVaultAEAD_Encrypt(t *testing.T) { 46 port, stopFunc := newServer(t) 47 defer stopFunc() 48 49 client, err := hcvault.NewClient( 50 fmt.Sprintf("hcvault://localhost:%d/", port), 51 // Using InsecureSkipVerify is fine here, since this is just a test running locally. 52 &tls.Config{InsecureSkipVerify: true}, // NOLINT 53 token, 54 ) 55 if err != nil { 56 t.Fatal("Cannot initialize a client:", err) 57 } 58 59 keyURI := fmt.Sprintf(keyURITmpl, port) 60 aead, err := client.GetAEAD(keyURI) 61 if err != nil { 62 t.Fatal("Cannot obtain Vault AEAD:", err) 63 } 64 pt := []byte("Hello World") 65 context := []byte("extracontext") 66 ct, err := aead.Encrypt(pt, context) 67 if err != nil { 68 t.Fatal("Error encrypting data:", err) 69 } 70 wantCT := encrypt(pt, context) 71 if !bytes.Equal(wantCT, ct) { 72 t.Fatalf("Incorrect cipher text, want=%s;got=%s", wantCT, ct) 73 } 74} 75 76func TestVaultAEAD_Decrypt(t *testing.T) { 77 port, stopFunc := newServer(t) 78 defer stopFunc() 79 80 client, err := hcvault.NewClient( 81 fmt.Sprintf("hcvault://localhost:%d/", port), 82 // Using InsecureSkipVerify is fine here, since this is just a test running locally. 83 &tls.Config{InsecureSkipVerify: true}, // NOLINT 84 token, 85 ) 86 if err != nil { 87 t.Fatal("Cannot initialize a client:", err) 88 } 89 90 keyURI := fmt.Sprintf(keyURITmpl, port) 91 aead, err := client.GetAEAD(keyURI) 92 if err != nil { 93 t.Fatal("Cannot obtain Vault AEAD:", err) 94 } 95 wantPT := []byte("Hello World") 96 context := []byte("extracontext") 97 ct := encrypt(wantPT, context) 98 pt, err := aead.Decrypt(ct, context) 99 if err != nil { 100 t.Fatal("Error decrypting data:", err) 101 } 102 if !bytes.Equal(wantPT, pt) { 103 t.Fatalf("Incorrect plain text, want=%s;got=%s", string(wantPT), string(pt)) 104 } 105} 106 107func TestGetAEADFailWithBadKeyURI(t *testing.T) { 108 port, stopFunc := newServer(t) 109 defer stopFunc() 110 111 client, err := hcvault.NewClient( 112 fmt.Sprintf("hcvault://localhost:%d/", port), 113 // Using InsecureSkipVerify is fine here, since this is just a test running locally. 114 &tls.Config{InsecureSkipVerify: true}, // NOLINT 115 token, 116 ) 117 if err != nil { 118 t.Fatalf("hcvault.NewClient() err = %v, want nil", err) 119 } 120 121 for _, test := range []struct { 122 name string 123 keyURI string 124 }{ 125 { 126 name: "empty", 127 keyURI: fmt.Sprintf("hcvault://localhost:%d/", port), 128 }, 129 { 130 name: "without slash", 131 keyURI: fmt.Sprintf("hcvault://localhost:%d/badKeyUri", port), 132 }, 133 { 134 name: "with one slash", 135 keyURI: fmt.Sprintf("hcvault://localhost:%d/bad/KeyUri", port), 136 }, 137 { 138 name: "with three slash", 139 keyURI: fmt.Sprintf("hcvault://localhost:%d/one/two/three/four", port), 140 }, 141 } { 142 t.Run(test.name, func(t *testing.T) { 143 if _, err := client.GetAEAD(test.keyURI); err == nil { 144 t.Errorf("client.GetAEAD(%q) err = nil, want error", test.keyURI) 145 } 146 }) 147 } 148} 149 150type closeFunc func() error 151 152func newServer(t *testing.T) (int, closeFunc) { 153 handler := func(w http.ResponseWriter, r *http.Request) { 154 switch r.RequestURI { 155 156 // Encrypt 157 case "/v1/transit/encrypt/key-1": 158 decoder := json.NewDecoder(r.Body) 159 var encReq = make(map[string]string) 160 if err := decoder.Decode(&encReq); err != nil { 161 t.Fatal("Cannot decode encryption request:", err) 162 } 163 pt64 := encReq["plaintext"] 164 pt, err := base64.StdEncoding.DecodeString(pt64) 165 if err != nil { 166 t.Fatal("plaintext must be base64 encoded") 167 } 168 context64 := encReq["context"] 169 context, err := base64.StdEncoding.DecodeString(context64) 170 if err != nil { 171 t.Fatal("context must be base64 encoded") 172 } 173 resp := map[string]interface{}{ 174 "data": map[string]string{ 175 "ciphertext": string(encrypt(pt, context)), 176 }, 177 } 178 respBytes, err := json.Marshal(resp) 179 if err != nil { 180 t.Fatal("Cannot encode encrypted data:", err) 181 } 182 if _, err := w.Write(respBytes); err != nil { 183 t.Fatal("Cannot send encrypted data response:", err) 184 } 185 186 // Decrypt 187 case "/v1/transit/decrypt/key-1": 188 decoder := json.NewDecoder(r.Body) 189 var encReq = make(map[string]string) 190 if err := decoder.Decode(&encReq); err != nil { 191 t.Fatal("Cannot decode encryption request:", err) 192 } 193 ct := encReq["ciphertext"] 194 context64 := encReq["context"] 195 context, err := base64.StdEncoding.DecodeString(context64) 196 if err != nil { 197 t.Fatal("context must be base64 encoded") 198 } 199 pt, err := decrypt([]byte(ct), context) 200 if err != nil { 201 t.Fatal("Cannot decrypt ciphertext:", err) 202 } 203 resp := map[string]interface{}{ 204 "data": map[string]string{ 205 "plaintext": base64.StdEncoding.EncodeToString(pt), 206 }, 207 } 208 respBytes, err := json.Marshal(resp) 209 if err != nil { 210 t.Fatal("Cannot encode encrypted data:", err) 211 } 212 if _, err := w.Write(respBytes); err != nil { 213 t.Fatal("Cannot send encrypted data response:", err) 214 } 215 216 default: 217 http.NotFound(w, r) 218 } 219 } 220 221 srcDir, ok := os.LookupEnv("TEST_SRCDIR") 222 if !ok { 223 t.Skip("TEST_SRCDIR not set") 224 } 225 226 vaultCertPath := filepath.Join(srcDir, vaultCert) 227 if _, err := os.Stat(vaultCertPath); err != nil { 228 t.Fatal("Cannot load Vault certificate file:", err) 229 } 230 vaultKeyPath := filepath.Join(srcDir, vaultKey) 231 if _, err := os.Stat(vaultKeyPath); err != nil { 232 t.Fatal("Cannot load Vault key file:", err) 233 } 234 235 l, err := net.Listen("tcp", ":0") 236 if err != nil { 237 t.Fatal("Cannot start Vault mock server:", err) 238 } 239 go http.ServeTLS(l, http.HandlerFunc(handler), vaultCertPath, vaultKeyPath) 240 241 port := l.Addr().(*net.TCPAddr).Port 242 return port, l.Close 243} 244 245func encrypt(pt, context []byte) []byte { 246 s := fmt.Sprintf( 247 "enc:%s:%s", 248 base64.StdEncoding.EncodeToString(context), 249 base64.StdEncoding.EncodeToString(pt), 250 ) 251 return []byte(s) 252} 253 254func decrypt(ctb, context []byte) ([]byte, error) { 255 ct := string(ctb) 256 parts := strings.Split(ct, ":") 257 if len(parts) != 3 || parts[0] != "enc" { 258 return nil, errors.New("malformed ciphertext") 259 } 260 context2, err := base64.StdEncoding.DecodeString(parts[1]) 261 if err != nil { 262 return nil, err 263 } 264 if !bytes.Equal(context, context2) { 265 return nil, errors.New("context doesn't match") 266 } 267 pt, err := base64.StdEncoding.DecodeString(parts[2]) 268 if err != nil { 269 return nil, err 270 } 271 return pt, nil 272} 273