| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- package oauth
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/i18n"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/setting/system_setting"
	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
)
// Auth styles select how client credentials are sent to the token endpoint.
// They are compared against CustomOAuthProvider.AuthStyle, whose values are
// presumably persisted configuration, so the numeric values 0, 1 and 2 must
// not change.
const (
	// AuthStyleAutoDetect leaves the choice to ExchangeToken, which currently
	// sends the credentials as POST form parameters.
	AuthStyleAutoDetect = iota
	// AuthStyleInParams sends client_id and client_secret as POST form parameters.
	AuthStyleInParams
	// AuthStyleInHeader sends the credentials as an HTTP Basic Authorization header.
	AuthStyleInHeader
)
- // GenericOAuthProvider implements OAuth for custom/generic OAuth providers
- type GenericOAuthProvider struct {
- config *model.CustomOAuthProvider
- }
- // NewGenericOAuthProvider creates a new generic OAuth provider from config
- func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
- return &GenericOAuthProvider{config: config}
- }
- func (p *GenericOAuthProvider) GetName() string {
- return p.config.Name
- }
- func (p *GenericOAuthProvider) IsEnabled() bool {
- return p.config.Enabled
- }
- func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
- return p.config
- }
- func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
- if code == "" {
- return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
- }
- logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
- redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
- values := url.Values{}
- values.Set("grant_type", "authorization_code")
- values.Set("code", code)
- values.Set("redirect_uri", redirectUri)
- // Determine auth style
- authStyle := p.config.AuthStyle
- if authStyle == AuthStyleAutoDetect {
- // Default to params style for most OAuth servers
- authStyle = AuthStyleInParams
- }
- var req *http.Request
- var err error
- if authStyle == AuthStyleInParams {
- values.Set("client_id", p.config.ClientId)
- values.Set("client_secret", p.config.ClientSecret)
- }
- req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.Header.Set("Accept", "application/json")
- if authStyle == AuthStyleInHeader {
- // Basic Auth
- credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
- req.Header.Set("Authorization", "Basic "+credentials)
- }
- logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
- p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
- client := http.Client{
- Timeout: 20 * time.Second,
- }
- res, err := client.Do(req)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
- return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
- }
- defer res.Body.Close()
- logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
- body, err := io.ReadAll(res.Body)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
- return nil, err
- }
- bodyStr := string(body)
- logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
- // Try to parse as JSON first
- var tokenResponse struct {
- AccessToken string `json:"access_token"`
- TokenType string `json:"token_type"`
- RefreshToken string `json:"refresh_token"`
- ExpiresIn int `json:"expires_in"`
- Scope string `json:"scope"`
- IDToken string `json:"id_token"`
- Error string `json:"error"`
- ErrorDesc string `json:"error_description"`
- }
- if err := json.Unmarshal(body, &tokenResponse); err != nil {
- // Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
- parsedValues, parseErr := url.ParseQuery(bodyStr)
- if parseErr != nil {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
- return nil, err
- }
- tokenResponse.AccessToken = parsedValues.Get("access_token")
- tokenResponse.TokenType = parsedValues.Get("token_type")
- tokenResponse.Scope = parsedValues.Get("scope")
- }
- if tokenResponse.Error != "" {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
- p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
- return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
- }
- if tokenResponse.AccessToken == "" {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
- return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
- }
- logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
- return &OAuthToken{
- AccessToken: tokenResponse.AccessToken,
- TokenType: tokenResponse.TokenType,
- RefreshToken: tokenResponse.RefreshToken,
- ExpiresIn: tokenResponse.ExpiresIn,
- Scope: tokenResponse.Scope,
- IDToken: tokenResponse.IDToken,
- }, nil
- }
- func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
- logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
- req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
- if err != nil {
- return nil, err
- }
- // Set authorization header
- tokenType := token.TokenType
- if tokenType == "" {
- tokenType = "Bearer"
- }
- req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
- req.Header.Set("Accept", "application/json")
- client := http.Client{
- Timeout: 20 * time.Second,
- }
- res, err := client.Do(req)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
- return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
- }
- defer res.Body.Close()
- logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
- if res.StatusCode != http.StatusOK {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
- return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
- }
- body, err := io.ReadAll(res.Body)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
- return nil, err
- }
- bodyStr := string(body)
- logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
- // Extract fields using gjson (supports JSONPath-like syntax)
- userId := gjson.Get(bodyStr, p.config.UserIdField).String()
- username := gjson.Get(bodyStr, p.config.UsernameField).String()
- displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
- email := gjson.Get(bodyStr, p.config.EmailField).String()
- // If user ID field returns a number, convert it
- if userId == "" {
- // Try to get as number
- userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
- if userIdNum.Exists() {
- userId = userIdNum.Raw
- // Remove quotes if present
- userId = strings.Trim(userId, "\"")
- }
- }
- if userId == "" {
- logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
- return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
- }
- logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
- p.config.Slug, userId, username, displayName, email)
- return &OAuthUser{
- ProviderUserID: userId,
- Username: username,
- DisplayName: displayName,
- Email: email,
- }, nil
- }
- func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
- return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
- }
- func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
- foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
- if err != nil {
- return err
- }
- *user = *foundUser
- return nil
- }
- func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
- // For generic providers, we store the binding in user_oauth_bindings table
- // This is handled separately in the OAuth controller
- }
- func (p *GenericOAuthProvider) GetProviderPrefix() string {
- return p.config.Slug + "_"
- }
- // GetProviderId returns the provider ID for binding purposes
- func (p *GenericOAuthProvider) GetProviderId() int {
- return p.config.Id
- }
- // IsGenericProvider returns true for generic providers
- func (p *GenericOAuthProvider) IsGenericProvider() bool {
- return true
- }
|