package controller

import (
	"fmt"
	"net/http"
	"strconv"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/i18n"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/oauth"
	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)

// providerParams returns the i18n template parameters for messages that
// mention a provider: a map holding name under the "Provider" key.
func providerParams(name string) map[string]any {
	return map[string]any{"Provider": name}
}

// GenerateOAuthCode generates a random state value for OAuth CSRF protection,
// stores it in the session under "oauth_state" and returns it in the "data"
// field of the JSON response. When the request carries a non-empty "aff"
// query parameter, it is stored in the session under "aff" as well.
func GenerateOAuthCode(c *gin.Context) {
	session := sessions.Default(c)
	state := common.GetRandomString(12)
	affCode := c.Query("aff")
	if affCode != "" {
		session.Set("aff", affCode)
	}
	session.Set("oauth_state", state)
	err := session.Save()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    state,
	})
}

// HandleOAuth handles the OAuth callback for all standard OAuth providers.
//
// The provider is selected by the "provider" route parameter. The "state"
// query parameter must match the value stored in the session. If the session
// already holds a "username", the request is treated as a bind request and
// delegated to handleOAuthBind; otherwise the authorization code is exchanged,
// the user is looked up or created, and setupLogin is called.
func HandleOAuth(c *gin.Context) {
	providerName := c.Param("provider")
	provider := oauth.GetProvider(providerName)
	if provider == nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
		})
		return
	}

	session := sessions.Default(c)

	// 1. Validate state (CSRF protection).
	// The comma-ok assertion rejects a missing or non-string session value
	// instead of panicking on it.
	state := c.Query("state")
	savedState, ok := session.Get("oauth_state").(string)
	if state == "" || !ok || state != savedState {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": i18n.T(c, i18n.MsgOAuthStateInvalid),
		})
		return
	}

	// 2. Check if user is already logged in (bind flow)
	username := session.Get("username")
	if username != nil {
		handleOAuthBind(c, provider)
		return
	}

	// 3. Check if provider is enabled
	if !provider.IsEnabled() {
		common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
		return
	}

	// 4. Handle error from provider
	errorCode := c.Query("error")
	if errorCode != "" {
		// error_description is optional in an OAuth error response; fall back
		// to the error code so the client never receives an empty message.
		errorDescription := c.Query("error_description")
		if errorDescription == "" {
			errorDescription = errorCode
		}
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": errorDescription,
		})
		return
	}

	// 5. Exchange code for token
	code := c.Query("code")
	token, err := provider.ExchangeToken(c.Request.Context(), code, c)
	if err != nil {
		handleOAuthError(c, err)
		return
	}

	// 6. Get user info
	oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
	if err != nil {
		handleOAuthError(c, err)
		return
	}

	// 7. Find or create user
	user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
	if err != nil {
		switch err.(type) {
		case *OAuthUserDeletedError:
			common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
		case *OAuthRegistrationDisabledError:
			common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
		default:
			common.ApiError(c, err)
		}
		return
	}

	// 8. Check user status
	if user.Status != common.UserStatusEnabled {
		common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
		return
	}

	// 9. Setup login
	setupLogin(user, c)
}

// handleOAuthBind handles binding an OAuth account to the user identified by
// the "id" value in the session.
//
// It exchanges the authorization code, fetches the provider user info, refuses
// the bind when the provider user ID (or its legacy ID) is already taken, and
// then records the binding: via model.UpdateUserOAuthBinding for
// *oauth.GenericOAuthProvider, or by updating the user record for any other
// provider.
func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
	if !provider.IsEnabled() {
		common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
		return
	}

	// Exchange code for token
	code := c.Query("code")
	token, err := provider.ExchangeToken(c.Request.Context(), code, c)
	if err != nil {
		handleOAuthError(c, err)
		return
	}

	// Get user info
	oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
	if err != nil {
		handleOAuthError(c, err)
		return
	}

	// Check if this OAuth account is already bound (check both new ID and legacy ID)
	if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
		common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
		return
	}
	// Also check legacy ID to prevent duplicate bindings during migration period
	if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
		if provider.IsUserIDTaken(legacyID) {
			common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
			return
		}
	}

	// Get current user from session.
	// The caller only verified that "username" is present, so "id" may be
	// missing or of an unexpected type; report that instead of panicking.
	session := sessions.Default(c)
	id, ok := session.Get("id").(int)
	if !ok {
		common.ApiError(c, fmt.Errorf("invalid session: missing user id"))
		return
	}
	user := model.User{Id: id}
	err = user.FillUserById()
	if err != nil {
		common.ApiError(c, err)
		return
	}

	// Handle binding based on provider type
	if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
		// Custom provider: use user_oauth_bindings table
		err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
		if err != nil {
			common.ApiError(c, err)
			return
		}
	} else {
		// Built-in provider: update user record directly
		provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
		err = user.Update(false)
		if err != nil {
			common.ApiError(c, err)
			return
		}
	}

	common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
}

