xref: /aosp_15_r20/external/tink/go/hybrid/subtle/elliptic_curves.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package subtle
18
19import (
20	"bytes"
21	"crypto/elliptic"
22	"crypto/rand"
23	"errors"
24	"fmt"
25	"math/big"
26)
27
// ECPublicKey represents an elliptic curve public key.
//
// The embedded elliptic.Curve identifies the curve the key lives on; its
// methods (Params, IsOnCurve, ScalarMult, ...) are promoted onto ECPublicKey.
// Point holds the X and Y coordinates of the public point.
type ECPublicKey struct {
	elliptic.Curve
	Point ECPoint
}
33
// ECPrivateKey represents an elliptic curve private key.
//
// D is the private scalar and PublicKey is the matching public key. The
// constructors in this file (GetECPrivateKey, GenerateECDHKeyPair) set
// PublicKey.Point to D times the base point of PublicKey.Curve.
type ECPrivateKey struct {
	PublicKey ECPublicKey
	D         *big.Int
}
39
40// GetECPrivateKey converts a stored private key to ECPrivateKey.
41func GetECPrivateKey(c elliptic.Curve, b []byte) *ECPrivateKey {
42	d := new(big.Int)
43	d.SetBytes(b)
44
45	x, y := c.Params().ScalarBaseMult(b)
46	pub := ECPublicKey{
47		Curve: c,
48		Point: ECPoint{
49			X: x,
50			Y: y,
51		},
52	}
53	return &ECPrivateKey{
54		PublicKey: pub,
55		D:         d,
56	}
57
58}
59
// ECPoint holds the X and Y coordinates of an elliptic curve point. The
// struct itself does not guarantee that the point lies on any curve.
type ECPoint struct {
	X *big.Int
	Y *big.Int
}
64
65func (p *ECPrivateKey) getParams() *elliptic.CurveParams {
66	return p.PublicKey.Curve.Params()
67}
68
// getModulus returns P, the order of the underlying field of curve c.
func getModulus(c elliptic.Curve) *big.Int {
	params := c.Params()
	return params.P
}
72
73func fieldSizeInBits(c elliptic.Curve) int {
74	t := big.NewInt(1)
75	r := t.Sub(getModulus(c), t)
76	return r.BitLen()
77}
78
79func fieldSizeInBytes(c elliptic.Curve) int {
80	return (fieldSizeInBits(c) + 7) / 8
81}
82
83func encodingSizeInBytes(c elliptic.Curve, p string) (int, error) {
84	cSize := fieldSizeInBytes(c)
85	switch p {
86	case "UNCOMPRESSED":
87		return 2*cSize + 1, nil
88	case "DO_NOT_USE_CRUNCHY_UNCOMPRESSED":
89		return 2 * cSize, nil
90	case "COMPRESSED":
91		return cSize + 1, nil
92	}
93	return 0, fmt.Errorf("invalid point format :%s", p)
94
95}
96
97// PointEncode encodes a point into the format specified.
98func PointEncode(c elliptic.Curve, pFormat string, pt ECPoint) ([]byte, error) {
99	if !c.IsOnCurve(pt.X, pt.Y) {
100		return nil, errors.New("curve check failed")
101	}
102	cSize := fieldSizeInBytes(c)
103	y := pt.Y.Bytes()
104	x := pt.X.Bytes()
105	switch pFormat {
106	case "UNCOMPRESSED":
107		encoded := make([]byte, 2*cSize+1)
108		copy(encoded[1+2*cSize-len(y):], y)
109		copy(encoded[1+cSize-len(x):], x)
110		encoded[0] = 4
111		return encoded, nil
112	case "DO_NOT_USE_CRUNCHY_UNCOMPRESSED":
113		encoded := make([]byte, 2*cSize)
114		if len(x) > cSize {
115			x = bytes.Replace(x, []byte("\x00"), []byte{}, -1)
116		}
117		if len(y) > cSize {
118			y = bytes.Replace(y, []byte("\x00"), []byte{}, -1)
119		}
120		copy(encoded[2*cSize-len(y):], y)
121		copy(encoded[cSize-len(x):], x)
122		return encoded, nil
123	case "COMPRESSED":
124		encoded := make([]byte, cSize+1)
125		copy(encoded[1+cSize-len(x):], x)
126		encoded[0] = 2
127		if pt.Y.Bit(0) > 0 {
128			encoded[0] = 3
129		}
130		return encoded, nil
131	}
132	return nil, errors.New("invalid point format")
133
134}
135
136// PointDecode decodes a encoded point to return an ECPoint
137func PointDecode(c elliptic.Curve, pFormat string, e []byte) (*ECPoint, error) {
138	cSize := fieldSizeInBytes(c)
139	x, y := new(big.Int), new(big.Int)
140	switch pFormat {
141	case "UNCOMPRESSED":
142		if len(e) != (2*cSize + 1) {
143			return nil, errors.New("invalid point size")
144		}
145		if e[0] != 4 {
146			return nil, errors.New("invalid point format")
147		}
148		x.SetBytes(e[1 : cSize+1])
149		y.SetBytes(e[cSize+1:])
150		if !c.IsOnCurve(x, y) {
151			return nil, errors.New("invalid point")
152		}
153		return &ECPoint{
154			X: x,
155			Y: y,
156		}, nil
157	case "DO_NOT_USE_CRUNCHY_UNCOMPRESSED":
158		if len(e) != 2*cSize {
159			return nil, errors.New("invalid point size")
160		}
161		x.SetBytes(e[:cSize])
162		y.SetBytes(e[cSize:])
163		if !c.IsOnCurve(x, y) {
164			return nil, errors.New("invalid point")
165		}
166		return &ECPoint{
167			X: x,
168			Y: y,
169		}, nil
170	case "COMPRESSED":
171		if len(e) != cSize+1 {
172			return nil, errors.New("compressed point has wrong length")
173		}
174		lsb := false
175		if e[0] == 2 {
176			lsb = false
177		} else if e[0] == 3 {
178			lsb = true
179		} else {
180			return nil, errors.New("invalid format")
181		}
182		x := new(big.Int)
183		x.SetBytes(e[1:])
184		if (x.Sign() == -1) || (x.Cmp(c.Params().P) != -1) {
185			return nil, errors.New("x is out of range")
186		}
187		y := getY(x, lsb, c)
188		return &ECPoint{
189			X: x,
190			Y: y,
191		}, nil
192	}
193	return nil, fmt.Errorf("invalid format: %s", pFormat)
194}
195
// getY recovers a Y coordinate for the X coordinate x on curve c, assuming
// the curve equation y² = x³ - 3x + b (mod p) used by crypto/elliptic's
// CurveParams. x is not modified.
//
// lsb selects which of the two roots is returned: true for the odd one,
// false for the even one.
//
// If x³ - 3x + b has no square root mod p (x is not the X coordinate of any
// curve point), the returned value is meaningless and is not on the curve.
// Callers must check the result with IsOnCurve.
func getY(x *big.Int, lsb bool, c elliptic.Curve) *big.Int {
	params := c.Params()
	p := params.P

	// rhs = x³ - 3x + b
	rhs := new(big.Int).Mul(x, x)
	rhs.Mul(rhs, x)
	threeX := new(big.Int).Lsh(x, 1)
	threeX.Add(threeX, x)
	rhs.Sub(rhs, threeX)
	rhs.Add(rhs, params.B)

	y := new(big.Int).ModSqrt(rhs, p)
	if y == nil {
		// Not a quadratic residue: no point has this X. Return a non-nil
		// value that IsOnCurve will reject.
		return rhs
	}
	// The two roots are y and p - y. p is odd, so they have opposite parity;
	// switch to the other root if the parity is not the requested one.
	if (y.Bit(0) == 1) != lsb {
		y.Sub(p, y)
		y.Mod(y, p)
	}
	return y
}
219
220func validatePublicPoint(pub *ECPoint, priv *ECPrivateKey) error {
221	if priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) {
222		return nil
223	}
224	return errors.New("invalid public key")
225}
226
227// ComputeSharedSecret is used to compute a shared secret using given private key and peer public key.
228func ComputeSharedSecret(pub *ECPoint, priv *ECPrivateKey) ([]byte, error) {
229	if err := validatePublicPoint(pub, priv); err != nil {
230		return nil, err
231	}
232
233	x, y := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
234
235	if x == nil {
236		return nil, errors.New("shared key compute error")
237	}
238	// check if x,y are on the curve
239	if err := validatePublicPoint(&ECPoint{X: x, Y: y}, priv); err != nil {
240		return nil, errors.New("invalid shared key")
241	}
242
243	sharedSecret := make([]byte, maxSharedKeyLength(priv.PublicKey))
244	return x.FillBytes(sharedSecret), nil
245}
246
247func maxSharedKeyLength(pub ECPublicKey) int {
248	return (pub.Curve.Params().BitSize + 7) / 8
249}
250
251// GenerateECDHKeyPair will create a new private key for a given curve.
252func GenerateECDHKeyPair(c elliptic.Curve) (*ECPrivateKey, error) {
253	p, x, y, err := elliptic.GenerateKey(c, rand.Reader)
254	if err != nil {
255		return nil, err
256	}
257	return &ECPrivateKey{
258		PublicKey: ECPublicKey{
259			Curve: c,
260			Point: ECPoint{
261				X: x,
262				Y: y,
263			},
264		},
265		D: new(big.Int).SetBytes(p),
266	}, nil
267
268}
269
// GetCurve returns the elliptic.Curve for a given standard curve name.
//
// Each supported curve is accepted under several aliases, such as
// "secp256r1", "NIST_P256", "P-256" and "EllipticCurveType_NIST_P256" (P-224
// has no EllipticCurveType_ alias). Any other name yields an
// "unsupported curve" error.
func GetCurve(c string) (elliptic.Curve, error) {
	var newCurve func() elliptic.Curve
	switch c {
	case "secp224r1", "NIST_P224", "P-224":
		newCurve = elliptic.P224
	case "secp256r1", "NIST_P256", "P-256", "EllipticCurveType_NIST_P256":
		newCurve = elliptic.P256
	case "secp384r1", "NIST_P384", "P-384", "EllipticCurveType_NIST_P384":
		newCurve = elliptic.P384
	case "secp521r1", "NIST_P521", "P-521", "EllipticCurveType_NIST_P521":
		newCurve = elliptic.P521
	}
	if newCurve == nil {
		return nil, errors.New("unsupported curve")
	}
	return newCurve(), nil
}
285