basedb.go 7.6 KB


  1. // Copyright (C) 2025 The Syncthing Authors.
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  5. // You can obtain one at https://mozilla.org/MPL/2.0/.
  6. package sqlite
  7. import (
  8. "cmp"
  9. "database/sql"
  10. "embed"
  11. "io/fs"
  12. "net/url"
  13. "path/filepath"
  14. "slices"
  15. "strconv"
  16. "strings"
  17. "sync"
  18. "text/template"
  19. "time"
  20. "github.com/jmoiron/sqlx"
  21. "github.com/syncthing/syncthing/lib/build"
  22. "github.com/syncthing/syncthing/lib/protocol"
  23. )
// currentSchemaVersion is the schema version that openBase records as
// applied once schema scripts and migrations have run successfully.
const currentSchemaVersion = 3

// embedded is the embedded file system holding the files matched by the
// sql/** pattern; scripts are globbed and read from it.
//
//go:embed sql/**
var embedded embed.FS
// baseDB wraps a single SQLite database handle together with a cache of
// prepared statements and the data used for SQL template expansion.
type baseDB struct {
	// path is the database file path as given to openBase.
	path string
	// baseName is filepath.Base(path).
	baseName string
	// sql is the underlying database handle.
	sql *sqlx.DB

	// updateLock is held by Close for the duration of the shutdown.
	// NOTE(review): presumably also serializes writers elsewhere in the
	// package -- confirm against the callers.
	updateLock sync.Mutex
	// updatePoints and checkpointsCount are not used in this file.
	// NOTE(review): presumably bookkeeping for periodic checkpointing --
	// confirm against the code that updates them.
	updatePoints     int
	checkpointsCount int

	// statementsMut guards statements.
	statementsMut sync.RWMutex
	// statements caches prepared statements, keyed by the whitespace
	// trimmed, unexpanded SQL template string.
	statements map[string]*sqlx.Stmt
	// tplInput is the data passed to every SQL template execution.
	tplInput map[string]any
}
  38. //nolint:noctx
  39. func openBase(path string, maxConns int, pragmas, schemaScripts, migrationScripts []string) (*baseDB, error) {
  40. // Open the database with options to enable foreign keys and recursive
  41. // triggers (needed for the delete+insert triggers on row replace).
  42. pathURL := url.URL{
  43. Scheme: "file",
  44. Path: fileToUriPath(path),
  45. RawQuery: commonOptions,
  46. }
  47. sqlDB, err := sqlx.Open(dbDriver, pathURL.String())
  48. if err != nil {
  49. return nil, wrap(err)
  50. }
  51. sqlDB.SetMaxOpenConns(maxConns)
  52. for _, pragma := range pragmas {
  53. if _, err := sqlDB.Exec("PRAGMA " + pragma); err != nil {
  54. return nil, wrap(err, "PRAGMA "+pragma)
  55. }
  56. }
  57. db := &baseDB{
  58. path: path,
  59. baseName: filepath.Base(path),
  60. sql: sqlDB,
  61. statements: make(map[string]*sqlx.Stmt),
  62. tplInput: map[string]any{
  63. "FlagLocalUnsupported": protocol.FlagLocalUnsupported,
  64. "FlagLocalIgnored": protocol.FlagLocalIgnored,
  65. "FlagLocalMustRescan": protocol.FlagLocalMustRescan,
  66. "FlagLocalReceiveOnly": protocol.FlagLocalReceiveOnly,
  67. "FlagLocalGlobal": protocol.FlagLocalGlobal,
  68. "FlagLocalNeeded": protocol.FlagLocalNeeded,
  69. "FlagLocalRemoteInvalid": protocol.FlagLocalRemoteInvalid,
  70. "LocalInvalidFlags": protocol.LocalInvalidFlags,
  71. "SyncthingVersion": build.LongVersion,
  72. },
  73. }
  74. for _, script := range schemaScripts {
  75. if err := db.runScripts(script); err != nil {
  76. return nil, wrap(err)
  77. }
  78. }
  79. ver, _ := db.getAppliedSchemaVersion()
  80. if ver.SchemaVersion > 0 {
  81. type migration struct {
  82. script string
  83. version int
  84. }
  85. migrations := make([]migration, 0, len(migrationScripts))
  86. for _, script := range migrationScripts {
  87. base := filepath.Base(script)
  88. nstr, _, ok := strings.Cut(base, "-")
  89. if !ok {
  90. continue
  91. }
  92. n, err := strconv.ParseInt(nstr, 10, 32)
  93. if err != nil {
  94. continue
  95. }
  96. migrations = append(migrations, migration{
  97. script: script,
  98. version: int(n),
  99. })
  100. }
  101. slices.SortFunc(migrations, func(m1, m2 migration) int { return cmp.Compare(m1.version, m2.version) })
  102. for _, m := range migrations {
  103. if err := db.applyMigration(m.version, m.script); err != nil {
  104. return nil, wrap(err)
  105. }
  106. }
  107. }
  108. // Set the current schema version, if not already set
  109. if err := setAppliedSchemaVersion(currentSchemaVersion, db.sql); err != nil {
  110. return nil, wrap(err)
  111. }
  112. return db, nil
  113. }
  114. func fileToUriPath(path string) string {
  115. path = filepath.ToSlash(path)
  116. if (build.IsWindows && len(path) >= 2 && path[1] == ':') ||
  117. (strings.HasPrefix(path, "//") && !strings.HasPrefix(path, "///")) {
  118. // Add an extra leading slash for Windows drive letter or UNC path
  119. path = "/" + path
  120. }
  121. return path
  122. }
  123. func (s *baseDB) Close() error {
  124. s.updateLock.Lock()
  125. s.statementsMut.Lock()
  126. defer s.updateLock.Unlock()
  127. defer s.statementsMut.Unlock()
  128. for _, stmt := range s.statements {
  129. stmt.Close()
  130. }
  131. return wrap(s.sql.Close())
  132. }
  133. var tplFuncs = template.FuncMap{
  134. "or": func(vs ...int) int {
  135. v := vs[0]
  136. for _, ov := range vs[1:] {
  137. v |= ov
  138. }
  139. return v
  140. },
  141. }
  142. // stmt returns a prepared statement for the given SQL string, after
  143. // applying local template expansions. The statement is cached.
  144. func (s *baseDB) stmt(tpl string) stmt {
  145. tpl = strings.TrimSpace(tpl)
  146. // Fast concurrent lookup of cached statement
  147. s.statementsMut.RLock()
  148. stmt, ok := s.statements[tpl]
  149. s.statementsMut.RUnlock()
  150. if ok {
  151. return stmt
  152. }
  153. // On miss, take the full lock, check again
  154. s.statementsMut.Lock()
  155. defer s.statementsMut.Unlock()
  156. stmt, ok = s.statements[tpl]
  157. if ok {
  158. return stmt
  159. }
  160. // Prepare and cache
  161. stmt, err := s.sql.Preparex(s.expandTemplateVars(tpl))
  162. if err != nil {
  163. return failedStmt{err}
  164. }
  165. s.statements[tpl] = stmt
  166. return stmt
  167. }
  168. // expandTemplateVars just applies template expansions to the template
  169. // string, or dies trying
  170. func (s *baseDB) expandTemplateVars(tpl string) string {
  171. var sb strings.Builder
  172. compTpl := template.Must(template.New("tpl").Funcs(tplFuncs).Parse(tpl))
  173. if err := compTpl.Execute(&sb, s.tplInput); err != nil {
  174. panic("bug: bad template: " + err.Error())
  175. }
  176. return sb.String()
  177. }
