1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package cmp_test 6 7import ( 8 "cmp" 9 "fmt" 10 "math" 11 "slices" 12 "sort" 13 "strings" 14 "testing" 15 "unsafe" 16) 17 18var negzero = math.Copysign(0, -1) 19var nonnilptr uintptr = uintptr(unsafe.Pointer(&negzero)) 20var nilptr uintptr = uintptr(unsafe.Pointer(nil)) 21 22var tests = []struct { 23 x, y any 24 compare int 25}{ 26 {1, 2, -1}, 27 {1, 1, 0}, 28 {2, 1, +1}, 29 {"a", "aa", -1}, 30 {"a", "a", 0}, 31 {"aa", "a", +1}, 32 {1.0, 1.1, -1}, 33 {1.1, 1.1, 0}, 34 {1.1, 1.0, +1}, 35 {math.Inf(1), math.Inf(1), 0}, 36 {math.Inf(-1), math.Inf(-1), 0}, 37 {math.Inf(-1), 1.0, -1}, 38 {1.0, math.Inf(-1), +1}, 39 {math.Inf(1), 1.0, +1}, 40 {1.0, math.Inf(1), -1}, 41 {math.NaN(), math.NaN(), 0}, 42 {0.0, math.NaN(), +1}, 43 {math.NaN(), 0.0, -1}, 44 {math.NaN(), math.Inf(-1), -1}, 45 {math.Inf(-1), math.NaN(), +1}, 46 {0.0, 0.0, 0}, 47 {negzero, negzero, 0}, 48 {negzero, 0.0, 0}, 49 {0.0, negzero, 0}, 50 {negzero, 1.0, -1}, 51 {negzero, -1.0, +1}, 52 {nilptr, nonnilptr, -1}, 53 {nonnilptr, nilptr, 1}, 54 {nonnilptr, nonnilptr, 0}, 55} 56 57func TestLess(t *testing.T) { 58 for _, test := range tests { 59 var b bool 60 switch test.x.(type) { 61 case int: 62 b = cmp.Less(test.x.(int), test.y.(int)) 63 case string: 64 b = cmp.Less(test.x.(string), test.y.(string)) 65 case float64: 66 b = cmp.Less(test.x.(float64), test.y.(float64)) 67 case uintptr: 68 b = cmp.Less(test.x.(uintptr), test.y.(uintptr)) 69 } 70 if b != (test.compare < 0) { 71 t.Errorf("Less(%v, %v) == %t, want %t", test.x, test.y, b, test.compare < 0) 72 } 73 } 74} 75 76func TestCompare(t *testing.T) { 77 for _, test := range tests { 78 var c int 79 switch test.x.(type) { 80 case int: 81 c = cmp.Compare(test.x.(int), test.y.(int)) 82 case string: 83 c = cmp.Compare(test.x.(string), test.y.(string)) 84 case float64: 85 c = cmp.Compare(test.x.(float64), test.y.(float64)) 86 case uintptr: 87 c = cmp.Compare(test.x.(uintptr), test.y.(uintptr)) 88 } 89 if c != test.compare { 90 t.Errorf("Compare(%v, %v) == %d, want %d", test.x, test.y, c, test.compare) 91 } 92 } 93} 94 95func TestSort(t *testing.T) { 96 // Test that our comparison function is consistent with 97 // sort.Float64s. 98 input := []float64{1.0, 0.0, negzero, math.Inf(1), math.Inf(-1), math.NaN()} 99 sort.Float64s(input) 100 for i := 0; i < len(input)-1; i++ { 101 if cmp.Less(input[i+1], input[i]) { 102 t.Errorf("Less sort mismatch at %d in %v", i, input) 103 } 104 if cmp.Compare(input[i], input[i+1]) > 0 { 105 t.Errorf("Compare sort mismatch at %d in %v", i, input) 106 } 107 } 108} 109 110func TestOr(t *testing.T) { 111 cases := []struct { 112 in []int 113 want int 114 }{ 115 {nil, 0}, 116 {[]int{0}, 0}, 117 {[]int{1}, 1}, 118 {[]int{0, 2}, 2}, 119 {[]int{3, 0}, 3}, 120 {[]int{4, 5}, 4}, 121 {[]int{0, 6, 7}, 6}, 122 } 123 for _, tc := range cases { 124 if got := cmp.Or(tc.in...); got != tc.want { 125 t.Errorf("cmp.Or(%v) = %v; want %v", tc.in, got, tc.want) 126 } 127 } 128} 129 130func ExampleOr() { 131 // Suppose we have some user input 132 // that may or may not be an empty string 133 userInput1 := "" 134 userInput2 := "some text" 135 136 fmt.Println(cmp.Or(userInput1, "default")) 137 fmt.Println(cmp.Or(userInput2, "default")) 138 fmt.Println(cmp.Or(userInput1, userInput2, "default")) 139 // Output: 140 // default 141 // some text 142 // some text 143} 144 145func ExampleOr_sort() { 146 type Order struct { 147 Product string 148 Customer string 149 Price float64 150 } 151 orders := []Order{ 152 {"foo", "alice", 1.00}, 153 {"bar", "bob", 3.00}, 154 {"baz", "carol", 4.00}, 155 {"foo", "alice", 2.00}, 156 {"bar", "carol", 1.00}, 157 {"foo", "bob", 4.00}, 158 } 159 // Sort by customer first, product second, and last by higher price 160 slices.SortFunc(orders, func(a, b Order) int { 161 return cmp.Or( 162 strings.Compare(a.Customer, b.Customer), 163 strings.Compare(a.Product, b.Product), 164 cmp.Compare(b.Price, a.Price), 165 ) 166 }) 167 for _, order := range orders { 168 fmt.Printf("%s %s %.2f\n", order.Product, order.Customer, order.Price) 169 } 170 171 // Output: 172 // foo alice 2.00 173 // foo alice 1.00 174 // bar bob 3.00 175 // foo bob 4.00 176 // bar carol 1.00 177 // baz carol 4.00 178} 179