1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"context"
9	"database/sql/driver"
10	"errors"
11	"fmt"
12	"io"
13	"reflect"
14	"slices"
15	"strconv"
16	"strings"
17	"sync"
18	"sync/atomic"
19	"testing"
20	"time"
21)
22
23// fakeDriver is a fake database that implements Go's driver.Driver
24// interface, just for testing.
25//
26// It speaks a query language that's semantically similar to but
27// syntactically different and simpler than SQL.  The syntax is as
28// follows:
29//
30//	WIPE
31//	CREATE|<tablename>|<col>=<type>,<col>=<type>,...
32//	  where types are: "string", [u]int{8,16,32,64}, "bool"
33//	INSERT|<tablename>|col=val,col2=val2,col3=?
34//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
35//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
36//
37// Any of these can be preceded by PANIC|<method>|, to cause the
38// named method on fakeStmt to panic.
39//
// Any of these can be preceded by WAIT|<duration>|, to cause preparing
// and executing the statement to wait for the specified duration.
42//
43// Multiple of these can be combined when separated with a semicolon.
44//
45// When opening a fakeDriver's database, it starts empty with no
46// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex         // guards openCount, closeCount and dbs
	openCount  int                // conn opens
	closeCount int                // conn closes
	waitCh     chan struct{}      // if non-nil, Open blocks receiving from it before returning
	waitingCh  chan struct{}      // Open sends on it just before blocking on waitCh
	dbs        map[string]*fakeDB // databases by name; created lazily by getDB
}
55
// fakeConnector is a connector backed by the package-level fdriver.
type fakeConnector struct {
	name string // data source name handed to fdriver.Open by Connect

	waiter func(context.Context) // copied onto every fakeConn returned by Connect
	closed bool                  // set by Close; a second Close reports an error
}
62
63func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
64	conn, err := fdriver.Open(c.name)
65	conn.(*fakeConn).waiter = c.waiter
66	return conn, err
67}
68
// Driver returns the package-level fake driver, fdriver.
func (c *fakeConnector) Driver() driver.Driver {
	return fdriver
}
72
73func (c *fakeConnector) Close() error {
74	if c.closed {
75		return errors.New("fakedb: connector is closed")
76	}
77	c.closed = true
78	return nil
79}
80
// fakeDriverCtx embeds fakeDriver and additionally provides OpenConnector.
type fakeDriverCtx struct {
	fakeDriver
}

// Compile-time check that *fakeDriverCtx implements driver.DriverContext.
var _ driver.DriverContext = &fakeDriverCtx{}

// OpenConnector returns a fakeConnector that remembers name. No connection
// is opened here and the returned error is always nil.
func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
	return &fakeConnector{name: name}, nil
}
90
// fakeDB is one named in-memory database handed out by fakeDriver.getDB.
type fakeDB struct {
	name string

	// useRawBytes is set by the USE_RAWBYTES command. When set,
	// rowsCursor.Next hands out stored []byte values without cloning them.
	useRawBytes atomic.Bool

	mu       sync.Mutex        // held by wipe, createTable and callers of table
	tables   map[string]*table // nil until the first createTable, and again after wipe
	badConn  bool              // toggled by fakeConn.isBad for conns opened with the badConn option
	allowAny bool              // passed to checkSubsetTypes; true disables its type check
}
101
// fakeError is an error with a fixed message that optionally wraps
// another error, so callers can match the cause with errors.Is/As.
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns the message verbatim; the wrapped error is not included.
func (e fakeError) Error() string { return e.Message }

// Unwrap exposes the wrapped error, which may be nil.
func (e fakeError) Unwrap() error { return e.Wrapped }
114
// table is one in-memory table: parallel slices of column names and
// column types, plus the stored rows.
type table struct {
	mu      sync.Mutex // held by execInsert and QueryContext while they read or append rows
	colname []string
	coltype []string // coltype[i] is the type of colname[i]
	rows    []*row
}

