generic.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package oauth
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. "github.com/QuantumNous/new-api/i18n"
  13. "github.com/QuantumNous/new-api/logger"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/QuantumNous/new-api/setting/system_setting"
  16. "github.com/gin-gonic/gin"
  17. "github.com/tidwall/gjson"
  18. )
  // Auth styles select how ExchangeToken sends the client credentials to the
  // token endpoint; they are compared against the provider config's AuthStyle.
  const (
  	AuthStyleAutoDetect = iota // 0: no explicit choice; ExchangeToken falls back to AuthStyleInParams
  	AuthStyleInParams          // 1: client_id and client_secret as POST form parameters
  	AuthStyleInHeader          // 2: HTTP Basic Authorization header
  )
  25. // GenericOAuthProvider implements OAuth for custom/generic OAuth providers
  26. type GenericOAuthProvider struct {
  27. config *model.CustomOAuthProvider
  28. }
  29. // NewGenericOAuthProvider creates a new generic OAuth provider from config
  30. func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
  31. return &GenericOAuthProvider{config: config}
  32. }
  33. func (p *GenericOAuthProvider) GetName() string {
  34. return p.config.Name
  35. }
  36. func (p *GenericOAuthProvider) IsEnabled() bool {
  37. return p.config.Enabled
  38. }
  39. func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
  40. return p.config
  41. }
  42. func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
  43. if code == "" {
  44. return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
  45. }
  46. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
  47. redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
  48. values := url.Values{}
  49. values.Set("grant_type", "authorization_code")
  50. values.Set("code", code)
  51. values.Set("redirect_uri", redirectUri)
  52. // Determine auth style
  53. authStyle := p.config.AuthStyle
  54. if authStyle == AuthStyleAutoDetect {
  55. // Default to params style for most OAuth servers
  56. authStyle = AuthStyleInParams
  57. }
  58. var req *http.Request
  59. var err error
  60. if authStyle == AuthStyleInParams {
  61. values.Set("client_id", p.config.ClientId)
  62. values.Set("client_secret", p.config.ClientSecret)
  63. }
  64. req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
  65. if err != nil {
  66. return nil, err
  67. }
  68. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  69. req.Header.Set("Accept", "application/json")
  70. if authStyle == AuthStyleInHeader {
  71. // Basic Auth
  72. credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
  73. req.Header.Set("Authorization", "Basic "+credentials)
  74. }
  75. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
  76. p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
  77. client := http.Client{
  78. Timeout: 20 * time.Second,
  79. }
  80. res, err := client.Do(req)
  81. if err != nil {
  82. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
  83. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
  84. }
  85. defer res.Body.Close()
  86. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
  87. body, err := io.ReadAll(res.Body)
  88. if err != nil {
  89. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
  90. return nil, err
  91. }
  92. bodyStr := string(body)
  93. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
  94. // Try to parse as JSON first
  95. var tokenResponse struct {
  96. AccessToken string `json:"access_token"`
  97. TokenType string `json:"token_type"`
  98. RefreshToken string `json:"refresh_token"`
  99. ExpiresIn int `json:"expires_in"`
  100. Scope string `json:"scope"`
  101. IDToken string `json:"id_token"`
  102. Error string `json:"error"`
  103. ErrorDesc string `json:"error_description"`
  104. }
  105. if err := json.Unmarshal(body, &tokenResponse); err != nil {
  106. // Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
  107. parsedValues, parseErr := url.ParseQuery(bodyStr)
  108. if parseErr != nil {
  109. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
  110. return nil, err
  111. }
  112. tokenResponse.AccessToken = parsedValues.Get("access_token")
  113. tokenResponse.TokenType = parsedValues.Get("token_type")
  114. tokenResponse.Scope = parsedValues.Get("scope")
  115. }
  116. if tokenResponse.Error != "" {
  117. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
  118. p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
  119. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
  120. }
  121. if tokenResponse.AccessToken == "" {
  122. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
  123. return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
  124. }
  125. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
  126. return &OAuthToken{
  127. AccessToken: tokenResponse.AccessToken,
  128. TokenType: tokenResponse.TokenType,
  129. RefreshToken: tokenResponse.RefreshToken,
  130. ExpiresIn: tokenResponse.ExpiresIn,
  131. Scope: tokenResponse.Scope,
  132. IDToken: tokenResponse.IDToken,
  133. }, nil
  134. }
  135. func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
  136. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
  137. req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
  138. if err != nil {
  139. return nil, err
  140. }
  141. // Set authorization header
  142. tokenType := token.TokenType
  143. if tokenType == "" {
  144. tokenType = "Bearer"
  145. }
  146. req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
  147. req.Header.Set("Accept", "application/json")
  148. client := http.Client{
  149. Timeout: 20 * time.Second,
  150. }
  151. res, err := client.Do(req)
  152. if err != nil {
  153. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
  154. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
  155. }
  156. defer res.Body.Close()
  157. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
  158. if res.StatusCode != http.StatusOK {
  159. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
  160. return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
  161. }
  162. body, err := io.ReadAll(res.Body)
  163. if err != nil {
  164. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
  165. return nil, err
  166. }
  167. bodyStr := string(body)
  168. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
  169. // Extract fields using gjson (supports JSONPath-like syntax)
  170. userId := gjson.Get(bodyStr, p.config.UserIdField).String()
  171. username := gjson.Get(bodyStr, p.config.UsernameField).String()
  172. displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
  173. email := gjson.Get(bodyStr, p.config.EmailField).String()
  174. // If user ID field returns a number, convert it
  175. if userId == "" {
  176. // Try to get as number
  177. userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
  178. if userIdNum.Exists() {
  179. userId = userIdNum.Raw
  180. // Remove quotes if present
  181. userId = strings.Trim(userId, "\"")
  182. }
  183. }
  184. if userId == "" {
  185. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
  186. return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
  187. }
  188. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
  189. p.config.Slug, userId, username, displayName, email)
  190. return &OAuthUser{
  191. ProviderUserID: userId,
  192. Username: username,
  193. DisplayName: displayName,
  194. Email: email,
  195. }, nil
  196. }
  197. func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
  198. return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
  199. }
  200. func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
  201. foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
  202. if err != nil {
  203. return err
  204. }
  205. *user = *foundUser
  206. return nil
  207. }
  208. func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
  209. // For generic providers, we store the binding in user_oauth_bindings table
  210. // This is handled separately in the OAuth controller
  211. }
  212. func (p *GenericOAuthProvider) GetProviderPrefix() string {
  213. return p.config.Slug + "_"
  214. }
  215. // GetProviderId returns the provider ID for binding purposes
  216. func (p *GenericOAuthProvider) GetProviderId() int {
  217. return p.config.Id
  218. }
  219. // IsGenericProvider returns true for generic providers
  220. func (p *GenericOAuthProvider) IsGenericProvider() bool {
  221. return true
  222. }