| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348 |
- package controller
- import (
- "fmt"
- "net/http"
- "strconv"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/i18n"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/oauth"
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- "gorm.io/gorm"
- )
// providerParams builds the parameter map handed to i18n templates, holding
// the given provider name under the "Provider" key.
func providerParams(name string) map[string]any {
	params := make(map[string]any, 1)
	params["Provider"] = name
	return params
}
- // GenerateOAuthCode generates a state code for OAuth CSRF protection
- func GenerateOAuthCode(c *gin.Context) {
- session := sessions.Default(c)
- state := common.GetRandomString(12)
- affCode := c.Query("aff")
- if affCode != "" {
- session.Set("aff", affCode)
- }
- session.Set("oauth_state", state)
- err := session.Save()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": state,
- })
- }
- // HandleOAuth handles OAuth callback for all standard OAuth providers
- func HandleOAuth(c *gin.Context) {
- providerName := c.Param("provider")
- provider := oauth.GetProvider(providerName)
- if provider == nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
- })
- return
- }
- session := sessions.Default(c)
- // 1. Validate state (CSRF protection)
- state := c.Query("state")
- if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": i18n.T(c, i18n.MsgOAuthStateInvalid),
- })
- return
- }
- // 2. Check if user is already logged in (bind flow)
- username := session.Get("username")
- if username != nil {
- handleOAuthBind(c, provider)
- return
- }
- // 3. Check if provider is enabled
- if !provider.IsEnabled() {
- common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
- return
- }
- // 4. Handle error from provider
- errorCode := c.Query("error")
- if errorCode != "" {
- errorDescription := c.Query("error_description")
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": errorDescription,
- })
- return
- }
- // 5. Exchange code for token
- code := c.Query("code")
- token, err := provider.ExchangeToken(c.Request.Context(), code, c)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // 6. Get user info
- oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // 7. Find or create user
- user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
- if err != nil {
- switch err.(type) {
- case *OAuthUserDeletedError:
- common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
- case *OAuthRegistrationDisabledError:
- common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
- default:
- common.ApiError(c, err)
- }
- return
- }
- // 8. Check user status
- if user.Status != common.UserStatusEnabled {
- common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
- return
- }
- // 9. Setup login
- setupLogin(user, c)
- }
- // handleOAuthBind handles binding OAuth account to existing user
- func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
- if !provider.IsEnabled() {
- common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
- return
- }
- // Exchange code for token
- code := c.Query("code")
- token, err := provider.ExchangeToken(c.Request.Context(), code, c)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // Get user info
- oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // Check if this OAuth account is already bound (check both new ID and legacy ID)
- if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
- common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
- return
- }
- // Also check legacy ID to prevent duplicate bindings during migration period
- if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
- if provider.IsUserIDTaken(legacyID) {
- common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
- return
- }
- }
- // Get current user from session
- session := sessions.Default(c)
- id := session.Get("id")
- user := model.User{Id: id.(int)}
- err = user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // Handle binding based on provider type
- if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
- // Custom provider: use user_oauth_bindings table
- err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- } else {
- // Built-in provider: update user record directly
- provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- }
- common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
- }
- // findOrCreateOAuthUser finds existing user or creates new user
- func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
- user := &model.User{}
- // Check if user already exists with new ID
- if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
- err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
- if err != nil {
- return nil, err
- }
- // Check if user has been deleted
- if user.Id == 0 {
- return nil, &OAuthUserDeletedError{}
- }
- return user, nil
- }
- // Try to find user with legacy ID (for GitHub migration from login to numeric ID)
- if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
- if provider.IsUserIDTaken(legacyID) {
- err := provider.FillUserByProviderID(user, legacyID)
- if err != nil {
- return nil, err
- }
- if user.Id != 0 {
- // Found user with legacy ID, migrate to new ID
- common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
- user.Id, legacyID, oauthUser.ProviderUserID))
- if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
- common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
- // Continue with login even if migration fails
- }
- return user, nil
- }
- }
- }
- // User doesn't exist, create new user if registration is enabled
- if !common.RegisterEnabled {
- return nil, &OAuthRegistrationDisabledError{}
- }
- // Set up new user
- user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
- if oauthUser.DisplayName != "" {
- user.DisplayName = oauthUser.DisplayName
- } else if oauthUser.Username != "" {
- user.DisplayName = oauthUser.Username
- } else {
- user.DisplayName = provider.GetName() + " User"
- }
- if oauthUser.Email != "" {
- user.Email = oauthUser.Email
- }
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
- // Handle affiliate code
- affCode := session.Get("aff")
- inviterId := 0
- if affCode != nil {
- inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
- }
- // Use transaction to ensure user creation and OAuth binding are atomic
- if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
- // Custom provider: create user and binding in a transaction
- err := model.DB.Transaction(func(tx *gorm.DB) error {
- // Create user
- if err := user.InsertWithTx(tx, inviterId); err != nil {
- return err
- }
- // Create OAuth binding
- binding := &model.UserOAuthBinding{
- UserId: user.Id,
- ProviderId: genericProvider.GetProviderId(),
- ProviderUserId: oauthUser.ProviderUserID,
- }
- if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
- return err
- }
- return nil
- })
- if err != nil {
- return nil, err
- }
- // Perform post-transaction tasks (logs, sidebar config, inviter rewards)
- user.FinalizeOAuthUserCreation(inviterId)
- } else {
- // Built-in provider: create user and update provider ID in a transaction
- err := model.DB.Transaction(func(tx *gorm.DB) error {
- // Create user
- if err := user.InsertWithTx(tx, inviterId); err != nil {
- return err
- }
- // Set the provider user ID on the user model and update
- provider.SetProviderUserID(user, oauthUser.ProviderUserID)
- if err := tx.Model(user).Updates(map[string]interface{}{
- "github_id": user.GitHubId,
- "discord_id": user.DiscordId,
- "oidc_id": user.OidcId,
- "linux_do_id": user.LinuxDOId,
- "wechat_id": user.WeChatId,
- "telegram_id": user.TelegramId,
- }).Error; err != nil {
- return err
- }
- return nil
- })
- if err != nil {
- return nil, err
- }
- // Perform post-transaction tasks
- user.FinalizeOAuthUserCreation(inviterId)
- }
- return user, nil
- }
// Error types for OAuth

// OAuthUserDeletedError reports that the user matching an OAuth identity has
// been deleted.
type OAuthUserDeletedError struct{}

// Error implements the error interface.
func (*OAuthUserDeletedError) Error() string {
	return "user has been deleted"
}
// OAuthRegistrationDisabledError reports that a new user could not be created
// because registration is disabled.
type OAuthRegistrationDisabledError struct{}

// Error implements the error interface.
func (*OAuthRegistrationDisabledError) Error() string {
	return "registration is disabled"
}
- // handleOAuthError handles OAuth errors and returns translated message
- func handleOAuthError(c *gin.Context, err error) {
- switch e := err.(type) {
- case *oauth.OAuthError:
- if e.Params != nil {
- common.ApiErrorI18n(c, e.MsgKey, e.Params)
- } else {
- common.ApiErrorI18n(c, e.MsgKey)
- }
- case *oauth.TrustLevelError:
- common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
- default:
- common.ApiError(c, err)
- }
- }
|