codex_usage.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/relay/channel/codex"
  14. "github.com/QuantumNous/new-api/service"
  15. "github.com/gin-gonic/gin"
  16. )
  17. func GetCodexChannelUsage(c *gin.Context) {
  18. channelId, err := strconv.Atoi(c.Param("id"))
  19. if err != nil {
  20. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  21. return
  22. }
  23. ch, err := model.GetChannelById(channelId, true)
  24. if err != nil {
  25. common.ApiError(c, err)
  26. return
  27. }
  28. if ch == nil {
  29. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  30. return
  31. }
  32. if ch.Type != constant.ChannelTypeCodex {
  33. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  34. return
  35. }
  36. if ch.ChannelInfo.IsMultiKey {
  37. c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"})
  38. return
  39. }
  40. oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
  41. if err != nil {
  42. common.SysError("failed to parse oauth key: " + err.Error())
  43. c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析凭证失败,请检查渠道配置"})
  44. return
  45. }
  46. accessToken := strings.TrimSpace(oauthKey.AccessToken)
  47. accountID := strings.TrimSpace(oauthKey.AccountID)
  48. if accessToken == "" {
  49. c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"})
  50. return
  51. }
  52. if accountID == "" {
  53. c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"})
  54. return
  55. }
  56. client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy)
  57. if err != nil {
  58. common.ApiError(c, err)
  59. return
  60. }
  61. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  62. defer cancel()
  63. statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
  64. if err != nil {
  65. common.SysError("failed to fetch codex usage: " + err.Error())
  66. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
  67. return
  68. }
  69. if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" {
  70. refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
  71. defer refreshCancel()
  72. res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
  73. if refreshErr == nil {
  74. oauthKey.AccessToken = res.AccessToken
  75. oauthKey.RefreshToken = res.RefreshToken
  76. oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
  77. oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
  78. if strings.TrimSpace(oauthKey.Type) == "" {
  79. oauthKey.Type = "codex"
  80. }
  81. encoded, encErr := common.Marshal(oauthKey)
  82. if encErr == nil {
  83. _ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error
  84. model.InitChannelCache()
  85. service.ResetProxyClientCache()
  86. }
  87. ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second)
  88. defer cancel2()
  89. statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
  90. if err != nil {
  91. common.SysError("failed to fetch codex usage after refresh: " + err.Error())
  92. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
  93. return
  94. }
  95. }
  96. }
  97. var payload any
  98. if json.Unmarshal(body, &payload) != nil {
  99. payload = string(body)
  100. }
  101. ok := statusCode >= 200 && statusCode < 300
  102. resp := gin.H{
  103. "success": ok,
  104. "message": "",
  105. "upstream_status": statusCode,
  106. "data": payload,
  107. }
  108. if !ok {
  109. resp["message"] = fmt.Sprintf("upstream status: %d", statusCode)
  110. }
  111. c.JSON(http.StatusOK, resp)
  112. }