||
- package sqlite3_test
- import (
- "database/sql"
- "fmt"
- "math/rand"
- "regexp"
- "strconv"
- "sync"
- "testing"
- "time"
- )
// Dialect identifies the SQL flavour of the database under test; the helper
// methods on DB switch on it to emit compatible SQL.
type Dialect int

// Known dialects. SQLITE is the zero value of Dialect.
const (
	// SQLITE selects SQLite syntax.
	SQLITE Dialect = iota
	// POSTGRESQL selects PostgreSQL syntax ($n placeholders, bytea, serial).
	POSTGRESQL
	// MYSQL selects MySQL syntax.
	MYSQL
)
- type DB struct {
- *testing.T
- *sql.DB
- dialect Dialect
- once sync.Once
- }
- var db *DB
- // the following tables will be created and dropped during the test
- var testTables = []string{"foo", "bar", "t", "bench"}
- var tests = []testing.InternalTest{
- {"TestBlobs", TestBlobs},
- {"TestManyQueryRow", TestManyQueryRow},
- {"TestTxQuery", TestTxQuery},
- {"TestPreparedStmt", TestPreparedStmt},
- }
- var benchmarks = []testing.InternalBenchmark{
- {"BenchmarkExec", BenchmarkExec},
- {"BenchmarkQuery", BenchmarkQuery},
- {"BenchmarkParams", BenchmarkParams},
- {"BenchmarkStmt", BenchmarkStmt},
- {"BenchmarkRows", BenchmarkRows},
- {"BenchmarkStmtRows", BenchmarkStmtRows},
- }
- // RunTests runs the SQL test suite
- func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
- db = &DB{t, d, dialect, sync.Once{}}
- testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
- if !testing.Short() {
- for _, b := range benchmarks {
- fmt.Printf("%-20s", b.Name)
- r := testing.Benchmark(b.F)
- fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
- }
- }
- db.tearDown()
- }
- func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
- res, err := db.Exec(sql, args...)
- if err != nil {
- db.Fatalf("Error running %q: %v", sql, err)
- }
- return res
- }
- func (db *DB) tearDown() {
- for _, tbl := range testTables {
- switch db.dialect {
- case SQLITE:
- db.mustExec("drop table if exists " + tbl)
- case MYSQL, POSTGRESQL:
- db.mustExec("drop table if exists " + tbl)
- default:
- db.Fatal("unkown dialect")
- }
- }
- }
- // q replaces ? parameters if needed
- func (db *DB) q(sql string) string {
- switch db.dialect {
- case POSTGRESQL: // repace with $1, $2, ..
- qrx := regexp.MustCompile(`\?`)
- n := 0
- return qrx.ReplaceAllStringFunc(sql, func(string) string {
- n++
- return "$" + strconv.Itoa(n)
- })
- }
- return sql
- }
- func (db *DB) blobType(size int) string {
- switch db.dialect {
- case SQLITE:
- return fmt.Sprintf("blob[%d]", size)
- case POSTGRESQL:
- return "bytea"
- case MYSQL:
- return fmt.Sprintf("VARBINARY(%d)", size)
- }
- panic("unkown dialect")
- }
- func (db *DB) serialPK() string {
- switch db.dialect {
- case SQLITE:
- return "integer primary key autoincrement"
- case POSTGRESQL:
- return "serial primary key"
- case MYSQL:
- return "integer primary key auto_increment"
- }
- panic("unkown dialect")
- }
- func (db *DB) now() string {
- switch db.dialect {
- case SQLITE:
- return "datetime('now')"
- case POSTGRESQL:
- return "now()"
- case MYSQL:
- return "now()"
- }
- panic("unkown dialect")
- }
- func makeBench() {
- if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
- panic(err)
- }
- st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
- if err != nil {
- panic(err)
- }
- defer st.Close()
- for i := 0; i < 100; i++ {
- if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
- panic(err)
- }
- }
- }
- func TestResult(t *testing.T) {
- db.tearDown()
- db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
- for i := 1; i < 3; i++ {
- r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
- n, err := r.RowsAffected()
- if err != nil {
- t.Fatal(err)
- }
- if n != 1 {
- t.Errorf("got %v, want %v", n, 1)
- }
- n, err = r.LastInsertId()
- if err != nil {
- t.Fatal(err)
- }
- if n != int64(i) {
- t.Errorf("got %v, want %v", n, i)
- }
- }
- if _, err := db.Exec("error!"); err == nil {
- t.Fatalf("expected error")
- }
- }
- func TestBlobs(t *testing.T) {
- db.tearDown()
- var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
- db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
- db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
- want := fmt.Sprintf("%x", blob)
- b := make([]byte, 16)
- err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
- got := fmt.Sprintf("%x", b)
- if err != nil {
- t.Errorf("[]byte scan: %v", err)
- } else if got != want {
- t.Errorf("for []byte, got %q; want %q", got, want)
- }
- err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
- want = string(blob)
- if err != nil {
- t.Errorf("string scan: %v", err)
- } else if got != want {
- t.Errorf("for string, got %q; want %q", got, want)
- }
- }
- func TestManyQueryRow(t *testing.T) {
- if testing.Short() {
- t.Log("skipping in short mode")
- return
- }
- db.tearDown()
- db.mustExec("create table foo (id integer primary key, name varchar(50))")
- db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
- var name string
- for i := 0; i < 10000; i++ {
- err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
- if err != nil || name != "bob" {
- t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
- }
- }
- }
- func TestTxQuery(t *testing.T) {
- db.tearDown()
- tx, err := db.Begin()
- if err != nil {
- t.Fatal(err)
- }
- defer tx.Rollback()
- _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
- if err != nil {
- t.Fatal(err)
- }
- _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
- if err != nil {
- t.Fatal(err)
- }
- r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
- if err != nil {
- t.Fatal(err)
- }
- defer r.Close()
- if !r.Next() {
- if r.Err() != nil {
- t.Fatal(err)
- }
- t.Fatal("expected one rows")
- }
- var name string
- err = r.Scan(&name)
- if err != nil {
- t.Fatal(err)
- }
- }
- func TestPreparedStmt(t *testing.T) {
- db.tearDown()
- db.mustExec("CREATE TABLE t (count INT)")
- sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
- if err != nil {
- t.Fatalf("prepare 1: %v", err)
- }
- ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
- if err != nil {
- t.Fatalf("prepare 2: %v", err)
- }
- for n := 1; n <= 3; n++ {
- if _, err := ins.Exec(n); err != nil {
- t.Fatalf("insert(%d) = %v", n, err)
- }
- }
- const nRuns = 10
- var wg sync.WaitGroup
- for i := 0; i < nRuns; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for j := 0; j < 10; j++ {
- count := 0
- if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
- t.Errorf("Query: %v", err)
- return
- }
- if _, err := ins.Exec(rand.Intn(100)); err != nil {
- t.Errorf("Insert: %v", err)
- return
- }
- }
- }()
- }
- wg.Wait()
- }
- // Benchmarks need to use panic() since b.Error errors are lost when
- // running via testing.Benchmark() I would like to run these via go
- // test -bench but calling Benchmark() from a benchmark test
- // currently hangs go.
- func BenchmarkExec(b *testing.B) {
- for i := 0; i < b.N; i++ {
- if _, err := db.Exec("select 1"); err != nil {
- panic(err)
- }
- }
- }
- func BenchmarkQuery(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- // var t time.Time
- if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
- panic(err)
- }
- }
- }
- func BenchmarkParams(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- // var t time.Time
- if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
- panic(err)
- }
- }
- }
- func BenchmarkStmt(b *testing.B) {
- st, err := db.Prepare("select ?, ?, ?, ?")
- if err != nil {
- panic(err)
- }
- defer st.Close()
- for n := 0; n < b.N; n++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- // var t time.Time
- if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
- panic(err)
- }
- }
- }
- func BenchmarkRows(b *testing.B) {
- db.once.Do(makeBench)
- for n := 0; n < b.N; n++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- var t time.Time
- r, err := db.Query("select * from bench")
- if err != nil {
- panic(err)
- }
- for r.Next() {
- if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
- panic(err)
- }
- }
- if err = r.Err(); err != nil {
- panic(err)
- }
- }
- }
- func BenchmarkStmtRows(b *testing.B) {
- db.once.Do(makeBench)
- st, err := db.Prepare("select * from bench")
- if err != nil {
- panic(err)
- }
- defer st.Close()
- for n := 0; n < b.N; n++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- var t time.Time
- r, err := st.Query()
- if err != nil {
- panic(err)
- }
- for r.Next() {
- if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
- panic(err)
- }
- }
- if err = r.Err(); err != nil {
- panic(err)
- }
- }
- }
|