codex_credential_refresh.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/model"
  11. )
  // CodexCredentialRefreshOptions controls the optional side effects of
  // RefreshCodexChannelCredential.
  type CodexCredentialRefreshOptions struct {
  	// ResetCaches, when true, makes RefreshCodexChannelCredential call
  	// model.InitChannelCache and ResetProxyClientCache after the refreshed
  	// credential has been written to the database. The zero value skips both.
  	ResetCaches bool
  }
  15. type CodexOAuthKey struct {
  16. IDToken string `json:"id_token,omitempty"`
  17. AccessToken string `json:"access_token,omitempty"`
  18. RefreshToken string `json:"refresh_token,omitempty"`
  19. AccountID string `json:"account_id,omitempty"`
  20. LastRefresh string `json:"last_refresh,omitempty"`
  21. Email string `json:"email,omitempty"`
  22. Type string `json:"type,omitempty"`
  23. Expired string `json:"expired,omitempty"`
  24. }
  25. func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
  26. if strings.TrimSpace(raw) == "" {
  27. return nil, errors.New("codex channel: empty oauth key")
  28. }
  29. var key CodexOAuthKey
  30. if err := common.Unmarshal([]byte(raw), &key); err != nil {
  31. return nil, errors.New("codex channel: invalid oauth key json")
  32. }
  33. return &key, nil
  34. }
  35. func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
  36. ch, err := model.GetChannelById(channelID, true)
  37. if err != nil {
  38. return nil, nil, err
  39. }
  40. if ch == nil {
  41. return nil, nil, fmt.Errorf("channel not found")
  42. }
  43. if ch.Type != constant.ChannelTypeCodex {
  44. return nil, nil, fmt.Errorf("channel type is not Codex")
  45. }
  46. oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
  47. if err != nil {
  48. return nil, nil, err
  49. }
  50. if strings.TrimSpace(oauthKey.RefreshToken) == "" {
  51. return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential")
  52. }
  53. refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
  54. defer cancel()
  55. res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
  56. if err != nil {
  57. return nil, nil, err
  58. }
  59. oauthKey.AccessToken = res.AccessToken
  60. oauthKey.RefreshToken = res.RefreshToken
  61. oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
  62. oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
  63. if strings.TrimSpace(oauthKey.Type) == "" {
  64. oauthKey.Type = "codex"
  65. }
  66. if strings.TrimSpace(oauthKey.AccountID) == "" {
  67. if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
  68. oauthKey.AccountID = accountID
  69. }
  70. }
  71. if strings.TrimSpace(oauthKey.Email) == "" {
  72. if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
  73. oauthKey.Email = email
  74. }
  75. }
  76. encoded, err := common.Marshal(oauthKey)
  77. if err != nil {
  78. return nil, nil, err
  79. }
  80. if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error; err != nil {
  81. return nil, nil, err
  82. }
  83. if opts.ResetCaches {
  84. model.InitChannelCache()
  85. ResetProxyClientCache()
  86. }
  87. return oauthKey, ch, nil
  88. }