codex_usage.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/relay/channel/codex"
  14. "github.com/QuantumNous/new-api/service"
  15. "github.com/gin-gonic/gin"
  16. )
  17. func GetCodexChannelUsage(c *gin.Context) {
  18. channelId, err := strconv.Atoi(c.Param("id"))
  19. if err != nil {
  20. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  21. return
  22. }
  23. ch, err := model.GetChannelById(channelId, true)
  24. if err != nil {
  25. common.ApiError(c, err)
  26. return
  27. }
  28. if ch == nil {
  29. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  30. return
  31. }
  32. if ch.Type != constant.ChannelTypeCodex {
  33. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  34. return
  35. }
  36. if ch.ChannelInfo.IsMultiKey {
  37. c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"})
  38. return
  39. }
  40. oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
  41. if err != nil {
  42. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  43. return
  44. }
  45. accessToken := strings.TrimSpace(oauthKey.AccessToken)
  46. accountID := strings.TrimSpace(oauthKey.AccountID)
  47. if accessToken == "" {
  48. c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"})
  49. return
  50. }
  51. if accountID == "" {
  52. c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"})
  53. return
  54. }
  55. client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy)
  56. if err != nil {
  57. common.ApiError(c, err)
  58. return
  59. }
  60. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  61. defer cancel()
  62. statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
  63. if err != nil {
  64. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  65. return
  66. }
  67. if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" {
  68. refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
  69. defer refreshCancel()
  70. res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
  71. if refreshErr == nil {
  72. oauthKey.AccessToken = res.AccessToken
  73. oauthKey.RefreshToken = res.RefreshToken
  74. oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
  75. oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
  76. if strings.TrimSpace(oauthKey.Type) == "" {
  77. oauthKey.Type = "codex"
  78. }
  79. encoded, encErr := common.Marshal(oauthKey)
  80. if encErr == nil {
  81. _ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error
  82. model.InitChannelCache()
  83. service.ResetProxyClientCache()
  84. }
  85. ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second)
  86. defer cancel2()
  87. statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
  88. if err != nil {
  89. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  90. return
  91. }
  92. }
  93. }
  94. var payload any
  95. if json.Unmarshal(body, &payload) != nil {
  96. payload = string(body)
  97. }
  98. ok := statusCode >= 200 && statusCode < 300
  99. resp := gin.H{
  100. "success": ok,
  101. "message": "",
  102. "upstream_status": statusCode,
  103. "data": payload,
  104. }
  105. if !ok {
  106. resp["message"] = fmt.Sprintf("upstream status: %d", statusCode)
  107. }
  108. c.JSON(http.StatusOK, resp)
  109. }