// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sql

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"reflect"
	"slices"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL. The syntax is as
// follows:
//
//	WIPE
//	CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//	  where types include "string", "bool", "byte", "int16", "int32",
//	  "int64", "float64", "datetime", "any", and "null"-prefixed
//	  variants such as "nullstring" and "nullint64"
//	INSERT|<tablename>|col=val,col2=val2,col3=?
//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to make preparing
// the statement (and executing it via ExecContext) sleep for the
// specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
47type fakeDriver struct { 48 mu sync.Mutex // guards 3 following fields 49 openCount int // conn opens 50 closeCount int // conn closes 51 waitCh chan struct{} 52 waitingCh chan struct{} 53 dbs map[string]*fakeDB 54} 55 56type fakeConnector struct { 57 name string 58 59 waiter func(context.Context) 60 closed bool 61} 62 63func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 64 conn, err := fdriver.Open(c.name) 65 conn.(*fakeConn).waiter = c.waiter 66 return conn, err 67} 68 69func (c *fakeConnector) Driver() driver.Driver { 70 return fdriver 71} 72 73func (c *fakeConnector) Close() error { 74 if c.closed { 75 return errors.New("fakedb: connector is closed") 76 } 77 c.closed = true 78 return nil 79} 80 81type fakeDriverCtx struct { 82 fakeDriver 83} 84 85var _ driver.DriverContext = &fakeDriverCtx{} 86 87func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 88 return &fakeConnector{name: name}, nil 89} 90 91type fakeDB struct { 92 name string 93 94 useRawBytes atomic.Bool 95 96 mu sync.Mutex 97 tables map[string]*table 98 badConn bool 99 allowAny bool 100} 101 102type fakeError struct { 103 Message string 104 Wrapped error 105} 106 107func (err fakeError) Error() string { 108 return err.Message 109} 110 111func (err fakeError) Unwrap() error { 112 return err.Wrapped 113} 114 115type table struct { 116 mu sync.Mutex 117 colname []string 118 coltype []string 119 rows []*row 120} 121 122func (t *table) columnIndex(name string) int { 123 return slices.Index(t.colname, name) 124} 125 126type row struct { 127 cols []any // must be same size as its table colname + coltype 128} 129 130type memToucher interface { 131 // touchMem reads & writes some memory, to help find data races. 132 touchMem() 133} 134 135type fakeConn struct { 136 db *fakeDB // where to return ourselves to 137 138 currTx *fakeTx 139 140 // Every operation writes to line to enable the race detector 141 // check for data races. 
142 line int64 143 144 // Stats for tests: 145 mu sync.Mutex 146 stmtsMade int 147 stmtsClosed int 148 numPrepare int 149 150 // bad connection tests; see isBad() 151 bad bool 152 stickyBad bool 153 154 skipDirtySession bool // tests that use Conn should set this to true. 155 156 // dirtySession tests ResetSession, true if a query has executed 157 // until ResetSession is called. 158 dirtySession bool 159 160 // The waiter is called before each query. May be used in place of the "WAIT" 161 // directive. 162 waiter func(context.Context) 163} 164 165func (c *fakeConn) touchMem() { 166 c.line++ 167} 168 169func (c *fakeConn) incrStat(v *int) { 170 c.mu.Lock() 171 *v++ 172 c.mu.Unlock() 173} 174 175type fakeTx struct { 176 c *fakeConn 177} 178 179type boundCol struct { 180 Column string 181 Placeholder string 182 Ordinal int 183} 184 185type fakeStmt struct { 186 memToucher 187 c *fakeConn 188 q string // just for debugging 189 190 cmd string 191 table string 192 panic string 193 wait time.Duration 194 195 next *fakeStmt // used for returning multiple results. 196 197 closed bool 198 199 colName []string // used by CREATE, INSERT, SELECT (selected columns) 200 colType []string // used by CREATE 201 colValue []any // used by INSERT (mix of strings and "?" for bound params) 202 placeholders int // used by INSERT/SELECT: number of ? 
params 203 204 whereCol []boundCol // used by SELECT (all placeholders) 205 206 placeholderConverter []driver.ValueConverter // used by INSERT 207} 208 209var fdriver driver.Driver = &fakeDriver{} 210 211func init() { 212 Register("test", fdriver) 213} 214 215type Dummy struct { 216 driver.Driver 217} 218 219func TestDrivers(t *testing.T) { 220 unregisterAllDrivers() 221 Register("test", fdriver) 222 Register("invalid", Dummy{}) 223 all := Drivers() 224 if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") { 225 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 226 } 227} 228 229// hook to simulate connection failures 230var hookOpenErr struct { 231 sync.Mutex 232 fn func() error 233} 234 235func setHookOpenErr(fn func() error) { 236 hookOpenErr.Lock() 237 defer hookOpenErr.Unlock() 238 hookOpenErr.fn = fn 239} 240 241// Supports dsn forms: 242// 243// <dbname> 244// <dbname>;<opts> (only currently supported option is `badConn`, 245// which causes driver.ErrBadConn to be returned on 246// every other conn.Begin()) 247func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 248 hookOpenErr.Lock() 249 fn := hookOpenErr.fn 250 hookOpenErr.Unlock() 251 if fn != nil { 252 if err := fn(); err != nil { 253 return nil, err 254 } 255 } 256 parts := strings.Split(dsn, ";") 257 if len(parts) < 1 { 258 return nil, errors.New("fakedb: no database name") 259 } 260 name := parts[0] 261 262 db := d.getDB(name) 263 264 d.mu.Lock() 265 d.openCount++ 266 d.mu.Unlock() 267 conn := &fakeConn{db: db} 268 269 if len(parts) >= 2 && parts[1] == "badConn" { 270 conn.bad = true 271 } 272 if d.waitCh != nil { 273 d.waitingCh <- struct{}{} 274 <-d.waitCh 275 d.waitCh = nil 276 d.waitingCh = nil 277 } 278 return conn, nil 279} 280 281func (d *fakeDriver) getDB(name string) *fakeDB { 282 d.mu.Lock() 283 defer d.mu.Unlock() 284 if d.dbs == nil { 285 d.dbs = make(map[string]*fakeDB) 286 } 287 db, ok := 
d.dbs[name] 288 if !ok { 289 db = &fakeDB{name: name} 290 d.dbs[name] = db 291 } 292 return db 293} 294 295func (db *fakeDB) wipe() { 296 db.mu.Lock() 297 defer db.mu.Unlock() 298 db.tables = nil 299} 300 301func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 302 db.mu.Lock() 303 defer db.mu.Unlock() 304 if db.tables == nil { 305 db.tables = make(map[string]*table) 306 } 307 if _, exist := db.tables[name]; exist { 308 return fmt.Errorf("fakedb: table %q already exists", name) 309 } 310 if len(columnNames) != len(columnTypes) { 311 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 312 name, len(columnNames), len(columnTypes)) 313 } 314 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 315 return nil 316} 317 318// must be called with db.mu lock held 319func (db *fakeDB) table(table string) (*table, bool) { 320 if db.tables == nil { 321 return nil, false 322 } 323 t, ok := db.tables[table] 324 return t, ok 325} 326 327func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 328 db.mu.Lock() 329 defer db.mu.Unlock() 330 t, ok := db.table(table) 331 if !ok { 332 return 333 } 334 if i := slices.Index(t.colname, column); i != -1 { 335 return t.coltype[i], true 336 } 337 return "", false 338} 339 340func (c *fakeConn) isBad() bool { 341 if c.stickyBad { 342 return true 343 } else if c.bad { 344 if c.db == nil { 345 return false 346 } 347 // alternate between bad conn and not bad conn 348 c.db.badConn = !c.db.badConn 349 return c.db.badConn 350 } else { 351 return false 352 } 353} 354 355func (c *fakeConn) isDirtyAndMark() bool { 356 if c.skipDirtySession { 357 return false 358 } 359 if c.currTx != nil { 360 c.dirtySession = true 361 return false 362 } 363 if c.dirtySession { 364 return true 365 } 366 c.dirtySession = true 367 return false 368} 369 370func (c *fakeConn) Begin() (driver.Tx, error) { 371 if c.isBad() { 372 return nil, fakeError{Wrapped: driver.ErrBadConn} 
373 } 374 if c.currTx != nil { 375 return nil, errors.New("fakedb: already in a transaction") 376 } 377 c.touchMem() 378 c.currTx = &fakeTx{c: c} 379 return c.currTx, nil 380} 381 382var hookPostCloseConn struct { 383 sync.Mutex 384 fn func(*fakeConn, error) 385} 386 387func setHookpostCloseConn(fn func(*fakeConn, error)) { 388 hookPostCloseConn.Lock() 389 defer hookPostCloseConn.Unlock() 390 hookPostCloseConn.fn = fn 391} 392 393var testStrictClose *testing.T 394 395// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 396// fails to close. If nil, the check is disabled. 397func setStrictFakeConnClose(t *testing.T) { 398 testStrictClose = t 399} 400 401func (c *fakeConn) ResetSession(ctx context.Context) error { 402 c.dirtySession = false 403 c.currTx = nil 404 if c.isBad() { 405 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} 406 } 407 return nil 408} 409 410var _ driver.Validator = (*fakeConn)(nil) 411 412func (c *fakeConn) IsValid() bool { 413 return !c.isBad() 414} 415 416func (c *fakeConn) Close() (err error) { 417 drv := fdriver.(*fakeDriver) 418 defer func() { 419 if err != nil && testStrictClose != nil { 420 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 421 } 422 hookPostCloseConn.Lock() 423 fn := hookPostCloseConn.fn 424 hookPostCloseConn.Unlock() 425 if fn != nil { 426 fn(c, err) 427 } 428 if err == nil { 429 drv.mu.Lock() 430 drv.closeCount++ 431 drv.mu.Unlock() 432 } 433 }() 434 c.touchMem() 435 if c.currTx != nil { 436 return errors.New("fakedb: can't close fakeConn; in a Transaction") 437 } 438 if c.db == nil { 439 return errors.New("fakedb: can't close fakeConn; already closed") 440 } 441 if c.stmtsMade > c.stmtsClosed { 442 return errors.New("fakedb: can't close; dangling statement(s)") 443 } 444 c.db = nil 445 return nil 446} 447 448func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 449 for _, arg := range args { 450 switch arg.Value.(type) { 451 case int64, 
float64, bool, nil, []byte, string, time.Time: 452 default: 453 if !allowAny { 454 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 455 } 456 } 457 } 458 return nil 459} 460 461func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 462 // Ensure that ExecContext is called if available. 463 panic("ExecContext was not called.") 464} 465 466func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 467 // This is an optional interface, but it's implemented here 468 // just to check that all the args are of the proper types. 469 // ErrSkip is returned so the caller acts as if we didn't 470 // implement this at all. 471 err := checkSubsetTypes(c.db.allowAny, args) 472 if err != nil { 473 return nil, err 474 } 475 return nil, driver.ErrSkip 476} 477 478func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 479 // Ensure that ExecContext is called if available. 480 panic("QueryContext was not called.") 481} 482 483func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 484 // This is an optional interface, but it's implemented here 485 // just to check that all the args are of the proper types. 486 // ErrSkip is returned so the caller acts as if we didn't 487 // implement this at all. 488 err := checkSubsetTypes(c.db.allowAny, args) 489 if err != nil { 490 return nil, err 491 } 492 return nil, driver.ErrSkip 493} 494 495func errf(msg string, args ...any) error { 496 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 497} 498 499// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 500// (note that where columns must always contain ? 
marks, 501// just a limitation for fakedb) 502func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 503 if len(parts) != 3 { 504 stmt.Close() 505 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 506 } 507 stmt.table = parts[0] 508 509 stmt.colName = strings.Split(parts[1], ",") 510 for n, colspec := range strings.Split(parts[2], ",") { 511 if colspec == "" { 512 continue 513 } 514 nameVal := strings.Split(colspec, "=") 515 if len(nameVal) != 2 { 516 stmt.Close() 517 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 518 } 519 column, value := nameVal[0], nameVal[1] 520 _, ok := c.db.columnType(stmt.table, column) 521 if !ok { 522 stmt.Close() 523 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 524 } 525 if !strings.HasPrefix(value, "?") { 526 stmt.Close() 527 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 528 stmt.table, column) 529 } 530 stmt.placeholders++ 531 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 532 } 533 return stmt, nil 534} 535 536// parts are table|col=type,col2=type2 537func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 538 if len(parts) != 2 { 539 stmt.Close() 540 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 541 } 542 stmt.table = parts[0] 543 for n, colspec := range strings.Split(parts[1], ",") { 544 nameType := strings.Split(colspec, "=") 545 if len(nameType) != 2 { 546 stmt.Close() 547 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 548 } 549 stmt.colName = append(stmt.colName, nameType[0]) 550 stmt.colType = append(stmt.colType, nameType[1]) 551 } 552 return stmt, nil 553} 554 555// parts are table|col=?,col2=val 556func (c *fakeConn) prepareInsert(ctx 
context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 557 if len(parts) != 2 { 558 stmt.Close() 559 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 560 } 561 stmt.table = parts[0] 562 for n, colspec := range strings.Split(parts[1], ",") { 563 nameVal := strings.Split(colspec, "=") 564 if len(nameVal) != 2 { 565 stmt.Close() 566 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 567 } 568 column, value := nameVal[0], nameVal[1] 569 ctype, ok := c.db.columnType(stmt.table, column) 570 if !ok { 571 stmt.Close() 572 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 573 } 574 stmt.colName = append(stmt.colName, column) 575 576 if !strings.HasPrefix(value, "?") { 577 var subsetVal any 578 // Convert to driver subset type 579 switch ctype { 580 case "string": 581 subsetVal = []byte(value) 582 case "blob": 583 subsetVal = []byte(value) 584 case "int32": 585 i, err := strconv.Atoi(value) 586 if err != nil { 587 stmt.Close() 588 return nil, errf("invalid conversion to int32 from %q", value) 589 } 590 subsetVal = int64(i) // int64 is a subset type, but not int32 591 case "table": // For testing cursor reads. 
592 c.skipDirtySession = true 593 vparts := strings.Split(value, "!") 594 595 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 596 if err != nil { 597 return nil, err 598 } 599 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 600 substmt.Close() 601 if err != nil { 602 return nil, err 603 } 604 subsetVal = cursor 605 default: 606 stmt.Close() 607 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 608 } 609 stmt.colValue = append(stmt.colValue, subsetVal) 610 } else { 611 stmt.placeholders++ 612 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 613 stmt.colValue = append(stmt.colValue, value) 614 } 615 } 616 return stmt, nil 617} 618 619// hook to simulate broken connections 620var hookPrepareBadConn func() bool 621 622func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 623 panic("use PrepareContext") 624} 625 626func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 627 c.numPrepare++ 628 if c.db == nil { 629 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 630 } 631 632 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 633 return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn} 634 } 635 636 c.touchMem() 637 var firstStmt, prev *fakeStmt 638 for _, query := range strings.Split(query, ";") { 639 parts := strings.Split(query, "|") 640 if len(parts) < 1 { 641 return nil, errf("empty query") 642 } 643 stmt := &fakeStmt{q: query, c: c, memToucher: c} 644 if firstStmt == nil { 645 firstStmt = stmt 646 } 647 if len(parts) >= 3 { 648 switch parts[0] { 649 case "PANIC": 650 stmt.panic = parts[1] 651 parts = parts[2:] 652 case "WAIT": 653 wait, err := time.ParseDuration(parts[1]) 654 if err != nil { 655 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 
656 } 657 parts = parts[2:] 658 stmt.wait = wait 659 } 660 } 661 cmd := parts[0] 662 stmt.cmd = cmd 663 parts = parts[1:] 664 665 if c.waiter != nil { 666 c.waiter(ctx) 667 if err := ctx.Err(); err != nil { 668 return nil, err 669 } 670 } 671 672 if stmt.wait > 0 { 673 wait := time.NewTimer(stmt.wait) 674 select { 675 case <-wait.C: 676 case <-ctx.Done(): 677 wait.Stop() 678 return nil, ctx.Err() 679 } 680 } 681 682 c.incrStat(&c.stmtsMade) 683 var err error 684 switch cmd { 685 case "WIPE": 686 // Nothing 687 case "USE_RAWBYTES": 688 c.db.useRawBytes.Store(true) 689 case "SELECT": 690 stmt, err = c.prepareSelect(stmt, parts) 691 case "CREATE": 692 stmt, err = c.prepareCreate(stmt, parts) 693 case "INSERT": 694 stmt, err = c.prepareInsert(ctx, stmt, parts) 695 case "NOSERT": 696 // Do all the prep-work like for an INSERT but don't actually insert the row. 697 // Used for some of the concurrent tests. 698 stmt, err = c.prepareInsert(ctx, stmt, parts) 699 default: 700 stmt.Close() 701 return nil, errf("unsupported command type %q", cmd) 702 } 703 if err != nil { 704 return nil, err 705 } 706 if prev != nil { 707 prev.next = stmt 708 } 709 prev = stmt 710 } 711 return firstStmt, nil 712} 713 714func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 715 if s.panic == "ColumnConverter" { 716 panic(s.panic) 717 } 718 if len(s.placeholderConverter) == 0 { 719 return driver.DefaultParameterConverter 720 } 721 return s.placeholderConverter[idx] 722} 723 724func (s *fakeStmt) Close() error { 725 if s.panic == "Close" { 726 panic(s.panic) 727 } 728 if s.c == nil { 729 panic("nil conn in fakeStmt.Close") 730 } 731 if s.c.db == nil { 732 panic("in fakeStmt.Close, conn's db is nil (already closed)") 733 } 734 s.touchMem() 735 if !s.closed { 736 s.c.incrStat(&s.c.stmtsClosed) 737 s.closed = true 738 } 739 if s.next != nil { 740 s.next.Close() 741 } 742 return nil 743} 744 745var errClosed = errors.New("fakedb: statement has been closed") 746 747// hook to simulate 
broken connections 748var hookExecBadConn func() bool 749 750func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 751 panic("Using ExecContext") 752} 753 754var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 755 756func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 757 if s.panic == "Exec" { 758 panic(s.panic) 759 } 760 if s.closed { 761 return nil, errClosed 762 } 763 764 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 765 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} 766 } 767 if s.c.isDirtyAndMark() { 768 return nil, errFakeConnSessionDirty 769 } 770 771 err := checkSubsetTypes(s.c.db.allowAny, args) 772 if err != nil { 773 return nil, err 774 } 775 s.touchMem() 776 777 if s.wait > 0 { 778 time.Sleep(s.wait) 779 } 780 781 select { 782 default: 783 case <-ctx.Done(): 784 return nil, ctx.Err() 785 } 786 787 db := s.c.db 788 switch s.cmd { 789 case "WIPE": 790 db.wipe() 791 return driver.ResultNoRows, nil 792 case "USE_RAWBYTES": 793 s.c.db.useRawBytes.Store(true) 794 return driver.ResultNoRows, nil 795 case "CREATE": 796 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 797 return nil, err 798 } 799 return driver.ResultNoRows, nil 800 case "INSERT": 801 return s.execInsert(args, true) 802 case "NOSERT": 803 // Do all the prep-work like for an INSERT but don't actually insert the row. 804 // Used for some of the concurrent tests. 805 return s.execInsert(args, false) 806 } 807 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 808} 809 810func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value { 811 for i := range args { 812 if args[i].Name == name { 813 return args[i].Value 814 } 815 } 816 return nil 817} 818 819// When doInsert is true, add the row to the table. 
820// When doInsert is false do prep-work and error checking, but don't 821// actually add the row to the table. 822func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 823 db := s.c.db 824 if len(args) != s.placeholders { 825 panic("error in pkg db; should only get here if size is correct") 826 } 827 db.mu.Lock() 828 t, ok := db.table(s.table) 829 db.mu.Unlock() 830 if !ok { 831 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 832 } 833 834 t.mu.Lock() 835 defer t.mu.Unlock() 836 837 var cols []any 838 if doInsert { 839 cols = make([]any, len(t.colname)) 840 } 841 argPos := 0 842 for n, colname := range s.colName { 843 colidx := t.columnIndex(colname) 844 if colidx == -1 { 845 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 846 } 847 var val any 848 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 849 if strvalue == "?" { 850 val = args[argPos].Value 851 } else { 852 // Assign value from argument placeholder name. 
853 if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil { 854 val = v 855 } 856 } 857 argPos++ 858 } else { 859 val = s.colValue[n] 860 } 861 if doInsert { 862 cols[colidx] = val 863 } 864 } 865 866 if doInsert { 867 t.rows = append(t.rows, &row{cols: cols}) 868 } 869 return driver.RowsAffected(1), nil 870} 871 872// hook to simulate broken connections 873var hookQueryBadConn func() bool 874 875func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 876 panic("Use QueryContext") 877} 878 879func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 880 if s.panic == "Query" { 881 panic(s.panic) 882 } 883 if s.closed { 884 return nil, errClosed 885 } 886 887 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 888 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} 889 } 890 if s.c.isDirtyAndMark() { 891 return nil, errFakeConnSessionDirty 892 } 893 894 err := checkSubsetTypes(s.c.db.allowAny, args) 895 if err != nil { 896 return nil, err 897 } 898 899 s.touchMem() 900 db := s.c.db 901 if len(args) != s.placeholders { 902 panic("error in pkg db; should only get here if size is correct") 903 } 904 905 setMRows := make([][]*row, 0, 1) 906 setColumns := make([][]string, 0, 1) 907 setColType := make([][]string, 0, 1) 908 909 for { 910 db.mu.Lock() 911 t, ok := db.table(s.table) 912 db.mu.Unlock() 913 if !ok { 914 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 915 } 916 917 if s.table == "magicquery" { 918 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 919 if args[0].Value == "sleep" { 920 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 921 } 922 } 923 } 924 if s.table == "tx_status" && s.colName[0] == "tx_status" { 925 txStatus := "autocommit" 926 if s.c.currTx != nil { 927 txStatus = "transaction" 928 } 929 cursor := &rowsCursor{ 930 db: s.c.db, 931 parentMem: s.c, 932 
posRow: -1, 933 rows: [][]*row{ 934 { 935 { 936 cols: []any{ 937 txStatus, 938 }, 939 }, 940 }, 941 }, 942 cols: [][]string{ 943 { 944 "tx_status", 945 }, 946 }, 947 colType: [][]string{ 948 { 949 "string", 950 }, 951 }, 952 errPos: -1, 953 } 954 return cursor, nil 955 } 956 957 t.mu.Lock() 958 959 colIdx := make(map[string]int) // select column name -> column index in table 960 for _, name := range s.colName { 961 idx := t.columnIndex(name) 962 if idx == -1 { 963 t.mu.Unlock() 964 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 965 } 966 colIdx[name] = idx 967 } 968 969 mrows := []*row{} 970 rows: 971 for _, trow := range t.rows { 972 // Process the where clause, skipping non-match rows. This is lazy 973 // and just uses fmt.Sprintf("%v") to test equality. Good enough 974 // for test code. 975 for _, wcol := range s.whereCol { 976 idx := t.columnIndex(wcol.Column) 977 if idx == -1 { 978 t.mu.Unlock() 979 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 980 } 981 tcol := trow.cols[idx] 982 if bs, ok := tcol.([]byte); ok { 983 // lazy hack to avoid sprintf %v on a []byte 984 tcol = string(bs) 985 } 986 var argValue any 987 if wcol.Placeholder == "?" 
{ 988 argValue = args[wcol.Ordinal-1].Value 989 } else { 990 if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil { 991 argValue = v 992 } 993 } 994 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 995 continue rows 996 } 997 } 998 mrow := &row{cols: make([]any, len(s.colName))} 999 for seli, name := range s.colName { 1000 mrow.cols[seli] = trow.cols[colIdx[name]] 1001 } 1002 mrows = append(mrows, mrow) 1003 } 1004 1005 var colType []string 1006 for _, column := range s.colName { 1007 colType = append(colType, t.coltype[t.columnIndex(column)]) 1008 } 1009 1010 t.mu.Unlock() 1011 1012 setMRows = append(setMRows, mrows) 1013 setColumns = append(setColumns, s.colName) 1014 setColType = append(setColType, colType) 1015 1016 if s.next == nil { 1017 break 1018 } 1019 s = s.next 1020 } 1021 1022 cursor := &rowsCursor{ 1023 db: s.c.db, 1024 parentMem: s.c, 1025 posRow: -1, 1026 rows: setMRows, 1027 cols: setColumns, 1028 colType: setColType, 1029 errPos: -1, 1030 } 1031 return cursor, nil 1032} 1033 1034func (s *fakeStmt) NumInput() int { 1035 if s.panic == "NumInput" { 1036 panic(s.panic) 1037 } 1038 return s.placeholders 1039} 1040 1041// hook to simulate broken connections 1042var hookCommitBadConn func() bool 1043 1044func (tx *fakeTx) Commit() error { 1045 tx.c.currTx = nil 1046 if hookCommitBadConn != nil && hookCommitBadConn() { 1047 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1048 } 1049 tx.c.touchMem() 1050 return nil 1051} 1052 1053// hook to simulate broken connections 1054var hookRollbackBadConn func() bool 1055 1056func (tx *fakeTx) Rollback() error { 1057 tx.c.currTx = nil 1058 if hookRollbackBadConn != nil && hookRollbackBadConn() { 1059 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1060 } 1061 tx.c.touchMem() 1062 return nil 1063} 1064 1065type rowsCursor struct { 1066 db *fakeDB 1067 parentMem memToucher 1068 cols [][]string 1069 colType [][]string 1070 
posSet int 1071 posRow int 1072 rows [][]*row 1073 closed bool 1074 1075 // errPos and err are for making Next return early with error. 1076 errPos int 1077 err error 1078 1079 // a clone of slices to give out to clients, indexed by the 1080 // original slice's first byte address. we clone them 1081 // just so we're able to corrupt them on close. 1082 bytesClone map[*byte][]byte 1083 1084 // Every operation writes to line to enable the race detector 1085 // check for data races. 1086 // This is separate from the fakeConn.line to allow for drivers that 1087 // can start multiple queries on the same transaction at the same time. 1088 line int64 1089 1090 // closeErr is returned when rowsCursor.Close 1091 closeErr error 1092} 1093 1094func (rc *rowsCursor) touchMem() { 1095 rc.parentMem.touchMem() 1096 rc.line++ 1097} 1098 1099func (rc *rowsCursor) Close() error { 1100 rc.touchMem() 1101 rc.parentMem.touchMem() 1102 rc.closed = true 1103 return rc.closeErr 1104} 1105 1106func (rc *rowsCursor) Columns() []string { 1107 return rc.cols[rc.posSet] 1108} 1109 1110func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1111 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1112} 1113 1114var rowsCursorNextHook func(dest []driver.Value) error 1115 1116func (rc *rowsCursor) Next(dest []driver.Value) error { 1117 if rowsCursorNextHook != nil { 1118 return rowsCursorNextHook(dest) 1119 } 1120 1121 if rc.closed { 1122 return errors.New("fakedb: cursor is closed") 1123 } 1124 rc.touchMem() 1125 rc.posRow++ 1126 if rc.posRow == rc.errPos { 1127 return rc.err 1128 } 1129 if rc.posRow >= len(rc.rows[rc.posSet]) { 1130 return io.EOF // per interface spec 1131 } 1132 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1133 // TODO(bradfitz): convert to subset types? naah, I 1134 // think the subset types should only be input to 1135 // driver, but the sql package should be able to handle 1136 // a wider range of types coming out of drivers. 
all 1137 // for ease of drivers, and to prevent drivers from 1138 // messing up conversions or doing them differently. 1139 dest[i] = v 1140 1141 if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() { 1142 if rc.bytesClone == nil { 1143 rc.bytesClone = make(map[*byte][]byte) 1144 } 1145 clone, ok := rc.bytesClone[&bs[0]] 1146 if !ok { 1147 clone = make([]byte, len(bs)) 1148 copy(clone, bs) 1149 rc.bytesClone[&bs[0]] = clone 1150 } 1151 dest[i] = clone 1152 } 1153 } 1154 return nil 1155} 1156 1157func (rc *rowsCursor) HasNextResultSet() bool { 1158 rc.touchMem() 1159 return rc.posSet < len(rc.rows)-1 1160} 1161 1162func (rc *rowsCursor) NextResultSet() error { 1163 rc.touchMem() 1164 if rc.HasNextResultSet() { 1165 rc.posSet++ 1166 rc.posRow = -1 1167 return nil 1168 } 1169 return io.EOF // Per interface spec. 1170} 1171 1172// fakeDriverString is like driver.String, but indirects pointers like 1173// DefaultValueConverter. 1174// 1175// This could be surprising behavior to retroactively apply to 1176// driver.String now that Go1 is out, but this is convenient for 1177// our TestPointerParamsAndScans. 
1178type fakeDriverString struct{} 1179 1180func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { 1181 switch c := v.(type) { 1182 case string, []byte: 1183 return v, nil 1184 case *string: 1185 if c == nil { 1186 return nil, nil 1187 } 1188 return *c, nil 1189 } 1190 return fmt.Sprintf("%v", v), nil 1191} 1192 1193type anyTypeConverter struct{} 1194 1195func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { 1196 return v, nil 1197} 1198 1199func converterForType(typ string) driver.ValueConverter { 1200 switch typ { 1201 case "bool": 1202 return driver.Bool 1203 case "nullbool": 1204 return driver.Null{Converter: driver.Bool} 1205 case "byte", "int16": 1206 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1207 case "int32": 1208 return driver.Int32 1209 case "nullbyte", "nullint32", "nullint16": 1210 return driver.Null{Converter: driver.DefaultParameterConverter} 1211 case "string": 1212 return driver.NotNull{Converter: fakeDriverString{}} 1213 case "nullstring": 1214 return driver.Null{Converter: fakeDriverString{}} 1215 case "int64": 1216 // TODO(coopernurse): add type-specific converter 1217 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1218 case "nullint64": 1219 // TODO(coopernurse): add type-specific converter 1220 return driver.Null{Converter: driver.DefaultParameterConverter} 1221 case "float64": 1222 // TODO(coopernurse): add type-specific converter 1223 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1224 case "nullfloat64": 1225 // TODO(coopernurse): add type-specific converter 1226 return driver.Null{Converter: driver.DefaultParameterConverter} 1227 case "datetime": 1228 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1229 case "nulldatetime": 1230 return driver.Null{Converter: driver.DefaultParameterConverter} 1231 case "any": 1232 return anyTypeConverter{} 1233 } 1234 panic("invalid fakedb column type of " + typ) 1235} 1236 1237func 
colTypeToReflectType(typ string) reflect.Type { 1238 switch typ { 1239 case "bool": 1240 return reflect.TypeFor[bool]() 1241 case "nullbool": 1242 return reflect.TypeFor[NullBool]() 1243 case "int16": 1244 return reflect.TypeFor[int16]() 1245 case "nullint16": 1246 return reflect.TypeFor[NullInt16]() 1247 case "int32": 1248 return reflect.TypeFor[int32]() 1249 case "nullint32": 1250 return reflect.TypeFor[NullInt32]() 1251 case "string": 1252 return reflect.TypeFor[string]() 1253 case "nullstring": 1254 return reflect.TypeFor[NullString]() 1255 case "int64": 1256 return reflect.TypeFor[int64]() 1257 case "nullint64": 1258 return reflect.TypeFor[NullInt64]() 1259 case "float64": 1260 return reflect.TypeFor[float64]() 1261 case "nullfloat64": 1262 return reflect.TypeFor[NullFloat64]() 1263 case "datetime": 1264 return reflect.TypeFor[time.Time]() 1265 case "any": 1266 return reflect.TypeFor[any]() 1267 } 1268 panic("invalid fakedb column type of " + typ) 1269} 1270