1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cmp_test
6
7import (
8	"cmp"
9	"fmt"
10	"math"
11	"slices"
12	"sort"
13	"strings"
14	"testing"
15	"unsafe"
16)
17
var (
	// negzero is IEEE 754 negative zero; the tests expect it to compare
	// equal to +0 while still sorting consistently with sort.Float64s.
	negzero = math.Copysign(0, -1)

	// nonnilptr and nilptr are uintptr operands derived from a non-nil and a
	// nil pointer, used to exercise cmp on uintptr values. negzero is a
	// package-level variable, so the address captured in nonnilptr refers
	// to storage that lives for the whole test binary.
	nonnilptr = uintptr(unsafe.Pointer(&negzero))
	nilptr    = uintptr(unsafe.Pointer(nil))
)
21
22var tests = []struct {
23	x, y    any
24	compare int
25}{
26	{1, 2, -1},
27	{1, 1, 0},
28	{2, 1, +1},
29	{"a", "aa", -1},
30	{"a", "a", 0},
31	{"aa", "a", +1},
32	{1.0, 1.1, -1},
33	{1.1, 1.1, 0},
34	{1.1, 1.0, +1},
35	{math.Inf(1), math.Inf(1), 0},
36	{math.Inf(-1), math.Inf(-1), 0},
37	{math.Inf(-1), 1.0, -1},
38	{1.0, math.Inf(-1), +1},
39	{math.Inf(1), 1.0, +1},
40	{1.0, math.Inf(1), -1},
41	{math.NaN(), math.NaN(), 0},
42	{0.0, math.NaN(), +1},
43	{math.NaN(), 0.0, -1},
44	{math.NaN(), math.Inf(-1), -1},
45	{math.Inf(-1), math.NaN(), +1},
46	{0.0, 0.0, 0},
47	{negzero, negzero, 0},
48	{negzero, 0.0, 0},
49	{0.0, negzero, 0},
50	{negzero, 1.0, -1},
51	{negzero, -1.0, +1},
52	{nilptr, nonnilptr, -1},
53	{nonnilptr, nilptr, 1},
54	{nonnilptr, nonnilptr, 0},
55}
56
57func TestLess(t *testing.T) {
58	for _, test := range tests {
59		var b bool
60		switch test.x.(type) {
61		case int:
62			b = cmp.Less(test.x.(int), test.y.(int))
63		case string:
64			b = cmp.Less(test.x.(string), test.y.(string))
65		case float64:
66			b = cmp.Less(test.x.(float64), test.y.(float64))
67		case uintptr:
68			b = cmp.Less(test.x.(uintptr), test.y.(uintptr))
69		}
70		if b != (test.compare < 0) {
71			t.Errorf("Less(%v, %v) == %t, want %t", test.x, test.y, b, test.compare < 0)
72		}
73	}
74}
75
76func TestCompare(t *testing.T) {
77	for _, test := range tests {
78		var c int
79		switch test.x.(type) {
80		case int:
81			c = cmp.Compare(test.x.(int), test.y.(int))
82		case string:
83			c = cmp.Compare(test.x.(string), test.y.(string))
84		case float64:
85			c = cmp.Compare(test.x.(float64), test.y.(float64))
86		case uintptr:
87			c = cmp.Compare(test.x.(uintptr), test.y.(uintptr))
88		}
89		if c != test.compare {
90			t.Errorf("Compare(%v, %v) == %d, want %d", test.x, test.y, c, test.compare)
91		}
92	}
93}
94
95func TestSort(t *testing.T) {
96	// Test that our comparison function is consistent with
97	// sort.Float64s.
98	input := []float64{1.0, 0.0, negzero, math.Inf(1), math.Inf(-1), math.NaN()}
99	sort.Float64s(input)
100	for i := 0; i < len(input)-1; i++ {
101		if cmp.Less(input[i+1], input[i]) {
102			t.Errorf("Less sort mismatch at %d in %v", i, input)
103		}
104		if cmp.Compare(input[i], input[i+1]) > 0 {
105			t.Errorf("Compare sort mismatch at %d in %v", i, input)
106		}
107	}
108}
109
110func TestOr(t *testing.T) {
111	cases := []struct {
112		in   []int
113		want int
114	}{
115		{nil, 0},
116		{[]int{0}, 0},
117		{[]int{1}, 1},
118		{[]int{0, 2}, 2},
119		{[]int{3, 0}, 3},
120		{[]int{4, 5}, 4},
121		{[]int{0, 6, 7}, 6},
122	}
123	for _, tc := range cases {
124		if got := cmp.Or(tc.in...); got != tc.want {
125			t.Errorf("cmp.Or(%v) = %v; want %v", tc.in, got, tc.want)
126		}
127	}
128}
129
130func ExampleOr() {
131	// Suppose we have some user input
132	// that may or may not be an empty string
133	userInput1 := ""
134	userInput2 := "some text"
135
136	fmt.Println(cmp.Or(userInput1, "default"))
137	fmt.Println(cmp.Or(userInput2, "default"))
138	fmt.Println(cmp.Or(userInput1, userInput2, "default"))
139	// Output:
140	// default
141	// some text
142	// some text
143}
144
145func ExampleOr_sort() {
146	type Order struct {
147		Product  string
148		Customer string
149		Price    float64
150	}
151	orders := []Order{
152		{"foo", "alice", 1.00},
153		{"bar", "bob", 3.00},
154		{"baz", "carol", 4.00},
155		{"foo", "alice", 2.00},
156		{"bar", "carol", 1.00},
157		{"foo", "bob", 4.00},
158	}
159	// Sort by customer first, product second, and last by higher price
160	slices.SortFunc(orders, func(a, b Order) int {
161		return cmp.Or(
162			strings.Compare(a.Customer, b.Customer),
163			strings.Compare(a.Product, b.Product),
164			cmp.Compare(b.Price, a.Price),
165		)
166	})
167	for _, order := range orders {
168		fmt.Printf("%s %s %.2f\n", order.Product, order.Customer, order.Price)
169	}
170
171	// Output:
172	// foo alice 2.00
173	// foo alice 1.00
174	// bar bob 3.00
175	// foo bob 4.00
176	// bar carol 1.00
177	// baz carol 4.00
178}
179