// findOrCreateOAuthUser finds the user bound to oauthUser or creates a new one.
//
// Lookup order: the provider user ID first, then the "legacy_id" entry of
// oauthUser.Extra (migrating the stored ID on a hit). If neither matches, a new
// user is inserted, provided common.RegisterEnabled is true.
//
// It returns *OAuthUserDeletedError when the provider ID is taken but the
// filled user has Id 0, and *OAuthRegistrationDisabledError when a new user
// would be needed but registration is disabled. Other errors come from the
// provider lookup or the user insert.
func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
	user := &model.User{}

	// Check if user already exists with new ID
	if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
		err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
		if err != nil {
			return nil, err
		}
		// Check if user has been deleted
		if user.Id == 0 {
			return nil, &OAuthUserDeletedError{}
		}
		return user, nil
	}

	// Try to find user with legacy ID (for GitHub migration from login to numeric ID)
	if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
		if provider.IsUserIDTaken(legacyID) {
			err := provider.FillUserByProviderID(user, legacyID)
			if err != nil {
				return nil, err
			}
			if user.Id != 0 {
				// Found user with legacy ID, migrate to new ID
				common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s", user.Id, legacyID, oauthUser.ProviderUserID))
				if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
					common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
					// Continue with login even if migration fails
				}
				return user, nil
			}
		}
	}

	// User doesn't exist, create new user if registration is enabled
	if !common.RegisterEnabled {
		return nil, &OAuthRegistrationDisabledError{}
	}

	// Set up new user
	user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
	if oauthUser.DisplayName != "" {
		user.DisplayName = oauthUser.DisplayName
	} else if oauthUser.Username != "" {
		user.DisplayName = oauthUser.Username
	} else {
		user.DisplayName = provider.GetName() + " User"
	}
	if oauthUser.Email != "" {
		user.Email = oauthUser.Email
	}
	user.Role = common.RoleCommonUser
	user.Status = common.UserStatusEnabled

	// Handle affiliate code.
	// The comma-ok assertion skips a missing or non-string session value
	// instead of panicking on it.
	inviterID := 0
	if affCode, ok := session.Get("aff").(string); ok {
		// The lookup error is deliberately discarded; presumably an unknown
		// affiliate code should not block registration.
		inviterID, _ = model.GetUserIdByAffCode(affCode)
	}

	if err := user.Insert(inviterID); err != nil {
		return nil, err
	}

	// For custom providers, create the binding after user is created
	if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
		binding := &model.UserOAuthBinding{
			UserId:         user.Id,
			ProviderId:     genericProvider.GetProviderId(),
			ProviderUserId: oauthUser.ProviderUserID,
		}
		if err := model.CreateUserOAuthBinding(binding); err != nil {
			common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
			// Don't fail the registration, just log the error
		}
	} else {
		// Built-in provider: set the provider user ID on the user model
		provider.SetProviderUserID(user, oauthUser.ProviderUserID)
		if err := user.Update(false); err != nil {
			common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
		}
	}

	return user, nil
}

// Error types for OAuth

// OAuthUserDeletedError is returned by findOrCreateOAuthUser when the provider
// user ID is taken but the user record it resolves to has Id 0.
type OAuthUserDeletedError struct{}

// Error implements the error interface.
func (e *OAuthUserDeletedError) Error() string {
	return "user has been deleted"
}

// OAuthRegistrationDisabledError is returned by findOrCreateOAuthUser when a
// new user would have to be created but common.RegisterEnabled is false.
type OAuthRegistrationDisabledError struct{}

// Error implements the error interface.
func (e *OAuthRegistrationDisabledError) Error() string {
	return "registration is disabled"
}

// handleOAuthError writes the API error response for err.
// *oauth.OAuthError is translated through its MsgKey (with Params when set),
// *oauth.TrustLevelError maps to i18n.MsgOAuthTrustLevelLow, and any other
// error is passed to common.ApiError unchanged.
func handleOAuthError(c *gin.Context, err error) {
	switch e := err.(type) {
	case *oauth.OAuthError:
		if e.Params != nil {
			common.ApiErrorI18n(c, e.MsgKey, e.Params)
		} else {
			common.ApiErrorI18n(c, e.MsgKey)
		}
	case *oauth.TrustLevelError:
		common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
	default:
		common.ApiError(c, err)
	}
}