1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package subtle
6
7import (
8	"testing"
9	"testing/quick"
10)
11
// TestConstantTimeCompareStruct is one row of the ConstantTimeCompare table:
// the two input slices and the result ConstantTimeCompare is expected to give.
type TestConstantTimeCompareStruct struct {
	a   []byte // first input
	b   []byte // second input
	out int    // expected return value
}
16
17var testConstantTimeCompareData = []TestConstantTimeCompareStruct{
18	{[]byte{}, []byte{}, 1},
19	{[]byte{0x11}, []byte{0x11}, 1},
20	{[]byte{0x12}, []byte{0x11}, 0},
21	{[]byte{0x11}, []byte{0x11, 0x12}, 0},
22	{[]byte{0x11, 0x12}, []byte{0x11}, 0},
23}
24
25func TestConstantTimeCompare(t *testing.T) {
26	for i, test := range testConstantTimeCompareData {
27		if r := ConstantTimeCompare(test.a, test.b); r != test.out {
28			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
29		}
30	}
31}
32
// TestConstantTimeByteEqStruct is one row of the ConstantTimeByteEq table:
// the two input bytes and the result ConstantTimeByteEq is expected to give.
type TestConstantTimeByteEqStruct struct {
	a   uint8 // first input
	b   uint8 // second input
	out int   // expected return value
}
37
38var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{
39	{0, 0, 1},
40	{0, 1, 0},
41	{1, 0, 0},
42	{0xff, 0xff, 1},
43	{0xff, 0xfe, 0},
44}
45
// byteEq is the straightforward (not constant-time) model of
// ConstantTimeByteEq: it reports 1 when a and b are equal and 0 otherwise.
func byteEq(a, b uint8) int {
	if a != b {
		return 0
	}
	return 1
}
52
53func TestConstantTimeByteEq(t *testing.T) {
54	for i, test := range testConstandTimeByteEqData {
55		if r := ConstantTimeByteEq(test.a, test.b); r != test.out {
56			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
57		}
58	}
59	err := quick.CheckEqual(ConstantTimeByteEq, byteEq, nil)
60	if err != nil {
61		t.Error(err)
62	}
63}
64
// eq is the straightforward (not constant-time) model of ConstantTimeEq:
// it reports 1 when a and b are equal and 0 otherwise.
func eq(a, b int32) int {
	switch {
	case a == b:
		return 1
	default:
		return 0
	}
}
71
72func TestConstantTimeEq(t *testing.T) {
73	err := quick.CheckEqual(ConstantTimeEq, eq, nil)
74	if err != nil {
75		t.Error(err)
76	}
77}
78
79func makeCopy(v int, x, y []byte) []byte {
80	if len(x) > len(y) {
81		x = x[0:len(y)]
82	} else {
83		y = y[0:len(x)]
84	}
85	if v == 1 {
86		copy(x, y)
87	}
88	return x
89}
90
91func constantTimeCopyWrapper(v int, x, y []byte) []byte {
92	if len(x) > len(y) {
93		x = x[0:len(y)]
94	} else {
95		y = y[0:len(x)]
96	}
97	v &= 1
98	ConstantTimeCopy(v, x, y)
99	return x
100}
101
102func TestConstantTimeCopy(t *testing.T) {
103	err := quick.CheckEqual(constantTimeCopyWrapper, makeCopy, nil)
104	if err != nil {
105		t.Error(err)
106	}
107}
108
// lessOrEqTests lists (x, y) pairs with the value ConstantTimeLessOrEq must
// return: 1 when x <= y and 0 otherwise.
//
// Besides small values, the table covers the top of the function's
// documented input range, 0x7fffffff (2**31 - 1). Those pairs catch an
// implementation that compares by subtraction and overflows near the limit.
var lessOrEqTests = []struct {
	x, y, result int
}{
	{0, 0, 1},
	{1, 0, 0},
	{0, 1, 1},
	{10, 20, 1},
	{20, 10, 0},
	{10, 10, 1},
	{0x7fffffff, 0x7fffffff, 1},
	{0x7fffffff, 0x7ffffffe, 0},
	{0x7ffffffe, 0x7fffffff, 1},
	{0, 0x7fffffff, 1},
	{0x7fffffff, 0, 0},
}
119
120func TestConstantTimeLessOrEq(t *testing.T) {
121	for i, test := range lessOrEqTests {
122		result := ConstantTimeLessOrEq(test.x, test.y)
123		if result != test.result {
124			t.Errorf("#%d: %d <= %d gave %d, expected %d", i, test.x, test.y, result, test.result)
125		}
126	}
127}
128
// benchmarkGlobal is a package-level sink. The benchmarks store their final
// value here so the loop results are observably used.
var benchmarkGlobal byte
130
131func BenchmarkConstantTimeByteEq(b *testing.B) {
132	var x, y uint8
133
134	for i := 0; i < b.N; i++ {
135		x, y = uint8(ConstantTimeByteEq(x, y)), x
136	}
137
138	benchmarkGlobal = x
139}
140
141func BenchmarkConstantTimeEq(b *testing.B) {
142	var x, y int
143
144	for i := 0; i < b.N; i++ {
145		x, y = ConstantTimeEq(int32(x), int32(y)), x
146	}
147
148	benchmarkGlobal = uint8(x)
149}
150
151func BenchmarkConstantTimeLessOrEq(b *testing.B) {
152	var x, y int
153
154	for i := 0; i < b.N; i++ {
155		x, y = ConstantTimeLessOrEq(x, y), x
156	}
157
158	benchmarkGlobal = uint8(x)
159}
160