1// Copyright 2021 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package bigmod 6 7import ( 8 "fmt" 9 "math/big" 10 "math/bits" 11 "math/rand" 12 "reflect" 13 "strings" 14 "testing" 15 "testing/quick" 16) 17 18func (n *Nat) String() string { 19 var limbs []string 20 for i := range n.limbs { 21 limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i])) 22 } 23 return "{" + strings.Join(limbs, " ") + "}" 24} 25 26// Generate generates an even nat. It's used by testing/quick to produce random 27// *nat values for quick.Check invocations. 28func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { 29 limbs := make([]uint, size) 30 for i := 0; i < size; i++ { 31 limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) 32 } 33 return reflect.ValueOf(&Nat{limbs}) 34} 35 36func testModAddCommutative(a *Nat, b *Nat) bool { 37 m := maxModulus(uint(len(a.limbs))) 38 aPlusB := new(Nat).set(a) 39 aPlusB.Add(b, m) 40 bPlusA := new(Nat).set(b) 41 bPlusA.Add(a, m) 42 return aPlusB.Equal(bPlusA) == 1 43} 44 45func TestModAddCommutative(t *testing.T) { 46 err := quick.Check(testModAddCommutative, &quick.Config{}) 47 if err != nil { 48 t.Error(err) 49 } 50} 51 52func testModSubThenAddIdentity(a *Nat, b *Nat) bool { 53 m := maxModulus(uint(len(a.limbs))) 54 original := new(Nat).set(a) 55 a.Sub(b, m) 56 a.Add(b, m) 57 return a.Equal(original) == 1 58} 59 60func TestModSubThenAddIdentity(t *testing.T) { 61 err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) 62 if err != nil { 63 t.Error(err) 64 } 65} 66 67func TestMontgomeryRoundtrip(t *testing.T) { 68 err := quick.Check(func(a *Nat) bool { 69 one := &Nat{make([]uint, len(a.limbs))} 70 one.limbs[0] = 1 71 aPlusOne := new(big.Int).SetBytes(natBytes(a)) 72 aPlusOne.Add(aPlusOne, big.NewInt(1)) 73 m, _ := NewModulusFromBig(aPlusOne) 74 monty := new(Nat).set(a) 75 monty.montgomeryRepresentation(m) 76 aAgain := new(Nat).set(monty) 77 aAgain.montgomeryMul(monty, one, m) 78 if a.Equal(aAgain) != 1 { 79 t.Errorf("%v != %v", a, aAgain) 80 return false 81 } 82 return true 83 }, &quick.Config{}) 84 if err != nil { 85 t.Error(err) 86 } 87} 88 89func TestShiftIn(t *testing.T) { 90 if bits.UintSize != 64 { 91 t.Skip("examples are only valid in 64 bit") 92 } 93 examples := []struct { 94 m, x, expected []byte 95 y uint64 96 }{{ 97 m: []byte{13}, 98 x: []byte{0}, 99 y: 0xFFFF_FFFF_FFFF_FFFF, 100 expected: []byte{2}, 101 }, { 102 m: []byte{13}, 103 x: []byte{7}, 104 y: 0xFFFF_FFFF_FFFF_FFFF, 105 expected: []byte{10}, 106 }, { 107 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, 108 x: make([]byte, 9), 109 y: 0xFFFF_FFFF_FFFF_FFFF, 110 expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 111 }, { 112 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, 113 x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 114 y: 0, 115 expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06}, 116 }} 117 118 for i, tt := range examples { 119 m := modulusFromBytes(tt.m) 120 got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m) 121 if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 { 122 t.Errorf("%d: got %v, expected %v", i, got, exp) 123 } 124 } 125} 126 127func TestModulusAndNatSizes(t *testing.T) { 128 // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as 129 // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two 130 // limbs, if they are not, they fit in three. This can be a problem because 131 // modulus strips leading zeroes and nat does not. 132 m := modulusFromBytes([]byte{ 133 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 134 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) 135 xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 136 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe} 137 natFromBytes(xb).ExpandFor(m) // must not panic for shrinking 138 NewNat().SetBytes(xb, m) 139} 140 141func TestSetBytes(t *testing.T) { 142 tests := []struct { 143 m, b []byte 144 fail bool 145 }{{ 146 m: []byte{0xff, 0xff}, 147 b: []byte{0x00, 0x01}, 148 }, { 149 m: []byte{0xff, 0xff}, 150 b: []byte{0xff, 0xff}, 151 fail: true, 152 }, { 153 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 154 b: []byte{0x00, 0x01}, 155 }, { 156 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 157 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 158 }, { 159 m: []byte{0xff, 0xff}, 160 b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 161 fail: true, 162 }, { 163 m: []byte{0xff, 0xff}, 164 b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 165 fail: true, 166 }, { 167 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 168 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 169 }, { 170 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 171 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 172 fail: true, 173 }, { 174 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 175 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 176 fail: true, 177 }, { 178 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 179 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 180 fail: true, 181 }, { 182 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd}, 183 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 184 fail: true, 185 }} 186 187 for i, tt := range tests { 188 m := modulusFromBytes(tt.m) 189 got, err := NewNat().SetBytes(tt.b, m) 190 if err != nil { 191 if !tt.fail { 192 t.Errorf("%d: unexpected error: %v", i, err) 193 } 194 continue 195 } 196 if tt.fail { 197 t.Errorf("%d: unexpected success", i) 198 continue 199 } 200 if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes { 201 t.Errorf("%d: got %v, expected %v", i, got, expected) 202 } 203 } 204 205 f := func(xBytes []byte) bool { 206 m := maxModulus(uint(len(xBytes)*8/_W + 1)) 207 got, err := NewNat().SetBytes(xBytes, m) 208 if err != nil { 209 return false 210 } 211 return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes 212 } 213 214 err := quick.Check(f, &quick.Config{}) 215 if err != nil { 216 t.Error(err) 217 } 218} 219 220func TestExpand(t *testing.T) { 221 sliced := []uint{1, 2, 3, 4} 222 examples := []struct { 223 in []uint 224 n int 225 out []uint 226 }{{ 227 []uint{1, 2}, 228 4, 229 []uint{1, 2, 0, 0}, 230 }, { 231 sliced[:2], 232 4, 233 []uint{1, 2, 0, 0}, 234 }, { 235 []uint{1, 2}, 236 2, 237 []uint{1, 2}, 238 }} 239 240 for i, tt := range examples { 241 got := (&Nat{tt.in}).expand(tt.n) 242 if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 { 243 t.Errorf("%d: got %v, expected %v", i, got, tt.out) 244 } 245 } 246} 247 248func TestMod(t *testing.T) { 249 m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}) 250 x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) 251 out := new(Nat) 252 out.Mod(x, m) 253 expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) 254 if out.Equal(expected) != 1 { 255 t.Errorf("%+v != %+v", out, expected) 256 } 257} 258 259func TestModSub(t *testing.T) { 260 m := modulusFromBytes([]byte{13}) 261 x := &Nat{[]uint{6}} 262 y := &Nat{[]uint{7}} 263 x.Sub(y, m) 264 expected := &Nat{[]uint{12}} 265 if x.Equal(expected) != 1 { 266 t.Errorf("%+v != %+v", x, expected) 267 } 268 x.Sub(y, m) 269 expected = &Nat{[]uint{5}} 270 if x.Equal(expected) != 1 { 271 t.Errorf("%+v != %+v", x, expected) 272 } 273} 274 275func TestModAdd(t *testing.T) { 276 m := modulusFromBytes([]byte{13}) 277 x := &Nat{[]uint{6}} 278 y := &Nat{[]uint{7}} 279 x.Add(y, m) 280 expected := &Nat{[]uint{0}} 281 if x.Equal(expected) != 1 { 282 t.Errorf("%+v != %+v", x, expected) 283 } 284 x.Add(y, m) 285 expected = &Nat{[]uint{7}} 286 if x.Equal(expected) != 1 { 287 t.Errorf("%+v != %+v", x, expected) 288 } 289} 290 291func TestExp(t *testing.T) { 292 m := modulusFromBytes([]byte{13}) 293 x := &Nat{[]uint{3}} 294 out := &Nat{[]uint{0}} 295 out.Exp(x, []byte{12}, m) 296 expected := &Nat{[]uint{1}} 297 if out.Equal(expected) != 1 { 298 t.Errorf("%+v != %+v", out, expected) 299 } 300} 301 302func TestExpShort(t *testing.T) { 303 m := modulusFromBytes([]byte{13}) 304 x := &Nat{[]uint{3}} 305 out := &Nat{[]uint{0}} 306 out.ExpShortVarTime(x, 12, m) 307 expected := &Nat{[]uint{1}} 308 if out.Equal(expected) != 1 { 309 t.Errorf("%+v != %+v", out, expected) 310 } 311} 312 313// TestMulReductions tests that Mul reduces results equal or slightly greater 314// than the modulus. Some Montgomery algorithms don't and need extra care to 315// return correct results. See https://go.dev/issue/13907. 316func TestMulReductions(t *testing.T) { 317 // Two short but multi-limb primes. 318 a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10) 319 b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) 320 n := new(big.Int).Mul(a, b) 321 322 N, _ := NewModulusFromBig(n) 323 A := NewNat().setBig(a).ExpandFor(N) 324 B := NewNat().setBig(b).ExpandFor(N) 325 326 if A.Mul(B, N).IsZero() != 1 { 327 t.Error("a * b mod (a * b) != 0") 328 } 329 330 i := new(big.Int).ModInverse(a, b) 331 N, _ = NewModulusFromBig(b) 332 A = NewNat().setBig(a).ExpandFor(N) 333 I := NewNat().setBig(i).ExpandFor(N) 334 one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) 335 336 if A.Mul(I, N).Equal(one) != 1 { 337 t.Error("a * inv(a) mod b != 1") 338 } 339} 340 341func natBytes(n *Nat) []byte { 342 return n.Bytes(maxModulus(uint(len(n.limbs)))) 343} 344 345func natFromBytes(b []byte) *Nat { 346 // Must not use Nat.SetBytes as it's used in TestSetBytes. 347 bb := new(big.Int).SetBytes(b) 348 return NewNat().setBig(bb) 349} 350 351func modulusFromBytes(b []byte) *Modulus { 352 bb := new(big.Int).SetBytes(b) 353 m, _ := NewModulusFromBig(bb) 354 return m 355} 356 357// maxModulus returns the biggest modulus that can fit in n limbs. 358func maxModulus(n uint) *Modulus { 359 b := big.NewInt(1) 360 b.Lsh(b, n*_W) 361 b.Sub(b, big.NewInt(1)) 362 m, _ := NewModulusFromBig(b) 363 return m 364} 365 366func makeBenchmarkModulus() *Modulus { 367 return maxModulus(32) 368} 369 370func makeBenchmarkValue() *Nat { 371 x := make([]uint, 32) 372 for i := 0; i < 32; i++ { 373 x[i]-- 374 } 375 return &Nat{limbs: x} 376} 377 378func makeBenchmarkExponent() []byte { 379 e := make([]byte, 256) 380 for i := 0; i < 32; i++ { 381 e[i] = 0xFF 382 } 383 return e 384} 385 386func BenchmarkModAdd(b *testing.B) { 387 x := makeBenchmarkValue() 388 y := makeBenchmarkValue() 389 m := makeBenchmarkModulus() 390 391 b.ResetTimer() 392 for i := 0; i < b.N; i++ { 393 x.Add(y, m) 394 } 395} 396 397func BenchmarkModSub(b *testing.B) { 398 x := makeBenchmarkValue() 399 y := makeBenchmarkValue() 400 m := makeBenchmarkModulus() 401 402 b.ResetTimer() 403 for i := 0; i < b.N; i++ { 404 x.Sub(y, m) 405 } 406} 407 408func BenchmarkMontgomeryRepr(b *testing.B) { 409 x := makeBenchmarkValue() 410 m := makeBenchmarkModulus() 411 412 b.ResetTimer() 413 for i := 0; i < b.N; i++ { 414 x.montgomeryRepresentation(m) 415 } 416} 417 418func BenchmarkMontgomeryMul(b *testing.B) { 419 x := makeBenchmarkValue() 420 y := makeBenchmarkValue() 421 out := makeBenchmarkValue() 422 m := makeBenchmarkModulus() 423 424 b.ResetTimer() 425 for i := 0; i < b.N; i++ { 426 out.montgomeryMul(x, y, m) 427 } 428} 429 430func BenchmarkModMul(b *testing.B) { 431 x := makeBenchmarkValue() 432 y := makeBenchmarkValue() 433 m := makeBenchmarkModulus() 434 435 b.ResetTimer() 436 for i := 0; i < b.N; i++ { 437 x.Mul(y, m) 438 } 439} 440 441func BenchmarkExpBig(b *testing.B) { 442 out := new(big.Int) 443 exponentBytes := makeBenchmarkExponent() 444 x := new(big.Int).SetBytes(exponentBytes) 445 e := new(big.Int).SetBytes(exponentBytes) 446 n := new(big.Int).SetBytes(exponentBytes) 447 one := new(big.Int).SetUint64(1) 448 n.Add(n, one) 449 450 b.ResetTimer() 451 for i := 0; i < b.N; i++ { 452 out.Exp(x, e, n) 453 } 454} 455 456func BenchmarkExp(b *testing.B) { 457 x := makeBenchmarkValue() 458 e := makeBenchmarkExponent() 459 out := makeBenchmarkValue() 460 m := makeBenchmarkModulus() 461 462 b.ResetTimer() 463 for i := 0; i < b.N; i++ { 464 out.Exp(x, e, m) 465 } 466} 467 468func TestNewModFromBigZero(t *testing.T) { 469 expected := "modulus must be >= 0" 470 _, err := NewModulusFromBig(big.NewInt(0)) 471 if err == nil || err.Error() != expected { 472 t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) 473 } 474 475 expected = "modulus must be odd" 476 _, err = NewModulusFromBig(big.NewInt(2)) 477 if err == nil || err.Error() != expected { 478 t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) 479 } 480} 481