// columnIndex returns the position of name in t.colname, or -1 if the
// table has no such column.
func (t *table) columnIndex(name string) int {
	return slices.Index(t.colname, name)
}
125
// row is a single stored or projected row of values.
type row struct {
	cols []any // must be same size as its table colname + coltype
}
129
// memToucher is implemented by the fake driver objects (fakeConn and
// rowsCursor in this file) that want unsynchronized use to be visible to
// the race detector.
type memToucher interface {
	// touchMem reads & writes some memory, to help find data races.
	touchMem()
}
134
// fakeConn is the connection type returned by fakeDriver.Open.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by a successful Close

	// currTx is the open transaction, or nil. It is set by Begin and
	// cleared by Commit, Rollback and ResetSession.
	currTx *fakeTx

	// Every operation writes to line to enable the race detector
	// check for data races.
	line int64

	// Stats for tests:
	// incrStat updates stmtsMade and stmtsClosed while holding mu;
	// numPrepare is incremented directly by PrepareContext.
	mu          sync.Mutex
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool
	stickyBad bool

	skipDirtySession bool // tests that use Conn should set this to true.

	// dirtySession tests ResetSession, true if a query has executed
	// until ResetSession is called.
	dirtySession bool

	// The waiter is called before each query. May be used in place of the "WAIT"
	// directive.
	// In this file it is invoked from PrepareContext, once per statement.
	waiter func(context.Context)
}
164
// touchMem implements memToucher by doing an unsynchronized write to
// c.line, so that concurrent use of the conn is reported by the race
// detector.
func (c *fakeConn) touchMem() {
	c.line++
}
168
169func (c *fakeConn) incrStat(v *int) {
170	c.mu.Lock()
171	*v++
172	c.mu.Unlock()
173}
174
// fakeTx is the transaction handle returned by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn // the connection the transaction was started on
}
178
// boundCol describes one filter column of a prepared SELECT.
type boundCol struct {
	Column      string // name of the column being filtered on
	Placeholder string // "?" for a positional argument, or "?name" for a named one
	Ordinal     int    // 1-based position of this placeholder within the statement
}
184
// fakeStmt is a statement prepared by fakeConn.PrepareContext. A query
// containing several semicolon-separated statements becomes a chain of
// fakeStmt values linked through next.
type fakeStmt struct {
	memToucher
	c *fakeConn
	q string // just for debugging

	cmd   string        // command word, e.g. "SELECT" or "INSERT"
	table string        // table the command operates on
	panic string        // name of the method that should panic (PANIC directive)
	wait  time.Duration // delay requested with the WAIT directive

	next *fakeStmt // used for returning multiple results.

	closed bool // set by Close; ExecContext and QueryContext then return errClosed

	colName      []string // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string // used by CREATE
	colValue     []any    // used by INSERT (mix of strings and "?" for bound params)
	placeholders int      // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
208
// fdriver is the single fakeDriver instance shared by the tests.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
214
// Dummy satisfies driver.Driver purely through its embedded interface
// field. TestDrivers registers a zero Dummy, whose embedded Driver is nil,
// under the name "invalid".
type Dummy struct {
	driver.Driver
}
218
219func TestDrivers(t *testing.T) {
220	unregisterAllDrivers()
221	Register("test", fdriver)
222	Register("invalid", Dummy{})
223	all := Drivers()
224	if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") {
225		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
226	}
227}
228
// hookOpenErr holds an optional callback that fakeDriver.Open consults
// before doing anything else; a non-nil error from it makes Open fail.
// It is used to simulate connection failures.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the open hook. Passing nil removes it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
240
// Open returns a new fakeConn attached to the in-memory database named by
// dsn, creating that database on first use.
//
// Supports dsn forms:
//
//	<dbname>
//	<dbname>;<opts>  (only currently supported option is `badConn`,
//	                  which causes driver.ErrBadConn to be returned on
//	                  every other conn.Begin())
//
// If a hook installed with setHookOpenErr returns an error, Open returns
// that error and no connection.
func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
	// Copy the hook out under its lock and call it without the lock held.
	hookOpenErr.Lock()
	fn := hookOpenErr.fn
	hookOpenErr.Unlock()
	if fn != nil {
		if err := fn(); err != nil {
			return nil, err
		}
	}
	parts := strings.Split(dsn, ";")
	// NOTE(review): strings.Split with a non-empty separator always returns
	// at least one element, so this check looks unreachable.
	if len(parts) < 1 {
		return nil, errors.New("fakedb: no database name")
	}
	name := parts[0]

	db := d.getDB(name)

	d.mu.Lock()
	d.openCount++
	d.mu.Unlock()
	conn := &fakeConn{db: db}

	if len(parts) >= 2 && parts[1] == "badConn" {
		conn.bad = true
	}
	// When waitCh is set, announce on waitingCh that Open is about to
	// block, wait on waitCh, then clear both channels so the handshake
	// happens only once. These fields are accessed without holding d.mu.
	if d.waitCh != nil {
		d.waitingCh <- struct{}{}
		<-d.waitCh
		d.waitCh = nil
		d.waitingCh = nil
	}
	return conn, nil
}
280
281func (d *fakeDriver) getDB(name string) *fakeDB {
282	d.mu.Lock()
283	defer d.mu.Unlock()
284	if d.dbs == nil {
285		d.dbs = make(map[string]*fakeDB)
286	}
287	db, ok := d.dbs[name]
288	if !ok {
289		db = &fakeDB{name: name}
290		d.dbs[name] = db
291	}
292	return db
293}
294
295func (db *fakeDB) wipe() {
296	db.mu.Lock()
297	defer db.mu.Unlock()
298	db.tables = nil
299}
300
301func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
302	db.mu.Lock()
303	defer db.mu.Unlock()
304	if db.tables == nil {
305		db.tables = make(map[string]*table)
306	}
307	if _, exist := db.tables[name]; exist {
308		return fmt.Errorf("fakedb: table %q already exists", name)
309	}
310	if len(columnNames) != len(columnTypes) {
311		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
312			name, len(columnNames), len(columnTypes))
313	}
314	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
315	return nil
316}
317
318// must be called with db.mu lock held
319func (db *fakeDB) table(table string) (*table, bool) {
320	if db.tables == nil {
321		return nil, false
322	}
323	t, ok := db.tables[table]
324	return t, ok
325}
326
327func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
328	db.mu.Lock()
329	defer db.mu.Unlock()
330	t, ok := db.table(table)
331	if !ok {
332		return
333	}
334	if i := slices.Index(t.colname, column); i != -1 {
335		return t.coltype[i], true
336	}
337	return "", false
338}
339
340func (c *fakeConn) isBad() bool {
341	if c.stickyBad {
342		return true
343	} else if c.bad {
344		if c.db == nil {
345			return false
346		}
347		// alternate between bad conn and not bad conn
348		c.db.badConn = !c.db.badConn
349		return c.db.badConn
350	} else {
351		return false
352	}
353}
354
355func (c *fakeConn) isDirtyAndMark() bool {
356	if c.skipDirtySession {
357		return false
358	}
359	if c.currTx != nil {
360		c.dirtySession = true
361		return false
362	}
363	if c.dirtySession {
364		return true
365	}
366	c.dirtySession = true
367	return false
368}
369
370func (c *fakeConn) Begin() (driver.Tx, error) {
371	if c.isBad() {
372		return nil, fakeError{Wrapped: driver.ErrBadConn}
373	}
374	if c.currTx != nil {
375		return nil, errors.New("fakedb: already in a transaction")
376	}
377	c.touchMem()
378	c.currTx = &fakeTx{c: c}
379	return c.currTx, nil
380}
381
382var hookPostCloseConn struct {
383	sync.Mutex
384	fn func(*fakeConn, error)
385}
386
387func setHookpostCloseConn(fn func(*fakeConn, error)) {
388	hookPostCloseConn.Lock()
389	defer hookPostCloseConn.Unlock()
390	hookPostCloseConn.fn = fn
391}
392
// testStrictClose, when non-nil, is the test that fakeConn.Close reports
// its own failures to via Errorf.
var testStrictClose *testing.T

// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
// fails to close. If nil, the check is disabled.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
400
401func (c *fakeConn) ResetSession(ctx context.Context) error {
402	c.dirtySession = false
403	c.currTx = nil
404	if c.isBad() {
405		return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
406	}
407	return nil
408}
409
// Compile-time check that *fakeConn implements driver.Validator.
var _ driver.Validator = (*fakeConn)(nil)

// IsValid reports whether the conn is usable, i.e. whether isBad returns
// false. For conns opened with the badConn option each call also flips
// the alternating bad/not-bad state kept by isBad.
func (c *fakeConn) IsValid() bool {
	return !c.isBad()
}
415
// Close closes the conn. It refuses to close, returning an error, when a
// transaction is still open, when the conn was already closed, or when
// more statements were made than closed. On success c.db is set to nil
// and the driver's closeCount is incremented.
func (c *fakeConn) Close() (err error) {
	drv := fdriver.(*fakeDriver)
	// The deferred function reads the named result err, so it sees
	// whichever error the return statements below produce.
	defer func() {
		if err != nil && testStrictClose != nil {
			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
		}
		// Copy the hook out under its lock and call it without the lock held.
		hookPostCloseConn.Lock()
		fn := hookPostCloseConn.fn
		hookPostCloseConn.Unlock()
		if fn != nil {
			fn(c, err)
		}
		if err == nil {
			drv.mu.Lock()
			drv.closeCount++
			drv.mu.Unlock()
		}
	}()
	c.touchMem()
	if c.currTx != nil {
		return errors.New("fakedb: can't close fakeConn; in a Transaction")
	}
	if c.db == nil {
		return errors.New("fakedb: can't close fakeConn; already closed")
	}
	// NOTE(review): these counters are read here without holding c.mu,
	// unlike the writes made through incrStat.
	if c.stmtsMade > c.stmtsClosed {
		return errors.New("fakedb: can't close; dangling statement(s)")
	}
	c.db = nil
	return nil
}
447
// checkSubsetTypes verifies that every argument value is one of int64,
// float64, bool, nil, []byte, string or time.Time. It returns an error
// describing the first offending argument. When allowAny is true no
// argument is rejected.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for i := range args {
		arg := &args[i]
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
460
// Exec always panics. It exists so that a caller which uses it instead of
// ExecContext is detected.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	// Ensure that ExecContext is called if available.
	panic("ExecContext was not called.")
}
465
466func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
467	// This is an optional interface, but it's implemented here
468	// just to check that all the args are of the proper types.
469	// ErrSkip is returned so the caller acts as if we didn't
470	// implement this at all.
471	err := checkSubsetTypes(c.db.allowAny, args)
472	if err != nil {
473		return nil, err
474	}
475	return nil, driver.ErrSkip
476}
477
// Query always panics. It exists so that a caller which uses it instead
// of QueryContext is detected.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	// Ensure that QueryContext is called if available.
	panic("QueryContext was not called.")
}
482
483func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
484	// This is an optional interface, but it's implemented here
485	// just to check that all the args are of the proper types.
486	// ErrSkip is returned so the caller acts as if we didn't
487	// implement this at all.
488	err := checkSubsetTypes(c.db.allowAny, args)
489	if err != nil {
490		return nil, err
491	}
492	return nil, driver.ErrSkip
493}
494
// errf formats msg with args like fmt.Sprintf and returns the result as
// an error whose text is prefixed with "fakedb: ".
func errf(msg string, args ...any) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
498
499// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
500// (note that where columns must always contain ? marks,
501// just a limitation for fakedb)
502func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
503	if len(parts) != 3 {
504		stmt.Close()
505		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
506	}
507	stmt.table = parts[0]
508
509	stmt.colName = strings.Split(parts[1], ",")
510	for n, colspec := range strings.Split(parts[2], ",") {
511		if colspec == "" {
512			continue
513		}
514		nameVal := strings.Split(colspec, "=")
515		if len(nameVal) != 2 {
516			stmt.Close()
517			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
518		}
519		column, value := nameVal[0], nameVal[1]
520		_, ok := c.db.columnType(stmt.table, column)
521		if !ok {
522			stmt.Close()
523			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
524		}
525		if !strings.HasPrefix(value, "?") {
526			stmt.Close()
527			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
528				stmt.table, column)
529		}
530		stmt.placeholders++
531		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
532	}
533	return stmt, nil
534}
535
536// parts are table|col=type,col2=type2
537func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
538	if len(parts) != 2 {
539		stmt.Close()
540		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
541	}
542	stmt.table = parts[0]
543	for n, colspec := range strings.Split(parts[1], ",") {
544		nameType := strings.Split(colspec, "=")
545		if len(nameType) != 2 {
546			stmt.Close()
547			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
548		}
549		stmt.colName = append(stmt.colName, nameType[0])
550		stmt.colType = append(stmt.colType, nameType[1])
551	}
552	return stmt, nil
553}
554
// prepareInsert fills in stmt for an INSERT (or NOSERT) command. Each
// column spec is either a placeholder, whose value starts with "?", or a
// pre-bound literal that is converted here according to the column type.
//
// parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	if len(parts) != 2 {
		stmt.Close()
		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
	}
	stmt.table = parts[0]
	for n, colspec := range strings.Split(parts[1], ",") {
		nameVal := strings.Split(colspec, "=")
		if len(nameVal) != 2 {
			stmt.Close()
			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
		}
		column, value := nameVal[0], nameVal[1]
		ctype, ok := c.db.columnType(stmt.table, column)
		if !ok {
			stmt.Close()
			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
		}
		stmt.colName = append(stmt.colName, column)

		if !strings.HasPrefix(value, "?") {
			var subsetVal any
			// Convert to driver subset type
			switch ctype {
			case "string":
				subsetVal = []byte(value)
			case "blob":
				subsetVal = []byte(value)
			case "int32":
				i, err := strconv.Atoi(value)
				if err != nil {
					stmt.Close()
					return nil, errf("invalid conversion to int32 from %q", value)
				}
				subsetVal = int64(i) // int64 is a subset type, but not int32
			case "table": // For testing cursor reads.
				// The literal has the form <table>!<col>!<col>...; it is
				// turned into a SELECT whose result cursor becomes the
				// stored value.
				c.skipDirtySession = true
				vparts := strings.Split(value, "!")

				// NOTE(review): unlike the other error paths in this
				// function, the two below return without closing stmt.
				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
				if err != nil {
					return nil, err
				}
				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
				substmt.Close()
				if err != nil {
					return nil, err
				}
				subsetVal = cursor
			default:
				stmt.Close()
				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
			}
			stmt.colValue = append(stmt.colValue, subsetVal)
		} else {
			// Placeholder: keep the raw "?..." text as the value and
			// record the converter for the column's type.
			stmt.placeholders++
			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
			stmt.colValue = append(stmt.colValue, value)
		}
	}
	return stmt, nil
}
618
// hookPrepareBadConn is a hook to simulate broken connections: when it is
// non-nil and returns true, PrepareContext fails with an error wrapping
// driver.ErrBadConn.
var hookPrepareBadConn func() bool
621
// Prepare always panics; callers are expected to use PrepareContext.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
625
// PrepareContext parses query, which may hold several semicolon-separated
// statements, into a chain of fakeStmt values linked through next, and
// returns the first one.
//
// Each statement may start with a PANIC|<method>| or WAIT|<duration>|
// directive. A WAIT delay is served here, during preparation, and is
// cut short if ctx is done.
func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	c.numPrepare++
	if c.db == nil {
		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
	}

	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
		return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
	}

	c.touchMem()
	var firstStmt, prev *fakeStmt
	for _, query := range strings.Split(query, ";") {
		parts := strings.Split(query, "|")
		// NOTE(review): strings.Split with a non-empty separator always
		// returns at least one element, so this check looks unreachable.
		if len(parts) < 1 {
			return nil, errf("empty query")
		}
		stmt := &fakeStmt{q: query, c: c, memToucher: c}
		if firstStmt == nil {
			firstStmt = stmt
		}
		// Strip an optional leading directive and its argument.
		if len(parts) >= 3 {
			switch parts[0] {
			case "PANIC":
				stmt.panic = parts[1]
				parts = parts[2:]
			case "WAIT":
				wait, err := time.ParseDuration(parts[1])
				if err != nil {
					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
				}
				parts = parts[2:]
				stmt.wait = wait
			}
		}
		cmd := parts[0]
		stmt.cmd = cmd
		parts = parts[1:]

		if c.waiter != nil {
			c.waiter(ctx)
			if err := ctx.Err(); err != nil {
				return nil, err
			}
		}

		if stmt.wait > 0 {
			wait := time.NewTimer(stmt.wait)
			select {
			case <-wait.C:
			case <-ctx.Done():
				wait.Stop()
				return nil, ctx.Err()
			}
		}

		// Count the statement before dispatching: the prepare helpers
		// close stmt on failure, which bumps stmtsClosed to match.
		c.incrStat(&c.stmtsMade)
		var err error
		switch cmd {
		case "WIPE":
			// Nothing
		case "USE_RAWBYTES":
			c.db.useRawBytes.Store(true)
		case "SELECT":
			stmt, err = c.prepareSelect(stmt, parts)
		case "CREATE":
			stmt, err = c.prepareCreate(stmt, parts)
		case "INSERT":
			stmt, err = c.prepareInsert(ctx, stmt, parts)
		case "NOSERT":
			// Do all the prep-work like for an INSERT but don't actually insert the row.
			// Used for some of the concurrent tests.
			stmt, err = c.prepareInsert(ctx, stmt, parts)
		default:
			stmt.Close()
			return nil, errf("unsupported command type %q", cmd)
		}
		if err != nil {
			return nil, err
		}
		if prev != nil {
			prev.next = stmt
		}
		prev = stmt
	}
	return firstStmt, nil
}
713
714func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
715	if s.panic == "ColumnConverter" {
716		panic(s.panic)
717	}
718	if len(s.placeholderConverter) == 0 {
719		return driver.DefaultParameterConverter
720	}
721	return s.placeholderConverter[idx]
722}
723
724func (s *fakeStmt) Close() error {
725	if s.panic == "Close" {
726		panic(s.panic)
727	}
728	if s.c == nil {
729		panic("nil conn in fakeStmt.Close")
730	}
731	if s.c.db == nil {
732		panic("in fakeStmt.Close, conn's db is nil (already closed)")
733	}
734	s.touchMem()
735	if !s.closed {
736		s.c.incrStat(&s.c.stmtsClosed)
737		s.closed = true
738	}
739	if s.next != nil {
740		s.next.Close()
741	}
742	return nil
743}
744
// errClosed is returned by fakeStmt.ExecContext and fakeStmt.QueryContext
// when the statement has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")

