postgres.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package database
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/ding113/claude-code-hub/internal/config"
	"github.com/ding113/claude-code-hub/internal/pkg/logger"
	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/driver/pgdriver"
)
// Package-level state for the process-wide connection handed out by GetDB
// (singleton pattern, kept consistent with the Node.js version).
var (
	// dbInstance is the shared handle once it has been created.
	dbInstance *bun.DB
	dbOnce     sync.Once
	dbErr      error
)
// PostgresDB wraps a PostgreSQL database connection together with a
// config.DatabaseConfig (presumably the one it was created from — confirm).
//
// NOTE(review): NewPostgres and GetDB in this file return a bare *bun.DB rather
// than a *PostgresDB; confirm whether this wrapper is still used elsewhere.
type PostgresDB struct {
	// DB is the underlying Bun database handle.
	DB  *bun.DB
	cfg config.DatabaseConfig
}
  27. // NewPostgres 创建 PostgreSQL 数据库连接
  28. // 支持两种配置方式:
  29. // 1. DSN 连接字符串(优先)
  30. // 2. 分离的配置字段
  31. func NewPostgres(cfg config.DatabaseConfig) (*bun.DB, error) {
  32. // 获取 DSN
  33. dsn := cfg.DSN
  34. if dsn == "" {
  35. // 如果没有 DSN,则从分离的配置字段构建
  36. dsn = buildDSN(cfg)
  37. }
  38. // 验证 DSN 不为空
  39. if dsn == "" {
  40. return nil, errors.New("DSN environment variable is not set")
  41. }
  42. // 检查是否为占位符模板(与 Node.js 版本保持一致)
  43. if strings.Contains(dsn, "user:password@host:port") {
  44. return nil, errors.New("DSN contains placeholder template, please set a valid DSN")
  45. }
  46. // 创建连接器
  47. connector := pgdriver.NewConnector(
  48. pgdriver.WithDSN(dsn),
  49. pgdriver.WithDialTimeout(cfg.ConnectTimeout),
  50. pgdriver.WithReadTimeout(cfg.IdleTimeout), // 读取超时使用空闲超时
  51. )
  52. // 创建 sql.DB
  53. sqlDB := sql.OpenDB(connector)
  54. // 设置连接池参数
  55. // MaxOpenConns: 最大打开连接数
  56. // - 与 Node.js 版本的 max 参数对应
  57. sqlDB.SetMaxOpenConns(cfg.PoolMax)
  58. // MaxIdleConns: 最大空闲连接数
  59. sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
  60. // ConnMaxLifetime: 连接最大生命周期
  61. sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
  62. // ConnMaxIdleTime: 空闲连接最大存活时间
  63. // - 与 Node.js 版本的 idle_timeout 参数对应
  64. sqlDB.SetConnMaxIdleTime(cfg.IdleTimeout)
  65. // 创建 Bun DB
  66. db := bun.NewDB(sqlDB, pgdialect.New())
  67. // 测试连接
  68. ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout)
  69. defer cancel()
  70. if err := db.PingContext(ctx); err != nil {
  71. return nil, fmt.Errorf("failed to ping database: %w", err)
  72. }
  73. // 记录连接信息(隐藏敏感信息)
  74. logDSN := sanitizeDSN(dsn)
  75. logger.Info().
  76. Str("dsn", logDSN).
  77. Int("pool_max", cfg.PoolMax).
  78. Int("max_idle_conns", cfg.MaxIdleConns).
  79. Dur("idle_timeout", cfg.IdleTimeout).
  80. Dur("connect_timeout", cfg.ConnectTimeout).
  81. Dur("conn_max_lifetime", cfg.ConnMaxLifetime).
  82. Msg("PostgreSQL connected")
  83. return db, nil
  84. }
  85. // GetDB 获取数据库单例(懒加载)
  86. // 与 Node.js 版本的 getDb() 函数对应
  87. func GetDB(cfg config.DatabaseConfig) (*bun.DB, error) {
  88. dbOnce.Do(func() {
  89. dbInstance, dbErr = NewPostgres(cfg)
  90. })
  91. return dbInstance, dbErr
  92. }
  93. // ClosePostgres 关闭数据库连接
  94. func ClosePostgres(db *bun.DB) error {
  95. if db != nil {
  96. logger.Info().Msg("Closing PostgreSQL connection")
  97. return db.Close()
  98. }
  99. return nil
  100. }
  101. // HealthCheck 健康检查
  102. // 返回数据库连接状态和统计信息
  103. func HealthCheck(ctx context.Context, db *bun.DB) (*HealthStatus, error) {
  104. if db == nil {
  105. return nil, errors.New("database connection is nil")
  106. }
  107. status := &HealthStatus{
  108. Healthy: false,
  109. Timestamp: time.Now(),
  110. }
  111. // 执行 ping 检查
  112. start := time.Now()
  113. err := db.PingContext(ctx)
  114. status.Latency = time.Since(start)
  115. if err != nil {
  116. status.Error = err.Error()
  117. return status, err
  118. }
  119. status.Healthy = true
  120. // 获取连接池统计信息
  121. sqlDB := db.DB
  122. stats := sqlDB.Stats()
  123. status.Stats = &PoolStats{
  124. MaxOpenConnections: stats.MaxOpenConnections,
  125. OpenConnections: stats.OpenConnections,
  126. InUse: stats.InUse,
  127. Idle: stats.Idle,
  128. WaitCount: stats.WaitCount,
  129. WaitDuration: stats.WaitDuration,
  130. MaxIdleClosed: stats.MaxIdleClosed,
  131. MaxLifetimeClosed: stats.MaxLifetimeClosed,
  132. }
  133. return status, nil
  134. }
