1
0

linuxdo.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package controller
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "net/url"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/gin-contrib/sessions"
  15. "github.com/gin-gonic/gin"
  16. )
  // LinuxdoUser holds the fields this controller decodes from the Linux DO
// Connect user-info endpoint (https://connect.linux.do/api/user).
// Only Id is validated by the callers visible in this file (a zero Id is
// rejected); the remaining fields are carried along as returned.
type LinuxdoUser struct {
	Id         int    `json:"id"`
	Username   string `json:"username"`
	Name       string `json:"name"`
	Active     bool   `json:"active"`
	TrustLevel int    `json:"trust_level"`
	Silenced   bool   `json:"silenced"`
}
  25. func LinuxDoBind(c *gin.Context) {
  26. if !common.LinuxDOOAuthEnabled {
  27. c.JSON(http.StatusOK, gin.H{
  28. "success": false,
  29. "message": "管理员未开启通过 Linux DO 登录以及注册",
  30. })
  31. return
  32. }
  33. code := c.Query("code")
  34. linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
  35. if err != nil {
  36. common.ApiError(c, err)
  37. return
  38. }
  39. user := model.User{
  40. LinuxDOId: strconv.Itoa(linuxdoUser.Id),
  41. }
  42. if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
  43. c.JSON(http.StatusOK, gin.H{
  44. "success": false,
  45. "message": "该 Linux DO 账户已被绑定",
  46. })
  47. return
  48. }
  49. session := sessions.Default(c)
  50. id := session.Get("id")
  51. user.Id = id.(int)
  52. err = user.FillUserById()
  53. if err != nil {
  54. common.ApiError(c, err)
  55. return
  56. }
  57. user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
  58. err = user.Update(false)
  59. if err != nil {
  60. common.ApiError(c, err)
  61. return
  62. }
  63. c.JSON(http.StatusOK, gin.H{
  64. "success": true,
  65. "message": "bind",
  66. })
  67. }
  68. func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
  69. if code == "" {
  70. return nil, errors.New("invalid code")
  71. }
  72. // Get access token using Basic auth
  73. tokenEndpoint := "https://connect.linux.do/oauth2/token"
  74. credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
  75. basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
  76. // Get redirect URI from request
  77. scheme := "http"
  78. if c.Request.TLS != nil {
  79. scheme = "https"
  80. }
  81. redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
  82. data := url.Values{}
  83. data.Set("grant_type", "authorization_code")
  84. data.Set("code", code)
  85. data.Set("redirect_uri", redirectURI)
  86. req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
  87. if err != nil {
  88. return nil, err
  89. }
  90. req.Header.Set("Authorization", basicAuth)
  91. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  92. req.Header.Set("Accept", "application/json")
  93. client := http.Client{Timeout: 5 * time.Second}
  94. res, err := client.Do(req)
  95. if err != nil {
  96. return nil, errors.New("failed to connect to Linux DO server")
  97. }
  98. defer res.Body.Close()
  99. var tokenRes struct {
  100. AccessToken string `json:"access_token"`
  101. Message string `json:"message"`
  102. }
  103. if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
  104. return nil, err
  105. }
  106. if tokenRes.AccessToken == "" {
  107. return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
  108. }
  109. // Get user info
  110. userEndpoint := "https://connect.linux.do/api/user"
  111. req, err = http.NewRequest("GET", userEndpoint, nil)
  112. if err != nil {
  113. return nil, err
  114. }
  115. req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
  116. req.Header.Set("Accept", "application/json")
  117. res2, err := client.Do(req)
  118. if err != nil {
  119. return nil, errors.New("failed to get user info from Linux DO")
  120. }
  121. defer res2.Body.Close()
  122. var linuxdoUser LinuxdoUser
  123. if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
  124. return nil, err
  125. }
  126. if linuxdoUser.Id == 0 {
  127. return nil, errors.New("invalid user info returned")
  128. }
  129. return &linuxdoUser, nil
  130. }
  131. func LinuxdoOAuth(c *gin.Context) {
  132. session := sessions.Default(c)
  133. errorCode := c.Query("error")
  134. if errorCode != "" {
  135. errorDescription := c.Query("error_description")
  136. c.JSON(http.StatusOK, gin.H{
  137. "success": false,
  138. "message": errorDescription,
  139. })
  140. return
  141. }
  142. state := c.Query("state")
  143. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  144. c.JSON(http.StatusForbidden, gin.H{
  145. "success": false,
  146. "message": "state is empty or not same",
  147. })
  148. return
  149. }
  150. username := session.Get("username")
  151. if username != nil {
  152. LinuxDoBind(c)
  153. return
  154. }
  155. if !common.LinuxDOOAuthEnabled {
  156. c.JSON(http.StatusOK, gin.H{
  157. "success": false,
  158. "message": "管理员未开启通过 Linux DO 登录以及注册",
  159. })
  160. return
  161. }
  162. code := c.Query("code")
  163. linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
  164. if err != nil {
  165. common.ApiError(c, err)
  166. return
  167. }
  168. user := model.User{
  169. LinuxDOId: strconv.Itoa(linuxdoUser.Id),
  170. }
  171. // Check if user exists
  172. if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
  173. err := user.FillUserByLinuxDOId()
  174. if err != nil {
  175. c.JSON(http.StatusOK, gin.H{
  176. "success": false,
  177. "message": err.Error(),
  178. })
  179. return
  180. }
  181. if user.Id == 0 {
  182. c.JSON(http.StatusOK, gin.H{
  183. "success": false,
  184. "message": "用户已注销",
  185. })
  186. return
  187. }
  188. } else {
  189. if common.RegisterEnabled {
  190. user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
  191. user.DisplayName = linuxdoUser.Name
  192. user.Role = common.RoleCommonUser
  193. user.Status = common.UserStatusEnabled
  194. affCode := session.Get("aff")
  195. inviterId := 0
  196. if affCode != nil {
  197. inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
  198. }
  199. if err := user.Insert(inviterId); err != nil {
  200. c.JSON(http.StatusOK, gin.H{
  201. "success": false,
  202. "message": err.Error(),
  203. })
  204. return
  205. }
  206. } else {
  207. c.JSON(http.StatusOK, gin.H{
  208. "success": false,
  209. "message": "管理员关闭了新用户注册",
  210. })
  211. return
  212. }
  213. }
  214. if user.Status != common.UserStatusEnabled {
  215. c.JSON(http.StatusOK, gin.H{
  216. "message": "用户已被封禁",
  217. "success": false,
  218. })
  219. return
  220. }
  221. setupLogin(&user, c)
  222. }