Просмотр исходного кода

fix(oauth): enhance error handling and transaction management for OAuth user creation and binding

- Improve error handling in DeleteCustomOAuthProvider to log and return errors when fetching binding counts.
- Refactor user creation and OAuth binding logic to use transactions for atomic operations, ensuring data integrity.
- Add unique constraints to UserOAuthBinding model to prevent duplicate bindings.
- Enhance GitHub OAuth provider error logging for non-200 responses.
- Update AccountManagement component to provide clearer error messages on API failures.
CaIon 1 неделя назад
Родитель
Сommit
2567cff6c8

+ 6 - 1
controller/custom_oauth.go

@@ -296,7 +296,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
 	}
 
 	// Check if there are any user bindings
-	count, _ := model.GetBindingCountByProviderId(id)
+	count, err := model.GetBindingCountByProviderId(id)
+	if err != nil {
+		common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
+		common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
+		return
+	}
 	if count > 0 {
 		common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
 		return

+ 53 - 17
controller/oauth.go

@@ -11,6 +11,7 @@ import (
 	"github.com/QuantumNous/new-api/oauth"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-gonic/gin"
+	"gorm.io/gorm"
 )
 
 // providerParams returns map with Provider key for i18n templates
@@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
 		inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
 	}
 
-	if err := user.Insert(inviterId); err != nil {
-		return nil, err
-	}
-
-	// For custom providers, create the binding after user is created
+	// Use transaction to ensure user creation and OAuth binding are atomic
 	if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
-		binding := &model.UserOAuthBinding{
-			UserId:         user.Id,
-			ProviderId:     genericProvider.GetProviderId(),
-			ProviderUserId: oauthUser.ProviderUserID,
-		}
-		if err := model.CreateUserOAuthBinding(binding); err != nil {
-			common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
-			// Don't fail the registration, just log the error
+		// Custom provider: create user and binding in a transaction
+		err := model.DB.Transaction(func(tx *gorm.DB) error {
+			// Create user
+			if err := user.InsertWithTx(tx, inviterId); err != nil {
+				return err
+			}
+
+			// Create OAuth binding
+			binding := &model.UserOAuthBinding{
+				UserId:         user.Id,
+				ProviderId:     genericProvider.GetProviderId(),
+				ProviderUserId: oauthUser.ProviderUserID,
+			}
+			if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
+				return err
+			}
+
+			return nil
+		})
+		if err != nil {
+			return nil, err
 		}
