1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bigmod
6
7import (
8	"errors"
9	"internal/byteorder"
10	"math/big"
11	"math/bits"
12)
13
const (
	// _W is the width of a single limb in bits, i.e. the native word size.
	_W = bits.UintSize
	// _S is the width of a single limb in bytes.
	_S = _W >> 3
)
20
// choice is a boolean meant for constant-time code: it only ever holds 0 or 1.
// It is an unsigned integer rather than a bool so that it can be stretched
// into a mask and used to select values without branching on it.
type choice uint

// not inverts a choice, mapping yes to no and no to yes.
func not(c choice) choice { return c ^ 1 }

// The two values a choice may hold.
const (
	no  = choice(0)
	yes = choice(1)
)
30
31// ctMask is all 1s if on is yes, and all 0s otherwise.
32func ctMask(on choice) uint { return -uint(on) }
33
34// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
35// function does not depend on its inputs.
36func ctEq(x, y uint) choice {
37	// If x != y, then either x - y or y - x will generate a carry.
38	_, c1 := bits.Sub(x, y, 0)
39	_, c2 := bits.Sub(y, x, 0)
40	return not(choice(c1 | c2))
41}
42
// Nat is a natural number of arbitrary size.
//
// Every Nat has an announced length, namely the number of limbs it currently
// stores. Operations are allowed to leak that length, but do not leak any
// information about the values held in the limbs.
type Nat struct {
	// limbs stores the value in base 2^W, least-significant limb first, with
	// W = bits.UintSize.
	limbs []uint
}
52
53// preallocTarget is the size in bits of the numbers used to implement the most
54// common and most performant RSA key size. It's also enough to cover some of
55// the operations of key sizes up to 4096.
56const preallocTarget = 2048
57const preallocLimbs = (preallocTarget + _W - 1) / _W
58
59// NewNat returns a new nat with a size of zero, just like new(Nat), but with
60// the preallocated capacity to hold a number of up to preallocTarget bits.
61// NewNat inlines, so the allocation can live on the stack.
62func NewNat() *Nat {
63	limbs := make([]uint, 0, preallocLimbs)
64	return &Nat{limbs}
65}
66
67// expand expands x to n limbs, leaving its value unchanged.
68func (x *Nat) expand(n int) *Nat {
69	if len(x.limbs) > n {
70		panic("bigmod: internal error: shrinking nat")
71	}
72	if cap(x.limbs) < n {
73		newLimbs := make([]uint, n)
74		copy(newLimbs, x.limbs)
75		x.limbs = newLimbs
76		return x
77	}
78	extraLimbs := x.limbs[len(x.limbs):n]
79	clear(extraLimbs)
80	x.limbs = x.limbs[:n]
81	return x
82}
83
84// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
85func (x *Nat) reset(n int) *Nat {
86	if cap(x.limbs) < n {
87		x.limbs = make([]uint, n)
88		return x
89	}
90	clear(x.limbs)
91	x.limbs = x.limbs[:n]
92	return x
93}
94
95// set assigns x = y, optionally resizing x to the appropriate size.
96func (x *Nat) set(y *Nat) *Nat {
97	x.reset(len(y.limbs))
98	copy(x.limbs, y.limbs)
99	return x
100}
101
102// setBig assigns x = n, optionally resizing n to the appropriate size.
103//
104// The announced length of x is set based on the actual bit size of the input,
105// ignoring leading zeroes.
106func (x *Nat) setBig(n *big.Int) *Nat {
107	limbs := n.Bits()
108	x.reset(len(limbs))
109	for i := range limbs {
110		x.limbs[i] = uint(limbs[i])
111	}
112	return x
113}
114
115// Bytes returns x as a zero-extended big-endian byte slice. The size of the
116// slice will match the size of m.
117//
118// x must have the same size as m and it must be reduced modulo m.
119func (x *Nat) Bytes(m *Modulus) []byte {
120	i := m.Size()
121	bytes := make([]byte, i)
122	for _, limb := range x.limbs {
123		for j := 0; j < _S; j++ {
124			i--
125			if i < 0 {
126				if limb == 0 {
127					break
128				}
129				panic("bigmod: modulus is smaller than nat")
130			}
131			bytes[i] = byte(limb)
132			limb >>= 8
133		}
134	}
135	return bytes
136}
137
138// SetBytes assigns x = b, where b is a slice of big-endian bytes.
139// SetBytes returns an error if b >= m.
140//
141// The output will be resized to the size of m and overwritten.
142func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
143	if err := x.setBytes(b, m); err != nil {
144		return nil, err
145	}
146	if x.cmpGeq(m.nat) == yes {
147		return nil, errors.New("input overflows the modulus")
148	}
149	return x, nil
150}
151
152// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
153// SetOverflowingBytes returns an error if b has a longer bit length than m, but
154// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
155//
156// The output will be resized to the size of m and overwritten.
157func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
158	if err := x.setBytes(b, m); err != nil {
159		return nil, err
160	}
161	leading := _W - bitLen(x.limbs[len(x.limbs)-1])
162	if leading < m.leading {
163		return nil, errors.New("input overflows the modulus size")
164	}
165	x.maybeSubtractModulus(no, m)
166	return x, nil
167}
168
169// bigEndianUint returns the contents of buf interpreted as a
170// big-endian encoded uint value.
171func bigEndianUint(buf []byte) uint {
172	if _W == 64 {
173		return uint(byteorder.BeUint64(buf))
174	}
175	return uint(byteorder.BeUint32(buf))
176}
177
178func (x *Nat) setBytes(b []byte, m *Modulus) error {
179	x.resetFor(m)
180	i, k := len(b), 0
181	for k < len(x.limbs) && i >= _S {
182		x.limbs[k] = bigEndianUint(b[i-_S : i])
183		i -= _S
184		k++
185	}
186	for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
187		x.limbs[k] |= uint(b[i-1]) << s
188		i--
189	}
190	if i > 0 {
191		return errors.New("input overflows the modulus size")
192	}
193	return nil
194}
195
196// Equal returns 1 if x == y, and 0 otherwise.
197//
198// Both operands must have the same announced length.
199func (x *Nat) Equal(y *Nat) choice {
200	// Eliminate bounds checks in the loop.
201	size := len(x.limbs)
202	xLimbs := x.limbs[:size]
203	yLimbs := y.limbs[:size]
204
205	equal := yes
206	for i := 0; i < size; i++ {
207		equal &= ctEq(xLimbs[i], yLimbs[i])
208	}
209	return equal
210}
211
212// IsZero returns 1 if x == 0, and 0 otherwise.
213func (x *Nat) IsZero() choice {
214	// Eliminate bounds checks in the loop.
215	size := len(x.limbs)
216	xLimbs := x.limbs[:size]
217
218	zero := yes
219	for i := 0; i < size; i++ {
220		zero &= ctEq(xLimbs[i], 0)
221	}
222	return zero
223}
224
225// cmpGeq returns 1 if x >= y, and 0 otherwise.
226//
227// Both operands must have the same announced length.
228func (x *Nat) cmpGeq(y *Nat) choice {
229	// Eliminate bounds checks in the loop.
230	size := len(x.limbs)
231	xLimbs := x.limbs[:size]
232	yLimbs := y.limbs[:size]
233
234	var c uint
235	for i := 0; i < size; i++ {
236		_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
237	}
238	// If there was a carry, then subtracting y underflowed, so
239	// x is not greater than or equal to y.
240	return not(choice(c))
241}
242
243// assign sets x <- y if on == 1, and does nothing otherwise.
244//
245// Both operands must have the same announced length.
246func (x *Nat) assign(on choice, y *Nat) *Nat {
247	// Eliminate bounds checks in the loop.
248	size := len(x.limbs)
249	xLimbs := x.limbs[:size]
250	yLimbs := y.limbs[:size]
251
252	mask := ctMask(on)
253	for i := 0; i < size; i++ {
254		xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
255	}
256	return x
257}
258
259// add computes x += y and returns the carry.
260//
261// Both operands must have the same announced length.
262func (x *Nat) add(y *Nat) (c uint) {
263	// Eliminate bounds checks in the loop.
264	size := len(x.limbs)
265	xLimbs := x.limbs[:size]
266	yLimbs := y.limbs[:size]
267
268	for i := 0; i < size; i++ {
269		xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
270	}
271	return
272}
273
274// sub computes x -= y. It returns the borrow of the subtraction.
275//
276// Both operands must have the same announced length.
277func (x *Nat) sub(y *Nat) (c uint) {
278	// Eliminate bounds checks in the loop.
279	size := len(x.limbs)
280	xLimbs := x.limbs[:size]
281	yLimbs := y.limbs[:size]
282
283	for i := 0; i < size; i++ {
284		xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
285	}
286	return
287}
288
289// Modulus is used for modular arithmetic, precomputing relevant constants.
290//
291// Moduli are assumed to be odd numbers. Moduli can also leak the exact
292// number of bits needed to store their value, and are stored without padding.
293//
294// Their actual value is still kept secret.
295type Modulus struct {
296	// The underlying natural number for this modulus.
297	//
298	// This will be stored without any padding, and shouldn't alias with any
299	// other natural number being used.
300	nat     *Nat
301	leading int  // number of leading zeros in the modulus
302	m0inv   uint // -nat.limbs[0]⁻¹ mod _W
303	rr      *Nat // R*R for montgomeryRepresentation
304}
305
306// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
307func rr(m *Modulus) *Nat {
308	rr := NewNat().ExpandFor(m)
309	n := uint(len(rr.limbs))
310	mLen := uint(m.BitLen())
311	logR := _W * n
312
313	// We start by computing R = 2^(_W * n) mod m. We can get pretty close, to
314	// 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce.
315	rr.limbs[n-1] = 1 << ((mLen - 1) % _W)
316	// Then we double until we reach 2^(_W * n).
317	for i := mLen - 1; i < logR; i++ {
318		rr.Add(rr, m)
319	}
320
321	// Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in
322	// the Montgomery domain, meaning we can use Montgomery multiplication now).
323	// We could do that by doubling _W * n times, or with a square-and-double
324	// chain log2(_W * n) long. Turns out the fastest thing is to start out with
325	// doublings, and switch to square-and-double once the exponent is large
326	// enough to justify the cost of the multiplications.
327
328	// The threshold is selected experimentally as a linear function of n.
329	threshold := n / 4
330
331	// We calculate how many of the most-significant bits of the exponent we can
332	// compute before crossing the threshold, and we do it with doublings.
333	i := bits.UintSize
334	for logR>>i <= threshold {
335		i--
336	}
337	for k := uint(0); k < logR>>i; k++ {
338		rr.Add(rr, m)
339	}
340
341	// Then we process the remaining bits of the exponent with a
342	// square-and-double chain.
343	for i > 0 {
344		rr.montgomeryMul(rr, rr, m)
345		i--
346		if logR>>i&1 != 0 {
347			rr.Add(rr, m)
348		}
349	}
350
351	return rr
352}
353
// minusInverseModW computes -x⁻¹ mod 2^_W for an odd x.
//
// The result is a constant precomputed for Montgomery multiplication.
func minusInverseModW(x uint) uint {
	// For odd x, x is its own inverse modulo 8 (1⁻¹ = 1, 3⁻¹ = 3, 5⁻¹ = 5,
	// and 7⁻¹ = 7 mod 8), so inv starts out correct in its three low bits.
	// Each round doubles the number of correct low bits, so five rounds are
	// enough for 64 bits (and waste only one round for 32 bits).
	//
	// See https://crypto.stackexchange.com/a/47496.
	inv := x
	for round := 5; round > 0; round-- {
		inv *= 2 - x*inv
	}
	return -inv
}
371
372// NewModulusFromBig creates a new Modulus from a [big.Int].
373//
374// The Int must be odd. The number of significant bits (and nothing else) is
375// leaked through timing side-channels.
376func NewModulusFromBig(n *big.Int) (*Modulus, error) {
377	if b := n.Bits(); len(b) == 0 {
378		return nil, errors.New("modulus must be >= 0")
379	} else if b[0]&1 != 1 {
380		return nil, errors.New("modulus must be odd")
381	}
382	m := &Modulus{}
383	m.nat = NewNat().setBig(n)
384	m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
385	m.m0inv = minusInverseModW(m.nat.limbs[0])
386	m.rr = rr(m)
387	return m, nil
388}
389
// bitLen is a version of bits.Len that only leaks the bit length of n, but not
// its value. bits.Len and bits.LeadingZeros use a lookup table for the
// low-order bits on some architectures.
func bitLen(n uint) int {
	width := 0
	// We assume, here and elsewhere, that comparison to zero is constant time
	// with respect to different non-zero values.
	for ; n != 0; n >>= 1 {
		width++
	}
	return width
}
403
404// Size returns the size of m in bytes.
405func (m *Modulus) Size() int {
406	return (m.BitLen() + 7) / 8
407}
408
409// BitLen returns the size of m in bits.
410func (m *Modulus) BitLen() int {
411	return len(m.nat.limbs)*_W - int(m.leading)
412}
413
414// Nat returns m as a Nat. The return value must not be written to.
415func (m *Modulus) Nat() *Nat {
416	return m.nat
417}
418
419// shiftIn calculates x = x << _W + y mod m.
420//
421// This assumes that x is already reduced mod m.
422func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
423	d := NewNat().resetFor(m)
424
425	// Eliminate bounds checks in the loop.
426	size := len(m.nat.limbs)
427	xLimbs := x.limbs[:size]
428	dLimbs := d.limbs[:size]
429	mLimbs := m.nat.limbs[:size]
430
431	// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
432	// from y. Effectively, it left-shifts x and adds y one bit at a time,
433	// reducing it every time.
434	//
435	// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
436	// The next iteration (and finally the return line) will use either result
437	// based on whether 2x + b overflows m.
438	needSubtraction := no
439	for i := _W - 1; i >= 0; i-- {
440		carry := (y >> i) & 1
441		var borrow uint
442		mask := ctMask(needSubtraction)
443		for i := 0; i < size; i++ {
444			l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
445			xLimbs[i], carry = bits.Add(l, l, carry)
446			dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
447		}
448		// Like in maybeSubtractModulus, we need the subtraction if either it
449		// didn't underflow (meaning 2x + b > m) or if computing 2x + b
450		// overflowed (meaning 2x + b > 2^_W*n > m).
451		needSubtraction = not(choice(borrow)) | choice(carry)
452	}
453	return x.assign(needSubtraction, d)
454}
455
456// Mod calculates out = x mod m.
457//
458// This works regardless how large the value of x is.
459//
460// The output will be resized to the size of m and overwritten.
461func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
462	out.resetFor(m)
463	// Working our way from the most significant to the least significant limb,
464	// we can insert each limb at the least significant position, shifting all
465	// previous limbs left by _W. This way each limb will get shifted by the
466	// correct number of bits. We can insert at least N - 1 limbs without
467	// overflowing m. After that, we need to reduce every time we shift.
468	i := len(x.limbs) - 1
469	// For the first N - 1 limbs we can skip the actual shifting and position
470	// them at the shifted position, which starts at min(N - 2, i).
471	start := len(m.nat.limbs) - 2
472	if i < start {
473		start = i
474	}
475	for j := start; j >= 0; j-- {
476		out.limbs[j] = x.limbs[i]
477		i--
478	}
479	// We shift in the remaining limbs, reducing modulo m each time.
480	for i >= 0 {
481		out.shiftIn(x.limbs[i], m)
482		i--
483	}
484	return out
485}
486
487// ExpandFor ensures x has the right size to work with operations modulo m.
488//
489// The announced size of x must be smaller than or equal to that of m.
490func (x *Nat) ExpandFor(m *Modulus) *Nat {
491	return x.expand(len(m.nat.limbs))
492}
493
494// resetFor ensures out has the right size to work with operations modulo m.
495//
496// out is zeroed and may start at any size.
497func (out *Nat) resetFor(m *Modulus) *Nat {
498	return out.reset(len(m.nat.limbs))
499}
500
501// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
502//
503// It can be used to reduce modulo m a value up to 2m - 1, which is a common
504// range for results computed by higher level operations.
505//
506// always is usually a carry that indicates that the operation that produced x
507// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
508//
509// x and m operands must have the same announced length.
510func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
511	t := NewNat().set(x)
512	underflow := t.sub(m.nat)
513	// We keep the result if x - m didn't underflow (meaning x >= m)
514	// or if always was set.
515	keep := not(choice(underflow)) | choice(always)
516	x.assign(keep, t)
517}
518
519// Sub computes x = x - y mod m.
520//
521// The length of both operands must be the same as the modulus. Both operands
522// must already be reduced modulo m.
523func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
524	underflow := x.sub(y)
525	// If the subtraction underflowed, add m.
526	t := NewNat().set(x)
527	t.add(m.nat)
528	x.assign(choice(underflow), t)
529	return x
530}
531
532// Add computes x = x + y mod m.
533//
534// The length of both operands must be the same as the modulus. Both operands
535// must already be reduced modulo m.
536func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
537	overflow := x.add(y)
538	x.maybeSubtractModulus(choice(overflow), m)
539	return x
540}
541
542// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
543// n = len(m.nat.limbs).
544//
545// Faster Montgomery multiplication replaces standard modular multiplication for
546// numbers in this representation.
547//
548// This assumes that x is already reduced mod m.
549func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
550	// A Montgomery multiplication (which computes a * b / R) by R * R works out
551	// to a multiplication by R, which takes the value out of the Montgomery domain.
552	return x.montgomeryMul(x, m.rr, m)
553}
554
555// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
556// n = len(m.nat.limbs).
557//
558// This assumes that x is already reduced mod m.
559func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
560	// By Montgomery multiplying with 1 not in Montgomery representation, we
561	// convert out back from Montgomery representation, because it works out to
562	// dividing by R.
563	one := NewNat().ExpandFor(m)
564	one.limbs[0] = 1
565	return x.montgomeryMul(x, one, m)
566}
567
568// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
569// n = len(m.nat.limbs), also known as a Montgomery multiplication.
570//
571// All inputs should be the same length and already reduced modulo m.
572// x will be resized to the size of m and overwritten.
573func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
574	n := len(m.nat.limbs)
575	mLimbs := m.nat.limbs[:n]
576	aLimbs := a.limbs[:n]
577	bLimbs := b.limbs[:n]
578
579	switch n {
580	default:
581		// Attempt to use a stack-allocated backing array.
582		T := make([]uint, 0, preallocLimbs*2)
583		if cap(T) < n*2 {
584			T = make([]uint, 0, n*2)
585		}
586		T = T[:n*2]
587
588		// This loop implements Word-by-Word Montgomery Multiplication, as
589		// described in Algorithm 4 (Fig. 3) of "Efficient Software
590		// Implementations of Modular Exponentiation" by Shay Gueron
591		// [https://eprint.iacr.org/2011/239.pdf].
592		var c uint
593		for i := 0; i < n; i++ {
594			_ = T[n+i] // bounds check elimination hint
595
596			// Step 1 (T = a × b) is computed as a large pen-and-paper column
597			// multiplication of two numbers with n base-2^_W digits. If we just
598			// wanted to produce 2n-wide T, we would do
599			//
600			//   for i := 0; i < n; i++ {
601			//       d := bLimbs[i]
602			//       T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
603			//   }
604			//
605			// where d is a digit of the multiplier, T[i:n+i] is the shifted
606			// position of the product of that digit, and T[n+i] is the final carry.
607			// Note that T[i] isn't modified after processing the i-th digit.
608			//
609			// Instead of running two loops, one for Step 1 and one for Steps 2–6,
610			// the result of Step 1 is computed during the next loop. This is
611			// possible because each iteration only uses T[i] in Step 2 and then
612			// discards it in Step 6.
613			d := bLimbs[i]
614			c1 := addMulVVW(T[i:n+i], aLimbs, d)
615
616			// Step 6 is replaced by shifting the virtual window we operate
617			// over: T of the algorithm is T[i:] for us. That means that T1 in
618			// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
619			Y := T[i] * m.m0inv
620
621			// Step 4 and 5 add Y × m to T, which as mentioned above is stored
622			// at T[i:]. The two carries (from a × d and Y × m) are added up in
623			// the next word T[n+i], and the carry bit from that addition is
624			// brought forward to the next iteration.
625			c2 := addMulVVW(T[i:n+i], mLimbs, Y)
626			T[n+i], c = bits.Add(c1, c2, c)
627		}
628
629		// Finally for Step 7 we copy the final T window into x, and subtract m
630		// if necessary (which as explained in maybeSubtractModulus can be the
631		// case both if x >= m, or if x overflowed).
632		//
633		// The paper suggests in Section 4 that we can do an "Almost Montgomery
634		// Multiplication" by subtracting only in the overflow case, but the
635		// cost is very similar since the constant time subtraction tells us if
636		// x >= m as a side effect, and taking care of the broken invariant is
637		// highly undesirable (see https://go.dev/issue/13907).
638		copy(x.reset(n).limbs, T[n:])
639		x.maybeSubtractModulus(choice(c), m)
640
641	// The following specialized cases follow the exact same algorithm, but
642	// optimized for the sizes most used in RSA. addMulVVW is implemented in
643	// assembly with loop unrolling depending on the architecture and bounds
644	// checks are removed by the compiler thanks to the constant size.
645	case 1024 / _W:
646		const n = 1024 / _W // compiler hint
647		T := make([]uint, n*2)
648		var c uint
649		for i := 0; i < n; i++ {
650			d := bLimbs[i]
651			c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
652			Y := T[i] * m.m0inv
653			c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
654			T[n+i], c = bits.Add(c1, c2, c)
655		}
656		copy(x.reset(n).limbs, T[n:])
657		x.maybeSubtractModulus(choice(c), m)
658
659	case 1536 / _W:
660		const n = 1536 / _W // compiler hint
661		T := make([]uint, n*2)
662		var c uint
663		for i := 0; i < n; i++ {
664			d := bLimbs[i]
665			c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
666			Y := T[i] * m.m0inv
667			c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
668			T[n+i], c = bits.Add(c1, c2, c)
669		}
670		copy(x.reset(n).limbs, T[n:])
671		x.maybeSubtractModulus(choice(c), m)
672
673	case 2048 / _W:
674		const n = 2048 / _W // compiler hint
675		T := make([]uint, n*2)
676		var c uint
677		for i := 0; i < n; i++ {
678			d := bLimbs[i]
679			c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
680			Y := T[i] * m.m0inv
681			c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
682			T[n+i], c = bits.Add(c1, c2, c)
683		}
684		copy(x.reset(n).limbs, T[n:])
685		x.maybeSubtractModulus(choice(c), m)
686	}
687
688	return x
689}
690
// addMulVVW computes z += x * y, where z and x are multi-word values and y is
// a single word, and returns the carry out of the top word of z. It can be
// thought of as one row of a pen-and-paper column multiplication.
//
// x must have at least len(z) words.
func addMulVVW(z, x []uint, y uint) (carry uint) {
	_ = x[len(z)-1] // bounds check elimination hint
	for i, zi := range z {
		hi, lo := bits.Mul(x[i], y)
		var c0, c1 uint
		// Fold in the existing word of z. Each bits.Add with a zero operand
		// is there to get an add-with-carry instruction that absorbs the
		// carry of the bits.Add right before it.
		lo, c0 = bits.Add(lo, zi, 0)
		hi, _ = bits.Add(hi, 0, c0)
		// Fold in the carry from the previous word.
		lo, c1 = bits.Add(lo, carry, 0)
		hi, _ = bits.Add(hi, 0, c1)
		z[i], carry = lo, hi
	}
	return carry
}
709
710// Mul calculates x = x * y mod m.
711//
712// The length of both operands must be the same as the modulus. Both operands
713// must already be reduced modulo m.
714func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
715	// A Montgomery multiplication by a value out of the Montgomery domain
716	// takes the result out of Montgomery representation.
717	xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
718	return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
719}
720
721// Exp calculates out = x^e mod m.
722//
723// The exponent e is represented in big-endian order. The output will be resized
724// to the size of m and overwritten. x must already be reduced modulo m.
725func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
726	// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
727	// than 2 bit windows, but use an extra 12 nats worth of scratch space.
728	// Using bit sizes that don't divide 8 are more complex to implement, but
729	// are likely to be more efficient if necessary.
730
731	table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
732		// newNat calls are unrolled so they are allocated on the stack.
733		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
734		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
735		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
736	}
737	table[0].set(x).montgomeryRepresentation(m)
738	for i := 1; i < len(table); i++ {
739		table[i].montgomeryMul(table[i-1], table[0], m)
740	}
741
742	out.resetFor(m)
743	out.limbs[0] = 1
744	out.montgomeryRepresentation(m)
745	tmp := NewNat().ExpandFor(m)
746	for _, b := range e {
747		for _, j := range []int{4, 0} {
748			// Square four times. Optimization note: this can be implemented
749			// more efficiently than with generic Montgomery multiplication.
750			out.montgomeryMul(out, out, m)
751			out.montgomeryMul(out, out, m)
752			out.montgomeryMul(out, out, m)
753			out.montgomeryMul(out, out, m)
754
755			// Select x^k in constant time from the table.
756			k := uint((b >> j) & 0b1111)
757			for i := range table {
758				tmp.assign(ctEq(k, uint(i+1)), table[i])
759			}
760
761			// Multiply by x^k, discarding the result if k = 0.
762			tmp.montgomeryMul(out, tmp, m)
763			out.assign(not(ctEq(k, 0)), tmp)
764		}
765	}
766
767	return out.montgomeryReduction(m)
768}
769
770// ExpShortVarTime calculates out = x^e mod m.
771//
772// The output will be resized to the size of m and overwritten. x must already
773// be reduced modulo m. This leaks the exponent through timing side-channels.
774func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
775	// For short exponents, precomputing a table and using a window like in Exp
776	// doesn't pay off. Instead, we do a simple conditional square-and-multiply
777	// chain, skipping the initial run of zeroes.
778	xR := NewNat().set(x).montgomeryRepresentation(m)
779	out.set(xR)
780	for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
781		out.montgomeryMul(out, out, m)
782		if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
783			out.montgomeryMul(out, xR, m)
784		}
785	}
786	return out.montgomeryReduction(m)
787}
788