| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- package controller
- import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "strconv"
- "strings"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay/channel/codex"
- "github.com/QuantumNous/new-api/service"
- "github.com/gin-gonic/gin"
- )
- func GetCodexChannelUsage(c *gin.Context) {
- channelId, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
- return
- }
- ch, err := model.GetChannelById(channelId, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if ch == nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
- return
- }
- if ch.Type != constant.ChannelTypeCodex {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
- return
- }
- if ch.ChannelInfo.IsMultiKey {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"})
- return
- }
- oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- accessToken := strings.TrimSpace(oauthKey.AccessToken)
- accountID := strings.TrimSpace(oauthKey.AccountID)
- if accessToken == "" {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"})
- return
- }
- if accountID == "" {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"})
- return
- }
- client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
- defer cancel()
- statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" {
- refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
- defer refreshCancel()
- res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
- if refreshErr == nil {
- oauthKey.AccessToken = res.AccessToken
- oauthKey.RefreshToken = res.RefreshToken
- oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
- oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
- if strings.TrimSpace(oauthKey.Type) == "" {
- oauthKey.Type = "codex"
- }
- encoded, encErr := common.Marshal(oauthKey)
- if encErr == nil {
- _ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error
- model.InitChannelCache()
- service.ResetProxyClientCache()
- }
- ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second)
- defer cancel2()
- statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- }
- }
- var payload any
- if json.Unmarshal(body, &payload) != nil {
- payload = string(body)
- }
- ok := statusCode >= 200 && statusCode < 300
- resp := gin.H{
- "success": ok,
- "message": "",
- "upstream_status": statusCode,
- "data": payload,
- }
- if !ok {
- resp["message"] = fmt.Sprintf("upstream status: %d", statusCode)
- }
- c.JSON(http.StatusOK, resp)
- }
|