sqlite.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. // Package sqlite3 wraps the C SQLite API.
  2. package sqlite3
  3. import (
  4. "context"
  5. "math/bits"
  6. "os"
  7. "sync"
  8. "unsafe"
  9. "github.com/tetratelabs/wazero"
  10. "github.com/tetratelabs/wazero/api"
  11. "github.com/ncruces/go-sqlite3/internal/util"
  12. "github.com/ncruces/go-sqlite3/vfs"
  13. )
  14. // Configure SQLite Wasm.
  15. //
  16. // Importing package embed initializes [Binary]
  17. // with an appropriate build of SQLite:
  18. //
  19. // import _ "github.com/ncruces/go-sqlite3/embed"
  20. var (
  21. Binary []byte // Wasm binary to load.
  22. Path string // Path to load the binary from.
  23. RuntimeConfig wazero.RuntimeConfig
  24. )
  25. // Initialize decodes and compiles the SQLite Wasm binary.
  26. // This is called implicitly when the first connection is openned,
  27. // but is potentially slow, so you may want to call it at a more convenient time.
  28. func Initialize() error {
  29. instance.once.Do(compileSQLite)
  30. return instance.err
  31. }
  32. var instance struct {
  33. runtime wazero.Runtime
  34. compiled wazero.CompiledModule
  35. err error
  36. once sync.Once
  37. }
  38. func compileSQLite() {
  39. ctx := context.Background()
  40. cfg := RuntimeConfig
  41. if cfg == nil {
  42. cfg = wazero.NewRuntimeConfig()
  43. if bits.UintSize < 64 {
  44. cfg = cfg.WithMemoryLimitPages(512) // 32MB
  45. } else {
  46. cfg = cfg.WithMemoryLimitPages(4096) // 256MB
  47. }
  48. cfg = cfg.WithCoreFeatures(api.CoreFeaturesV2)
  49. }
  50. instance.runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
  51. env := instance.runtime.NewHostModuleBuilder("env")
  52. env = vfs.ExportHostFunctions(env)
  53. env = exportCallbacks(env)
  54. _, instance.err = env.Instantiate(ctx)
  55. if instance.err != nil {
  56. return
  57. }
  58. bin := Binary
  59. if bin == nil && Path != "" {
  60. bin, instance.err = os.ReadFile(Path)
  61. if instance.err != nil {
  62. return
  63. }
  64. }
  65. if bin == nil {
  66. instance.err = util.NoBinaryErr
  67. return
  68. }
  69. instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
  70. }
  71. type sqlite struct {
  72. ctx context.Context
  73. mod api.Module
  74. funcs struct {
  75. fn [32]api.Function
  76. id [32]*byte
  77. mask uint32
  78. }
  79. stack [9]stk_t
  80. }
  81. func instantiateSQLite() (sqlt *sqlite, err error) {
  82. if err := Initialize(); err != nil {
  83. return nil, err
  84. }
  85. sqlt = new(sqlite)
  86. sqlt.ctx = util.NewContext(context.Background())
  87. sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
  88. instance.compiled, wazero.NewModuleConfig().WithName(""))
  89. if err != nil {
  90. return nil, err
  91. }
  92. if sqlt.getfn("sqlite3_progress_handler_go") == nil {
  93. return nil, util.BadBinaryErr
  94. }
  95. return sqlt, nil
  96. }
  97. func (sqlt *sqlite) close() error {
  98. return sqlt.mod.Close(sqlt.ctx)
  99. }
  100. func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
  101. if rc == _OK {
  102. return nil
  103. }
  104. if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
  105. panic(util.OOMErr)
  106. }
  107. if handle != 0 {
  108. var msg, query string
  109. if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
  110. msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
  111. switch {
  112. case msg == "not an error":
  113. msg = ""
  114. case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
  115. msg = ""
  116. }
  117. }
  118. if len(sql) != 0 {
  119. if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
  120. query = sql[0][i:]
  121. }
  122. }
  123. if msg != "" || query != "" {
  124. return &Error{code: rc, msg: msg, sql: query}
  125. }
  126. }
  127. return xErrorCode(rc)
  128. }
  129. func (sqlt *sqlite) getfn(name string) api.Function {
  130. c := &sqlt.funcs
  131. p := unsafe.StringData(name)
  132. for i := range c.id {
  133. if c.id[i] == p {
  134. c.id[i] = nil
  135. c.mask &^= uint32(1) << i
  136. return c.fn[i]
  137. }
  138. }
  139. return sqlt.mod.ExportedFunction(name)
  140. }
  141. func (sqlt *sqlite) putfn(name string, fn api.Function) {
  142. c := &sqlt.funcs
  143. p := unsafe.StringData(name)
  144. i := bits.TrailingZeros32(^c.mask)
  145. if i < 32 {
  146. c.id[i] = p
  147. c.fn[i] = fn
  148. c.mask |= uint32(1) << i
  149. } else {
  150. c.id[0] = p
  151. c.fn[0] = fn
  152. c.mask = uint32(1)
  153. }
  154. }
  155. func (sqlt *sqlite) call(name string, params ...stk_t) stk_t {
  156. copy(sqlt.stack[:], params)
  157. fn := sqlt.getfn(name)
  158. err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
  159. if err != nil {
  160. panic(err)
  161. }
  162. sqlt.putfn(name, fn)
  163. return stk_t(sqlt.stack[0])
  164. }
  165. func (sqlt *sqlite) free(ptr ptr_t) {
  166. if ptr == 0 {
  167. return
  168. }
  169. sqlt.call("sqlite3_free", stk_t(ptr))
  170. }
  171. func (sqlt *sqlite) new(size int64) ptr_t {
  172. ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size)))
  173. if ptr == 0 && size != 0 {
  174. panic(util.OOMErr)
  175. }
  176. return ptr
  177. }
  178. func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
  179. ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size)))
  180. if ptr == 0 && size != 0 {
  181. panic(util.OOMErr)
  182. }
  183. return ptr
  184. }
  185. func (sqlt *sqlite) newBytes(b []byte) ptr_t {
  186. if len(b) == 0 {
  187. return 0
  188. }
  189. ptr := sqlt.new(int64(len(b)))
  190. util.WriteBytes(sqlt.mod, ptr, b)
  191. return ptr
  192. }
  193. func (sqlt *sqlite) newString(s string) ptr_t {
  194. ptr := sqlt.new(int64(len(s)) + 1)
  195. util.WriteString(sqlt.mod, ptr, s)
  196. return ptr
  197. }
  198. const arenaSize = 4096
  199. func (sqlt *sqlite) newArena() arena {
  200. return arena{
  201. sqlt: sqlt,
  202. base: sqlt.new(arenaSize),
  203. }
  204. }
  205. type arena struct {
  206. sqlt *sqlite
  207. ptrs []ptr_t
  208. base ptr_t
  209. next int32
  210. }
  211. func (a *arena) free() {
  212. if a.sqlt == nil {
  213. return
  214. }
  215. for _, ptr := range a.ptrs {
  216. a.sqlt.free(ptr)
  217. }
  218. a.sqlt.free(a.base)
  219. a.sqlt = nil
  220. }
  221. func (a *arena) mark() (reset func()) {
  222. ptrs := len(a.ptrs)
  223. next := a.next
  224. return func() {
  225. rest := a.ptrs[ptrs:]
  226. for _, ptr := range a.ptrs[:ptrs] {
  227. a.sqlt.free(ptr)
  228. }
  229. a.ptrs = rest
  230. a.next = next
  231. }
  232. }
  233. func (a *arena) new(size int64) ptr_t {
  234. // Align the next address, to 4 or 8 bytes.
  235. if size&7 != 0 {
  236. a.next = (a.next + 3) &^ 3
  237. } else {
  238. a.next = (a.next + 7) &^ 7
  239. }
  240. if size <= arenaSize-int64(a.next) {
  241. ptr := a.base + ptr_t(a.next)
  242. a.next += int32(size)
  243. return ptr_t(ptr)
  244. }
  245. ptr := a.sqlt.new(size)
  246. a.ptrs = append(a.ptrs, ptr)
  247. return ptr_t(ptr)
  248. }
  249. func (a *arena) bytes(b []byte) ptr_t {
  250. if len(b) == 0 {
  251. return 0
  252. }
  253. ptr := a.new(int64(len(b)))
  254. util.WriteBytes(a.sqlt.mod, ptr, b)
  255. return ptr
  256. }
  257. func (a *arena) string(s string) ptr_t {
  258. ptr := a.new(int64(len(s)) + 1)
  259. util.WriteString(a.sqlt.mod, ptr, s)
  260. return ptr
  261. }
  262. func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
  263. util.ExportFuncII(env, "go_progress_handler", progressCallback)
  264. util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback)
  265. util.ExportFuncIII(env, "go_busy_handler", busyCallback)
  266. util.ExportFuncII(env, "go_commit_hook", commitCallback)
  267. util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
  268. util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
  269. util.ExportFuncIIIII(env, "go_wal_hook", walCallback)
  270. util.ExportFuncIIIII(env, "go_trace", traceCallback)
  271. util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback)
  272. util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
  273. util.ExportFuncVIII(env, "go_log", logCallback)
  274. util.ExportFuncVI(env, "go_destroy", destroyCallback)
  275. util.ExportFuncVIIII(env, "go_func", funcCallback)
  276. util.ExportFuncVIIIII(env, "go_step", stepCallback)
  277. util.ExportFuncVIIII(env, "go_value", valueCallback)
  278. util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
  279. util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
  280. util.ExportFuncIIIIII(env, "go_compare", compareCallback)
  281. util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
  282. util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
  283. util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
  284. util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
  285. util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
  286. util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback)
  287. util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback)
  288. util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback)
  289. util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback)
  290. util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback)
  291. util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback)
  292. util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback)
  293. util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback)
  294. util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback)
  295. util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback)
  296. util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
  297. util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
  298. util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
  299. util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
  300. util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
  301. util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
  302. util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
  303. util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
  304. return env
  305. }