| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- package service
- import (
- "context"
- "errors"
- "fmt"
- "strings"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/model"
- )
// CodexCredentialRefreshOptions carries the optional switches accepted by
// RefreshCodexChannelCredential.
type CodexCredentialRefreshOptions struct {
	// ResetCaches, when true, makes a successful refresh also call
	// model.InitChannelCache and ResetProxyClientCache after the new key has
	// been written.
	ResetCaches bool
}
// CodexOAuthKey is the JSON document kept in a Codex channel's key field.
//
// Every field is tagged omitempty, so blank values are left out of the encoded
// document. Field order is significant for the encoded output and must not be
// rearranged.
type CodexOAuthKey struct {
	IDToken string `json:"id_token,omitempty"`
	// AccessToken and RefreshToken are replaced from the token endpoint's
	// response by RefreshCodexChannelCredential.
	AccessToken  string `json:"access_token,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
	// AccountID is filled from the access token's JWT when it is blank.
	AccountID string `json:"account_id,omitempty"`
	// LastRefresh is the RFC 3339 time of the most recent refresh.
	LastRefresh string `json:"last_refresh,omitempty"`
	// Email is filled from the access token's JWT when it is blank.
	Email string `json:"email,omitempty"`
	// Type is set to "codex" on refresh when it is blank.
	Type string `json:"type,omitempty"`
	// Expired is the RFC 3339 expiry time reported by the latest refresh.
	Expired string `json:"expired,omitempty"`
}
- func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
- if strings.TrimSpace(raw) == "" {
- return nil, errors.New("codex channel: empty oauth key")
- }
- var key CodexOAuthKey
- if err := common.Unmarshal([]byte(raw), &key); err != nil {
- return nil, errors.New("codex channel: invalid oauth key json")
- }
- return &key, nil
- }
- func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
- ch, err := model.GetChannelById(channelID, true)
- if err != nil {
- return nil, nil, err
- }
- if ch == nil {
- return nil, nil, fmt.Errorf("channel not found")
- }
- if ch.Type != constant.ChannelTypeCodex {
- return nil, nil, fmt.Errorf("channel type is not Codex")
- }
- oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
- if err != nil {
- return nil, nil, err
- }
- if strings.TrimSpace(oauthKey.RefreshToken) == "" {
- return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential")
- }
- refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
- if err != nil {
- return nil, nil, err
- }
- oauthKey.AccessToken = res.AccessToken
- oauthKey.RefreshToken = res.RefreshToken
- oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
- oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
- if strings.TrimSpace(oauthKey.Type) == "" {
- oauthKey.Type = "codex"
- }
- if strings.TrimSpace(oauthKey.AccountID) == "" {
- if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
- oauthKey.AccountID = accountID
- }
- }
- if strings.TrimSpace(oauthKey.Email) == "" {
- if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
- oauthKey.Email = email
- }
- }
- encoded, err := common.Marshal(oauthKey)
- if err != nil {
- return nil, nil, err
- }
- if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error; err != nil {
- return nil, nil, err
- }
- if opts.ResetCaches {
- model.InitChannelCache()
- ResetProxyClientCache()
- }
- return oauthKey, ch, nil
- }
|