// HealthStatus is the result of a database health check, as produced by
// HealthCheck.
type HealthStatus struct {
	// Healthy reports whether the ping succeeded.
	Healthy bool `json:"healthy"`
	// Latency is the measured ping duration. Being a time.Duration, it is
	// encoded by encoding/json as an integer number of nanoseconds.
	Latency time.Duration `json:"latency"`
	// Error holds the ping error's text; empty (and omitted from JSON) when healthy.
	Error string `json:"error,omitempty"`
	// Timestamp is when the check started.
	Timestamp time.Time `json:"timestamp"`
	// Stats is a connection-pool snapshot; set only when the ping succeeded.
	Stats *PoolStats `json:"stats,omitempty"`
}
// PoolStats is a JSON-friendly copy of the database/sql connection-pool
// statistics (sql.DBStats). HealthCheck fills each field from the field of
// the same name.
// NOTE(review): sql.DBStats.MaxIdleTimeClosed is not mirrored here.
type PoolStats struct {
	// MaxOpenConnections is the configured pool limit (0 means unlimited).
	MaxOpenConnections int `json:"max_open_connections"`
	// OpenConnections counts established connections, in use or idle.
	OpenConnections int `json:"open_connections"`
	// InUse counts connections currently in use.
	InUse int `json:"in_use"`
	// Idle counts idle connections.
	Idle int `json:"idle"`
	// WaitCount is the total number of times a caller waited for a connection.
	WaitCount int64 `json:"wait_count"`
	// WaitDuration is the total time spent waiting; encoded as nanoseconds.
	WaitDuration time.Duration `json:"wait_duration"`
	// MaxIdleClosed counts connections closed due to the max-idle limit.
	MaxIdleClosed int64 `json:"max_idle_closed"`
	// MaxLifetimeClosed counts connections closed due to the max lifetime.
	MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
}
  154. // buildDSN 从分离的配置字段构建 DSN
  155. func buildDSN(cfg config.DatabaseConfig) string {
  156. if cfg.Host == "" {
  157. return ""
  158. }
  159. return fmt.Sprintf(
  160. "postgres://%s:%s@%s:%d/%s?sslmode=%s",
  161. cfg.User,
  162. cfg.Password,
  163. cfg.Host,
  164. cfg.Port,
  165. cfg.DBName,
  166. cfg.SSLMode,
  167. )
  168. }
  169. // sanitizeDSN 清理 DSN 中的敏感信息(用于日志)
  170. func sanitizeDSN(dsn string) string {
  171. // 简单处理:隐藏密码部分
  172. // postgres://user:password@host:port/dbname -> postgres://user:***@host:port/dbname
  173. if !strings.Contains(dsn, "://") {
  174. return dsn
  175. }
  176. parts := strings.SplitN(dsn, "://", 2)
  177. if len(parts) != 2 {
  178. return dsn
  179. }
  180. protocol := parts[0]
  181. rest := parts[1]
  182. // 查找 @ 符号
  183. atIndex := strings.Index(rest, "@")
  184. if atIndex == -1 {
  185. return dsn
  186. }
  187. userPass := rest[:atIndex]
  188. hostAndRest := rest[atIndex:]
  189. // 查找密码部分
  190. colonIndex := strings.Index(userPass, ":")
  191. if colonIndex == -1 {
  192. return dsn
  193. }
  194. user := userPass[:colonIndex]
  195. return fmt.Sprintf("%s://%s:***%s", protocol, user, hostAndRest)
  196. }