xref: /aosp_15_r20/external/tink/go/aead/subtle/aes_gcm_siv.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package subtle
18
19import (
20	"crypto/aes"
21	"crypto/subtle"
22	"encoding/binary"
23	"fmt"
24	"math"
25
26	// Placeholder for internal crypto/cipher allowlist, please ignore.
27	// Placeholder for internal crypto/subtle allowlist, please ignore. // to allow import of "crypto/subte"
28	"github.com/google/tink/go/subtle/random"
29)
30
// Sizes, in bytes, used by the AES-GCM-SIV construction of RFC 8452. Apart
// from the nonce, every size is exactly one AES block.
const (
	// AESGCMSIVNonceSize is the nonce length mandated by RFC 8452.
	AESGCMSIVNonceSize = 12

	// aesgcmsivBlockSize is the unit AES-GCM-SIV operates on: tags, POLYVAL
	// digests and KDF input blocks are all one block long. It is, by
	// definition, the AES block size.
	aesgcmsivBlockSize = aes.BlockSize

	// aesgcmsivTagSize is the length of the authentication tag appended to
	// every ciphertext.
	aesgcmsivTagSize = aesgcmsivBlockSize

	// aesgcmsivPolyvalSize is the length of a POLYVAL digest.
	aesgcmsivPolyvalSize = aesgcmsivBlockSize
)
48
// AESGCMSIV implements the AEAD interface with AES-GCM-SIV (RFC 8452).
// Instances are normally created with NewAESGCMSIV, which validates the key
// length.
type AESGCMSIV struct {
	// Key is the AES-GCM-SIV key: 16 bytes for AES-128 or 32 bytes for
	// AES-256.
	Key []byte
}
53
54// NewAESGCMSIV returns an AESGCMSIV instance.
55// The key argument should be the AES key, either 16 or 32 bytes to select
56// AES-128 or AES-256.
57func NewAESGCMSIV(key []byte) (*AESGCMSIV, error) {
58	keySize := uint32(len(key))
59	if err := ValidateAESKeySize(keySize); err != nil {
60		return nil, fmt.Errorf("aes_gcm_siv: %s", err)
61	}
62	return &AESGCMSIV{Key: key}, nil
63}
64
65// Encrypt encrypts plaintext with associatedData.
66//
67// The resulting ciphertext consists of three parts:
68// (1) the Nonce used for encryption
69// (2) the actual ciphertext
70// (3) the authentication tag.
71func (a *AESGCMSIV) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
72	if len(plaintext) > math.MaxInt32-AESGCMSIVNonceSize-aesgcmsivTagSize {
73		return nil, fmt.Errorf("aes_gcm_siv: plaintext too long")
74	}
75	if len(associatedData) > math.MaxInt32 {
76		return nil, fmt.Errorf("aes_gcm_siv: associatedData too long")
77	}
78
79	nonce := random.GetRandomBytes(uint32(AESGCMSIVNonceSize))
80	authKey, encKey, err := a.deriveKeys(nonce)
81	if err != nil {
82		return nil, err
83	}
84
85	polyval, err := a.computePolyval(authKey, plaintext, associatedData)
86	if err != nil {
87		return nil, err
88	}
89	tag, err := a.computeTag(polyval, nonce, encKey)
90	if err != nil {
91		return nil, err
92	}
93
94	ct, err := a.aesCTR(encKey, tag, plaintext)
95	if err != nil {
96		return nil, err
97	}
98
99	ret := make([]byte, 0, AESGCMSIVNonceSize+aesgcmsivTagSize+len(plaintext))
100	ret = append(ret, nonce...)
101	ret = append(ret, ct...)
102	ret = append(ret, tag...)
103
104	return ret, nil
105}
106
107// Decrypt decrypts ciphertext with associatedData.
108func (a *AESGCMSIV) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
109	if len(ciphertext) < AESGCMSIVNonceSize+aesgcmsivTagSize {
110		return nil, fmt.Errorf("aes_gcm_siv: ciphertext too short")
111	}
112	if len(ciphertext) > math.MaxInt32 {
113		return nil, fmt.Errorf("aes_gcm_siv: ciphertext too long")
114	}
115	if len(associatedData) > math.MaxInt32 {
116		return nil, fmt.Errorf("aes_gcm_siv: associatedData too long")
117	}
118
119	nonce := ciphertext[:AESGCMSIVNonceSize]
120	tag := ciphertext[len(ciphertext)-aesgcmsivTagSize:]
121	ciphertext = ciphertext[AESGCMSIVNonceSize : len(ciphertext)-aesgcmsivTagSize]
122
123	authKey, encKey, err := a.deriveKeys(nonce)
124	if err != nil {
125		return nil, err
126	}
127
128	pt, err := a.aesCTR(encKey, tag, ciphertext)
129	if err != nil {
130		return nil, err
131	}
132
133	polyval, err := a.computePolyval(authKey, pt, associatedData)
134	if err != nil {
135		return nil, err
136	}
137
138	expectedTag, err := a.computeTag(polyval, nonce, encKey)
139	if err != nil {
140		return nil, err
141	}
142
143	if subtle.ConstantTimeCompare(expectedTag, tag) != 1 {
144		return nil, fmt.Errorf("aes_gcm_siv: message authentication failure")
145	}
146
147	return pt, nil
148}
149
150// The KDF as described by the RFC #8452. This uses the AES-GCM-SIV key and
151// nonce to generate the authentication key and the encryption key.
152func (a *AESGCMSIV) deriveKeys(nonce []byte) ([]byte, []byte, error) {
153	if len(nonce) != AESGCMSIVNonceSize {
154		return nil, nil, fmt.Errorf("aes_gcm_siv: invalid nonce size")
155	}
156	nonceBlock := make([]byte, aesgcmsivBlockSize)
157	copy(nonceBlock[aesgcmsivBlockSize-AESGCMSIVNonceSize:], nonce)
158	block, err := aes.NewCipher(a.Key)
159	if err != nil {
160		return nil, nil, fmt.Errorf("aes_gcm_siv: failed to create block cipher, error: %v", err)
161	}
162
163	encBlock := make([]byte, block.BlockSize())
164	kdfAes := func(counter uint32, dst []byte) {
165		binary.LittleEndian.PutUint32(nonceBlock[:4], counter)
166		block.Encrypt(encBlock, nonceBlock)
167		copy(dst, encBlock[0:8])
168	}
169
170	authKey := make([]byte, aesgcmsivBlockSize)
171	kdfAes(0, authKey[0:8])
172	kdfAes(1, authKey[8:16])
173
174	encKey := make([]byte, len(a.Key))
175	kdfAes(2, encKey[0:8])
176	kdfAes(3, encKey[8:16])
177
178	if len(a.Key) == 32 {
179		kdfAes(4, encKey[16:24])
180		kdfAes(5, encKey[24:32])
181	}
182
183	return authKey, encKey, nil
184}
185
186func (a *AESGCMSIV) computePolyval(authKey, pt, ad []byte) ([]byte, error) {
187	lengthBlock := make([]byte, aesgcmsivBlockSize)
188	binary.LittleEndian.PutUint64(lengthBlock[:8], uint64(len(ad))*8)
189	binary.LittleEndian.PutUint64(lengthBlock[8:], uint64(len(pt))*8)
190
191	p, err := NewPolyval(authKey)
192	if err != nil {
193		return nil, fmt.Errorf("aes_gcm_siv: failed to create polyval, error: %v", err)
194	}
195
196	p.Update(ad)
197	p.Update(pt)
198	p.Update(lengthBlock)
199	polyval := p.Finish()
200
201	return polyval[:], nil
202}
203
204func (a *AESGCMSIV) computeTag(polyval, nonce, encKey []byte) ([]byte, error) {
205	if len(polyval) != aesgcmsivPolyvalSize {
206		return nil, fmt.Errorf("aes_gcm_siv: polyval returned invalid sized response")
207	}
208
209	for i, val := range nonce {
210		polyval[i] ^= val
211	}
212	polyval[aesgcmsivPolyvalSize-1] &= 0x7f
213
214	block, err := aes.NewCipher(encKey)
215	if err != nil {
216		return nil, fmt.Errorf("aes_gcm_siv: failed to create block cipher, error: %v", err)
217	}
218
219	tag := make([]byte, aesgcmsivTagSize)
220	block.Encrypt(tag, polyval)
221	return tag, nil
222}
223
224// aesCTR implements the AES-CTR operation in AES-GCM-SIV.
225// Note that RFC 8452 defines AES-CTR differently compared to standard AES
226// in CTR mode: the way they increment the counter block is completely different.
227func (a *AESGCMSIV) aesCTR(key, tag, in []byte) ([]byte, error) {
228	if len(tag) != aesgcmsivTagSize {
229		return nil, fmt.Errorf("aes_gcm_siv: incorrect IV size for stream cipher")
230	}
231
232	block, err := aes.NewCipher(key)
233	if err != nil {
234		return nil, fmt.Errorf(
235			"aes_gcm_siv: failed to create block cipher, error: %v", err)
236	}
237
238	counter := make([]byte, aesgcmsivBlockSize)
239	copy(counter, tag)
240	counter[aesgcmsivBlockSize-1] |= 0x80
241	counterInc := binary.LittleEndian.Uint32(counter[0:4])
242
243	output := make([]byte, len(in))
244	outputIdx := 0
245	keystreamBlock := make([]byte, block.BlockSize())
246	for len(in) > 0 {
247		block.Encrypt(keystreamBlock, counter)
248		counterInc++
249		binary.LittleEndian.PutUint32(counter[0:4], counterInc)
250
251		n := xorBytes(output[outputIdx:], in, keystreamBlock)
252		outputIdx += n
253		in = in[n:]
254	}
255
256	return output, nil
257}
258
// xorBytes stores a[i] ^ b[i] into dst[i] for every i below
// min(len(a), len(b)) and returns that count. dst must be at least that long,
// or the function panics.
//
// The architecture-optimized xorBytes in crypto/cipher (xor_*.go) would be
// preferable, but it is not available to this package.
// NOTE(review): Go 1.20+ exports that routine as crypto/subtle.XORBytes; it
// could replace this helper if the module's minimum Go version allows.
func xorBytes(dst, a, b []byte) int {
	n := len(b)
	if len(a) < n {
		n = len(a)
	}
	for i := range a[:n] {
		dst[i] = a[i] ^ b[i]
	}
	return n
}
275