xref: /aosp_15_r20/external/tink/go/jwt/jwt_rsa_ssa_pss_signer_key_manager.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package jwt
18
19import (
20	"crypto/rand"
21	"crypto/rsa"
22	"errors"
23	"fmt"
24	"math/big"
25
26	"google.golang.org/protobuf/proto"
27	"github.com/google/tink/go/core/registry"
28	"github.com/google/tink/go/internal/signature"
29	"github.com/google/tink/go/keyset"
30	jrsppb "github.com/google/tink/go/proto/jwt_rsa_ssa_pss_go_proto"
31	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
32)
33
// jwtPSSignerKeyVersion is the key version this key manager accepts for
// JwtRsaSsaPssPrivateKey and JwtRsaSsaPssKeyFormat messages.
const jwtPSSignerKeyVersion = 0

// jwtPSSignerTypeURL is the type URL under which this key manager is
// addressed and which it places in the KeyData it produces.
const jwtPSSignerTypeURL = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPrivateKey"
38
// errPSInvalidPrivateKey reports a missing or malformed
// JwtRsaSsaPssPrivateKey.
var errPSInvalidPrivateKey = errors.New("invalid JwtRsaSsaPssPrivateKey")

// errPSInvalidKeyFormat reports a missing or malformed RSA SSA PSS key
// format.
var errPSInvalidKeyFormat = errors.New("invalid RSA SSA PSS key format")
43
44// jwtPSSignerKeyManager implements the KeyManager interface
45// for JWT Signing using the 'PS256', 'PS384', and 'PS512' JWA algorithm.
46type jwtPSSignerKeyManager struct{}
47
48var _ registry.PrivateKeyManager = (*jwtPSSignerKeyManager)(nil)
49
50func (km *jwtPSSignerKeyManager) Primitive(serializedKey []byte) (interface{}, error) {
51	if serializedKey == nil {
52		return nil, fmt.Errorf("invalid JwtRsaSsaPSSPrivateKey")
53	}
54	privKey := &jrsppb.JwtRsaSsaPssPrivateKey{}
55	if err := proto.Unmarshal(serializedKey, privKey); err != nil {
56		return nil, fmt.Errorf("failed to unmarshal JwtRsaSsaPssPrivateKey: %v", err)
57	}
58	if err := validatePSPrivateKey(privKey); err != nil {
59		return nil, err
60	}
61	rsaPrivKey := &rsa.PrivateKey{
62		PublicKey: rsa.PublicKey{
63			N: bytesToBigInt(privKey.GetPublicKey().GetN()),
64			E: int(bytesToBigInt(privKey.GetPublicKey().GetE()).Int64()),
65		},
66		D: bytesToBigInt(privKey.GetD()),
67		Primes: []*big.Int{
68			bytesToBigInt(privKey.GetP()),
69			bytesToBigInt(privKey.GetQ()),
70		},
71		Precomputed: rsa.PrecomputedValues{
72			Dp: bytesToBigInt(privKey.GetDp()),
73			Dq: bytesToBigInt(privKey.GetDq()),
74			// in crypto/rsa `Qinv` is the "Chinese Remainder Theorem
75			// coefficient q^(-1) mod p. Which is `GetCrt` in the tink proto and not
76			// the `CRTValues`.
77			Qinv: bytesToBigInt(privKey.GetCrt()),
78		},
79	}
80	algorithm := privKey.GetPublicKey().GetAlgorithm()
81	if err := signature.Validate_RSA_SSA_PSS(validPSAlgToHash[algorithm], psAlgToSaltLen[algorithm], rsaPrivKey); err != nil {
82		return nil, err
83	}
84	signer, err := signature.New_RSA_SSA_PSS_Signer(validPSAlgToHash[algorithm], psAlgToSaltLen[algorithm], rsaPrivKey)
85	if err != nil {
86		return nil, err
87	}
88	return newSignerWithKID(signer, algorithm.String(), psCustomKID(privKey.GetPublicKey()))
89}
90
91func validatePSPrivateKey(privKey *jrsppb.JwtRsaSsaPssPrivateKey) error {
92	if err := keyset.ValidateKeyVersion(privKey.Version, jwtPSSignerKeyVersion); err != nil {
93		return err
94	}
95	if privKey.GetD() == nil ||
96		len(privKey.GetPublicKey().GetN()) == 0 ||
97		len(privKey.GetPublicKey().GetE()) == 0 ||
98		privKey.GetP() == nil ||
99		privKey.GetQ() == nil ||
100		privKey.GetDp() == nil ||
101		privKey.GetDq() == nil ||
102		privKey.GetCrt() == nil {
103		return fmt.Errorf("invalid private key")
104	}
105	if err := validatePSPublicKey(privKey.GetPublicKey()); err != nil {
106		return err
107	}
108	return nil
109}
110
111func (km *jwtPSSignerKeyManager) NewKey(serializedKeyFormat []byte) (proto.Message, error) {
112	if len(serializedKeyFormat) == 0 {
113		return nil, errPSInvalidKeyFormat
114	}
115	keyFormat := &jrsppb.JwtRsaSsaPssKeyFormat{}
116	if err := proto.Unmarshal(serializedKeyFormat, keyFormat); err != nil {
117		return nil, fmt.Errorf("failed to unmarshal JwtRsaSsaPssKeyFormat: %v", err)
118	}
119	if err := keyset.ValidateKeyVersion(keyFormat.GetVersion(), jwtPSSignerKeyVersion); err != nil {
120		return nil, err
121	}
122	rsaKey, err := rsa.GenerateKey(rand.Reader, int(keyFormat.GetModulusSizeInBits()))
123	if err != nil {
124		return nil, err
125	}
126	privKey := &jrsppb.JwtRsaSsaPssPrivateKey{
127		Version: jwtPSSignerKeyVersion,
128		PublicKey: &jrsppb.JwtRsaSsaPssPublicKey{
129			Version:   jwtPSSignerKeyVersion,
130			Algorithm: keyFormat.GetAlgorithm(),
131			N:         rsaKey.PublicKey.N.Bytes(),
132			E:         keyFormat.GetPublicExponent(),
133		},
134		D:  rsaKey.D.Bytes(),
135		P:  rsaKey.Primes[0].Bytes(),
136		Q:  rsaKey.Primes[1].Bytes(),
137		Dp: rsaKey.Precomputed.Dp.Bytes(),
138		Dq: rsaKey.Precomputed.Dq.Bytes(),
139		// in crypto/rsa `Qinv` is the "Chinese Remainder Theorem
140		// coefficient q^(-1) mod p. Which is `Crt` in the tink proto and not
141		// the `CRTValues`.
142		Crt: rsaKey.Precomputed.Qinv.Bytes(),
143	}
144	if err := validatePSPrivateKey(privKey); err != nil {
145		return nil, err
146	}
147	return privKey, nil
148}
149
150func (km *jwtPSSignerKeyManager) NewKeyData(serializedKeyFormat []byte) (*tinkpb.KeyData, error) {
151	key, err := km.NewKey(serializedKeyFormat)
152	if err != nil {
153		return nil, err
154	}
155	serializedKey, err := proto.Marshal(key)
156	if err != nil {
157		return nil, err
158	}
159	return &tinkpb.KeyData{
160		TypeUrl:         jwtPSSignerTypeURL,
161		Value:           serializedKey,
162		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PRIVATE,
163	}, nil
164}
165
166func (km *jwtPSSignerKeyManager) PublicKeyData(serializedPrivKey []byte) (*tinkpb.KeyData, error) {
167	if serializedPrivKey == nil {
168		return nil, errPSInvalidKeyFormat
169	}
170	privKey := &jrsppb.JwtRsaSsaPssPrivateKey{}
171	if err := proto.Unmarshal(serializedPrivKey, privKey); err != nil {
172		return nil, fmt.Errorf("failed to unmarshal JwtRsaSsaPssPrivateKey: %v", err)
173	}
174	if err := validatePSPrivateKey(privKey); err != nil {
175		return nil, err
176	}
177	serializedPubKey, err := proto.Marshal(privKey.GetPublicKey())
178	if err != nil {
179		return nil, err
180	}
181	return &tinkpb.KeyData{
182		TypeUrl:         jwtPSVerifierTypeURL,
183		Value:           serializedPubKey,
184		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
185	}, nil
186}
187
188func (km *jwtPSSignerKeyManager) DoesSupport(typeURL string) bool {
189	return jwtPSSignerTypeURL == typeURL
190}
191
192func (km *jwtPSSignerKeyManager) TypeURL() string {
193	return jwtPSSignerTypeURL
194}
195