+
+		// Perform post-transaction tasks (logs, sidebar config, inviter rewards)
+		user.FinalizeOAuthUserCreation(inviterId)
 	} else {
-		// Built-in provider: set the provider user ID on the user model
-		provider.SetProviderUserID(user, oauthUser.ProviderUserID)
-		if err := user.Update(false); err != nil {
-			common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
+		// Built-in provider: create user and update provider ID in a transaction
+		err := model.DB.Transaction(func(tx *gorm.DB) error {
+			// Create user
+			if err := user.InsertWithTx(tx, inviterId); err != nil {
+				return err
+			}
+
+			// Set the provider user ID on the user model and update
+			provider.SetProviderUserID(user, oauthUser.ProviderUserID)
+			if err := tx.Model(user).Updates(map[string]interface{}{
+				"github_id":    user.GitHubId,
+				"discord_id":   user.DiscordId,
+				"oidc_id":      user.OidcId,
+				"linux_do_id":  user.LinuxDOId,
+				"wechat_id":    user.WeChatId,
+				"telegram_id":  user.TelegramId,
+			}).Error; err != nil {
+				return err
+			}
+
+			return nil
+		})
+		if err != nil {
+			return nil, err
 		}
+
+		// Perform post-transaction tasks
+		user.FinalizeOAuthUserCreation(inviterId)
 	}
 
 	return user, nil

+ 6 - 1
model/custom_oauth_provider.go

@@ -97,13 +97,18 @@ func DeleteCustomOAuthProvider(id int) error {
 }
 
 // IsSlugTaken checks if a slug is already taken by another provider
+// Returns true on DB errors (fail-closed) to prevent slug conflicts
 func IsSlugTaken(slug string, excludeId int) bool {
 	var count int64
 	query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
 	if excludeId > 0 {
 		query = query.Where("id != ?", excludeId)
 	}
-	query.Count(&count)
+	res := query.Count(&count)
+	if res.Error != nil {
+		// Fail-closed: treat DB errors as slug being taken to prevent conflicts
+		return true
+	}
 	return count > 0
 }
 

+ 59 - 0
model/user.go

@@ -429,6 +429,65 @@ func (user *User) Insert(inviterId int) error {
 	return nil
 }
 
+// InsertWithTx inserts a new user within an existing transaction.
+// This is used for OAuth registration where user creation and binding need to be atomic.
+// Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
+func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
+	var err error
+	if user.Password != "" {
+		user.Password, err = common.Password2Hash(user.Password)
+		if err != nil {
+			return err
+		}
+	}
+	user.Quota = common.QuotaForNewUser
+	user.AffCode = common.GetRandomString(4)
+
+	// 初始化用户设置
+	if user.Setting == "" {
+		defaultSetting := dto.UserSetting{}
+		user.SetSetting(defaultSetting)
+	}
+
+	result := tx.Create(user)
+	if result.Error != nil {
+		return result.Error
+	}
+
+	return nil
+}
+
+// FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
+// This should be called after the transaction commits successfully.
+func (user *User) FinalizeOAuthUserCreation(inviterId int) {
+	// 用户创建成功后,根据角色初始化边栏配置
+	var createdUser User
+	if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
+		defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
+		if defaultSidebarConfig != "" {
+			currentSetting := createdUser.GetSetting()
+			currentSetting.SidebarModules = defaultSidebarConfig
+			createdUser.SetSetting(currentSetting)
+			createdUser.Update(false)
+			common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
+		}
+	}
+
+	if common.QuotaForNewUser > 0 {
+		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
+	}
+	if inviterId != 0 {
+		if common.QuotaForInvitee > 0 {
+			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
+			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
+		}
+		if common.QuotaForInviter > 0 {
+			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
+			_ = inviteUser(inviterId)
+		}
+	}
+}
+
 func (user *User) Update(updatePassword bool) error {
 	var err error
 	if updatePassword {

+ 28 - 6
model/user_oauth_binding.go

@@ -3,18 +3,17 @@ package model
 import (
 	"errors"
 	"time"
+
+	"gorm.io/gorm"
 )
 
 // UserOAuthBinding stores the binding relationship between users and custom OAuth providers
 type UserOAuthBinding struct {
 	Id             int       `json:"id" gorm:"primaryKey"`
-	UserId         int       `json:"user_id" gorm:"index;not null"`                                               // User ID
-	ProviderId     int       `json:"provider_id" gorm:"index;not null"`                                           // Custom OAuth provider ID
-	ProviderUserId string    `json:"provider_user_id" gorm:"type:varchar(256);not null"`                          // User ID from OAuth provider
+	UserId         int       `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"`                                        // User ID - one binding per user per provider
+	ProviderId     int       `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"`     // Custom OAuth provider ID
+	ProviderUserId string    `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"`           // User ID from OAuth provider - one OAuth account per provider
 	CreatedAt      time.Time `json:"created_at"`
-
-	// Composite unique index to prevent duplicate bindings
-	// One OAuth account can only be bound to one user
 }
 
 func (UserOAuthBinding) TableName() string {
@@ -82,6 +81,29 @@ func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
 	return DB.Create(binding).Error
 }
 
+// CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction
+func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
+	if binding.UserId == 0 {
+		return errors.New("user ID is required")
+	}
+	if binding.ProviderId == 0 {
+		return errors.New("provider ID is required")
+	}
+	if binding.ProviderUserId == "" {
+		return errors.New("provider user ID is required")
+	}
+
+	// Check if this provider user ID is already taken (use tx to check within the same transaction)
+	var count int64
+	tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count)
+	if count > 0 {
+		return errors.New("this OAuth account is already bound to another user")
+	}
+
+	binding.CreatedAt = time.Now()
+	return tx.Create(binding).Error
+}
+
 // UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
 func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
 	// Check if the new provider user ID is already taken by another user

+ 12 - 0
oauth/github.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"net/http"
 	"strconv"
 	"time"
@@ -122,6 +123,17 @@ func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*O
 
 	logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
 
+	// Check for non-200 status codes before attempting to decode
+	if res.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(res.Body)
+		bodyStr := string(body)
+		if len(bodyStr) > 500 {
+			bodyStr = bodyStr[:500] + "..."
+		}
+		logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo failed: status=%d, body=%s", res.StatusCode, bodyStr))
+		return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, map[string]any{"Provider": "GitHub"}, fmt.Sprintf("status %d", res.StatusCode))
+	}
+
 	var githubUser gitHubUser
 	err = json.NewDecoder(res.Body).Decode(&githubUser)
 	if err != nil {

+ 4 - 2
web/src/components/settings/personal/cards/AccountManagement.jsx

@@ -107,9 +107,11 @@ const AccountManagement = ({
       const res = await API.get('/api/user/oauth/bindings');
       if (res.data.success) {
         setCustomOAuthBindings(res.data.data || []);
+      } else {
+        showError(res.data.message || t('获取绑定信息失败'));
       }
     } catch (error) {
-      // ignore
+      showError(error.response?.data?.message || error.message || t('获取绑定信息失败'));
     }
   };
 
@@ -131,7 +133,7 @@ const AccountManagement = ({
             showError(res.data.message);
           }
         } catch (error) {
-          showError(t('操作失败'));
+          showError(error.response?.data?.message || error.message || t('操作失败'));
         } finally {
           setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
         }