Browse Source

fix(db): apply all migrations and schema in one transaction

Jakob Borg 1 month ago
parent
commit
e41d6b9c1e
4 changed files with 42 additions and 21 deletions
  1. 6 0
      .golangci.yml
  2. 29 19
      internal/db/sqlite/basedb.go
  3. 6 1
      internal/db/sqlite/db_test.go
  4. 1 1
      internal/db/sqlite/sql/README.md

+ 6 - 0
.golangci.yml

@@ -64,6 +64,12 @@ linters:
       # relax the slog rules for debug lines, for now
       - linters: [sloglint]
         source: Debug
+      # contexts are irrelevant for SQLite
+      - linters: [noctx]
+        text: database/sql
+      # Rollback errors can be ignored
+      - linters: [errcheck]
+        source: Rollback
   settings:
     sloglint:
       context: "scope"

+ 29 - 19
internal/db/sqlite/basedb.go

@@ -10,6 +10,7 @@ import (
 	"database/sql"
 	"embed"
 	"io/fs"
+	"log/slog"
 	"net/url"
 	"path/filepath"
 	"strconv"
@@ -19,6 +20,7 @@ import (
 	"time"
 
 	"github.com/jmoiron/sqlx"
+	"github.com/syncthing/syncthing/internal/slogutil"
 	"github.com/syncthing/syncthing/lib/build"
 	"github.com/syncthing/syncthing/lib/protocol"
 )
@@ -81,13 +83,19 @@ func openBase(path string, maxConns int, pragmas, schemaScripts, migrationScript
 		},
 	}
 
+	tx, err := db.sql.Beginx()
+	if err != nil {
+		return nil, wrap(err)
+	}
+	defer tx.Rollback()
+
 	for _, script := range schemaScripts {
-		if err := db.runScripts(script); err != nil {
+		if err := db.runScripts(tx, script); err != nil {
 			return nil, wrap(err)
 		}
 	}
 
-	ver, _ := db.getAppliedSchemaVersion()
+	ver, _ := db.getAppliedSchemaVersion(tx)
 	shouldVacuum := false
 	if ver.SchemaVersion > 0 {
 		filter := func(scr string) bool {
@@ -100,10 +108,14 @@ func openBase(path string, maxConns int, pragmas, schemaScripts, migrationScript
 			if err != nil {
 				return false
 			}
-			return int(n) > ver.SchemaVersion
+			if int(n) > ver.SchemaVersion {
+				slog.Info("Applying database migration", slogutil.FilePath(db.baseName), slog.String("script", scr))
+				return true
+			}
+			return false
 		}
 		for _, script := range migrationScripts {
-			if err := db.runScripts(script, filter); err != nil {
+			if err := db.runScripts(tx, script, filter); err != nil {
 				return nil, wrap(err)
 			}
 			shouldVacuum = true
@@ -111,7 +123,11 @@ func openBase(path string, maxConns int, pragmas, schemaScripts, migrationScript
 	}
 
 	// Set the current schema version, if not already set
-	if err := db.setAppliedSchemaVersion(currentSchemaVersion); err != nil {
+	if err := db.setAppliedSchemaVersion(tx, currentSchemaVersion); err != nil {
+		return nil, wrap(err)
+	}
+
+	if err := tx.Commit(); err != nil {
 		return nil, wrap(err)
 	}
 
@@ -228,18 +244,12 @@ func (f failedStmt) Get(_ any, _ ...any) error           { return f.err }
 func (f failedStmt) Queryx(_ ...any) (*sqlx.Rows, error) { return nil, f.err }
 func (f failedStmt) Select(_ any, _ ...any) error        { return f.err }
 
-func (s *baseDB) runScripts(glob string, filter ...func(s string) bool) error {
+func (s *baseDB) runScripts(tx *sqlx.Tx, glob string, filter ...func(s string) bool) error {
 	scripts, err := fs.Glob(embedded, glob)
 	if err != nil {
 		return wrap(err)
 	}
 
-	tx, err := s.sql.Begin()
-	if err != nil {
-		return wrap(err)
-	}
-	defer tx.Rollback() //nolint:errcheck
-
 nextScript:
 	for _, scr := range scripts {
 		for _, fn := range filter {
@@ -262,7 +272,7 @@ nextScript:
 		}
 	}
 
-	return wrap(tx.Commit())
+	return nil
 }
 
 type schemaVersion struct {
@@ -275,20 +285,20 @@ func (s *schemaVersion) AppliedTime() time.Time {
 	return time.Unix(0, s.AppliedAt)
 }
 
-func (s *baseDB) setAppliedSchemaVersion(ver int) error {
-	_, err := s.stmt(`
+func (s *baseDB) setAppliedSchemaVersion(tx *sqlx.Tx, ver int) error {
+	_, err := tx.Exec(`
 		INSERT OR IGNORE INTO schemamigrations (schema_version, applied_at, syncthing_version)
 		VALUES (?, ?, ?)
-	`).Exec(ver, time.Now().UnixNano(), build.LongVersion)
+	`, ver, time.Now().UnixNano(), build.LongVersion)
 	return wrap(err)
 }
 
-func (s *baseDB) getAppliedSchemaVersion() (schemaVersion, error) {
+func (s *baseDB) getAppliedSchemaVersion(tx *sqlx.Tx) (schemaVersion, error) {
 	var v schemaVersion
-	err := s.stmt(`
+	err := tx.Get(&v, `
 		SELECT schema_version as schemaversion, applied_at as appliedat, syncthing_version as syncthingversion FROM schemamigrations
 		ORDER BY schema_version DESC
 		LIMIT 1
-	`).Get(&v)
+	`)
 	return v, wrap(err)
 }

+ 6 - 1
internal/db/sqlite/db_test.go

@@ -81,7 +81,12 @@ func TestBasics(t *testing.T) {
 	)
 
 	t.Run("SchemaVersion", func(t *testing.T) {
-		ver, err := sdb.getAppliedSchemaVersion()
+		tx, err := sdb.sql.Beginx()
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer tx.Rollback()
+		ver, err := sdb.getAppliedSchemaVersion(tx)
 		if err != nil {
 			t.Fatal(err)
 		}

+ 1 - 1
internal/db/sqlite/sql/README.md

@@ -2,7 +2,7 @@ These SQL scripts are embedded in the binary.
 
 Scripts in `schema/` are run at every startup, in alphanumerical order.
 
-Scripts in `migrations/` are run when a migration is needed; the must begin
+Scripts in `migrations/` are run when a migration is needed; they must begin
 with a number that equals the schema version that results from that
 migration. Migrations are not run on initial database creation, so the
 scripts in `schema/` should create the latest version.