linuxdo.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package controller
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "net/url"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/gin-contrib/sessions"
  15. "github.com/gin-gonic/gin"
  16. )
  // LinuxdoUser is the subset of the Linux DO Connect user profile that this
// controller decodes from the user-info endpoint response. Fields missing
// from the JSON body keep their zero values; unknown fields are ignored.
type LinuxdoUser struct {
	Id         int    `json:"id"`
	Username   string `json:"username"`
	Name       string `json:"name"`
	Active     bool   `json:"active"`
	TrustLevel int    `json:"trust_level"`
	Silenced   bool   `json:"silenced"`
}
  25. func LinuxDoBind(c *gin.Context) {
  26. if !common.LinuxDOOAuthEnabled {
  27. c.JSON(http.StatusOK, gin.H{
  28. "success": false,
  29. "message": "管理员未开启通过 Linux DO 登录以及注册",
  30. })
  31. return
  32. }
  33. code := c.Query("code")
  34. linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
  35. if err != nil {
  36. c.JSON(http.StatusOK, gin.H{
  37. "success": false,
  38. "message": err.Error(),
  39. })
  40. return
  41. }
  42. user := model.User{
  43. LinuxDOId: strconv.Itoa(linuxdoUser.Id),
  44. }
  45. if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
  46. c.JSON(http.StatusOK, gin.H{
  47. "success": false,
  48. "message": "该 Linux DO 账户已被绑定",
  49. })
  50. return
  51. }
  52. session := sessions.Default(c)
  53. id := session.Get("id")
  54. user.Id = id.(int)
  55. err = user.FillUserById()
  56. if err != nil {
  57. c.JSON(http.StatusOK, gin.H{
  58. "success": false,
  59. "message": err.Error(),
  60. })
  61. return
  62. }
  63. user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
  64. err = user.Update(false)
  65. if err != nil {
  66. c.JSON(http.StatusOK, gin.H{
  67. "success": false,
  68. "message": err.Error(),
  69. })
  70. return
  71. }
  72. c.JSON(http.StatusOK, gin.H{
  73. "success": true,
  74. "message": "bind",
  75. })
  76. }
  77. func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
  78. if code == "" {
  79. return nil, errors.New("invalid code")
  80. }
  81. // Get access token using Basic auth
  82. tokenEndpoint := "https://connect.linux.do/oauth2/token"
  83. credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
  84. basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
  85. // Get redirect URI from request
  86. scheme := "http"
  87. if c.Request.TLS != nil {
  88. scheme = "https"
  89. }
  90. redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
  91. data := url.Values{}
  92. data.Set("grant_type", "authorization_code")
  93. data.Set("code", code)
  94. data.Set("redirect_uri", redirectURI)
  95. req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
  96. if err != nil {
  97. return nil, err
  98. }
  99. req.Header.Set("Authorization", basicAuth)
  100. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  101. req.Header.Set("Accept", "application/json")
  102. client := http.Client{Timeout: 5 * time.Second}
  103. res, err := client.Do(req)
  104. if err != nil {
  105. return nil, errors.New("failed to connect to Linux DO server")
  106. }
  107. defer res.Body.Close()
  108. var tokenRes struct {
  109. AccessToken string `json:"access_token"`
  110. Message string `json:"message"`
  111. }
  112. if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
  113. return nil, err
  114. }
  115. if tokenRes.AccessToken == "" {
  116. return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
  117. }
  118. // Get user info
  119. userEndpoint := "https://connect.linux.do/api/user"
  120. req, err = http.NewRequest("GET", userEndpoint, nil)
  121. if err != nil {
  122. return nil, err
  123. }
  124. req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
  125. req.Header.Set("Accept", "application/json")
  126. res2, err := client.Do(req)
  127. if err != nil {
  128. return nil, errors.New("failed to get user info from Linux DO")
  129. }
  130. defer res2.Body.Close()
  131. var linuxdoUser LinuxdoUser
  132. if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
  133. return nil, err
  134. }
  135. if linuxdoUser.Id == 0 {
  136. return nil, errors.New("invalid user info returned")
  137. }
  138. return &linuxdoUser, nil
  139. }
  140. func LinuxdoOAuth(c *gin.Context) {
  141. session := sessions.Default(c)
  142. errorCode := c.Query("error")
  143. if errorCode != "" {
  144. errorDescription := c.Query("error_description")
  145. c.JSON(http.StatusOK, gin.H{
  146. "success": false,
  147. "message": errorDescription,
  148. })
  149. return
  150. }
  151. state := c.Query("state")
  152. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  153. c.JSON(http.StatusForbidden, gin.H{
  154. "success": false,
  155. "message": "state is empty or not same",
  156. })
  157. return
  158. }
  159. username := session.Get("username")
  160. if username != nil {
  161. LinuxDoBind(c)
  162. return
  163. }
  164. if !common.LinuxDOOAuthEnabled {
  165. c.JSON(http.StatusOK, gin.H{
  166. "success": false,
  167. "message": "管理员未开启通过 Linux DO 登录以及注册",
  168. })
  169. return
  170. }
  171. code := c.Query("code")
  172. linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
  173. if err != nil {
  174. c.JSON(http.StatusOK, gin.H{
  175. "success": false,
  176. "message": err.Error(),
  177. })
  178. return
  179. }
  180. user := model.User{
  181. LinuxDOId: strconv.Itoa(linuxdoUser.Id),
  182. }
  183. // Check if user exists
  184. if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
  185. err := user.FillUserByLinuxDOId()
  186. if err != nil {
  187. c.JSON(http.StatusOK, gin.H{
  188. "success": false,
  189. "message": err.Error(),
  190. })
  191. return
  192. }
  193. if user.Id == 0 {
  194. c.JSON(http.StatusOK, gin.H{
  195. "success": false,
  196. "message": "用户已注销",
  197. })
  198. return
  199. }
  200. } else {
  201. if common.RegisterEnabled {
  202. user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
  203. user.DisplayName = linuxdoUser.Name
  204. user.Role = common.RoleCommonUser
  205. user.Status = common.UserStatusEnabled
  206. affCode := session.Get("aff")
  207. inviterId := 0
  208. if affCode != nil {
  209. inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
  210. }
  211. if err := user.Insert(inviterId); err != nil {
  212. c.JSON(http.StatusOK, gin.H{
  213. "success": false,
  214. "message": err.Error(),
  215. })
  216. return
  217. }
  218. } else {
  219. c.JSON(http.StatusOK, gin.H{
  220. "success": false,
  221. "message": "管理员关闭了新用户注册",
  222. })
  223. return
  224. }
  225. }
  226. if user.Status != common.UserStatusEnabled {
  227. c.JSON(http.StatusOK, gin.H{
  228. "message": "用户已被封禁",
  229. "success": false,
  230. })
  231. return
  232. }
  233. setupLogin(&user, c)
  234. }