xref: /aosp_15_r20/external/tink/go/integration/hcvault/hcvault_aead_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2019 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//	http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// //////////////////////////////////////////////////////////////////////////////
16package hcvault_test
17
18import (
19	"bytes"
20	"crypto/tls"
21	"encoding/base64"
22	"encoding/json"
23	"errors"
24	"fmt"
25	"net"
26	"net/http"
27	"os"
28	"path/filepath"
29	"strings"
30	"testing"
31
32	"github.com/google/tink/go/integration/hcvault"
33)
34
// keyURITmpl is the key URI of the transit key "key-1" on the mock Vault
// server; the %d verb is filled in with the port returned by newServer.
const keyURITmpl = "hcvault://localhost:%d/transit/keys/key-1"

// token is the Vault token handed to hcvault.NewClient. The mock server's
// handler never inspects request headers, so any value works.
const token = "mytoken"
39
// vaultKey is the TLS private key file served by the mock Vault server. The
// path is relative to TEST_SRCDIR; newServer joins it onto that directory.
var vaultKey = filepath.Join(os.Getenv("TEST_WORKSPACE"), "/integration/hcvault/testdata/server.key")

// vaultCert is the TLS certificate file served by the mock Vault server,
// likewise relative to TEST_SRCDIR.
var vaultCert = filepath.Join(os.Getenv("TEST_WORKSPACE"), "/integration/hcvault/testdata/server.crt")
44
45func TestVaultAEAD_Encrypt(t *testing.T) {
46	port, stopFunc := newServer(t)
47	defer stopFunc()
48
49	client, err := hcvault.NewClient(
50		fmt.Sprintf("hcvault://localhost:%d/", port),
51		// Using InsecureSkipVerify is fine here, since this is just a test running locally.
52		&tls.Config{InsecureSkipVerify: true}, // NOLINT
53		token,
54	)
55	if err != nil {
56		t.Fatal("Cannot initialize a client:", err)
57	}
58
59	keyURI := fmt.Sprintf(keyURITmpl, port)
60	aead, err := client.GetAEAD(keyURI)
61	if err != nil {
62		t.Fatal("Cannot obtain Vault AEAD:", err)
63	}
64	pt := []byte("Hello World")
65	context := []byte("extracontext")
66	ct, err := aead.Encrypt(pt, context)
67	if err != nil {
68		t.Fatal("Error encrypting data:", err)
69	}
70	wantCT := encrypt(pt, context)
71	if !bytes.Equal(wantCT, ct) {
72		t.Fatalf("Incorrect cipher text, want=%s;got=%s", wantCT, ct)
73	}
74}
75
76func TestVaultAEAD_Decrypt(t *testing.T) {
77	port, stopFunc := newServer(t)
78	defer stopFunc()
79
80	client, err := hcvault.NewClient(
81		fmt.Sprintf("hcvault://localhost:%d/", port),
82		// Using InsecureSkipVerify is fine here, since this is just a test running locally.
83		&tls.Config{InsecureSkipVerify: true}, // NOLINT
84		token,
85	)
86	if err != nil {
87		t.Fatal("Cannot initialize a client:", err)
88	}
89
90	keyURI := fmt.Sprintf(keyURITmpl, port)
91	aead, err := client.GetAEAD(keyURI)
92	if err != nil {
93		t.Fatal("Cannot obtain Vault AEAD:", err)
94	}
95	wantPT := []byte("Hello World")
96	context := []byte("extracontext")
97	ct := encrypt(wantPT, context)
98	pt, err := aead.Decrypt(ct, context)
99	if err != nil {
100		t.Fatal("Error decrypting data:", err)
101	}
102	if !bytes.Equal(wantPT, pt) {
103		t.Fatalf("Incorrect plain text, want=%s;got=%s", string(wantPT), string(pt))
104	}
105}
106
107func TestGetAEADFailWithBadKeyURI(t *testing.T) {
108	port, stopFunc := newServer(t)
109	defer stopFunc()
110
111	client, err := hcvault.NewClient(
112		fmt.Sprintf("hcvault://localhost:%d/", port),
113		// Using InsecureSkipVerify is fine here, since this is just a test running locally.
114		&tls.Config{InsecureSkipVerify: true}, // NOLINT
115		token,
116	)
117	if err != nil {
118		t.Fatalf("hcvault.NewClient() err = %v, want nil", err)
119	}
120
121	for _, test := range []struct {
122		name   string
123		keyURI string
124	}{
125		{
126			name:   "empty",
127			keyURI: fmt.Sprintf("hcvault://localhost:%d/", port),
128		},
129		{
130			name:   "without slash",
131			keyURI: fmt.Sprintf("hcvault://localhost:%d/badKeyUri", port),
132		},
133		{
134			name:   "with one slash",
135			keyURI: fmt.Sprintf("hcvault://localhost:%d/bad/KeyUri", port),
136		},
137		{
138			name:   "with three slash",
139			keyURI: fmt.Sprintf("hcvault://localhost:%d/one/two/three/four", port),
140		},
141	} {
142		t.Run(test.name, func(t *testing.T) {
143			if _, err := client.GetAEAD(test.keyURI); err == nil {
144				t.Errorf("client.GetAEAD(%q) err = nil, want error", test.keyURI)
145			}
146		})
147	}
148}
149
type (
	// closeFunc shuts down a resource and reports any error from doing so;
	// newServer returns the mock server listener's Close method as one.
	closeFunc func() error
)
151
152func newServer(t *testing.T) (int, closeFunc) {
153	handler := func(w http.ResponseWriter, r *http.Request) {
154		switch r.RequestURI {
155
156		// Encrypt
157		case "/v1/transit/encrypt/key-1":
158			decoder := json.NewDecoder(r.Body)
159			var encReq = make(map[string]string)
160			if err := decoder.Decode(&encReq); err != nil {
161				t.Fatal("Cannot decode encryption request:", err)
162			}
163			pt64 := encReq["plaintext"]
164			pt, err := base64.StdEncoding.DecodeString(pt64)
165			if err != nil {
166				t.Fatal("plaintext must be base64 encoded")
167			}
168			context64 := encReq["context"]
169			context, err := base64.StdEncoding.DecodeString(context64)
170			if err != nil {
171				t.Fatal("context must be base64 encoded")
172			}
173			resp := map[string]interface{}{
174				"data": map[string]string{
175					"ciphertext": string(encrypt(pt, context)),
176				},
177			}
178			respBytes, err := json.Marshal(resp)
179			if err != nil {
180				t.Fatal("Cannot encode encrypted data:", err)
181			}
182			if _, err := w.Write(respBytes); err != nil {
183				t.Fatal("Cannot send encrypted data response:", err)
184			}
185
186		// Decrypt
187		case "/v1/transit/decrypt/key-1":
188			decoder := json.NewDecoder(r.Body)
189			var encReq = make(map[string]string)
190			if err := decoder.Decode(&encReq); err != nil {
191				t.Fatal("Cannot decode encryption request:", err)
192			}
193			ct := encReq["ciphertext"]
194			context64 := encReq["context"]
195			context, err := base64.StdEncoding.DecodeString(context64)
196			if err != nil {
197				t.Fatal("context must be base64 encoded")
198			}
199			pt, err := decrypt([]byte(ct), context)
200			if err != nil {
201				t.Fatal("Cannot decrypt ciphertext:", err)
202			}
203			resp := map[string]interface{}{
204				"data": map[string]string{
205					"plaintext": base64.StdEncoding.EncodeToString(pt),
206				},
207			}
208			respBytes, err := json.Marshal(resp)
209			if err != nil {
210				t.Fatal("Cannot encode encrypted data:", err)
211			}
212			if _, err := w.Write(respBytes); err != nil {
213				t.Fatal("Cannot send encrypted data response:", err)
214			}
215
216		default:
217			http.NotFound(w, r)
218		}
219	}
220
221	srcDir, ok := os.LookupEnv("TEST_SRCDIR")
222	if !ok {
223		t.Skip("TEST_SRCDIR not set")
224	}
225
226	vaultCertPath := filepath.Join(srcDir, vaultCert)
227	if _, err := os.Stat(vaultCertPath); err != nil {
228		t.Fatal("Cannot load Vault certificate file:", err)
229	}
230	vaultKeyPath := filepath.Join(srcDir, vaultKey)
231	if _, err := os.Stat(vaultKeyPath); err != nil {
232		t.Fatal("Cannot load Vault key file:", err)
233	}
234
235	l, err := net.Listen("tcp", ":0")
236	if err != nil {
237		t.Fatal("Cannot start Vault mock server:", err)
238	}
239	go http.ServeTLS(l, http.HandlerFunc(handler), vaultCertPath, vaultKeyPath)
240
241	port := l.Addr().(*net.TCPAddr).Port
242	return port, l.Close
243}
244
// encrypt is the mock server's non-cryptographic stand-in for Vault
// encryption. It returns "enc:<base64(context)>:<base64(pt)>" using standard
// padded base64, a format that decrypt parses back.
func encrypt(pt, context []byte) []byte {
	encode := base64.StdEncoding.EncodeToString
	return []byte(fmt.Sprintf("enc:%s:%s", encode(context), encode(pt)))
}
253
// decrypt reverses encrypt. ctb must have the form
// "enc:<base64(context)>:<base64(plaintext)>", and its embedded context must
// equal context. On success the decoded plaintext is returned. On any
// malformed input, base64 error or context mismatch it returns a nil slice
// and a non-nil error.
func decrypt(ctb, context []byte) ([]byte, error) {
	fields := strings.Split(string(ctb), ":")
	if !(len(fields) == 3 && fields[0] == "enc") {
		return nil, errors.New("malformed ciphertext")
	}

	embedded, err := base64.StdEncoding.DecodeString(fields[1])
	switch {
	case err != nil:
		return nil, err
	case !bytes.Equal(context, embedded):
		return nil, errors.New("context doesn't match")
	}

	plaintext, err := base64.StdEncoding.DecodeString(fields[2])
	if err != nil {
		// DecodeString may hand back partially decoded bytes alongside the
		// error; callers get nil instead.
		return nil, err
	}
	return plaintext, nil
}
273