1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"database/sql/driver"
9	"fmt"
10	"reflect"
11	"runtime"
12	"strings"
13	"sync"
14	"testing"
15	"time"
16)
17
// Shared fixtures for the conversion table:
//   - someTime is a fixed, non-zero instant used for time.Time cases.
//   - answer is the int64 that the pointer-destination case expects to see.
var (
	someTime       = time.Unix(123, 0)
	answer   int64 = 42
)
20
// userDefined is a named float64 destination; the conversion table stores
// float, int64 and string sources into it.
type userDefined float64

// userDefinedSlice is a named []int destination; scanning []byte into it is
// expected to be rejected by convertAssign.
type userDefinedSlice []int

// userDefinedString is a named string destination used to check scanning
// into user-defined string kinds.
type userDefinedString string
26
27type conversionTest struct {
28	s, d any // source and destination
29
30	// following are used if they're non-zero
31	wantint    int64
32	wantuint   uint64
33	wantstr    string
34	wantbytes  []byte
35	wantraw    RawBytes
36	wantf32    float32
37	wantf64    float64
38	wanttime   time.Time
39	wantbool   bool // used if d is of type *bool
40	wanterr    string
41	wantiface  any
42	wantptr    *int64 // if non-nil, *d's pointed value must be equal to *wantptr
43	wantnil    bool   // if true, *d must be *int64(nil)
44	wantusrdef userDefined
45	wantusrstr userDefinedString
46}
47
48// Target variables for scanning into.
49var (
50	scanstr    string
51	scanbytes  []byte
52	scanraw    RawBytes
53	scanint    int
54	scanuint8  uint8
55	scanuint16 uint16
56	scanbool   bool
57	scanf32    float32
58	scanf64    float64
59	scantime   time.Time
60	scanptr    *int64
61	scaniface  any
62)
63
64func conversionTests() []conversionTest {
65	// Return a fresh instance to test so "go test -count 2" works correctly.
66	return []conversionTest{
67		// Exact conversions (destination pointer type matches source type)
68		{s: "foo", d: &scanstr, wantstr: "foo"},
69		{s: 123, d: &scanint, wantint: 123},
70		{s: someTime, d: &scantime, wanttime: someTime},
71
72		// To strings
73		{s: "string", d: &scanstr, wantstr: "string"},
74		{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
75		{s: 123, d: &scanstr, wantstr: "123"},
76		{s: int8(123), d: &scanstr, wantstr: "123"},
77		{s: int64(123), d: &scanstr, wantstr: "123"},
78		{s: uint8(123), d: &scanstr, wantstr: "123"},
79		{s: uint16(123), d: &scanstr, wantstr: "123"},
80		{s: uint32(123), d: &scanstr, wantstr: "123"},
81		{s: uint64(123), d: &scanstr, wantstr: "123"},
82		{s: 1.5, d: &scanstr, wantstr: "1.5"},
83
84		// From time.Time:
85		{s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
86		{s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
87		{s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
88		{s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
89		{s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
90		{s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},
91
92		// To []byte
93		{s: nil, d: &scanbytes, wantbytes: nil},
94		{s: "string", d: &scanbytes, wantbytes: []byte("string")},
95		{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
96		{s: 123, d: &scanbytes, wantbytes: []byte("123")},
97		{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
98		{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
99		{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
100		{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
101		{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
102		{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
103		{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},
104
105		// To RawBytes
106		{s: nil, d: &scanraw, wantraw: nil},
107		{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
108		{s: "string", d: &scanraw, wantraw: RawBytes("string")},
109		{s: 123, d: &scanraw, wantraw: RawBytes("123")},
110		{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
111		{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
112		{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
113		{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
114		{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
115		{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
116		{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
117		// time.Time has been placed here to check that the RawBytes slice gets
118		// correctly reset when calling time.Time.AppendFormat.
119		{s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},
120
121		// Strings to integers
122		{s: "255", d: &scanuint8, wantuint: 255},
123		{s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
124		{s: "256", d: &scanuint16, wantuint: 256},
125		{s: "-1", d: &scanint, wantint: -1},
126		{s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},
127
128		// int64 to smaller integers
129		{s: int64(5), d: &scanuint8, wantuint: 5},
130		{s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
131		{s: int64(256), d: &scanuint16, wantuint: 256},
132		{s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},
133
134		// True bools
135		{s: true, d: &scanbool, wantbool: true},
136		{s: "True", d: &scanbool, wantbool: true},
137		{s: "TRUE", d: &scanbool, wantbool: true},
138		{s: "1", d: &scanbool, wantbool: true},
139		{s: 1, d: &scanbool, wantbool: true},
140		{s: int64(1), d: &scanbool, wantbool: true},
141		{s: uint16(1), d: &scanbool, wantbool: true},
142
143		// False bools
144		{s: false, d: &scanbool, wantbool: false},
145		{s: "false", d: &scanbool, wantbool: false},
146		{s: "FALSE", d: &scanbool, wantbool: false},
147		{s: "0", d: &scanbool, wantbool: false},
148		{s: 0, d: &scanbool, wantbool: false},
149		{s: int64(0), d: &scanbool, wantbool: false},
150		{s: uint16(0), d: &scanbool, wantbool: false},
151
152		// Not bools
153		{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
154		{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},
155
156		// Floats
157		{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
158		{s: int64(1), d: &scanf64, wantf64: float64(1)},
159		{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
160		{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
161		{s: "1.5", d: &scanf64, wantf64: float64(1.5)},
162
163		// Pointers
164		{s: any(nil), d: &scanptr, wantnil: true},
165		{s: int64(42), d: &scanptr, wantptr: &answer},
166
167		// To interface{}
168		{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
169		{s: int64(1), d: &scaniface, wantiface: int64(1)},
170		{s: "str", d: &scaniface, wantiface: "str"},
171		{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
172		{s: true, d: &scaniface, wantiface: true},
173		{s: nil, d: &scaniface},
174		{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},
175
176		// To a user-defined type
177		{s: 1.5, d: new(userDefined), wantusrdef: 1.5},
178		{s: int64(123), d: new(userDefined), wantusrdef: 123},
179		{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
180		{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
181		{s: "str", d: new(userDefinedString), wantusrstr: "str"},
182
183		// Other errors
184		{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
185	}
186}
187
// intPtrValue follows intptr through two levels of pointer indirection and
// returns the signed integer found there as an int64 boxed in an any. Levels
// that are not pointers are used as-is.
func intPtrValue(intptr any) any {
	outer := reflect.Indirect(reflect.ValueOf(intptr))
	inner := reflect.Indirect(outer)
	return inner.Int()
}
191
// intValue returns the signed integer that intptr points to, widened to
// int64. A non-pointer argument is read directly.
func intValue(intptr any) int64 {
	rv := reflect.ValueOf(intptr)
	if rv.Kind() == reflect.Pointer {
		rv = rv.Elem()
	}
	return rv.Int()
}
195
// uintValue returns the unsigned integer that intptr points to, widened to
// uint64. A non-pointer argument is read directly.
func uintValue(intptr any) uint64 {
	rv := reflect.ValueOf(intptr)
	if rv.Kind() != reflect.Pointer {
		return rv.Uint()
	}
	return rv.Elem().Uint()
}
199
// float64Value dereferences ptr, which must be a *float64; any other type
// causes a type-assertion panic.
func float64Value(ptr any) float64 {
	p := ptr.(*float64)
	return *p
}
203
// float32Value dereferences ptr, which must be a *float32; any other type
// causes a type-assertion panic.
func float32Value(ptr any) float32 {
	p := ptr.(*float32)
	return *p
}
207
// timeValue dereferences ptr, which must be a *time.Time; any other type
// causes a type-assertion panic.
func timeValue(ptr any) time.Time {
	p := ptr.(*time.Time)
	return *p
}
211
212func TestConversions(t *testing.T) {
213	for n, ct := range conversionTests() {
214		err := convertAssign(ct.d, ct.s)
215		errstr := ""
216		if err != nil {
217			errstr = err.Error()
218		}
219		errf := func(format string, args ...any) {
220			base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
221			t.Errorf(base+format, args...)
222		}
223		if errstr != ct.wanterr {
224			errf("got error %q, want error %q", errstr, ct.wanterr)
225		}
226		if ct.wantstr != "" && ct.wantstr != scanstr {
227			errf("want string %q, got %q", ct.wantstr, scanstr)
228		}
229		if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
230			errf("want byte %q, got %q", ct.wantbytes, scanbytes)
231		}
232		if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
233			errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
234		}
235		if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
236			errf("want int %d, got %d", ct.wantint, intValue(ct.d))
237		}
238		if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
239			errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
240		}
241		if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
242			errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
243		}
244		if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
245			errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
246		}
247		if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
248			errf("want bool %v, got %v", ct.wantbool, *bp)
249		}
250		if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
251			errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
252		}
253		if ct.wantnil && *ct.d.(**int64) != nil {
254			errf("want nil, got %v", intPtrValue(ct.d))
255		}
256		if ct.wantptr != nil {
257			if *ct.d.(**int64) == nil {
258				errf("want pointer to %v, got nil", *ct.wantptr)
259			} else if *ct.wantptr != intPtrValue(ct.d) {
260				errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
261			}
262		}
263		if ifptr, ok := ct.d.(*any); ok {
264			if !reflect.DeepEqual(ct.wantiface, scaniface) {
265				errf("want interface %#v, got %#v", ct.wantiface, scaniface)
266				continue
267			}
268			if srcBytes, ok := ct.s.([]byte); ok {
269				dstBytes := (*ifptr).([]byte)
270				if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
271					errf("copy into interface{} didn't copy []byte data")
272				}
273			}
274		}
275		if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
276			errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
277		}
278		if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
279			errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
280		}
281	}
282}
283
284func TestNullString(t *testing.T) {
285	var ns NullString
286	convertAssign(&ns, []byte("foo"))
287	if !ns.Valid {
288		t.Errorf("expecting not null")
289	}
290	if ns.String != "foo" {
291		t.Errorf("expecting foo; got %q", ns.String)
292	}
293	convertAssign(&ns, nil)
294	if ns.Valid {
295		t.Errorf("expecting null on nil")
296	}
297	if ns.String != "" {
298		t.Errorf("expecting blank on nil; got %q", ns.String)
299	}
300}
301
// valueConverterTest is one driver.ValueConverter case. Field order matters:
// the table of cases uses positional composite literals.
type valueConverterTest struct {
	c   driver.ValueConverter // converter under test
	in  any                   // value passed to ConvertValue
	out any                   // expected converted value
	err string                // expected error text; "" means success
}
307
308var valueConverterTests = []valueConverterTest{
309	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
310	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
311}
312
313func TestValueConverters(t *testing.T) {
314	for i, tt := range valueConverterTests {
315		out, err := tt.c.ConvertValue(tt.in)
316		goterr := ""
317		if err != nil {
318			goterr = err.Error()
319		}
320		if goterr != tt.err {
321			t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
322				i, tt.c, tt.in, tt.in, goterr, tt.err)
323		}
324		if tt.err != "" {
325			continue
326		}
327		if !reflect.DeepEqual(out, tt.out) {
328			t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
329				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
330		}
331	}
332}
333
// TestRawBytesAllocs checks that convertAssignRows can write a range of
// scalar and time.Time sources into a reused RawBytes destination without
// allocating. It also checks that the resulting bytes are the expected text.
// Allocation limits are only enforced on 64-bit amd64/arm64 builds using the
// gc compiler (see measureAllocs below).
func TestRawBytesAllocs(t *testing.T) {
	var tests = []struct {
		name string
		in   any
		want string
	}{
		{"uint64", uint64(12345678), "12345678"},
		{"uint32", uint32(1234), "1234"},
		{"uint16", uint16(12), "12"},
		{"uint8", uint8(1), "1"},
		{"uint", uint(123), "123"},
		{"int", int(123), "123"},
		{"int8", int8(1), "1"},
		{"int16", int16(12), "12"},
		{"int32", int32(1234), "1234"},
		{"int64", int64(12345678), "12345678"},
		{"float32", float32(1.5), "1.5"},
		{"float64", float64(64), "64"},
		{"bool", false, "false"},
		{"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
	}

	// buf is the RawBytes destination reused by every conversion; rows is
	// passed through to convertAssignRows.
	var buf RawBytes
	rows := &Rows{}
	// test converts in into buf and fails if the bytes differ from want.
	// The comparison is done byte by byte rather than via string(buf),
	// presumably to keep the check itself free of allocations.
	test := func(name string, in any, want string) {
		if err := convertAssignRows(&buf, in, rows); err != nil {
			t.Fatalf("%s: convertAssign = %v", name, err)
		}
		match := len(buf) == len(want)
		if match {
			for i, b := range buf {
				if want[i] != b {
					match = false
					break
				}
			}
		}
		if !match {
			t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
		}
	}

	n := testing.AllocsPerRun(100, func() {
		for _, tt := range tests {
			// Truncate rows.raw, keeping its capacity, before each case.
			rows.raw = rows.raw[:0]
			test(tt.name, tt.in, tt.want)
		}
	})

	// The numbers below are only valid for 64-bit interface word sizes,
	// and gc. With 32-bit words there are more convT2E allocs, and
	// with gccgo, only pointers currently go in interface data.
	// So only enforce the limits on amd64 and arm64 with the gc compiler.
	measureAllocs := false
	switch runtime.GOARCH {
	case "amd64", "arm64":
		measureAllocs = runtime.Compiler == "gc"
	}

	if n > 0.5 && measureAllocs {
		t.Fatalf("allocs = %v; want 0", n)
	}

	// Boxing the string source into an interface may cost one allocation
	// (a convT2E-style conversion), so allow at most one here.
	n = testing.AllocsPerRun(100, func() {
		test("string", "foo", "foo")
	})
	if n > 1.5 && measureAllocs {
		t.Fatalf("allocs = %v; want max 1", n)
	}
}
406
407// https://golang.org/issues/13905
408func TestUserDefinedBytes(t *testing.T) {
409	type userDefinedBytes []byte
410	var u userDefinedBytes
411	v := []byte("foo")
412
413	convertAssign(&u, v)
414	if &u[0] == &v[0] {
415		t.Fatal("userDefinedBytes got potentially dirty driver memory")
416	}
417}
418
// Valuer_V implements driver.Valuer with a value receiver. It is used to
// check how arguments of such types, including nil pointers to them, are
// converted.
type Valuer_V string

// Value reports v upper-cased.
func (v Valuer_V) Value() (driver.Value, error) {
	upper := strings.ToUpper(string(v))
	return upper, nil
}
424
// Valuer_P implements driver.Valuer with a pointer receiver that tolerates a
// nil receiver.
type Valuer_P string

// Value reports the upper-cased string that p points to. A nil p yields the
// fixed string "nil-to-str".
func (p *Valuer_P) Value() (driver.Value, error) {
	if p != nil {
		return strings.ToUpper(string(*p)), nil
	}
	return "nil-to-str", nil
}
433
434func TestDriverArgs(t *testing.T) {
435	var nilValuerVPtr *Valuer_V
436	var nilValuerPPtr *Valuer_P
437	var nilStrPtr *string
438	tests := []struct {
439		args []any
440		want []driver.NamedValue
441	}{
442		0: {
443			args: []any{Valuer_V("foo")},
444			want: []driver.NamedValue{
445				{
446					Ordinal: 1,
447					Value:   "FOO",
448				},
449			},
450		},
451		1: {
452			args: []any{nilValuerVPtr},
453			want: []driver.NamedValue{
454				{
455					Ordinal: 1,
456					Value:   nil,
457				},
458			},
459		},
460		2: {
461			args: []any{nilValuerPPtr},
462			want: []driver.NamedValue{
463				{
464					Ordinal: 1,
465					Value:   "nil-to-str",
466				},
467			},
468		},
469		3: {
470			args: []any{"plain-str"},
471			want: []driver.NamedValue{
472				{
473					Ordinal: 1,
474					Value:   "plain-str",
475				},
476			},
477		},
478		4: {
479			args: []any{nilStrPtr},
480			want: []driver.NamedValue{
481				{
482					Ordinal: 1,
483					Value:   nil,
484				},
485			},
486		},
487	}
488	for i, tt := range tests {
489		ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
490		got, err := driverArgsConnLocked(nil, ds, tt.args)
491		if err != nil {
492			t.Errorf("test[%d]: %v", i, err)
493			continue
494		}
495		if !reflect.DeepEqual(got, tt.want) {
496			t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
497		}
498	}
499}
500
// dec is a minimal test decimal that implements both Decompose and Compose,
// storing every part it is given, including the form.
type dec struct {
	form        byte
	neg         bool
	coefficient [16]byte
	exponent    int32
}

// Decompose reports d's parts. buf is ignored: the coefficient is always
// returned as a fresh 16-byte copy, so callers cannot mutate d through it.
func (d dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
	coefficient = make([]byte, len(d.coefficient))
	copy(coefficient, d.coefficient[:])
	form, negative, exponent = d.form, d.neg, d.exponent
	return form, negative, coefficient, exponent
}

// Compose stores the given parts into d.
//   - Forms 1 and 2 record only the form and sign, leaving the exponent and
//     coefficient as they were.
//   - Form 0 records all parts.
//   - Any other form is an error.
func (d *dec) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
	if form == 1 || form == 2 {
		d.form, d.neg = form, negative
		return nil
	}
	if form != 0 {
		return fmt.Errorf("unknown form %d", form)
	}
	d.form, d.neg, d.exponent = form, negative, exponent

	// Coefficients longer than 16 bytes are rejected even when the extra
	// bytes are zero; that simplification is fine for this test. Note the
	// form, sign and exponent have already been stored at this point.
	if len(coefficient) > 16 {
		return fmt.Errorf("coefficient too large")
	}
	copy(d.coefficient[:], coefficient)
	return nil
}
537
// decFinite is a test decimal with no form field. It always decomposes as
// form 0 and refuses to compose forms 1 and 2.
type decFinite struct {
	neg         bool
	coefficient [16]byte
	exponent    int32
}

// Decompose reports d's parts with form 0. buf is ignored: the coefficient
// is returned as a fresh 16-byte copy.
func (d decFinite) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
	coefficient = make([]byte, len(d.coefficient))
	copy(coefficient, d.coefficient[:])
	negative, exponent = d.neg, d.exponent
	return 0, negative, coefficient, exponent
}

// Compose stores the sign, exponent and coefficient of a form-0 value.
// Forms 1 and 2 are reported as unsupported; any other non-zero form is
// reported as unknown.
func (d *decFinite) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
	switch {
	case form == 1 || form == 2:
		return fmt.Errorf("unsupported form %d", form)
	case form != 0:
		return fmt.Errorf("unknown form %d", form)
	}
	d.neg, d.exponent = negative, exponent

	// Coefficients longer than 16 bytes are rejected even when the extra
	// bytes are zero; that simplification is fine for this test. The sign
	// and exponent have already been stored at this point.
	if len(coefficient) > 16 {
		return fmt.Errorf("coefficient too large")
	}
	copy(d.coefficient[:], coefficient)
	return nil
}
570
571func TestDecimal(t *testing.T) {
572	list := []struct {
573		name string
574		in   decimalDecompose
575		out  dec
576		err  bool
577	}{
578		{name: "same", in: dec{exponent: -6}, out: dec{exponent: -6}},
579
580		// Ensure reflection is not used to assign the value by using different types.
581		{name: "diff", in: decFinite{exponent: -6}, out: dec{exponent: -6}},
582
583		{name: "bad-form", in: dec{form: 200}, err: true},
584	}
585	for _, item := range list {
586		t.Run(item.name, func(t *testing.T) {
587			out := dec{}
588			err := convertAssign(&out, item.in)
589			if item.err {
590				if err == nil {
591					t.Fatalf("unexpected nil error")
592				}
593				return
594			}
595			if err != nil {
596				t.Fatalf("unexpected error: %v", err)
597			}
598			if !reflect.DeepEqual(out, item.out) {
599				t.Fatalf("got %#v want %#v", out, item.out)
600			}
601		})
602	}
603}
604