discord.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. package controller
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/setting/system_setting"
  14. "github.com/gin-contrib/sessions"
  15. "github.com/gin-gonic/gin"
  16. )
  type (
	// DiscordResponse is the JSON body decoded from Discord's OAuth2 token
	// endpoint after exchanging an authorization code.
	DiscordResponse struct {
		AccessToken  string `json:"access_token"`
		IDToken      string `json:"id_token"`
		RefreshToken string `json:"refresh_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
		Scope        string `json:"scope"`
	}

	// DiscordUser is the subset of the users/@me profile that login and
	// registration rely on.
	//
	// The Go field names do not follow Discord's naming:
	//   - UID  holds Discord's "id" (stored as the account's DiscordId),
	//   - ID   holds Discord's "username",
	//   - Name holds Discord's "global_name", which may be empty.
	DiscordUser struct {
		UID  string `json:"id"`
		ID   string `json:"username"`
		Name string `json:"global_name"`
	}
)
  30. func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
  31. if code == "" {
  32. return nil, errors.New("无效的参数")
  33. }
  34. values := url.Values{}
  35. values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
  36. values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
  37. values.Set("code", code)
  38. values.Set("grant_type", "authorization_code")
  39. values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
  40. formData := values.Encode()
  41. req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
  42. if err != nil {
  43. return nil, err
  44. }
  45. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  46. req.Header.Set("Accept", "application/json")
  47. client := http.Client{
  48. Timeout: 5 * time.Second,
  49. }
  50. res, err := client.Do(req)
  51. if err != nil {
  52. common.SysLog(err.Error())
  53. return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
  54. }
  55. defer res.Body.Close()
  56. var discordResponse DiscordResponse
  57. err = json.NewDecoder(res.Body).Decode(&discordResponse)
  58. if err != nil {
  59. return nil, err
  60. }
  61. if discordResponse.AccessToken == "" {
  62. common.SysError("Discord 获取 Token 失败,请检查设置!")
  63. return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
  64. }
  65. req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
  66. if err != nil {
  67. return nil, err
  68. }
  69. req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
  70. res2, err := client.Do(req)
  71. if err != nil {
  72. common.SysLog(err.Error())
  73. return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
  74. }
  75. defer res2.Body.Close()
  76. if res2.StatusCode != http.StatusOK {
  77. common.SysError("Discord 获取用户信息失败!请检查设置!")
  78. return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
  79. }
  80. var discordUser DiscordUser
  81. err = json.NewDecoder(res2.Body).Decode(&discordUser)
  82. if err != nil {
  83. return nil, err
  84. }
  85. if discordUser.UID == "" || discordUser.ID == "" {
  86. common.SysError("Discord 获取用户信息为空!请检查设置!")
  87. return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
  88. }
  89. return &discordUser, nil
  90. }
  91. func DiscordOAuth(c *gin.Context) {
  92. session := sessions.Default(c)
  93. state := c.Query("state")
  94. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  95. c.JSON(http.StatusForbidden, gin.H{
  96. "success": false,
  97. "message": "state is empty or not same",
  98. })
  99. return
  100. }
  101. username := session.Get("username")
  102. if username != nil {
  103. DiscordBind(c)
  104. return
  105. }
  106. if !system_setting.GetDiscordSettings().Enabled {
  107. c.JSON(http.StatusOK, gin.H{
  108. "success": false,
  109. "message": "管理员未开启通过 Discord 登录以及注册",
  110. })
  111. return
  112. }
  113. code := c.Query("code")
  114. discordUser, err := getDiscordUserInfoByCode(code)
  115. if err != nil {
  116. common.ApiError(c, err)
  117. return
  118. }
  119. user := model.User{
  120. DiscordId: discordUser.UID,
  121. }
  122. if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
  123. err := user.FillUserByDiscordId()
  124. if err != nil {
  125. c.JSON(http.StatusOK, gin.H{
  126. "success": false,
  127. "message": err.Error(),
  128. })
  129. return
  130. }
  131. } else {
  132. if common.RegisterEnabled {
  133. if discordUser.ID != "" {
  134. user.Username = discordUser.ID
  135. } else {
  136. user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
  137. }
  138. if discordUser.Name != "" {
  139. user.DisplayName = discordUser.Name
  140. } else {
  141. user.DisplayName = "Discord User"
  142. }
  143. err := user.Insert(0)
  144. if err != nil {
  145. c.JSON(http.StatusOK, gin.H{
  146. "success": false,
  147. "message": err.Error(),
  148. })
  149. return
  150. }
  151. } else {
  152. c.JSON(http.StatusOK, gin.H{
  153. "success": false,
  154. "message": "管理员关闭了新用户注册",
  155. })
  156. return
  157. }
  158. }
  159. if user.Status != common.UserStatusEnabled {
  160. c.JSON(http.StatusOK, gin.H{
  161. "message": "用户已被封禁",
  162. "success": false,
  163. })
  164. return
  165. }
  166. setupLogin(&user, c)
  167. }
  168. func DiscordBind(c *gin.Context) {
  169. if !system_setting.GetDiscordSettings().Enabled {
  170. c.JSON(http.StatusOK, gin.H{
  171. "success": false,
  172. "message": "管理员未开启通过 Discord 登录以及注册",
  173. })
  174. return
  175. }
  176. code := c.Query("code")
  177. discordUser, err := getDiscordUserInfoByCode(code)
  178. if err != nil {
  179. common.ApiError(c, err)
  180. return
  181. }
  182. user := model.User{
  183. DiscordId: discordUser.UID,
  184. }
  185. if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
  186. c.JSON(http.StatusOK, gin.H{
  187. "success": false,
  188. "message": "该 Discord 账户已被绑定",
  189. })
  190. return
  191. }
  192. session := sessions.Default(c)
  193. id := session.Get("id")
  194. user.Id = id.(int)
  195. err = user.FillUserById()
  196. if err != nil {
  197. common.ApiError(c, err)
  198. return
  199. }
  200. user.DiscordId = discordUser.UID
  201. err = user.Update(false)
  202. if err != nil {
  203. common.ApiError(c, err)
  204. return
  205. }
  206. c.JSON(http.StatusOK, gin.H{
  207. "success": true,
  208. "message": "bind",
  209. })
  210. }