// hookExecBadConn is a hook to simulate broken connections: when it is
// non-nil and returns true, fakeStmt.ExecContext fails with an error
// wrapping driver.ErrBadConn.
var hookExecBadConn func() bool
749
// Exec always panics; callers are expected to use ExecContext.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}
753
// errFakeConnSessionDirty is returned by fakeStmt.ExecContext and
// fakeStmt.QueryContext when fakeConn.isDirtyAndMark reports that the
// session is already dirty.
var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
755
756func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
757	if s.panic == "Exec" {
758		panic(s.panic)
759	}
760	if s.closed {
761		return nil, errClosed
762	}
763
764	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
765		return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
766	}
767	if s.c.isDirtyAndMark() {
768		return nil, errFakeConnSessionDirty
769	}
770
771	err := checkSubsetTypes(s.c.db.allowAny, args)
772	if err != nil {
773		return nil, err
774	}
775	s.touchMem()
776
777	if s.wait > 0 {
778		time.Sleep(s.wait)
779	}
780
781	select {
782	default:
783	case <-ctx.Done():
784		return nil, ctx.Err()
785	}
786
787	db := s.c.db
788	switch s.cmd {
789	case "WIPE":
790		db.wipe()
791		return driver.ResultNoRows, nil
792	case "USE_RAWBYTES":
793		s.c.db.useRawBytes.Store(true)
794		return driver.ResultNoRows, nil
795	case "CREATE":
796		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
797			return nil, err
798		}
799		return driver.ResultNoRows, nil
800	case "INSERT":
801		return s.execInsert(args, true)
802	case "NOSERT":
803		// Do all the prep-work like for an INSERT but don't actually insert the row.
804		// Used for some of the concurrent tests.
805		return s.execInsert(args, false)
806	}
807	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
808}
809
// valueFromPlaceholderName returns the value of the first argument whose
// Name equals name, or nil when no argument has that name.
func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
	hasName := func(a driver.NamedValue) bool { return a.Name == name }
	if i := slices.IndexFunc(args, hasName); i >= 0 {
		return args[i].Value
	}
	return nil
}
818
819// When doInsert is true, add the row to the table.
820// When doInsert is false do prep-work and error checking, but don't
821// actually add the row to the table.
822func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
823	db := s.c.db
824	if len(args) != s.placeholders {
825		panic("error in pkg db; should only get here if size is correct")
826	}
827	db.mu.Lock()
828	t, ok := db.table(s.table)
829	db.mu.Unlock()
830	if !ok {
831		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
832	}
833
834	t.mu.Lock()
835	defer t.mu.Unlock()
836
837	var cols []any
838	if doInsert {
839		cols = make([]any, len(t.colname))
840	}
841	argPos := 0
842	for n, colname := range s.colName {
843		colidx := t.columnIndex(colname)
844		if colidx == -1 {
845			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
846		}
847		var val any
848		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
849			if strvalue == "?" {
850				val = args[argPos].Value
851			} else {
852				// Assign value from argument placeholder name.
853				if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil {
854					val = v
855				}
856			}
857			argPos++
858		} else {
859			val = s.colValue[n]
860		}
861		if doInsert {
862			cols[colidx] = val
863		}
864	}
865
866	if doInsert {
867		t.rows = append(t.rows, &row{cols: cols})
868	}
869	return driver.RowsAffected(1), nil
870}
871
// hookQueryBadConn is a hook to simulate broken connections: when it is
// non-nil and returns true, fakeStmt.QueryContext fails with an error
// wrapping driver.ErrBadConn.
var hookQueryBadConn func() bool
874
// Query always panics; callers are expected to use QueryContext.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
878
879func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
880	if s.panic == "Query" {
881		panic(s.panic)
882	}
883	if s.closed {
884		return nil, errClosed
885	}
886
887	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
888		return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
889	}
890	if s.c.isDirtyAndMark() {
891		return nil, errFakeConnSessionDirty
892	}
893
894	err := checkSubsetTypes(s.c.db.allowAny, args)
895	if err != nil {
896		return nil, err
897	}
898
899	s.touchMem()
900	db := s.c.db
901	if len(args) != s.placeholders {
902		panic("error in pkg db; should only get here if size is correct")
903	}
904
905	setMRows := make([][]*row, 0, 1)
906	setColumns := make([][]string, 0, 1)
907	setColType := make([][]string, 0, 1)
908
909	for {
910		db.mu.Lock()
911		t, ok := db.table(s.table)
912		db.mu.Unlock()
913		if !ok {
914			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
915		}
916
917		if s.table == "magicquery" {
918			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
919				if args[0].Value == "sleep" {
920					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
921				}
922			}
923		}
924		if s.table == "tx_status" && s.colName[0] == "tx_status" {
925			txStatus := "autocommit"
926			if s.c.currTx != nil {
927				txStatus = "transaction"
928			}
929			cursor := &rowsCursor{
930				db:        s.c.db,
931				parentMem: s.c,
932				posRow:    -1,
933				rows: [][]*row{
934					{
935						{
936							cols: []any{
937								txStatus,
938							},
939						},
940					},
941				},
942				cols: [][]string{
943					{
944						"tx_status",
945					},
946				},
947				colType: [][]string{
948					{
949						"string",
950					},
951				},
952				errPos: -1,
953			}
954			return cursor, nil
955		}
956
957		t.mu.Lock()
958
959		colIdx := make(map[string]int) // select column name -> column index in table
960		for _, name := range s.colName {
961			idx := t.columnIndex(name)
962			if idx == -1 {
963				t.mu.Unlock()
964				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
965			}
966			colIdx[name] = idx
967		}
968
969		mrows := []*row{}
970	rows:
971		for _, trow := range t.rows {
972			// Process the where clause, skipping non-match rows. This is lazy
973			// and just uses fmt.Sprintf("%v") to test equality. Good enough
974			// for test code.
975			for _, wcol := range s.whereCol {
976				idx := t.columnIndex(wcol.Column)
977				if idx == -1 {
978					t.mu.Unlock()
979					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
980				}
981				tcol := trow.cols[idx]
982				if bs, ok := tcol.([]byte); ok {
983					// lazy hack to avoid sprintf %v on a []byte
984					tcol = string(bs)
985				}
986				var argValue any
987				if wcol.Placeholder == "?" {
988					argValue = args[wcol.Ordinal-1].Value
989				} else {
990					if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil {
991						argValue = v
992					}
993				}
994				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
995					continue rows
996				}
997			}
998			mrow := &row{cols: make([]any, len(s.colName))}
999			for seli, name := range s.colName {
1000				mrow.cols[seli] = trow.cols[colIdx[name]]
1001			}
1002			mrows = append(mrows, mrow)
1003		}
1004
1005		var colType []string
1006		for _, column := range s.colName {
1007			colType = append(colType, t.coltype[t.columnIndex(column)])
1008		}
1009
1010		t.mu.Unlock()
1011
1012		setMRows = append(setMRows, mrows)
1013		setColumns = append(setColumns, s.colName)
1014		setColType = append(setColType, colType)
1015
1016		if s.next == nil {
1017			break
1018		}
1019		s = s.next
1020	}
1021
1022	cursor := &rowsCursor{
1023		db:        s.c.db,
1024		parentMem: s.c,
1025		posRow:    -1,
1026		rows:      setMRows,
1027		cols:      setColumns,
1028		colType:   setColType,
1029		errPos:    -1,
1030	}
1031	return cursor, nil
1032}
1033
// NumInput returns the number of placeholders counted while the statement
// was prepared.
func (s *fakeStmt) NumInput() int {
	if s.panic == "NumInput" {
		panic(s.panic)
	}
	return s.placeholders
}
1040
1041// hook to simulate broken connections
1042var hookCommitBadConn func() bool
1043
1044func (tx *fakeTx) Commit() error {
1045	tx.c.currTx = nil
1046	if hookCommitBadConn != nil && hookCommitBadConn() {
1047		return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1048	}
1049	tx.c.touchMem()
1050	return nil
1051}
1052
1053// hook to simulate broken connections
1054var hookRollbackBadConn func() bool
1055
1056func (tx *fakeTx) Rollback() error {
1057	tx.c.currTx = nil
1058	if hookRollbackBadConn != nil && hookRollbackBadConn() {
1059		return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1060	}
1061	tx.c.touchMem()
1062	return nil
1063}
1064
// rowsCursor is the result cursor returned by fakeStmt.QueryContext. It
// holds one or more result sets; cols, colType and rows are indexed by
// result set.
type rowsCursor struct {
	db        *fakeDB
	parentMem memToucher
	cols      [][]string
	colType   [][]string
	posSet    int // index of the current result set
	posRow    int // index of the current row within the set; -1 before the first Next
	rows      [][]*row
	closed    bool

	// errPos and err are for making Next return early with error.
	errPos int
	err    error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address.  we clone them
	// just so we're able to corrupt them on close.
	// NOTE(review): Close in this file does not touch the clones.
	bytesClone map[*byte][]byte

	// Every operation writes to line to enable the race detector
	// check for data races.
	// This is separate from the fakeConn.line to allow for drivers that
	// can start multiple queries on the same transaction at the same time.
	line int64

	// closeErr is returned when rowsCursor.Close
	closeErr error
}
1093
// touchMem touches the parent's memory and then does an unsynchronized
// write to rc.line, so that concurrent use of either is reported by the
// race detector.
func (rc *rowsCursor) touchMem() {
	rc.parentMem.touchMem()
	rc.line++
}
1098
// Close marks the cursor closed, after which Next fails, and returns
// rc.closeErr.
func (rc *rowsCursor) Close() error {
	rc.touchMem()
	// rc.touchMem already touched the parent; this touches it once more.
	rc.parentMem.touchMem()
	rc.closed = true
	return rc.closeErr
}
1105
// Columns returns the column names of the current result set.
func (rc *rowsCursor) Columns() []string {
	return rc.cols[rc.posSet]
}
1109
// ColumnTypeScanType returns the Go type for column index of the current
// result set, as mapped from the fakedb column type by
// colTypeToReflectType.
func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
	return colTypeToReflectType(rc.colType[rc.posSet][index])
}
1113
// rowsCursorNextHook, when non-nil, replaces the whole body of
// rowsCursor.Next: Next returns whatever the hook returns.
var rowsCursorNextHook func(dest []driver.Value) error
1115
1116func (rc *rowsCursor) Next(dest []driver.Value) error {
1117	if rowsCursorNextHook != nil {
1118		return rowsCursorNextHook(dest)
1119	}
1120
1121	if rc.closed {
1122		return errors.New("fakedb: cursor is closed")
1123	}
1124	rc.touchMem()
1125	rc.posRow++
1126	if rc.posRow == rc.errPos {
1127		return rc.err
1128	}
1129	if rc.posRow >= len(rc.rows[rc.posSet]) {
1130		return io.EOF // per interface spec
1131	}
1132	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1133		// TODO(bradfitz): convert to subset types? naah, I
1134		// think the subset types should only be input to
1135		// driver, but the sql package should be able to handle
1136		// a wider range of types coming out of drivers. all
1137		// for ease of drivers, and to prevent drivers from
1138		// messing up conversions or doing them differently.
1139		dest[i] = v
1140
1141		if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
1142			if rc.bytesClone == nil {
1143				rc.bytesClone = make(map[*byte][]byte)
1144			}
1145			clone, ok := rc.bytesClone[&bs[0]]
1146			if !ok {
1147				clone = make([]byte, len(bs))
1148				copy(clone, bs)
1149				rc.bytesClone[&bs[0]] = clone
1150			}
1151			dest[i] = clone
1152		}
1153	}
1154	return nil
1155}
1156
1157func (rc *rowsCursor) HasNextResultSet() bool {
1158	rc.touchMem()
1159	return rc.posSet < len(rc.rows)-1
1160}
1161
1162func (rc *rowsCursor) NextResultSet() error {
1163	rc.touchMem()
1164	if rc.HasNextResultSet() {
1165		rc.posSet++
1166		rc.posRow = -1
1167		return nil
1168	}
1169	return io.EOF // Per interface spec.
1170}
1171
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue passes string and []byte values through unchanged,
// dereferences a *string (a nil pointer becomes a nil value), and
// formats anything else with %v. It never returns an error.
func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
	if sp, isPtr := v.(*string); isPtr {
		if sp != nil {
			return *sp, nil
		}
		return nil, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	default:
		return fmt.Sprintf("%v", v), nil
	}
}
1192
// anyTypeConverter is the converter that converterForType returns for
// columns of type "any".
type anyTypeConverter struct{}

