xref: /aosp_15_r20/external/tink/go/internal/signature/rsassapss_signer_verifier_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package signature_test
18
19import (
20	"crypto/rand"
21	"crypto/rsa"
22	"fmt"
23	"math/big"
24	"testing"
25
26	"github.com/google/tink/go/internal/signature"
27	"github.com/google/tink/go/subtle/random"
28	"github.com/google/tink/go/subtle"
29	"github.com/google/tink/go/testutil"
30)
31
32func TestRSASSAPSSSignVerify(t *testing.T) {
33	data := random.GetRandomBytes(20)
34	sigHash := "SHA256"
35	saltLength := 10
36	privKey, err := rsa.GenerateKey(rand.Reader, 3072)
37	if err != nil {
38		t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err)
39	}
40	signer, err := signature.New_RSA_SSA_PSS_Signer(sigHash, saltLength, privKey)
41	if err != nil {
42		t.Fatalf("New_RSA_SSA_PSS_Signer() error = %v, want nil", err)
43	}
44	verifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, saltLength, &privKey.PublicKey)
45	if err != nil {
46		t.Fatalf("New_RSA_SSA_PSS_Verifier() error = %v, want nil", err)
47	}
48	s, err := signer.Sign(data)
49	if err != nil {
50		t.Fatalf("Sign() err = %v, want nil", err)
51	}
52	if err = verifier.Verify(s, data); err != nil {
53		t.Fatalf("Verify() err = %v, want nil", err)
54	}
55}
56
57func TestRSASSAPSSSignVerifyInvalidFails(t *testing.T) {
58	data := random.GetRandomBytes(20)
59	sigHash := "SHA256"
60	saltLength := 10
61	privKey, err := rsa.GenerateKey(rand.Reader, 3072)
62	if err != nil {
63		t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err)
64	}
65	signer, err := signature.New_RSA_SSA_PSS_Signer(sigHash, saltLength, privKey)
66	if err != nil {
67		t.Fatalf("New_RSA_SSA_PSS_Signer() error = %v, want nil", err)
68	}
69	verifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, saltLength, &privKey.PublicKey)
70	if err != nil {
71		t.Fatalf("New_RSA_SSA_PSS_Verifier() error = %v, want nil", err)
72	}
73	s, err := signer.Sign(data)
74	if err != nil {
75		t.Fatalf("Sign() err = %v, want nil", err)
76	}
77	if err = verifier.Verify(s, data); err != nil {
78		t.Fatalf("Verify() err = %v, want nil", err)
79	}
80
81	modifiedSig := s[:]
82	// modify first byte in signature
83	modifiedSig[0] = byte(uint8(modifiedSig[0]) + 1)
84	if err := verifier.Verify(modifiedSig, data); err == nil {
85		t.Errorf("Verify(modifiedSig, data) err = nil, want error")
86	}
87	if err := verifier.Verify(s, []byte("invalid_data")); err == nil {
88		t.Errorf("Verify(s, invalid_data) err = nil, want error")
89	}
90	if err := verifier.Verify([]byte("invalid_signature"), data); err == nil {
91		t.Errorf("Verify(invalid_signature, data) err = nil, want error")
92	}
93
94	diffPrivKey, err := rsa.GenerateKey(rand.Reader, 3072)
95	if err != nil {
96		t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err)
97	}
98	diffVerifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, saltLength, &diffPrivKey.PublicKey)
99	if err != nil {
100		t.Fatalf("New_RSA_SSA_PSS_Verifier() error = %v, want nil", err)
101	}
102	if err := diffVerifier.Verify(s, data); err == nil {
103		t.Errorf("Verify() err = nil, want error")
104	}
105}
106
107func TestNewRSASSAPSSSignerVerifierFailWithInvalidInputs(t *testing.T) {
108	type testCase struct {
109		name    string
110		hash    string
111		salt    int
112		privKey *rsa.PrivateKey
113	}
114	validPrivKey, err := rsa.GenerateKey(rand.Reader, 3072)
115	if err != nil {
116		t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err)
117	}
118	for _, tc := range []testCase{
119		{
120			name:    "invalid hash function",
121			hash:    "SHA1",
122			privKey: validPrivKey,
123			salt:    0,
124		},
125		{
126			name: "invalid exponent",
127			hash: "SHA256",
128			salt: 0,
129			privKey: &rsa.PrivateKey{
130				D: validPrivKey.D,
131				PublicKey: rsa.PublicKey{
132					N: validPrivKey.N,
133					E: 8,
134				},
135				Primes:      validPrivKey.Primes,
136				Precomputed: validPrivKey.Precomputed,
137			},
138		},
139		{
140			name: "invalid modulus",
141			hash: "SHA256",
142			salt: 0,
143			privKey: &rsa.PrivateKey{
144				D: validPrivKey.D,
145				PublicKey: rsa.PublicKey{
146					N: big.NewInt(5),
147					E: validPrivKey.E,
148				},
149				Primes:      validPrivKey.Primes,
150				Precomputed: validPrivKey.Precomputed,
151			},
152		},
153		{
154			name:    "invalid salt",
155			hash:    "SHA256",
156			salt:    -1,
157			privKey: validPrivKey,
158		},
159	} {
160		t.Run(tc.name, func(t *testing.T) {
161			if _, err := signature.New_RSA_SSA_PSS_Signer(tc.hash, tc.salt, tc.privKey); err == nil {
162				t.Errorf("New_RSA_SSA_PSS_Signer() err = nil, want error")
163			}
164			if _, err := signature.New_RSA_SSA_PSS_Verifier(tc.hash, tc.salt, &tc.privKey.PublicKey); err == nil {
165				t.Errorf("New_RSA_SSA_PSS_Verifier() err = nil, want error")
166			}
167		})
168	}
169}
170
171type rsaSSAPSSSuite struct {
172	testutil.WycheproofSuite
173	TestGroups []*rsaSSAPSSGroup `json:"testGroups"`
174}
175
176type rsaSSAPSSGroup struct {
177	testutil.WycheproofGroup
178	SHA        string            `json:"sha"`
179	MGFSHA     string            `json:"mgfSha"`
180	SaltLength int               `json:"sLen"`
181	E          testutil.HexBytes `json:"e"`
182	N          testutil.HexBytes `json:"N"`
183	Tests      []*rsaSSAPSSCase  `json:"tests"`
184}
185
186type rsaSSAPSSCase struct {
187	testutil.WycheproofCase
188	Message   testutil.HexBytes `json:"msg"`
189	Signature testutil.HexBytes `json:"sig"`
190}
191
192func TestRSASSAPSSWycheproofCases(t *testing.T) {
193	testutil.SkipTestIfTestSrcDirIsNotSet(t)
194	ranTestCount := 0
195	vectorsFiles := []string{
196		"rsa_pss_2048_sha512_256_mgf1_28_test.json",
197		"rsa_pss_2048_sha512_256_mgf1_32_test.json",
198		"rsa_pss_2048_sha256_mgf1_0_test.json",
199		"rsa_pss_2048_sha256_mgf1_32_test.json",
200		"rsa_pss_3072_sha256_mgf1_32_test.json",
201		"rsa_pss_4096_sha256_mgf1_32_test.json",
202		"rsa_pss_4096_sha512_mgf1_32_test.json",
203	}
204	for _, v := range vectorsFiles {
205		suite := &rsaSSAPSSSuite{}
206		if err := testutil.PopulateSuite(suite, v); err != nil {
207			t.Fatalf("failed populating suite: %s", err)
208		}
209		for _, group := range suite.TestGroups {
210			sigHash := subtle.ConvertHashName(group.SHA)
211			if sigHash == "" {
212				continue
213			}
214			pubKey := &rsa.PublicKey{
215				E: int(new(big.Int).SetBytes(group.E).Uint64()),
216				N: new(big.Int).SetBytes(group.N),
217			}
218			verifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, group.SaltLength, pubKey)
219			if err != nil {
220				t.Fatalf("New_RSA_SSA_PSS_Verifier() err = %v, want nil", err)
221			}
222			for _, test := range group.Tests {
223				if (test.CaseID == 67 || test.CaseID == 68) && v == "rsa_pss_2048_sha256_mgf1_0_test.json" {
224					// crypto/rsa will interpret zero length salt and parse the salt length from signature.
225					// Since this test cases use a zero salt length as a parameter, even if a different parameter
226					// is provided, Golang will interpret it and parse the salt directly from the signature.
227					continue
228				}
229				ranTestCount++
230				caseName := fmt.Sprintf("%s: %s-%s-%s-%d:Case-%d", v, group.Type, group.SHA, group.MGFSHA, group.SaltLength, test.CaseID)
231				t.Run(caseName, func(t *testing.T) {
232					err := verifier.Verify(test.Signature, test.Message)
233					switch test.Result {
234					case "valid":
235						if err != nil {
236							t.Errorf("Verify() err = %, want nil", err)
237						}
238					case "invalid":
239						if err == nil {
240							t.Errorf("Verify() err = nil, want error")
241						}
242					case "acceptable":
243						// TODO(b/230489047): Inspect flags to appropriately handle acceptable test cases.
244					default:
245						t.Errorf("unsupported test result: %q", test.Result)
246					}
247				})
248			}
249		}
250	}
251	if ranTestCount < 578 {
252		t.Errorf("ranTestCount > %d, want > %d", ranTestCount, 578)
253	}
254}
255