oauth.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/i18n"
  8. "github.com/QuantumNous/new-api/model"
  9. "github.com/QuantumNous/new-api/oauth"
  10. "github.com/gin-contrib/sessions"
  11. "github.com/gin-gonic/gin"
  12. )
  13. // providerParams returns map with Provider key for i18n templates
  14. func providerParams(name string) map[string]any {
  15. return map[string]any{"Provider": name}
  16. }
  17. // GenerateOAuthCode generates a state code for OAuth CSRF protection
  18. func GenerateOAuthCode(c *gin.Context) {
  19. session := sessions.Default(c)
  20. state := common.GetRandomString(12)
  21. affCode := c.Query("aff")
  22. if affCode != "" {
  23. session.Set("aff", affCode)
  24. }
  25. session.Set("oauth_state", state)
  26. err := session.Save()
  27. if err != nil {
  28. common.ApiError(c, err)
  29. return
  30. }
  31. c.JSON(http.StatusOK, gin.H{
  32. "success": true,
  33. "message": "",
  34. "data": state,
  35. })
  36. }
  37. // HandleOAuth handles OAuth callback for all standard OAuth providers
  38. func HandleOAuth(c *gin.Context) {
  39. providerName := c.Param("provider")
  40. provider := oauth.GetProvider(providerName)
  41. if provider == nil {
  42. c.JSON(http.StatusBadRequest, gin.H{
  43. "success": false,
  44. "message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
  45. })
  46. return
  47. }
  48. session := sessions.Default(c)
  49. // 1. Validate state (CSRF protection)
  50. state := c.Query("state")
  51. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  52. c.JSON(http.StatusForbidden, gin.H{
  53. "success": false,
  54. "message": i18n.T(c, i18n.MsgOAuthStateInvalid),
  55. })
  56. return
  57. }
  58. // 2. Check if user is already logged in (bind flow)
  59. username := session.Get("username")
  60. if username != nil {
  61. handleOAuthBind(c, provider)
  62. return
  63. }
  64. // 3. Check if provider is enabled
  65. if !provider.IsEnabled() {
  66. common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
  67. return
  68. }
  69. // 4. Handle error from provider
  70. errorCode := c.Query("error")
  71. if errorCode != "" {
  72. errorDescription := c.Query("error_description")
  73. c.JSON(http.StatusOK, gin.H{
  74. "success": false,
  75. "message": errorDescription,
  76. })
  77. return
  78. }
  79. // 5. Exchange code for token
  80. code := c.Query("code")
  81. token, err := provider.ExchangeToken(c.Request.Context(), code, c)
  82. if err != nil {
  83. handleOAuthError(c, err)
  84. return
  85. }
  86. // 6. Get user info
  87. oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
  88. if err != nil {
  89. handleOAuthError(c, err)
  90. return
  91. }
  92. // 7. Find or create user
  93. user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
  94. if err != nil {
  95. switch err.(type) {
  96. case *OAuthUserDeletedError:
  97. common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
  98. case *OAuthRegistrationDisabledError:
  99. common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
  100. default:
  101. common.ApiError(c, err)
  102. }
  103. return
  104. }
  105. // 8. Check user status
  106. if user.Status != common.UserStatusEnabled {
  107. common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
  108. return
  109. }
  110. // 9. Setup login
  111. setupLogin(user, c)
  112. }
  113. // handleOAuthBind handles binding OAuth account to existing user
  114. func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
  115. if !provider.IsEnabled() {
  116. common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
  117. return
  118. }
  119. // Exchange code for token
  120. code := c.Query("code")
  121. token, err := provider.ExchangeToken(c.Request.Context(), code, c)
  122. if err != nil {
  123. handleOAuthError(c, err)
  124. return
  125. }
  126. // Get user info
  127. oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
  128. if err != nil {
  129. handleOAuthError(c, err)
  130. return
  131. }
  132. // Check if this OAuth account is already bound (check both new ID and legacy ID)
  133. if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
  134. common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
  135. return
  136. }
  137. // Also check legacy ID to prevent duplicate bindings during migration period
  138. if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
  139. if provider.IsUserIDTaken(legacyID) {
  140. common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
  141. return
  142. }
  143. }
  144. // Get current user from session
  145. session := sessions.Default(c)
  146. id := session.Get("id")
  147. user := model.User{Id: id.(int)}
  148. err = user.FillUserById()
  149. if err != nil {
  150. common.ApiError(c, err)
  151. return
  152. }
  153. // Handle binding based on provider type
  154. if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
  155. // Custom provider: use user_oauth_bindings table
  156. err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
  157. if err != nil {
  158. common.ApiError(c, err)
  159. return
  160. }
  161. } else {
  162. // Built-in provider: update user record directly
  163. provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
  164. err = user.Update(false)
  165. if err != nil {
  166. common.ApiError(c, err)
  167. return
  168. }
  169. }
  170. common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
  171. }
  172. // findOrCreateOAuthUser finds existing user or creates new user
  173. func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
  174. user := &model.User{}
  175. // Check if user already exists with new ID
  176. if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
  177. err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
  178. if err != nil {
  179. return nil, err
  180. }
  181. // Check if user has been deleted
  182. if user.Id == 0 {
  183. return nil, &OAuthUserDeletedError{}
  184. }
  185. return user, nil
  186. }
  187. // Try to find user with legacy ID (for GitHub migration from login to numeric ID)
  188. if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
  189. if provider.IsUserIDTaken(legacyID) {
  190. err := provider.FillUserByProviderID(user, legacyID)
  191. if err != nil {
  192. return nil, err
  193. }
  194. if user.Id != 0 {
  195. // Found user with legacy ID, migrate to new ID
  196. common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
  197. user.Id, legacyID, oauthUser.ProviderUserID))
  198. if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
  199. common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
  200. // Continue with login even if migration fails
  201. }
  202. return user, nil
  203. }
  204. }
  205. }
  206. // User doesn't exist, create new user if registration is enabled
  207. if !common.RegisterEnabled {
  208. return nil, &OAuthRegistrationDisabledError{}
  209. }
  210. // Set up new user
  211. user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
  212. if oauthUser.DisplayName != "" {
  213. user.DisplayName = oauthUser.DisplayName
  214. } else if oauthUser.Username != "" {
  215. user.DisplayName = oauthUser.Username
  216. } else {
  217. user.DisplayName = provider.GetName() + " User"
  218. }
  219. if oauthUser.Email != "" {
  220. user.Email = oauthUser.Email
  221. }
  222. user.Role = common.RoleCommonUser
  223. user.Status = common.UserStatusEnabled
  224. // Handle affiliate code
  225. affCode := session.Get("aff")
  226. inviterId := 0
  227. if affCode != nil {
  228. inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
  229. }
  230. if err := user.Insert(inviterId); err != nil {
  231. return nil, err
  232. }
  233. // For custom providers, create the binding after user is created
  234. if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
  235. binding := &model.UserOAuthBinding{
  236. UserId: user.Id,
  237. ProviderId: genericProvider.GetProviderId(),
  238. ProviderUserId: oauthUser.ProviderUserID,
  239. }
  240. if err := model.CreateUserOAuthBinding(binding); err != nil {
  241. common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
  242. // Don't fail the registration, just log the error
  243. }
  244. } else {
  245. // Built-in provider: set the provider user ID on the user model
  246. provider.SetProviderUserID(user, oauthUser.ProviderUserID)
  247. if err := user.Update(false); err != nil {
  248. common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
  249. }
  250. }
  251. return user, nil
  252. }
  253. // Error types for OAuth
  254. type OAuthUserDeletedError struct{}
  255. func (e *OAuthUserDeletedError) Error() string {
  256. return "user has been deleted"
  257. }
  258. type OAuthRegistrationDisabledError struct{}
  259. func (e *OAuthRegistrationDisabledError) Error() string {
  260. return "registration is disabled"
  261. }
  262. // handleOAuthError handles OAuth errors and returns translated message
  263. func handleOAuthError(c *gin.Context, err error) {
  264. switch e := err.(type) {
  265. case *oauth.OAuthError:
  266. if e.Params != nil {
  267. common.ApiErrorI18n(c, e.MsgKey, e.Params)
  268. } else {
  269. common.ApiErrorI18n(c, e.MsgKey)
  270. }
  271. case *oauth.TrustLevelError:
  272. common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
  273. default:
  274. common.ApiError(c, err)
  275. }
  276. }