// ConvertValue returns v unchanged and never fails.
func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
	return v, nil
}
1198
1199func converterForType(typ string) driver.ValueConverter {
1200	switch typ {
1201	case "bool":
1202		return driver.Bool
1203	case "nullbool":
1204		return driver.Null{Converter: driver.Bool}
1205	case "byte", "int16":
1206		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1207	case "int32":
1208		return driver.Int32
1209	case "nullbyte", "nullint32", "nullint16":
1210		return driver.Null{Converter: driver.DefaultParameterConverter}
1211	case "string":
1212		return driver.NotNull{Converter: fakeDriverString{}}
1213	case "nullstring":
1214		return driver.Null{Converter: fakeDriverString{}}
1215	case "int64":
1216		// TODO(coopernurse): add type-specific converter
1217		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1218	case "nullint64":
1219		// TODO(coopernurse): add type-specific converter
1220		return driver.Null{Converter: driver.DefaultParameterConverter}
1221	case "float64":
1222		// TODO(coopernurse): add type-specific converter
1223		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1224	case "nullfloat64":
1225		// TODO(coopernurse): add type-specific converter
1226		return driver.Null{Converter: driver.DefaultParameterConverter}
1227	case "datetime":
1228		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1229	case "nulldatetime":
1230		return driver.Null{Converter: driver.DefaultParameterConverter}
1231	case "any":
1232		return anyTypeConverter{}
1233	}
1234	panic("invalid fakedb column type of " + typ)
1235}
1236
1237func colTypeToReflectType(typ string) reflect.Type {
1238	switch typ {
1239	case "bool":
1240		return reflect.TypeFor[bool]()
1241	case "nullbool":
1242		return reflect.TypeFor[NullBool]()
1243	case "int16":
1244		return reflect.TypeFor[int16]()
1245	case "nullint16":
1246		return reflect.TypeFor[NullInt16]()
1247	case "int32":
1248		return reflect.TypeFor[int32]()
1249	case "nullint32":
1250		return reflect.TypeFor[NullInt32]()
1251	case "string":
1252		return reflect.TypeFor[string]()
1253	case "nullstring":
1254		return reflect.TypeFor[NullString]()
1255	case "int64":
1256		return reflect.TypeFor[int64]()
1257	case "nullint64":
1258		return reflect.TypeFor[NullInt64]()
1259	case "float64":
1260		return reflect.TypeFor[float64]()
1261	case "nullfloat64":
1262		return reflect.TypeFor[NullFloat64]()
1263	case "datetime":
1264		return reflect.TypeFor[time.Time]()
1265	case "any":
1266		return reflect.TypeFor[any]()
1267	}
1268	panic("invalid fakedb column type of " + typ)
1269}
1270