oidc.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package oauth
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/i18n"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/setting/system_setting"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func init() {
  17. Register("oidc", &OIDCProvider{})
  18. }
  19. // OIDCProvider implements OAuth for OIDC
  20. type OIDCProvider struct{}
  21. type oidcOAuthResponse struct {
  22. AccessToken string `json:"access_token"`
  23. IDToken string `json:"id_token"`
  24. RefreshToken string `json:"refresh_token"`
  25. TokenType string `json:"token_type"`
  26. ExpiresIn int `json:"expires_in"`
  27. Scope string `json:"scope"`
  28. }
  29. type oidcUser struct {
  30. OpenID string `json:"sub"`
  31. Email string `json:"email"`
  32. Name string `json:"name"`
  33. PreferredUsername string `json:"preferred_username"`
  34. Picture string `json:"picture"`
  35. }
  36. func (p *OIDCProvider) GetName() string {
  37. return "OIDC"
  38. }
  39. func (p *OIDCProvider) IsEnabled() bool {
  40. return system_setting.GetOIDCSettings().Enabled
  41. }
  42. func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
  43. if code == "" {
  44. return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
  45. }
  46. logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)])
  47. settings := system_setting.GetOIDCSettings()
  48. redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)
  49. values := url.Values{}
  50. values.Set("client_id", settings.ClientId)
  51. values.Set("client_secret", settings.ClientSecret)
  52. values.Set("code", code)
  53. values.Set("grant_type", "authorization_code")
  54. values.Set("redirect_uri", redirectUri)
  55. logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri)
  56. req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode()))
  57. if err != nil {
  58. return nil, err
  59. }
  60. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  61. req.Header.Set("Accept", "application/json")
  62. client := http.Client{
  63. Timeout: 5 * time.Second,
  64. }
  65. res, err := client.Do(req)
  66. if err != nil {
  67. logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error()))
  68. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
  69. }
  70. defer res.Body.Close()
  71. logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode)
  72. var oidcResponse oidcOAuthResponse
  73. err = json.NewDecoder(res.Body).Decode(&oidcResponse)
  74. if err != nil {
  75. logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error()))
  76. return nil, err
  77. }
  78. if oidcResponse.AccessToken == "" {
  79. logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token")
  80. return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"})
  81. }
  82. logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope)
  83. return &OAuthToken{
  84. AccessToken: oidcResponse.AccessToken,
  85. TokenType: oidcResponse.TokenType,
  86. RefreshToken: oidcResponse.RefreshToken,
  87. ExpiresIn: oidcResponse.ExpiresIn,
  88. Scope: oidcResponse.Scope,
  89. IDToken: oidcResponse.IDToken,
  90. }, nil
  91. }
  92. func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
  93. settings := system_setting.GetOIDCSettings()
  94. logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint)
  95. req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil)
  96. if err != nil {
  97. return nil, err
  98. }
  99. req.Header.Set("Authorization", "Bearer "+token.AccessToken)
  100. client := http.Client{
  101. Timeout: 5 * time.Second,
  102. }
  103. res, err := client.Do(req)
  104. if err != nil {
  105. logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error()))
  106. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
  107. }
  108. defer res.Body.Close()
  109. logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode)
  110. if res.StatusCode != http.StatusOK {
  111. logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode))
  112. return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
  113. }
  114. var oidcUser oidcUser
  115. err = json.NewDecoder(res.Body).Decode(&oidcUser)
  116. if err != nil {
  117. logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error()))
  118. return nil, err
  119. }
  120. if oidcUser.OpenID == "" || oidcUser.Email == "" {
  121. logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email))
  122. return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"})
  123. }
  124. logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email)
  125. return &OAuthUser{
  126. ProviderUserID: oidcUser.OpenID,
  127. Username: oidcUser.PreferredUsername,
  128. DisplayName: oidcUser.Name,
  129. Email: oidcUser.Email,
  130. }, nil
  131. }
  132. func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool {
  133. return model.IsOidcIdAlreadyTaken(providerUserID)
  134. }
  135. func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
  136. user.OidcId = providerUserID
  137. return user.FillUserByOidcId()
  138. }
  139. func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) {
  140. user.OidcId = providerUserID
  141. }
  142. func (p *OIDCProvider) GetProviderPrefix() string {
  143. return "oidc_"
  144. }