1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package subtle 6 7import ( 8 "testing" 9 "testing/quick" 10) 11 12type TestConstantTimeCompareStruct struct { 13 a, b []byte 14 out int 15} 16 17var testConstantTimeCompareData = []TestConstantTimeCompareStruct{ 18 {[]byte{}, []byte{}, 1}, 19 {[]byte{0x11}, []byte{0x11}, 1}, 20 {[]byte{0x12}, []byte{0x11}, 0}, 21 {[]byte{0x11}, []byte{0x11, 0x12}, 0}, 22 {[]byte{0x11, 0x12}, []byte{0x11}, 0}, 23} 24 25func TestConstantTimeCompare(t *testing.T) { 26 for i, test := range testConstantTimeCompareData { 27 if r := ConstantTimeCompare(test.a, test.b); r != test.out { 28 t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out) 29 } 30 } 31} 32 33type TestConstantTimeByteEqStruct struct { 34 a, b uint8 35 out int 36} 37 38var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{ 39 {0, 0, 1}, 40 {0, 1, 0}, 41 {1, 0, 0}, 42 {0xff, 0xff, 1}, 43 {0xff, 0xfe, 0}, 44} 45 46func byteEq(a, b uint8) int { 47 if a == b { 48 return 1 49 } 50 return 0 51} 52 53func TestConstantTimeByteEq(t *testing.T) { 54 for i, test := range testConstandTimeByteEqData { 55 if r := ConstantTimeByteEq(test.a, test.b); r != test.out { 56 t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out) 57 } 58 } 59 err := quick.CheckEqual(ConstantTimeByteEq, byteEq, nil) 60 if err != nil { 61 t.Error(err) 62 } 63} 64 65func eq(a, b int32) int { 66 if a == b { 67 return 1 68 } 69 return 0 70} 71 72func TestConstantTimeEq(t *testing.T) { 73 err := quick.CheckEqual(ConstantTimeEq, eq, nil) 74 if err != nil { 75 t.Error(err) 76 } 77} 78 79func makeCopy(v int, x, y []byte) []byte { 80 if len(x) > len(y) { 81 x = x[0:len(y)] 82 } else { 83 y = y[0:len(x)] 84 } 85 if v == 1 { 86 copy(x, y) 87 } 88 return x 89} 90 91func constantTimeCopyWrapper(v int, x, y []byte) []byte { 92 if len(x) > len(y) { 93 x = x[0:len(y)] 94 } else { 95 y = y[0:len(x)] 96 } 97 v &= 1 98 ConstantTimeCopy(v, x, y) 99 return x 100} 101 102func TestConstantTimeCopy(t *testing.T) { 103 err := quick.CheckEqual(constantTimeCopyWrapper, makeCopy, nil) 104 if err != nil { 105 t.Error(err) 106 } 107} 108 109var lessOrEqTests = []struct { 110 x, y, result int 111}{ 112 {0, 0, 1}, 113 {1, 0, 0}, 114 {0, 1, 1}, 115 {10, 20, 1}, 116 {20, 10, 0}, 117 {10, 10, 1}, 118} 119 120func TestConstantTimeLessOrEq(t *testing.T) { 121 for i, test := range lessOrEqTests { 122 result := ConstantTimeLessOrEq(test.x, test.y) 123 if result != test.result { 124 t.Errorf("#%d: %d <= %d gave %d, expected %d", i, test.x, test.y, result, test.result) 125 } 126 } 127} 128 129var benchmarkGlobal uint8 130 131func BenchmarkConstantTimeByteEq(b *testing.B) { 132 var x, y uint8 133 134 for i := 0; i < b.N; i++ { 135 x, y = uint8(ConstantTimeByteEq(x, y)), x 136 } 137 138 benchmarkGlobal = x 139} 140 141func BenchmarkConstantTimeEq(b *testing.B) { 142 var x, y int 143 144 for i := 0; i < b.N; i++ { 145 x, y = ConstantTimeEq(int32(x), int32(y)), x 146 } 147 148 benchmarkGlobal = uint8(x) 149} 150 151func BenchmarkConstantTimeLessOrEq(b *testing.B) { 152 var x, y int 153 154 for i := 0; i < b.N; i++ { 155 x, y = ConstantTimeLessOrEq(x, y), x 156 } 157 158 benchmarkGlobal = uint8(x) 159} 160