// stmt is the set of prepared statement operations used on the values
// returned by baseDB.stmt. Those are either a *sqlx.Stmt or, when
// preparing the statement failed, a failedStmt.
type stmt interface {
	// Exec executes the statement with args.
	Exec(args ...any) (sql.Result, error)
	// Get runs the statement with args, storing a result in dest.
	Get(dest any, args ...any) error
	// Queryx runs the statement with args and returns the resulting rows.
	Queryx(args ...any) (*sqlx.Rows, error)
	// Select runs the statement with args, storing results in dest.
	Select(dest any, args ...any) error
}
  184. type failedStmt struct {
  185. err error
  186. }
  187. func (f failedStmt) Exec(_ ...any) (sql.Result, error) { return nil, f.err }
  188. func (f failedStmt) Get(_ any, _ ...any) error { return f.err }
  189. func (f failedStmt) Queryx(_ ...any) (*sqlx.Rows, error) { return nil, f.err }
  190. func (f failedStmt) Select(_ any, _ ...any) error { return f.err }
  191. //nolint:noctx
  192. func (s *baseDB) runScripts(glob string, filter ...func(s string) bool) error {
  193. scripts, err := fs.Glob(embedded, glob)
  194. if err != nil {
  195. return wrap(err)
  196. }
  197. tx, err := s.sql.Begin()
  198. if err != nil {
  199. return wrap(err)
  200. }
  201. defer tx.Rollback() //nolint:errcheck
  202. nextScript:
  203. for _, scr := range scripts {
  204. for _, fn := range filter {
  205. if !fn(scr) {
  206. continue nextScript
  207. }
  208. }
  209. if err := s.execScript(tx, scr); err != nil {
  210. return wrap(err)
  211. }
  212. }
  213. return wrap(tx.Commit())
  214. }
  215. //nolint:noctx
  216. func (s *baseDB) applyMigration(ver int, script string) error {
  217. tx, err := s.sql.Begin()
  218. if err != nil {
  219. return wrap(err)
  220. }
  221. defer tx.Rollback() //nolint:errcheck
  222. if err := s.execScript(tx, script); err != nil {
  223. return wrap(err)
  224. }
  225. if err := setAppliedSchemaVersion(ver, tx); err != nil {
  226. return wrap(err)
  227. }
  228. return wrap(tx.Commit())
  229. }
  230. //nolint:noctx
  231. func (s *baseDB) execScript(tx *sql.Tx, scr string) error {
  232. bs, err := fs.ReadFile(embedded, scr)
  233. if err != nil {
  234. return wrap(err, scr)
  235. }
  236. // SQLite requires one statement per exec, so we split the init
  237. // files on lines containing only a semicolon and execute them
  238. // separately. We require it on a separate line because there are
  239. // also statement-internal semicolons in the triggers.
  240. for _, stmt := range strings.Split(string(bs), "\n;") {
  241. if _, err := tx.Exec(s.expandTemplateVars(stmt)); err != nil {
  242. return wrap(err, stmt)
  243. }
  244. }
  245. return nil
  246. }
  247. type schemaVersion struct {
  248. SchemaVersion int
  249. AppliedAt int64
  250. SyncthingVersion string
  251. }
  252. func (s *schemaVersion) AppliedTime() time.Time {
  253. return time.Unix(0, s.AppliedAt)
  254. }
  255. func setAppliedSchemaVersion(ver int, execer sqlx.Execer) error {
  256. _, err := execer.Exec(`
  257. INSERT OR IGNORE INTO schemamigrations (schema_version, applied_at, syncthing_version)
  258. VALUES (?, ?, ?)
  259. `, ver, time.Now().UnixNano(), build.LongVersion)
  260. return wrap(err)
  261. }
  262. func (s *baseDB) getAppliedSchemaVersion() (schemaVersion, error) {
  263. var v schemaVersion
  264. err := s.stmt(`
  265. SELECT schema_version as schemaversion, applied_at as appliedat, syncthing_version as syncthingversion FROM schemamigrations
  266. ORDER BY schema_version DESC
  267. LIMIT 1
  268. `).Get(&v)
  269. return v, wrap(err)
  270. }