1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package syscall_test
6
7import (
8	"fmt"
9	"slices"
10	"syscall"
11	"testing"
12	"unicode/utf16"
13	"unicode/utf8"
14	"unsafe"
15)
16
// wtf8tests pairs WTF-8 byte strings (str) with the potentially ill-formed
// UTF-16 code units (wstr) they correspond to. Unpaired surrogates appear in
// str as their 3-byte generalized-UTF-8 form (0xED 0xA0..0xBF 0x80..0xBF),
// while supplementary code points appear as ordinary 4-byte UTF-8 and map to
// a surrogate pair in wstr.
var wtf8tests = []struct {
	str  string
	wstr []uint16
}{
	// 1-byte
	{str: "\x00", wstr: []uint16{0x0000}},
	{str: "\x5C", wstr: []uint16{0x005C}},
	{str: "\x7F", wstr: []uint16{0x007F}},

	// 2-byte
	{str: "\xC2\x80", wstr: []uint16{0x0080}},
	{str: "\xD7\x8A", wstr: []uint16{0x05CA}},
	{str: "\xDF\xBF", wstr: []uint16{0x07FF}},

	// 3-byte
	{str: "\xE0\xA0\x80", wstr: []uint16{0x0800}},
	{str: "\xE2\xB0\xBC", wstr: []uint16{0x2C3C}},
	{str: "\xEF\xBF\xBF", wstr: []uint16{0xFFFF}},

	// Unpaired high surrogates (0xD800 to 0xDBFF).
	{str: "\xED\xA0\x80", wstr: []uint16{0xD800}},
	// Two high surrogates in a row.
	{str: "\xED\xA0\x80\xED\xA0\x80", wstr: []uint16{0xD800, 0xD800}},
	// High surrogate followed by a non-surrogate ('\n').
	{str: "\xED\xA0\x80\n", wstr: []uint16{0xD800, 0x000A}},
	// Unpaired high surrogate, a valid pair (U+1D306), another unpaired high surrogate.
	{str: "\xED\xA0\x80\xF0\x9D\x8C\x86\xED\xA0\x80", wstr: []uint16{0xD800, 0xD834, 0xDF06, 0xD800}},
	{str: "\xED\xA6\xAF", wstr: []uint16{0xD9AF}},
	{str: "\xED\xAF\xBF", wstr: []uint16{0xDBFF}},

	// Unpaired low surrogates (0xDC00 to 0xDFFF).
	{str: "\xED\xB0\x80", wstr: []uint16{0xDC00}},
	// Two low surrogates in a row.
	{str: "\xED\xB0\x80\xED\xB0\x80", wstr: []uint16{0xDC00, 0xDC00}},
	// Low surrogate followed by a non-surrogate ('\n').
	{str: "\xED\xB0\x80\n", wstr: []uint16{0xDC00, 0x000A}},
	// Unpaired low surrogate, a valid pair (U+1D306), another unpaired low surrogate.
	{str: "\xED\xB0\x80\xF0\x9D\x8C\x86\xED\xB0\x80", wstr: []uint16{0xDC00, 0xD834, 0xDF06, 0xDC00}},
	{str: "\xED\xBB\xAE", wstr: []uint16{0xDEEE}},
	{str: "\xED\xBF\xBF", wstr: []uint16{0xDFFF}},

	// 4-byte, encoded in UTF-16 as a surrogate pair.
	{str: "\xF0\x90\x80\x80", wstr: []uint16{0xD800, 0xDC00}},
	{str: "\xF0\x9D\x8C\x86", wstr: []uint16{0xD834, 0xDF06}},
	{str: "\xF4\x8F\xBF\xBF", wstr: []uint16{0xDBFF, 0xDFFF}},
}
133
134func TestWTF16Rountrip(t *testing.T) {
135	for _, tt := range wtf8tests {
136		t.Run(fmt.Sprintf("%X", tt.str), func(t *testing.T) {
137			got := syscall.EncodeWTF16(tt.str, nil)
138			got2 := string(syscall.DecodeWTF16(got, nil))
139			if got2 != tt.str {
140				t.Errorf("got:\n%s\nwant:\n%s", got2, tt.str)
141			}
142		})
143	}
144}
145
146func TestWTF16Golden(t *testing.T) {
147	for _, tt := range wtf8tests {
148		t.Run(fmt.Sprintf("%X", tt.str), func(t *testing.T) {
149			got := syscall.EncodeWTF16(tt.str, nil)
150			if !slices.Equal(got, tt.wstr) {
151				t.Errorf("got:\n%v\nwant:\n%v", got, tt.wstr)
152			}
153		})
154	}
155}
156
157func FuzzEncodeWTF16(f *testing.F) {
158	for _, tt := range wtf8tests {
159		f.Add(tt.str)
160	}
161	f.Fuzz(func(t *testing.T, b string) {
162		// test that there are no panics
163		got := syscall.EncodeWTF16(b, nil)
164		syscall.DecodeWTF16(got, nil)
165		if utf8.ValidString(b) {
166			// if the input is a valid UTF-8 string, then
167			// test that syscall.EncodeWTF16 behaves as
168			// utf16.Encode
169			want := utf16.Encode([]rune(b))
170			if !slices.Equal(got, want) {
171				t.Errorf("got:\n%v\nwant:\n%v", got, want)
172			}
173		}
174	})
175}
176
177func FuzzDecodeWTF16(f *testing.F) {
178	for _, tt := range wtf8tests {
179		b := unsafe.Slice((*uint8)(unsafe.Pointer(unsafe.SliceData(tt.wstr))), len(tt.wstr)*2)
180		f.Add(b)
181	}
182	f.Fuzz(func(t *testing.T, b []byte) {
183		u16 := unsafe.Slice((*uint16)(unsafe.Pointer(unsafe.SliceData(b))), len(b)/2)
184		got := syscall.DecodeWTF16(u16, nil)
185		if utf8.Valid(got) {
186			// if the input is a valid UTF-8 string, then
187			// test that syscall.DecodeWTF16 behaves as
188			// utf16.Decode
189			want := utf16.Decode(u16)
190			if string(got) != string(want) {
191				t.Errorf("got:\n%s\nwant:\n%s", string(got), string(want))
192			}
193		}
194		// WTF-8 should always roundtrip
195		got2 := syscall.EncodeWTF16(string(got), nil)
196		if !slices.Equal(got2, u16) {
197			t.Errorf("got:\n%v\nwant:\n%v", got2, u16)
198		}
199	})
200}
201