oauth.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package controller
import (
	"errors"
	"net/http"
	"strconv"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/i18n"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/oauth"
	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)
  12. // providerParams returns map with Provider key for i18n templates
  13. func providerParams(name string) map[string]any {
  14. return map[string]any{"Provider": name}
  15. }
  16. // GenerateOAuthCode generates a state code for OAuth CSRF protection
  17. func GenerateOAuthCode(c *gin.Context) {
  18. session := sessions.Default(c)
  19. state := common.GetRandomString(12)
  20. affCode := c.Query("aff")
  21. if affCode != "" {
  22. session.Set("aff", affCode)
  23. }
  24. session.Set("oauth_state", state)
  25. err := session.Save()
  26. if err != nil {
  27. common.ApiError(c, err)
  28. return
  29. }
  30. c.JSON(http.StatusOK, gin.H{
  31. "success": true,
  32. "message": "",
  33. "data": state,
  34. })
  35. }
  36. // HandleOAuth handles OAuth callback for all standard OAuth providers
  37. func HandleOAuth(c *gin.Context) {
  38. providerName := c.Param("provider")
  39. provider := oauth.GetProvider(providerName)
  40. if provider == nil {
  41. c.JSON(http.StatusBadRequest, gin.H{
  42. "success": false,
  43. "message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
  44. })
  45. return
  46. }
  47. session := sessions.Default(c)
  48. // 1. Validate state (CSRF protection)
  49. state := c.Query("state")
  50. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  51. c.JSON(http.StatusForbidden, gin.H{
  52. "success": false,
  53. "message": i18n.T(c, i18n.MsgOAuthStateInvalid),
  54. })
  55. return
  56. }
  57. // 2. Check if user is already logged in (bind flow)
  58. username := session.Get("username")
  59. if username != nil {
  60. handleOAuthBind(c, provider)
  61. return
  62. }
  63. // 3. Check if provider is enabled
  64. if !provider.IsEnabled() {
  65. common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
  66. return
  67. }
  68. // 4. Handle error from provider
  69. errorCode := c.Query("error")
  70. if errorCode != "" {
  71. errorDescription := c.Query("error_description")
  72. c.JSON(http.StatusOK, gin.H{
  73. "success": false,
  74. "message": errorDescription,
  75. })
  76. return
  77. }
  78. // 5. Exchange code for token
  79. code := c.Query("code")
  80. token, err := provider.ExchangeToken(c.Request.Context(), code, c)
  81. if err != nil {
  82. handleOAuthError(c, err)
  83. return
  84. }
  85. // 6. Get user info
  86. oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
  87. if err != nil {
  88. handleOAuthError(c, err)
  89. return
  90. }
  91. // 7. Find or create user
  92. user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
  93. if err != nil {
  94. switch err.(type) {
  95. case *OAuthUserDeletedError:
  96. common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
  97. case *OAuthRegistrationDisabledError:
  98. common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
  99. default:
  100. common.ApiError(c, err)
  101. }
  102. return
  103. }
  104. // 8. Check user status
  105. if user.Status != common.UserStatusEnabled {
  106. common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
  107. return
  108. }
  109. // 9. Setup login
  110. setupLogin(user, c)
  111. }
  112. // handleOAuthBind handles binding OAuth account to existing user
  113. func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
  114. if !provider.IsEnabled() {
  115. common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
  116. return
  117. }
  118. // Exchange code for token
  119. code := c.Query("code")
  120. token, err := provider.ExchangeToken(c.Request.Context(), code, c)
  121. if err != nil {
  122. handleOAuthError(c, err)
  123. return
  124. }
  125. // Get user info
  126. oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
  127. if err != nil {
  128. handleOAuthError(c, err)
  129. return
  130. }
  131. // Check if this OAuth account is already bound
  132. if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
  133. common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
  134. return
  135. }
  136. // Get current user from session
  137. session := sessions.Default(c)
  138. id := session.Get("id")
  139. user := model.User{Id: id.(int)}
  140. err = user.FillUserById()
  141. if err != nil {
  142. common.ApiError(c, err)
  143. return
  144. }
  145. // Update user with OAuth ID
  146. provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
  147. err = user.Update(false)
  148. if err != nil {
  149. common.ApiError(c, err)
  150. return
  151. }
  152. common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
  153. }
  154. // findOrCreateOAuthUser finds existing user or creates new user
  155. func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
  156. user := &model.User{}
  157. // Check if user already exists
  158. if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
  159. provider.SetProviderUserID(user, oauthUser.ProviderUserID)
  160. err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
  161. if err != nil {
  162. return nil, err
  163. }
  164. // Check if user has been deleted
  165. if user.Id == 0 {
  166. return nil, &OAuthUserDeletedError{}
  167. }
  168. return user, nil
  169. }
  170. // User doesn't exist, create new user if registration is enabled
  171. if !common.RegisterEnabled {
  172. return nil, &OAuthRegistrationDisabledError{}
  173. }
  174. // Set up new user
  175. user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
  176. if oauthUser.DisplayName != "" {
  177. user.DisplayName = oauthUser.DisplayName
  178. } else if oauthUser.Username != "" {
  179. user.DisplayName = oauthUser.Username
  180. } else {
  181. user.DisplayName = provider.GetName() + " User"
  182. }
  183. if oauthUser.Email != "" {
  184. user.Email = oauthUser.Email
  185. }
  186. user.Role = common.RoleCommonUser
  187. user.Status = common.UserStatusEnabled
  188. provider.SetProviderUserID(user, oauthUser.ProviderUserID)
  189. // Handle affiliate code
  190. affCode := session.Get("aff")
  191. inviterId := 0
  192. if affCode != nil {
  193. inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
  194. }
  195. if err := user.Insert(inviterId); err != nil {
  196. return nil, err
  197. }
  198. return user, nil
  199. }
  200. // Error types for OAuth
  201. type OAuthUserDeletedError struct{}
  202. func (e *OAuthUserDeletedError) Error() string {
  203. return "user has been deleted"
  204. }
  205. type OAuthRegistrationDisabledError struct{}
  206. func (e *OAuthRegistrationDisabledError) Error() string {
  207. return "registration is disabled"
  208. }
  209. // handleOAuthError handles OAuth errors and returns translated message
  210. func handleOAuthError(c *gin.Context, err error) {
  211. switch e := err.(type) {
  212. case *oauth.OAuthError:
  213. if e.Params != nil {
  214. common.ApiErrorI18n(c, e.MsgKey, e.Params)
  215. } else {
  216. common.ApiErrorI18n(c, e.MsgKey)
  217. }
  218. case *oauth.TrustLevelError:
  219. common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
  220. default:
  221. common.ApiError(c, err)
  222. }
  223. }