discord.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package oauth
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/i18n"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/setting/system_setting"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func init() {
  17. Register("discord", &DiscordProvider{})
  18. }
  // Types used by the Discord OAuth provider.
type (
	// DiscordProvider implements OAuth login for Discord. It holds no state;
	// its settings are read from system_setting on each call.
	DiscordProvider struct{}

	// discordOAuthResponse is the JSON body decoded from Discord's
	// /oauth2/token endpoint.
	discordOAuthResponse struct {
		AccessToken  string `json:"access_token"`
		IDToken      string `json:"id_token"`
		RefreshToken string `json:"refresh_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
		Scope        string `json:"scope"`
	}

	// discordUser is the subset of the /users/@me payload this provider reads.
	// Mind the field names: UID carries Discord's "id", ID carries the
	// "username", and Name carries the optional "global_name".
	discordUser struct {
		UID  string `json:"id"`
		ID   string `json:"username"`
		Name string `json:"global_name"`
	}
)
  34. func (p *DiscordProvider) GetName() string {
  35. return "Discord"
  36. }
  37. func (p *DiscordProvider) IsEnabled() bool {
  38. return system_setting.GetDiscordSettings().Enabled
  39. }
  40. func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
  41. if code == "" {
  42. return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
  43. }
  44. logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)])
  45. settings := system_setting.GetDiscordSettings()
  46. redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)
  47. values := url.Values{}
  48. values.Set("client_id", settings.ClientId)
  49. values.Set("client_secret", settings.ClientSecret)
  50. values.Set("code", code)
  51. values.Set("grant_type", "authorization_code")
  52. values.Set("redirect_uri", redirectUri)
  53. logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri)
  54. req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode()))
  55. if err != nil {
  56. return nil, err
  57. }
  58. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  59. req.Header.Set("Accept", "application/json")
  60. client := http.Client{
  61. Timeout: 5 * time.Second,
  62. }
  63. res, err := client.Do(req)
  64. if err != nil {
  65. logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error()))
  66. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
  67. }
  68. defer res.Body.Close()
  69. logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode)
  70. var discordResponse discordOAuthResponse
  71. err = json.NewDecoder(res.Body).Decode(&discordResponse)
  72. if err != nil {
  73. logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error()))
  74. return nil, err
  75. }
  76. if discordResponse.AccessToken == "" {
  77. logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token")
  78. return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"})
  79. }
  80. logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope)
  81. return &OAuthToken{
  82. AccessToken: discordResponse.AccessToken,
  83. TokenType: discordResponse.TokenType,
  84. RefreshToken: discordResponse.RefreshToken,
  85. ExpiresIn: discordResponse.ExpiresIn,
  86. Scope: discordResponse.Scope,
  87. IDToken: discordResponse.IDToken,
  88. }, nil
  89. }
  90. func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
  91. logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info")
  92. req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil)
  93. if err != nil {
  94. return nil, err
  95. }
  96. req.Header.Set("Authorization", "Bearer "+token.AccessToken)
  97. client := http.Client{
  98. Timeout: 5 * time.Second,
  99. }
  100. res, err := client.Do(req)
  101. if err != nil {
  102. logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error()))
  103. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
  104. }
  105. defer res.Body.Close()
  106. logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode)
  107. if res.StatusCode != http.StatusOK {
  108. logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode))
  109. return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
  110. }
  111. var discordUser discordUser
  112. err = json.NewDecoder(res.Body).Decode(&discordUser)
  113. if err != nil {
  114. logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error()))
  115. return nil, err
  116. }
  117. if discordUser.UID == "" || discordUser.ID == "" {
  118. logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields")
  119. return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"})
  120. }
  121. logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name)
  122. return &OAuthUser{
  123. ProviderUserID: discordUser.UID,
  124. Username: discordUser.ID,
  125. DisplayName: discordUser.Name,
  126. }, nil
  127. }
  128. func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool {
  129. return model.IsDiscordIdAlreadyTaken(providerUserID)
  130. }
  131. func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
  132. user.DiscordId = providerUserID
  133. return user.FillUserByDiscordId()
  134. }
  135. func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) {
  136. user.DiscordId = providerUserID
  137. }
  138. func (p *DiscordProvider) GetProviderPrefix() string {
  139. return "discord_"
  140. }