package model

import (
	"encoding/json"
	"errors"
	"fmt"
	"one-api/common"
	"one-api/dto"
	"strconv"
	"strings"

	"github.com/bytedance/gopkg/util/gopool"
	"gorm.io/gorm"
)

// User is the persisted user account.
//
// If you add sensitive fields, don't forget to clean them in the setupLogin
// function. Otherwise, the sensitive information will be saved on local
// storage in plain text!
type User struct {
	Id               int            `json:"id"`
	Username         string         `json:"username" gorm:"unique;index" validate:"max=12"`
	Password         string         `json:"password" gorm:"not null;" validate:"min=8,max=20"`
	// Only used for password-change verification; never saved to the database.
	OriginalPassword string         `json:"original_password" gorm:"-:all"`
	DisplayName      string         `json:"display_name" gorm:"index" validate:"max=20"`
	// admin, common
	Role             int            `json:"role" gorm:"type:int;default:1"`
	// enabled, disabled
	Status           int            `json:"status" gorm:"type:int;default:1"`
	Email            string         `json:"email" gorm:"index" validate:"max=50"`
	GitHubId         string         `json:"github_id" gorm:"column:github_id;index"`
	OidcId           string         `json:"oidc_id" gorm:"column:oidc_id;index"`
	WeChatId         string         `json:"wechat_id" gorm:"column:wechat_id;index"`
	TelegramId       string         `json:"telegram_id" gorm:"column:telegram_id;index"`
	// Only used for email verification; never saved to the database.
	VerificationCode string         `json:"verification_code" gorm:"-:all"`
	// Token used for system management; nil when none has been issued.
	AccessToken      *string        `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"`
	Quota            int            `json:"quota" gorm:"type:int;default:0"`
	// Quota consumed so far.
	UsedQuota        int            `json:"used_quota" gorm:"type:int;default:0;column:used_quota"`
	// Number of requests made.
	RequestCount     int            `json:"request_count" gorm:"type:int;default:0;"`
	Group            string         `json:"group" gorm:"type:varchar(64);default:'default'"`
	AffCode          string         `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
	AffCount         int            `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
	// Remaining (not yet transferred) invitation quota.
	AffQuota         int            `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"`
	// Total invitation quota ever earned (stored in column aff_history).
	AffHistoryQuota  int            `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"`
	InviterId        int            `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
	DeletedAt        gorm.DeletedAt `gorm:"index"`
	LinuxDOId        string         `json:"linux_do_id" gorm:"column:linux_do_id;index"`
	Setting          string         `json:"setting" gorm:"type:text;column:setting"`
	Remark           string         `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
	StripeCustomer   string         `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
}

// ToBaseUser copies the id, group, quota, status, username, setting and email
// of the user into a new UserBase.
func (user *User) ToBaseUser() *UserBase {
	cache := &UserBase{
		Id:       user.Id,
		Group:    user.Group,
		Quota:    user.Quota,
		Status:   user.Status,
		Username: user.Username,
		Setting:  user.Setting,
		Email:    user.Email,
	}
	return cache
}

// GetAccessToken returns the user's access token, or "" when none is set.
func (user *User) GetAccessToken() string {
	if user.AccessToken == nil {
		return ""
	}
	return *user.AccessToken
}

// SetAccessToken stores token as the user's access token.
func (user *User) SetAccessToken(token string) {
	user.AccessToken = &token
}

// GetSetting decodes the JSON-encoded Setting column. An empty column or a
// decode failure (which is logged) yields a zero-value dto.UserSetting.
func (user *User) GetSetting() dto.UserSetting {
	setting := dto.UserSetting{}
	if user.Setting != "" {
		err := json.Unmarshal([]byte(user.Setting), &setting)
		if err != nil {
			common.SysError("failed to unmarshal setting: " + err.Error())
		}
	}
	return setting
}

// SetSetting JSON-encodes setting into the Setting field (in memory only).
// On a marshal failure the error is logged and Setting is left unchanged.
func (user *User) SetSetting(setting dto.UserSetting) {
	settingBytes, err := json.Marshal(setting)
	if err != nil {
		common.SysError("failed to marshal setting: " + err.Error())
		return
	}
	user.Setting = string(settingBytes)
}

// CheckUserExistOrDeleted reports whether any user row, including soft-deleted
// ones, matches username (or email, when email is non-empty).
// It returns (false, nil) when no row exists, (true, nil) when one exists or was
// deleted, and (false, err) on any other database error.
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
	var user User

	// Only match on email when one was supplied, so an empty email does not
	// match every user that has no email.
	var err error
	if email == "" {
		err = DB.Unscoped().First(&user, "username = ?", username).Error
	} else {
		err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
	}
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// Does not exist.
			return false, nil
		}
		// Some other database error.
		return false, err
	}
	// Exists (possibly soft-deleted).
	return true, nil
}

// GetMaxUserId returns the id of the last user row (soft-deleted rows
// included). Query errors are ignored, in which case 0 is returned.
func GetMaxUserId() int {
	var user User
	DB.Unscoped().Last(&user)
	return user.Id
}

// GetAllUsers returns one page of users (soft-deleted included, newest id
// first, password column omitted) together with the total user count. Both
// queries run inside a single transaction.
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
	tx := DB.Begin()
	if tx.Error != nil {
		return nil, 0, tx.Error
	}
	defer func() {
		// Roll back on panic, then re-raise it so the failure is not turned
		// into an empty, error-free result.
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	// Total count.
	err = tx.Unscoped().Model(&User{}).Count(&total).Error
	if err != nil {
		tx.Rollback()
		return nil, 0, err
	}

	// Requested page.
	err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
	if err != nil {
		tx.Rollback()
		return nil, 0, err
	}

	if err = tx.Commit().Error; err != nil {
		return nil, 0, err
	}

	return users, total, nil
}

// SearchUsers returns users (soft-deleted included, newest id first, password
// column omitted) whose username, email or display name contains keyword.
// When keyword parses as an integer the id is matched too. When group is
// non-empty the results are further restricted to that group. It also returns
// the total number of matches, so startIdx/num can be used for paging.
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
	var users []*User
	var total int64
	var err error

	tx := DB.Begin()
	if tx.Error != nil {
		return nil, 0, tx.Error
	}
	defer func() {
		// Roll back on panic, then re-raise it so the failure is not turned
		// into an empty, error-free result.
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	// Base query.
	query := tx.Unscoped().Model(&User{})

	// Search condition over the string columns.
	likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
	pattern := "%" + keyword + "%"

	// If the keyword is an integer, search the id as well as the other fields.
	keywordInt, err := strconv.Atoi(keyword)
	if err == nil {
		likeCondition = "id = ? OR " + likeCondition
		if group != "" {
			query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
				keywordInt, pattern, pattern, pattern, group)
		} else {
			query = query.Where(likeCondition, keywordInt, pattern, pattern, pattern)
		}
	} else {
		// Non-numeric keyword: only search the string fields.
		if group != "" {
			query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
				pattern, pattern, pattern, group)
		} else {
			query = query.Where(likeCondition, pattern, pattern, pattern)
		}
	}

	// Total number of matches.
	err = query.Count(&total).Error
	if err != nil {
		tx.Rollback()
		return nil, 0, err
	}

	// Requested page of matches.
	err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
	if err != nil {
		tx.Rollback()
		return nil, 0, err
	}

	if err = tx.Commit().Error; err != nil {
		return nil, 0, err
	}

	return users, total, nil
}

// GetUserById loads the (non-deleted) user with the given id. The password
// column is only loaded when selectAll is true. id 0 is rejected.
func GetUserById(id int, selectAll bool) (*User, error) {
	if id == 0 {
		return nil, errors.New("id 为空!")
	}
	user := User{Id: id}
	var err error
	if selectAll {
		err = DB.First(&user, "id = ?", id).Error
	} else {
		err = DB.Omit("password").First(&user, "id = ?", id).Error
	}
	return &user, err
}

// GetUserIdByAffCode returns the id of the user owning affCode.
func GetUserIdByAffCode(affCode string) (int, error) {
	if affCode == "" {
		return 0, errors.New("affCode 为空!")
	}
	var user User
	err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
	return user.Id, err
}

// DeleteUserById soft-deletes the user and invalidates its cache entry.
func DeleteUserById(id int) (err error) {
	if id == 0 {
		return errors.New("id 为空!")
	}
	user := User{Id: id}
	return user.Delete()
}

// HardDeleteUserById permanently removes the user row, bypassing soft delete.
func HardDeleteUserById(id int) error {
	if id == 0 {
		return errors.New("id 为空!")
	}
	err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
	return err
}

// inviteUser credits the inviter with one successful invitation. It
// increments aff_count and adds common.QuotaForInviter to both aff_quota and
// aff_history.
//
// A single UPDATE with column expressions is used instead of
// read-modify-Save, so concurrent changes to other columns of the inviter's
// row (e.g. quota consumption) are not overwritten.
// Returns gorm.ErrRecordNotFound when no (non-deleted) user has that id.
func inviteUser(inviterId int) (err error) {
	if inviterId == 0 {
		return errors.New("id 为空!")
	}
	result := DB.Model(&User{}).Where("id = ?", inviterId).Updates(map[string]interface{}{
		"aff_count":   gorm.Expr("aff_count + ?", 1),
		"aff_quota":   gorm.Expr("aff_quota + ?", common.QuotaForInviter),
		"aff_history": gorm.Expr("aff_history + ?", common.QuotaForInviter),
	})
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		return gorm.ErrRecordNotFound
	}
	return nil
}

// TransferAffQuotaToQuota moves quota units from the user's invitation balance
// (aff_quota) to the regular balance (quota). quota must be at least
// common.QuotaPerUnit. On success the receiver is reloaded with the
// post-transfer values.
func (user *User) TransferAffQuotaToQuota(quota int) error {
	// Reject amounts below the minimum transferable quota.
	if float64(quota) < common.QuotaPerUnit {
		return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
	}

	tx := DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	// Ensures rollback on every early return; a no-op after Commit.
	defer tx.Rollback()

	// Load the current row (also surfaces "record not found").
	if err := tx.First(user, user.Id).Error; err != nil {
		return err
	}
	if user.AffQuota < quota {
		return errors.New("邀请额度不足!")
	}

	// Move the balance with one conditional UPDATE. The aff_quota >= ? guard
	// makes the check-and-debit atomic without relying on row locks. GORM v2
	// ignores the v1 "gorm:query_option" FOR UPDATE setting. Column
	// expressions also avoid overwriting concurrent quota changes.
	result := tx.Model(&User{}).
		Where("id = ? AND aff_quota >= ?", user.Id, quota).
		Updates(map[string]interface{}{
			"aff_quota": gorm.Expr("aff_quota - ?", quota),
			"quota":     gorm.Expr("quota + ?", quota),
		})
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		// The balance was drained between the read above and the update.
		return errors.New("邀请额度不足!")
	}

	// Refresh the receiver so callers see the new balances.
	if err := tx.First(user, user.Id).Error; err != nil {
		return err
	}

	return tx.Commit().Error
}

// Insert creates a new user.
//
// It hashes a non-empty password and grants common.QuotaForNewUser. It also
// assigns a random 4-character aff code.
//
// When inviterId is non-zero, it:
//   - credits the invitee with common.QuotaForInvitee (if positive);
//   - credits the inviter's invitation balance (if common.QuotaForInviter is
//     positive).
//
// Failures of the post-create quota/invite updates are ignored.
func (user *User) Insert(inviterId int) error {
	var err error
	if user.Password != "" {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}
	user.Quota = common.QuotaForNewUser
	//user.SetAccessToken(common.GetUUID())
	user.AffCode = common.GetRandomString(4)
	result := DB.Create(user)
	if result.Error != nil {
		return result.Error
	}
	if common.QuotaForNewUser > 0 {
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
	}
	if inviterId != 0 {
		if common.QuotaForInvitee > 0 {
			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
		}
		if common.QuotaForInviter > 0 {
			//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
			_ = inviteUser(inviterId)
		}
	}
	return nil
}

// Update writes the receiver's values to its row, then refreshes the user
// cache. The password is hashed first when updatePassword is true.
//
// The row is reloaded into the receiver before the update. Updates is then
// given a struct, so (per GORM semantics) only non-zero fields are written.
func (user *User) Update(updatePassword bool) error {
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}
	newUser := *user
	DB.First(&user, user.Id)
	if err = DB.Model(user).Updates(newUser).Error; err != nil {
		return err
	}

	// Update cache.
	return updateUserCache(*user)
}

// Edit updates the admin-editable columns, then refreshes the user cache.
// These are username, display_name, group, quota, remark, and password
// (only when updatePassword is true, hashed first).
//
// A map is used so zero values (e.g. quota 0, empty remark) are written too.
func (user *User) Edit(updatePassword bool) error {
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}

	newUser := *user
	updates := map[string]interface{}{
		"username":     newUser.Username,
		"display_name": newUser.DisplayName,
		"group":        newUser.Group,
		"quota":        newUser.Quota,
		"remark":       newUser.Remark,
	}
	if updatePassword {
		updates["password"] = newUser.Password
	}

	DB.First(&user, user.Id)
	if err = DB.Model(user).Updates(updates).Error; err != nil {
		return err
	}

	// Update cache.
	return updateUserCache(*user)
}

// Delete soft-deletes the user and invalidates its cache entry.
func (user *User) Delete() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	if err := DB.Delete(user).Error; err != nil {
		return err
	}

	// Invalidate the cached copy.
	return invalidateUserCache(user.Id)
}

// HardDelete permanently removes the user row, bypassing soft delete.
// The user cache is not touched.
func (user *User) HardDelete() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	err := DB.Unscoped().Delete(user).Error
	return err
}

// ValidateAndFill checks the password and user status.
//
// It looks up the user whose username or email equals the (trimmed) Username
// field and fills the receiver from that row. It then verifies the supplied
// password against the stored hash. Wrong credentials and a disabled account
// return the same error.
func (user *User) ValidateAndFill() (err error) {
	// When querying with a struct, GORM only uses non-zero fields as
	// conditions. A field whose value is 0, '', false or another zero value
	// won't be used to build query conditions.
	password := user.Password
	username := strings.TrimSpace(user.Username)
	if username == "" || password == "" {
		return errors.New("用户名或密码为空")
	}
	// Look up by username or email. The lookup error is not checked: when no
	// row is found, the receiver keeps the plaintext input in Password. The
	// check below is presumably expected to fail then (TODO confirm).
	DB.Where("username = ? OR email = ?", username, username).First(user)
	okay := common.ValidatePasswordAndHash(password, user.Password)
	if !okay || user.Status != common.UserStatusEnabled {
		return errors.New("用户名或密码错误,或用户已被封禁")
	}
	return nil
}

// FillUserById loads the user with user.Id into the receiver.
// Lookup errors are ignored.
func (user *User) FillUserById() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	DB.Where(User{Id: user.Id}).First(user)
	return nil
}

// FillUserByEmail loads the user with user.Email into the receiver.
// Lookup errors are ignored.
func (user *User) FillUserByEmail() error {
	if user.Email == "" {
		return errors.New("email 为空!")
	}
	DB.Where(User{Email: user.Email}).First(user)
	return nil
}

// FillUserByGitHubId loads the user with user.GitHubId into the receiver.
// Lookup errors are ignored.
func (user *User) FillUserByGitHubId() error {
	if user.GitHubId == "" {
		return errors.New("GitHub id 为空!")
	}
	DB.Where(User{GitHubId: user.GitHubId}).First(user)
	return nil
}

// FillUserByOidcId loads the user with user.OidcId into the receiver.
// Lookup errors are ignored.
func (user *User) FillUserByOidcId() error {
	if user.OidcId == "" {
		return errors.New("oidc id 为空!")
	}
	DB.Where(User{OidcId: user.OidcId}).First(user)
	return nil
}

// FillUserByWeChatId loads the user with user.WeChatId into the receiver.
// Lookup errors are ignored.
func (user *User) FillUserByWeChatId() error {
	if user.WeChatId == "" {
		return errors.New("WeChat id 为空!")
	}
	DB.Where(User{WeChatId: user.WeChatId}).First(user)
	return nil
}

// FillUserByTelegramId loads the user with user.TelegramId into the receiver.
// It returns a dedicated "not bound" error when no user has that Telegram id,
// and any other database error as-is.
func (user *User) FillUserByTelegramId() error {
	if user.TelegramId == "" {
		return errors.New("Telegram id 为空!")
	}
	err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return errors.New("该 Telegram 账户未绑定")
	}
	return err
}

// IsEmailAlreadyTaken reports whether a user row (soft-deleted included)
// has this email.
func IsEmailAlreadyTaken(email string) bool {
	return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
}

// IsWeChatIdAlreadyTaken reports whether a user row (soft-deleted included)
// has this WeChat id.
func IsWeChatIdAlreadyTaken(wechatId string) bool {
	return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
}

// IsGitHubIdAlreadyTaken reports whether a user row (soft-deleted included)
// has this GitHub id.
func IsGitHubIdAlreadyTaken(githubId string) bool {
	return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}

// IsOidcIdAlreadyTaken reports whether a non-deleted user has this OIDC id.
// NOTE(review): unlike the sibling Is*AlreadyTaken helpers, this query is not
// Unscoped, so soft-deleted users are not counted. Confirm this is intended.
func IsOidcIdAlreadyTaken(oidcId string) bool {
	return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}

// IsTelegramIdAlreadyTaken reports whether a user row (soft-deleted included)
// has this Telegram id.
func IsTelegramIdAlreadyTaken(telegramId string) bool {
	return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
}

// ResetUserPasswordByEmail hashes password and stores it for every user
// whose email matches.
func ResetUserPasswordByEmail(email string, password string) error {
	if email == "" || password == "" {
		return errors.New("邮箱地址或密码为空!")
	}
	hashedPassword, err := common.Password2Hash(password)
	if err != nil {
		return err
	}
	err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
	return err
}

// IsAdmin reports whether the user's role is at least common.RoleAdminUser.
// id 0, a missing user or a query error yield false.
func IsAdmin(userId int) bool {
	if userId == 0 {
		return false
	}
	var user User
	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
	if err != nil {
		common.SysError("no such user " + err.Error())
		return false
	}
	return user.Role >= common.RoleAdminUser
}

//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
//	defer func() {
//		// Update Redis cache asynchronously on successful DB read
//		if shouldUpdateRedis(fromDB, err) {
//			gopool.Go(func() {
//				if err := updateUserStatusCache(id, status); err != nil {
//					common.SysError("failed to update user status cache: " + err.Error())
//				}
//			})
//		}
//	}()
//	if !fromDB && common.RedisEnabled {
//		// Try Redis first
//		status, err := getUserStatusCache(id)
//		if err == nil {
//			return status == common.UserStatusEnabled, nil
//		}
//		// Don't return error - fall through to DB
//	}
//	fromDB = true
//	var user User
//	err = DB.Where("id = ?", id).Select("status").Find(&user).Error
//	if err != nil {
//		return false, err
//	}
//
//	return user.Status == common.UserStatusEnabled, nil
//}

// ValidateAccessToken returns the user owning token, or nil when token is
// empty or matches no user. The first "Bearer " occurrence is stripped first.
func ValidateAccessToken(token string) (user *User) {
	if token == "" {
		return nil
	}
	token = strings.Replace(token, "Bearer ", "", 1)
	user = &User{}
	if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
		return user
	}
	return nil
}

// GetUserQuota gets quota from Redis first, falls back to DB if needed.
// A successful DB read refreshes the Redis entry asynchronously.
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
	defer func() {
		// Update Redis cache asynchronously on successful DB read.
		if shouldUpdateRedis(fromDB, err) {
			gopool.Go(func() {
				if err := updateUserQuotaCache(id, quota); err != nil {
					common.SysError("failed to update user quota cache: " + err.Error())
				}
			})
		}
	}()
	if !fromDB && common.RedisEnabled {
		quota, err := getUserQuotaCache(id)
		if err == nil {
			return quota, nil
		}
		// Don't return error - fall through to DB.
	}
	fromDB = true
	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
	if err != nil {
		return 0, err
	}

	return quota, nil
}

// GetUserUsedQuota reads the user's used_quota column from the database.
func GetUserUsedQuota(id int) (quota int, err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
	return quota, err
}

// GetUserEmail reads the user's email column from the database.
func GetUserEmail(id int) (email string, err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
	return email, err
}

// GetUserGroup gets group from Redis first, falls back to DB if needed.
// A successful DB read refreshes the Redis entry asynchronously.
func GetUserGroup(id int, fromDB bool) (group string, err error) {
	defer func() {
		// Update Redis cache asynchronously on successful DB read.
		if shouldUpdateRedis(fromDB, err) {
			gopool.Go(func() {
				if err := updateUserGroupCache(id, group); err != nil {
					common.SysError("failed to update user group cache: " + err.Error())
				}
			})
		}
	}()
	if !fromDB && common.RedisEnabled {
		group, err := getUserGroupCache(id)
		if err == nil {
			return group, nil
		}
		// Don't return error - fall through to DB.
	}
	fromDB = true
	err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
	if err != nil {
		return "", err
	}

	return group, nil
}

// GetUserSetting gets setting from Redis first, falls back to DB if needed.
// The raw DB value is decoded via UserBase.GetSetting, and a successful DB
// read refreshes the Redis entry asynchronously.
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
	var setting string
	defer func() {
		// Update Redis cache asynchronously on successful DB read.
		if shouldUpdateRedis(fromDB, err) {
			gopool.Go(func() {
				if err := updateUserSettingCache(id, setting); err != nil {
					common.SysError("failed to update user setting cache: " + err.Error())
				}
			})
		}
	}()
	if !fromDB && common.RedisEnabled {
		setting, err := getUserSettingCache(id)
		if err == nil {
			return setting, nil
		}
		// Don't return error - fall through to DB.
	}
	fromDB = true
	err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
	if err != nil {
		return settingMap, err
	}
	userBase := &UserBase{
		Setting: setting,
	}
	return userBase.GetSetting(), nil
}

// IncreaseUserQuota adds quota (must be >= 0) to the user's balance.
//
// The cached balance is incremented asynchronously. When db is false and
// batch updating is enabled, the DB change is queued via addNewRecord;
// otherwise it is written immediately.
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	gopool.Go(func() {
		err := cacheIncrUserQuota(id, int64(quota))
		if err != nil {
			common.SysError("failed to increase user quota: " + err.Error())
		}
	})
	if !db && common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
		return nil
	}
	return increaseUserQuota(id, quota)
}

// increaseUserQuota atomically adds quota to the user's quota column.
func increaseUserQuota(id int, quota int) (err error) {
	return DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
}

// DecreaseUserQuota subtracts quota (must be >= 0) from the user's balance.
//
// The cached balance is decremented asynchronously. When batch updating is
// enabled, the DB change is queued via addNewRecord; otherwise it is written
// immediately.
func DecreaseUserQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	gopool.Go(func() {
		err := cacheDecrUserQuota(id, int64(quota))
		if err != nil {
			common.SysError("failed to decrease user quota: " + err.Error())
		}
	})
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
		return nil
	}
	return decreaseUserQuota(id, quota)
}

// decreaseUserQuota atomically subtracts quota from the user's quota column.
func decreaseUserQuota(id int, quota int) (err error) {
	return DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
}

// DeltaUpdateUserQuota applies a signed quota change. A positive delta goes
// through IncreaseUserQuota (non-forced DB write), a negative one through
// DecreaseUserQuota, and 0 is a no-op.
func DeltaUpdateUserQuota(id int, delta int) (err error) {
	if delta == 0 {
		return nil
	}
	if delta > 0 {
		return IncreaseUserQuota(id, delta, false)
	}
	return DecreaseUserQuota(id, -delta)
}

//func GetRootUserEmail() (email string) {
//	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
//	return email
//}

// GetRootUser returns the first user with the root role.
// Query errors are ignored.
func GetRootUser() (user *User) {
	DB.Where("role = ?", common.RoleRootUser).First(&user)
	return user
}

// UpdateUserUsedQuotaAndRequestCount adds quota to used_quota and 1 to
// request_count. The changes are queued when batch updating is enabled and
// written immediately otherwise.
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
		addNewRecord(BatchUpdateTypeRequestCount, id, 1)
		return
	}
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
}

// updateUserUsedQuotaAndRequestCount atomically increments used_quota by quota
// and request_count by count in one UPDATE. Errors are logged, not returned.
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"used_quota":    gorm.Expr("used_quota + ?", quota),
			"request_count": gorm.Expr("request_count + ?", count),
		},
	).Error
	if err != nil {
		common.SysError("failed to update user used quota and request count: " + err.Error())
		return
	}

	//// Update cache
	//if err := invalidateUserCache(id); err != nil {
	//	common.SysError("failed to invalidate user cache: " + err.Error())
	//}
}

// updateUserUsedQuota atomically increments used_quota by quota.
// Errors are logged, not returned.
func updateUserUsedQuota(id int, quota int) {
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"used_quota": gorm.Expr("used_quota + ?", quota),
		},
	).Error
	if err != nil {
		common.SysError("failed to update user used quota: " + err.Error())
	}
}

// updateUserRequestCount atomically increments request_count by count.
// Errors are logged, not returned.
func updateUserRequestCount(id int, count int) {
	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
	if err != nil {
		common.SysError("failed to update user request count: " + err.Error())
	}
}

// GetUsernameById gets username from Redis first, falls back to DB if needed.
// A successful DB read refreshes the Redis entry asynchronously.
func GetUsernameById(id int, fromDB bool) (username string, err error) {
	defer func() {
		// Update Redis cache asynchronously on successful DB read.
		if shouldUpdateRedis(fromDB, err) {
			gopool.Go(func() {
				if err := updateUserNameCache(id, username); err != nil {
					common.SysError("failed to update user name cache: " + err.Error())
				}
			})
		}
	}()
	if !fromDB && common.RedisEnabled {
		username, err := getUserNameCache(id)
		if err == nil {
			return username, nil
		}
		// Don't return error - fall through to DB.
	}
	fromDB = true
	err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
	if err != nil {
		return "", err
	}

	return username, nil
}

// IsLinuxDOIdAlreadyTaken reports whether a user row (soft-deleted included)
// has this LinuxDO id. Any error other than "record not found" counts as taken.
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
	var user User
	err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
	return !errors.Is(err, gorm.ErrRecordNotFound)
}

// FillUserByLinuxDOId loads the user with user.LinuxDOId into the receiver.
func (user *User) FillUserByLinuxDOId() error {
	if user.LinuxDOId == "" {
		return errors.New("linux do id is empty")
	}
	err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
	return err
}

// RootUserExists reports whether at least one (non-deleted) user has the
// root role. Query errors are treated as "does not exist".
func RootUserExists() bool {
	var user User
	err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
	return err == nil
}