oauth.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/i18n"
  8. "github.com/QuantumNous/new-api/model"
  9. "github.com/QuantumNous/new-api/oauth"
  10. "github.com/gin-contrib/sessions"
  11. "github.com/gin-gonic/gin"
  12. "gorm.io/gorm"
  13. )
  14. // providerParams returns map with Provider key for i18n templates
  15. func providerParams(name string) map[string]any {
  16. return map[string]any{"Provider": name}
  17. }
  18. // GenerateOAuthCode generates a state code for OAuth CSRF protection
  19. func GenerateOAuthCode(c *gin.Context) {
  20. session := sessions.Default(c)
  21. state := common.GetRandomString(12)
  22. affCode := c.Query("aff")
  23. if affCode != "" {
  24. session.Set("aff", affCode)
  25. }
  26. session.Set("oauth_state", state)
  27. err := session.Save()
  28. if err != nil {
  29. common.ApiError(c, err)
  30. return
  31. }
  32. c.JSON(http.StatusOK, gin.H{
  33. "success": true,
  34. "message": "",
  35. "data": state,
  36. })
  37. }
  38. // HandleOAuth handles OAuth callback for all standard OAuth providers
  39. func HandleOAuth(c *gin.Context) {
  40. providerName := c.Param("provider")
  41. provider := oauth.GetProvider(providerName)
  42. if provider == nil {
  43. c.JSON(http.StatusBadRequest, gin.H{
  44. "success": false,
  45. "message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
  46. })
  47. return
  48. }
  49. session := sessions.Default(c)
  50. // 1. Validate state (CSRF protection)
  51. state := c.Query("state")
  52. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  53. c.JSON(http.StatusForbidden, gin.H{
  54. "success": false,
  55. "message": i18n.T(c, i18n.MsgOAuthStateInvalid),
  56. })
  57. return
  58. }
  59. // 2. Check if user is already logged in (bind flow)
  60. username := session.Get("username")
  61. if username != nil {
  62. handleOAuthBind(c, provider)
  63. return
  64. }
  65. // 3. Check if provider is enabled
  66. if !provider.IsEnabled() {
  67. common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
  68. return
  69. }
  70. // 4. Handle error from provider
  71. errorCode := c.Query("error")
  72. if errorCode != "" {
  73. errorDescription := c.Query("error_description")
  74. c.JSON(http.StatusOK, gin.H{
  75. "success": false,
  76. "message": errorDescription,
  77. })
  78. return
  79. }
  80. // 5. Exchange code for token
  81. code := c.Query("code")
  82. token, err := provider.ExchangeToken(c.Request.Context(), code, c)
  83. if err != nil {
  84. handleOAuthError(c, err)
  85. return
  86. }
  87. // 6. Get user info
  88. oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
  89. if err != nil {
  90. handleOAuthError(c, err)
  91. return
  92. }
  93. // 7. Find or create user
  94. user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
  95. if err != nil {
  96. switch err.(type) {
  97. case *OAuthUserDeletedError:
  98. common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
  99. case *OAuthRegistrationDisabledError:
  100. common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
  101. default:
  102. common.ApiError(c, err)
  103. }
  104. return
  105. }
  106. // 8. Check user status
  107. if user.Status != common.UserStatusEnabled {
  108. common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
  109. return
  110. }
  111. // 9. Setup login
  112. setupLogin(user, c)
  113. }
  114. // handleOAuthBind handles binding OAuth account to existing user
  115. func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
  116. if !provider.IsEnabled() {
  117. common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
  118. return
  119. }
  120. // Exchange code for token
  121. code := c.Query("code")
  122. token, err := provider.ExchangeToken(c.Request.Context(), code, c)
  123. if err != nil {
  124. handleOAuthError(c, err)
  125. return
  126. }
  127. // Get user info
  128. oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
  129. if err != nil {
  130. handleOAuthError(c, err)
  131. return
  132. }
  133. // Check if this OAuth account is already bound (check both new ID and legacy ID)
  134. if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
  135. common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
  136. return
  137. }
  138. // Also check legacy ID to prevent duplicate bindings during migration period
  139. if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
  140. if provider.IsUserIDTaken(legacyID) {
  141. common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
  142. return
  143. }
  144. }
  145. // Get current user from session
  146. session := sessions.Default(c)
  147. id := session.Get("id")
  148. user := model.User{Id: id.(int)}
  149. err = user.FillUserById()
  150. if err != nil {
  151. common.ApiError(c, err)
  152. return
  153. }
  154. // Handle binding based on provider type
  155. if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
  156. // Custom provider: use user_oauth_bindings table
  157. err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
  158. if err != nil {
  159. common.ApiError(c, err)
  160. return
  161. }
  162. } else {
  163. // Built-in provider: update user record directly
  164. provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
  165. err = user.Update(false)
  166. if err != nil {
  167. common.ApiError(c, err)
  168. return
  169. }
  170. }
  171. common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
  172. }
  173. // findOrCreateOAuthUser finds existing user or creates new user
  174. func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
  175. user := &model.User{}
  176. // Check if user already exists with new ID
  177. if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
  178. err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
  179. if err != nil {
  180. return nil, err
  181. }
  182. // Check if user has been deleted
  183. if user.Id == 0 {
  184. return nil, &OAuthUserDeletedError{}
  185. }
  186. return user, nil
  187. }
  188. // Try to find user with legacy ID (for GitHub migration from login to numeric ID)
  189. if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
  190. if provider.IsUserIDTaken(legacyID) {
  191. err := provider.FillUserByProviderID(user, legacyID)
  192. if err != nil {
  193. return nil, err
  194. }
  195. if user.Id != 0 {
  196. // Found user with legacy ID, migrate to new ID
  197. common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
  198. user.Id, legacyID, oauthUser.ProviderUserID))
  199. if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
  200. common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
  201. // Continue with login even if migration fails
  202. }
  203. return user, nil
  204. }
  205. }
  206. }
  207. // User doesn't exist, create new user if registration is enabled
  208. if !common.RegisterEnabled {
  209. return nil, &OAuthRegistrationDisabledError{}
  210. }
  211. // Set up new user
  212. user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
  213. if oauthUser.DisplayName != "" {
  214. user.DisplayName = oauthUser.DisplayName
  215. } else if oauthUser.Username != "" {
  216. user.DisplayName = oauthUser.Username
  217. } else {
  218. user.DisplayName = provider.GetName() + " User"
  219. }
  220. if oauthUser.Email != "" {
  221. user.Email = oauthUser.Email
  222. }
  223. user.Role = common.RoleCommonUser
  224. user.Status = common.UserStatusEnabled
  225. // Handle affiliate code
  226. affCode := session.Get("aff")
  227. inviterId := 0
  228. if affCode != nil {
  229. inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
  230. }
  231. // Use transaction to ensure user creation and OAuth binding are atomic
  232. if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
  233. // Custom provider: create user and binding in a transaction
  234. err := model.DB.Transaction(func(tx *gorm.DB) error {
  235. // Create user
  236. if err := user.InsertWithTx(tx, inviterId); err != nil {
  237. return err
  238. }
  239. // Create OAuth binding
  240. binding := &model.UserOAuthBinding{
  241. UserId: user.Id,
  242. ProviderId: genericProvider.GetProviderId(),
  243. ProviderUserId: oauthUser.ProviderUserID,
  244. }
  245. if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
  246. return err
  247. }
  248. return nil
  249. })
  250. if err != nil {
  251. return nil, err
  252. }
  253. // Perform post-transaction tasks (logs, sidebar config, inviter rewards)
  254. user.FinalizeOAuthUserCreation(inviterId)
  255. } else {
  256. // Built-in provider: create user and update provider ID in a transaction
  257. err := model.DB.Transaction(func(tx *gorm.DB) error {
  258. // Create user
  259. if err := user.InsertWithTx(tx, inviterId); err != nil {
  260. return err
  261. }
  262. // Set the provider user ID on the user model and update
  263. provider.SetProviderUserID(user, oauthUser.ProviderUserID)
  264. if err := tx.Model(user).Updates(map[string]interface{}{
  265. "github_id": user.GitHubId,
  266. "discord_id": user.DiscordId,
  267. "oidc_id": user.OidcId,
  268. "linux_do_id": user.LinuxDOId,
  269. "wechat_id": user.WeChatId,
  270. "telegram_id": user.TelegramId,
  271. }).Error; err != nil {
  272. return err
  273. }
  274. return nil
  275. })
  276. if err != nil {
  277. return nil, err
  278. }
  279. // Perform post-transaction tasks
  280. user.FinalizeOAuthUserCreation(inviterId)
  281. }
  282. return user, nil
  283. }
  284. // Error types for OAuth
  285. type OAuthUserDeletedError struct{}
  286. func (e *OAuthUserDeletedError) Error() string {
  287. return "user has been deleted"
  288. }
  289. type OAuthRegistrationDisabledError struct{}
  290. func (e *OAuthRegistrationDisabledError) Error() string {
  291. return "registration is disabled"
  292. }
  293. // handleOAuthError handles OAuth errors and returns translated message
  294. func handleOAuthError(c *gin.Context, err error) {
  295. switch e := err.(type) {
  296. case *oauth.OAuthError:
  297. if e.Params != nil {
  298. common.ApiErrorI18n(c, e.MsgKey, e.Params)
  299. } else {
  300. common.ApiErrorI18n(c, e.MsgKey)
  301. }
  302. case *oauth.TrustLevelError:
  303. common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
  304. default:
  305. common.ApiError(c, err)
  306. }
  307. }