1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build boringcrypto && linux && (amd64 || arm64) && !android && !msan
6
7package boring
8
9// #include "goboringcrypto.h"
10import "C"
11import (
12	"errors"
13	"runtime"
14	"unsafe"
15)
16
// PublicKeyECDH is an ECDH public key held as a BoringCrypto EC_POINT.
//
// Fields:
//   - curve: the curve name passed to the constructor (e.g. "P-256").
//   - key: the C point; owned by this Go value and freed by finalize.
//   - group: the curve group used to build key; finalize does not free it.
//     NOTE(review): NewPublicKeyECDH releases its group with a deferred
//     EC_GROUP_free before returning, and PublicKey stores a group borrowed
//     (get0) from the private key's EC_KEY. Confirm such groups outlive the
//     key (e.g. static built-in curve groups) before dereferencing this field.
//   - bytes: the encoded point returned by Bytes: a copy of the caller's
//     input for NewPublicKeyECDH, or the uncompressed encoding for PublicKey.
type PublicKeyECDH struct {
	curve string
	key   *C.GO_EC_POINT
	group *C.GO_EC_GROUP
	bytes []byte
}
23
// finalize frees the C EC_POINT owned by k. The constructors install it
// with runtime.SetFinalizer, so it may run as soon as k is unreachable.
// Any code that passes k.key to cgo must therefore follow that call with
// runtime.KeepAlive(k).
func (k *PublicKeyECDH) finalize() {
	C._goboringcrypto_EC_POINT_free(k.key)
}
27
// PrivateKeyECDH is an ECDH private key held as a BoringCrypto EC_KEY.
// curve is the curve name passed to the constructor (e.g. "P-256").
// key is the C key; it is owned by this Go value and freed by finalize.
type PrivateKeyECDH struct {
	curve string
	key   *C.GO_EC_KEY
}
32
// finalize frees the C EC_KEY owned by k. The constructors install it with
// runtime.SetFinalizer, so any code passing k.key to cgo must follow that
// call with runtime.KeepAlive(k).
func (k *PrivateKeyECDH) finalize() {
	C._goboringcrypto_EC_KEY_free(k.key)
}
36
37func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
38	if len(bytes) < 1 {
39		return nil, errors.New("NewPublicKeyECDH: missing key")
40	}
41
42	nid, err := curveNID(curve)
43	if err != nil {
44		return nil, err
45	}
46
47	group := C._goboringcrypto_EC_GROUP_new_by_curve_name(nid)
48	if group == nil {
49		return nil, fail("EC_GROUP_new_by_curve_name")
50	}
51	defer C._goboringcrypto_EC_GROUP_free(group)
52	key := C._goboringcrypto_EC_POINT_new(group)
53	if key == nil {
54		return nil, fail("EC_POINT_new")
55	}
56	ok := C._goboringcrypto_EC_POINT_oct2point(group, key, (*C.uint8_t)(unsafe.Pointer(&bytes[0])), C.size_t(len(bytes)), nil) != 0
57	if !ok {
58		C._goboringcrypto_EC_POINT_free(key)
59		return nil, errors.New("point not on curve")
60	}
61
62	k := &PublicKeyECDH{curve, key, group, append([]byte(nil), bytes...)}
63	// Note: Because of the finalizer, any time k.key is passed to cgo,
64	// that call must be followed by a call to runtime.KeepAlive(k),
65	// to make sure k is not collected (and finalized) before the cgo
66	// call returns.
67	runtime.SetFinalizer(k, (*PublicKeyECDH).finalize)
68	return k, nil
69}
70
71func (k *PublicKeyECDH) Bytes() []byte { return k.bytes }
72
73func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) {
74	nid, err := curveNID(curve)
75	if err != nil {
76		return nil, err
77	}
78	key := C._goboringcrypto_EC_KEY_new_by_curve_name(nid)
79	if key == nil {
80		return nil, fail("EC_KEY_new_by_curve_name")
81	}
82	b := bytesToBN(bytes)
83	ok := b != nil && C._goboringcrypto_EC_KEY_set_private_key(key, b) != 0
84	if b != nil {
85		C._goboringcrypto_BN_free(b)
86	}
87	if !ok {
88		C._goboringcrypto_EC_KEY_free(key)
89		return nil, fail("EC_KEY_set_private_key")
90	}
91	k := &PrivateKeyECDH{curve, key}
92	// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
93	runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
94	return k, nil
95}
96
97func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
98	defer runtime.KeepAlive(k)
99
100	group := C._goboringcrypto_EC_KEY_get0_group(k.key)
101	if group == nil {
102		return nil, fail("EC_KEY_get0_group")
103	}
104	kbig := C._goboringcrypto_EC_KEY_get0_private_key(k.key)
105	if kbig == nil {
106		return nil, fail("EC_KEY_get0_private_key")
107	}
108	pt := C._goboringcrypto_EC_POINT_new(group)
109	if pt == nil {
110		return nil, fail("EC_POINT_new")
111	}
112	if C._goboringcrypto_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 {
113		C._goboringcrypto_EC_POINT_free(pt)
114		return nil, fail("EC_POINT_mul")
115	}
116	bytes, err := pointBytesECDH(k.curve, group, pt)
117	if err != nil {
118		C._goboringcrypto_EC_POINT_free(pt)
119		return nil, err
120	}
121	pub := &PublicKeyECDH{k.curve, pt, group, bytes}
122	// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
123	runtime.SetFinalizer(pub, (*PublicKeyECDH).finalize)
124	return pub, nil
125}
126
127func pointBytesECDH(curve string, group *C.GO_EC_GROUP, pt *C.GO_EC_POINT) ([]byte, error) {
128	out := make([]byte, 1+2*curveSize(curve))
129	n := C._goboringcrypto_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, (*C.uint8_t)(unsafe.Pointer(&out[0])), C.size_t(len(out)), nil)
130	if int(n) != len(out) {
131		return nil, fail("EC_POINT_point2oct")
132	}
133	return out, nil
134}
135
136func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
137	group := C._goboringcrypto_EC_KEY_get0_group(priv.key)
138	if group == nil {
139		return nil, fail("EC_KEY_get0_group")
140	}
141	privBig := C._goboringcrypto_EC_KEY_get0_private_key(priv.key)
142	if privBig == nil {
143		return nil, fail("EC_KEY_get0_private_key")
144	}
145	pt := C._goboringcrypto_EC_POINT_new(group)
146	if pt == nil {
147		return nil, fail("EC_POINT_new")
148	}
149	defer C._goboringcrypto_EC_POINT_free(pt)
150	if C._goboringcrypto_EC_POINT_mul(group, pt, nil, pub.key, privBig, nil) == 0 {
151		return nil, fail("EC_POINT_mul")
152	}
153	out, err := xCoordBytesECDH(priv.curve, group, pt)
154	if err != nil {
155		return nil, err
156	}
157	return out, nil
158}
159
160func xCoordBytesECDH(curve string, group *C.GO_EC_GROUP, pt *C.GO_EC_POINT) ([]byte, error) {
161	big := C._goboringcrypto_BN_new()
162	defer C._goboringcrypto_BN_free(big)
163	if C._goboringcrypto_EC_POINT_get_affine_coordinates_GFp(group, pt, big, nil, nil) == 0 {
164		return nil, fail("EC_POINT_get_affine_coordinates_GFp")
165	}
166	return bigBytesECDH(curve, big)
167}
168
169func bigBytesECDH(curve string, big *C.GO_BIGNUM) ([]byte, error) {
170	out := make([]byte, curveSize(curve))
171	if C._goboringcrypto_BN_bn2bin_padded((*C.uint8_t)(&out[0]), C.size_t(len(out)), big) == 0 {
172		return nil, fail("BN_bn2bin_padded")
173	}
174	return out, nil
175}
176
// curveSize returns the number of bytes used in this file for one coordinate
// or scalar of the named curve. That is the curve's bit size rounded up to
// whole bytes. It panics for any curve other than "P-256", "P-384" or "P-521".
func curveSize(curve string) int {
	var bits int
	switch curve {
	case "P-256":
		bits = 256
	case "P-384":
		bits = 384
	case "P-521":
		bits = 521
	default:
		panic("crypto/internal/boring: unknown curve " + curve)
	}
	return (bits + 7) / 8
}
189
190func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
191	nid, err := curveNID(curve)
192	if err != nil {
193		return nil, nil, err
194	}
195	key := C._goboringcrypto_EC_KEY_new_by_curve_name(nid)
196	if key == nil {
197		return nil, nil, fail("EC_KEY_new_by_curve_name")
198	}
199	if C._goboringcrypto_EC_KEY_generate_key_fips(key) == 0 {
200		C._goboringcrypto_EC_KEY_free(key)
201		return nil, nil, fail("EC_KEY_generate_key_fips")
202	}
203
204	group := C._goboringcrypto_EC_KEY_get0_group(key)
205	if group == nil {
206		C._goboringcrypto_EC_KEY_free(key)
207		return nil, nil, fail("EC_KEY_get0_group")
208	}
209	b := C._goboringcrypto_EC_KEY_get0_private_key(key)
210	if b == nil {
211		C._goboringcrypto_EC_KEY_free(key)
212		return nil, nil, fail("EC_KEY_get0_private_key")
213	}
214	bytes, err := bigBytesECDH(curve, b)
215	if err != nil {
216		C._goboringcrypto_EC_KEY_free(key)
217		return nil, nil, err
218	}
219
220	k := &PrivateKeyECDH{curve, key}
221	// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
222	runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
223	return k, bytes, nil
224}
225