Quellcode durchsuchen

Merge pull request #770 from Calcium-Ion/refactor_notify

feat: Add user notification settings and multiple notification methods
Calcium-Ion vor 10 Monaten
Ursprung
Commit
2b7435500c

+ 2 - 0
README.en.md

@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
 - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
 - `CRYPTO_SECRET`: Encryption key for encrypting database content
 - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
+- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
 
 ## Deployment
 

+ 3 - 0
README.md

@@ -95,6 +95,9 @@
 - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
 - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
 - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
+- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
+
 ## 部署
 
 > [!TIP]

+ 1 - 1
common/constants.go

@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
 
 var RetryTimes = 0
 
-var RootUserEmail = ""
+//var RootUserEmail = ""
 
 var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 

+ 8 - 0
common/logger.go

@@ -100,6 +100,14 @@ func LogQuota(quota int) string {
 	}
 }
 
+func FormatQuota(quota int) string {
+	if DisplayInCurrencyEnabled {
+		return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
+	} else {
+		return fmt.Sprintf("%d", quota)
+	}
+}
+
 // LogJson 仅供测试使用 only for test
 func LogJson(ctx context.Context, msg string, obj any) {
 	jsonStr, err := json.Marshal(obj)

+ 17 - 23
common/model-ratio.go

@@ -233,7 +233,11 @@ var (
 	modelRatioMapMutex                    = sync.RWMutex{}
 )
 
-var CompletionRatio map[string]float64 = nil
+var (
+	CompletionRatio      map[string]float64 = nil
+	CompletionRatioMutex                    = sync.RWMutex{}
+)
+
 var defaultCompletionRatio = map[string]float64{
 	"gpt-4-gizmo-*":  2,
 	"gpt-4o-gizmo-*": 3,
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
 	return defaultModelRatio
 }
 
-func CompletionRatio2JSONString() string {
+func GetCompletionRatioMap() map[string]float64 {
+	CompletionRatioMutex.Lock()
+	defer CompletionRatioMutex.Unlock()
 	if CompletionRatio == nil {
 		CompletionRatio = defaultCompletionRatio
 	}
+	return CompletionRatio
+}
+
+func CompletionRatio2JSONString() string {
+	GetCompletionRatioMap()
 	jsonBytes, err := json.Marshal(CompletionRatio)
 	if err != nil {
 		SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
 }
 
 func UpdateCompletionRatioByJSONString(jsonStr string) error {
+	CompletionRatioMutex.Lock()
+	defer CompletionRatioMutex.Unlock()
 	CompletionRatio = make(map[string]float64)
 	return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
 }
 
 func GetCompletionRatio(name string) float64 {
+	GetCompletionRatioMap()
+
 	if strings.Contains(name, "/") {
 		if ratio, ok := CompletionRatio[name]; ok {
 			return ratio
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
 	}
 	return 2
 }
-
-//func GetAudioPricePerMinute(name string) float64 {
-//	if strings.HasPrefix(name, "gpt-4o-realtime") {
-//		return 0.06
-//	}
-//	return 0.06
-//}
-//
-//func GetAudioCompletionPricePerMinute(name string) float64 {
-//	if strings.HasPrefix(name, "gpt-4o-realtime") {
-//		return 0.24
-//	}
-//	return 0.24
-//}
-
-func GetCompletionRatioMap() map[string]float64 {
-	if CompletionRatio == nil {
-		CompletionRatio = defaultCompletionRatio
-	}
-	return CompletionRatio
-}

+ 4 - 1
constant/env.go

@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
 
 var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
 
+var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+
 func InitEnv() {
 	modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
 	if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
 	}
 }
 
-// 是否生成初始令牌,默认关闭。
+// GenerateDefaultToken 是否生成初始令牌,默认关闭。
 var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

+ 14 - 0
constant/user_setting.go

@@ -0,0 +1,14 @@
+package constant
+
+var (
+	UserSettingNotifyType            = "notify_type"             // QuotaWarningType 额度预警类型
+	UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
+	UserSettingWebhookUrl            = "webhook_url"             // WebhookUrl webhook地址
+	UserSettingWebhookSecret         = "webhook_secret"          // WebhookSecret webhook密钥
+	UserSettingNotificationEmail     = "notification_email"      // NotificationEmail 通知邮箱地址
+)
+
+var (
+	NotifyTypeEmail   = "email"   // Email 邮件
+	NotifyTypeWebhook = "webhook" // Webhook
+)

+ 2 - 7
controller/channel-test.go

@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
 func testAllChannels(notify bool) error {
-	if common.RootUserEmail == "" {
-		common.RootUserEmail = model.GetRootUserEmail()
-	}
+
 	testAllChannelsLock.Lock()
 	if testAllChannelsRunning {
 		testAllChannelsLock.Unlock()
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
 		testAllChannelsRunning = false
 		testAllChannelsLock.Unlock()
 		if notify {
-			err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
-			if err != nil {
-				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-			}
+			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
 		}
 	})
 	return nil

+ 1 - 1
controller/pricing.go

@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
 	}
 	var group string
 	if exists {
-		user, err := model.GetUserById(userId.(int), false)
+		user, err := model.GetUserCache(userId.(int))
 		if err == nil {
 			group = user.Group
 		}

+ 114 - 4
controller/user.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"net/http"
+	"net/url"
 	"one-api/common"
 	"one-api/model"
 	"one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
 	if err != nil {
 		id = c.GetInt("id")
 	}
-	user, err := model.GetUserById(id, true)
+	user, err := model.GetUserCache(id)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
 		})
 		return
 	}
-	if user.Role == common.RoleRootUser {
-		common.RootUserEmail = email
-	}
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
 	})
 	return
 }
+
+type UpdateUserSettingRequest struct {
+	QuotaWarningType      string `json:"notify_type"`
+	QuotaWarningThreshold int    `json:"quota_warning_threshold"`
+	WebhookUrl            string `json:"webhook_url,omitempty"`
+	WebhookSecret         string `json:"webhook_secret,omitempty"`
+	NotificationEmail     string `json:"notification_email,omitempty"`
+}
+
+func UpdateUserSetting(c *gin.Context) {
+	var req UpdateUserSettingRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的参数",
+		})
+		return
+	}
+
+	// 验证预警类型
+	if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的预警类型",
+		})
+		return
+	}
+
+	// 验证预警阈值
+	if req.QuotaWarningThreshold <= 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "预警阈值必须大于0",
+		})
+		return
+	}
+
+	// 如果是webhook类型,验证webhook地址
+	if req.QuotaWarningType == constant.NotifyTypeWebhook {
+		if req.WebhookUrl == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "Webhook地址不能为空",
+			})
+			return
+		}
+		// 验证URL格式
+		if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无效的Webhook地址",
+			})
+			return
+		}
+	}
+
+	// 如果是邮件类型,验证邮箱地址
+	if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+		// 验证邮箱格式
+		if !strings.Contains(req.NotificationEmail, "@") {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无效的邮箱地址",
+			})
+			return
+		}
+	}
+
+	userId := c.GetInt("id")
+	user, err := model.GetUserById(userId, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// 构建设置
+	settings := map[string]interface{}{
+		constant.UserSettingNotifyType:            req.QuotaWarningType,
+		constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
+	}
+
+	// 如果是webhook类型,添加webhook相关设置
+	if req.QuotaWarningType == constant.NotifyTypeWebhook {
+		settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
+		if req.WebhookSecret != "" {
+			settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
+		}
+	}
+
+	// 如果提供了通知邮箱,添加到设置中
+	if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+		settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
+	}
+
+	// 更新用户设置
+	user.SetSetting(settings)
+	if err := user.Update(false); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "更新设置失败: " + err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "设置已更新",
+	})
+}

+ 1 - 1
docker-compose.yml

@@ -24,7 +24,7 @@ services:
       - redis
       - mysql
     healthcheck:
-      test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
+      test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
       interval: 30s
       timeout: 10s
       retries: 3

+ 25 - 0
dto/notify.go

@@ -0,0 +1,25 @@
+package dto
+
+type Notify struct {
+	Type    string        `json:"type"`
+	Title   string        `json:"title"`
+	Content string        `json:"content"`
+	Values  []interface{} `json:"values"`
+}
+
+const ContentValueParam = "{{value}}"
+
+const (
+	NotifyTypeQuotaExceed   = "quota_exceed"
+	NotifyTypeChannelUpdate = "channel_update"
+	NotifyTypeChannelTest   = "channel_test"
+)
+
+func NewNotify(t string, title string, content string, values []interface{}) Notify {
+	return Notify{
+		Type:    t,
+		Title:   title,
+		Content: content,
+		Values:  values,
+	}
+}

+ 1 - 0
dto/openai_request.go

@@ -18,6 +18,7 @@ type GeneralOpenAIRequest struct {
 	Model               string          `json:"model,omitempty"`
 	Messages            []Message       `json:"messages,omitempty"`
 	Prompt              any             `json:"prompt,omitempty"`
+	Prefix              any             `json:"prefix,omitempty"`
 	Suffix              any             `json:"suffix,omitempty"`
 	Stream              bool            `json:"stream,omitempty"`
 	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`

+ 3 - 82
model/token.go

@@ -3,13 +3,11 @@ package model
 import (
 	"errors"
 	"fmt"
-	"github.com/bytedance/gopkg/util/gopool"
-	"gorm.io/gorm"
 	"one-api/common"
-	relaycommon "one-api/relay/common"
-	"one-api/setting"
-	"strconv"
 	"strings"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"gorm.io/gorm"
 )
 
 type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
 	).Error
 	return err
 }
-
-func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
-	if quota < 0 {
-		return errors.New("quota 不能为负数!")
-	}
-	if relayInfo.IsPlayground {
-		return nil
-	}
-	//if relayInfo.TokenUnlimited {
-	//	return nil
-	//}
-	token, err := GetTokenById(relayInfo.TokenId)
-	if err != nil {
-		return err
-	}
-	if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
-		return errors.New("令牌额度不足")
-	}
-	err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
-	if err != nil {
-		return err
-	}
-	return nil
-}
-
-func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
-
-	if quota > 0 {
-		err = DecreaseUserQuota(relayInfo.UserId, quota)
-	} else {
-		err = IncreaseUserQuota(relayInfo.UserId, -quota)
-	}
-	if err != nil {
-		return err
-	}
-
-	if !relayInfo.IsPlayground {
-		if quota > 0 {
-			err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
-		} else {
-			err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
-		}
-		if err != nil {
-			return err
-		}
-	}
-
-	if sendEmail {
-		if (quota + preConsumedQuota) != 0 {
-			quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
-			noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
-			if quotaTooLow || noMoreQuota {
-				go func() {
-					email, err := GetUserEmail(relayInfo.UserId)
-					if err != nil {
-						common.SysError("failed to fetch user email: " + err.Error())
-					}
-					prompt := "您的额度即将用尽"
-					if noMoreQuota {
-						prompt = "您的额度已用尽"
-					}
-					if email != "" {
-						topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
-						err = common.SendEmail(prompt, email,
-							fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
-						if err != nil {
-							common.SysError("failed to send email" + err.Error())
-						}
-						common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
-					}
-				}()
-			}
-		}
-	}
-
-	return nil
-}

+ 1 - 1
model/token_cache.go

@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
 func cacheGetTokenByKey(key string) (*Token, error) {
 	hmacKey := common.GenerateHMAC(key)
 	if !common.RedisEnabled {
-		return nil, nil
+		return nil, fmt.Errorf("redis is not enabled")
 	}
 	var token Token
 	err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)

+ 77 - 13
model/user.go

@@ -1,6 +1,7 @@
 package model
 
 import (
+	"encoding/json"
 	"errors"
 	"fmt"
 	"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
 	InviterId        int            `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
 	DeletedAt        gorm.DeletedAt `gorm:"index"`
 	LinuxDOId        string         `json:"linux_do_id" gorm:"column:linux_do_id;index"`
+	Setting          string         `json:"setting" gorm:"type:text;column:setting"`
+}
+
+func (user *User) ToBaseUser() *UserBase {
+	cache := &UserBase{
+		Id:       user.Id,
+		Group:    user.Group,
+		Quota:    user.Quota,
+		Status:   user.Status,
+		Username: user.Username,
+		Setting:  user.Setting,
+		Email:    user.Email,
+	}
+	return cache
 }
 
 func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
 	user.AccessToken = &token
 }
 
+func (user *User) GetSetting() map[string]interface{} {
+	if user.Setting == "" {
+		return nil
+	}
+	return common.StrToMap(user.Setting)
+}
+
+func (user *User) SetSetting(setting map[string]interface{}) {
+	settingBytes, err := json.Marshal(setting)
+	if err != nil {
+		common.SysError("failed to marshal setting: " + err.Error())
+		return
+	}
+	user.Setting = string(settingBytes)
+}
+
 // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
 func CheckUserExistOrDeleted(username string, email string) (bool, error) {
 	var user User
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
 		return err
 	}
 
-	// 更新缓存
-	return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
+	// Update cache
+	return updateUserCache(*user)
 }
 
 func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
 		return err
 	}
 
-	// 更新缓存
-	return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
+	// Update cache
+	return updateUserCache(*user)
 }
 
 func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
 // ValidateAndFill check password & user status
 func (user *User) ValidateAndFill() (err error) {
 	// When querying with struct, GORM will only query with non-zero fields,
-	// that means if your fields value is 0, '', false or other zero values,
-	// it wont be used to build query conditions
+	// that means if your field's value is 0, '', false or other zero values,
+	// it won't be used to build query conditions
 	password := user.Password
 	username := strings.TrimSpace(user.Username)
 	if username == "" || password == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
 			return quota, nil
 		}
 		// Don't return error - fall through to DB
-		//common.SysError("failed to get user quota from cache: " + err.Error())
 	}
 	fromDB = true
 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
 	return group, nil
 }
 
+// GetUserSetting gets setting from Redis first, falls back to DB if needed
+func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
+	var setting string
+	defer func() {
+		// Update Redis cache asynchronously on successful DB read
+		if shouldUpdateRedis(fromDB, err) {
+			gopool.Go(func() {
+				if err := updateUserSettingCache(id, setting); err != nil {
+					common.SysError("failed to update user setting cache: " + err.Error())
+				}
+			})
+		}
+	}()
+	if !fromDB && common.RedisEnabled {
+		setting, err := getUserSettingCache(id)
+		if err == nil {
+			return setting, nil
+		}
+		// Don't return error - fall through to DB
+	}
+	fromDB = true
+	err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
+	if err != nil {
+		return map[string]interface{}{}, err
+	}
+
+	return common.StrToMap(setting), nil
+}
+
 func IncreaseUserQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
 	}
 }
 
-func GetRootUserEmail() (email string) {
-	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
-	return email
+//func GetRootUserEmail() (email string) {
+//	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+//	return email
+//}
+
+func GetRootUser() (user *User) {
+	DB.Where("role = ?", common.RoleRootUser).First(&user)
+	return user
 }
 
 func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
 	return !errors.Is(err, gorm.ErrRecordNotFound)
 }
 
-func (u *User) FillUserByLinuxDOId() error {
-	if u.LinuxDOId == "" {
+func (user *User) FillUserByLinuxDOId() error {
+	if user.LinuxDOId == "" {
 		return errors.New("linux do id is empty")
 	}
-	err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
+	err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
 	return err
 }

+ 128 - 121
model/user_cache.go

@@ -1,206 +1,213 @@
 package model
 
 import (
+	"encoding/json"
 	"fmt"
 	"one-api/common"
 	"one-api/constant"
-	"strconv"
 	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
 )
 
-// Change UserCache struct to userCache
-type userCache struct {
+// UserBase struct remains the same as it represents the cached data structure
+type UserBase struct {
 	Id       int    `json:"id"`
 	Group    string `json:"group"`
+	Email    string `json:"email"`
 	Quota    int    `json:"quota"`
 	Status   int    `json:"status"`
-	Role     int    `json:"role"`
 	Username string `json:"username"`
+	Setting  string `json:"setting"`
 }
 
-// Rename all exported functions to private ones
-// invalidateUserCache clears all user related cache
-func invalidateUserCache(userId int) error {
-	if !common.RedisEnabled {
+func (user *UserBase) GetSetting() map[string]interface{} {
+	if user.Setting == "" {
 		return nil
 	}
-
-	keys := []string{
-		fmt.Sprintf(constant.UserGroupKeyFmt, userId),
-		fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
-		fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
-		fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
-	}
-
-	for _, key := range keys {
-		if err := common.RedisDel(key); err != nil {
-			return fmt.Errorf("failed to delete cache key %s: %w", key, err)
-		}
-	}
-	return nil
+	return common.StrToMap(user.Setting)
 }
 
-// updateUserGroupCache updates user group cache
-func updateUserGroupCache(userId int, group string) error {
-	if !common.RedisEnabled {
-		return nil
+func (user *UserBase) SetSetting(setting map[string]interface{}) {
+	settingBytes, err := json.Marshal(setting)
+	if err != nil {
+		common.SysError("failed to marshal setting: " + err.Error())
+		return
 	}
-	return common.RedisSet(
-		fmt.Sprintf(constant.UserGroupKeyFmt, userId),
-		group,
-		time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
-	)
+	user.Setting = string(settingBytes)
 }
 
-// updateUserQuotaCache updates user quota cache
-func updateUserQuotaCache(userId int, quota int) error {
-	if !common.RedisEnabled {
-		return nil
-	}
-	return common.RedisSet(
-		fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
-		fmt.Sprintf("%d", quota),
-		time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
-	)
+// getUserCacheKey returns the key for user cache
+func getUserCacheKey(userId int) string {
+	return fmt.Sprintf("user:%d", userId)
 }
 
-// updateUserStatusCache updates user status cache
-func updateUserStatusCache(userId int, userEnabled bool) error {
+// invalidateUserCache clears user cache
+func invalidateUserCache(userId int) error {
 	if !common.RedisEnabled {
 		return nil
 	}
-	enabled := "0"
-	if userEnabled {
-		enabled = "1"
-	}
-	return common.RedisSet(
-		fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
-		enabled,
-		time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
-	)
+	return common.RedisHDelObj(getUserCacheKey(userId))
 }
 
-// updateUserNameCache updates username cache
-func updateUserNameCache(userId int, username string) error {
+// updateUserCache updates all user cache fields using hash
+func updateUserCache(user User) error {
 	if !common.RedisEnabled {
 		return nil
 	}
-	return common.RedisSet(
-		fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
-		username,
+
+	return common.RedisHSetObj(
+		getUserCacheKey(user.Id),
+		user.ToBaseUser(),
 		time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
 	)
 }
 
-// updateUserCache updates all user cache fields
-func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
-	if !common.RedisEnabled {
-		return nil
+// GetUserCache gets complete user cache from hash
+func GetUserCache(userId int) (userCache *UserBase, err error) {
+	var user *User
+	var fromDB bool
+	defer func() {
+		// Update Redis cache asynchronously on successful DB read
+		if shouldUpdateRedis(fromDB, err) && user != nil {
+			gopool.Go(func() {
+				if err := updateUserCache(*user); err != nil {
+					common.SysError("failed to update user status cache: " + err.Error())
+				}
+			})
+		}
+	}()
+
+	// Try getting from Redis first
+	userCache, err = cacheGetUserBase(userId)
+	if err == nil {
+		return userCache, nil
 	}
 
-	if err := updateUserGroupCache(userId, userGroup); err != nil {
-		return fmt.Errorf("update group cache: %w", err)
+	// If Redis fails, get from DB
+	fromDB = true
+	user, err = GetUserById(userId, false)
+	if err != nil {
+		return nil, err // Return nil and error if DB lookup fails
 	}
 
-	if err := updateUserQuotaCache(userId, quota); err != nil {
-		return fmt.Errorf("update quota cache: %w", err)
+	// Create cache object from user data
+	userCache = &UserBase{
+		Id:       user.Id,
+		Group:    user.Group,
+		Quota:    user.Quota,
+		Status:   user.Status,
+		Username: user.Username,
+		Setting:  user.Setting,
+		Email:    user.Email,
 	}
 
-	if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
-		return fmt.Errorf("update status cache: %w", err)
+	return userCache, nil
+}
+
+func cacheGetUserBase(userId int) (*UserBase, error) {
+	if !common.RedisEnabled {
+		return nil, fmt.Errorf("redis is not enabled")
 	}
+	var userCache UserBase
+	// Try getting from Redis first
+	err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
+	if err != nil {
+		return nil, err
+	}
+	return &userCache, nil
+}
 
-	if err := updateUserNameCache(userId, username); err != nil {
-		return fmt.Errorf("update username cache: %w", err)
+// Add atomic quota operations using hash fields
+func cacheIncrUserQuota(userId int, delta int64) error {
+	if !common.RedisEnabled {
+		return nil
 	}
+	return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
+}
 
-	return nil
+func cacheDecrUserQuota(userId int, delta int64) error {
+	return cacheIncrUserQuota(userId, -delta)
 }
 
-// getUserGroupCache gets user group from cache
+// Helper functions to get individual fields if needed
 func getUserGroupCache(userId int) (string, error) {
-	if !common.RedisEnabled {
-		return "", nil
+	cache, err := GetUserCache(userId)
+	if err != nil {
+		return "", err
 	}
-	return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
+	return cache.Group, nil
 }
 
-// getUserQuotaCache gets user quota from cache
 func getUserQuotaCache(userId int) (int, error) {
-	if !common.RedisEnabled {
-		return 0, nil
-	}
-	quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
+	cache, err := GetUserCache(userId)
 	if err != nil {
 		return 0, err
 	}
-	return strconv.Atoi(quotaStr)
+	return cache.Quota, nil
 }
 
-// getUserStatusCache gets user status from cache
 func getUserStatusCache(userId int) (int, error) {
-	if !common.RedisEnabled {
-		return 0, nil
-	}
-	statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
+	cache, err := GetUserCache(userId)
 	if err != nil {
 		return 0, err
 	}
-	return strconv.Atoi(statusStr)
+	return cache.Status, nil
 }
 
-// getUserNameCache gets username from cache
 func getUserNameCache(userId int) (string, error) {
-	if !common.RedisEnabled {
-		return "", nil
+	cache, err := GetUserCache(userId)
+	if err != nil {
+		return "", err
 	}
-	return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
+	return cache.Username, nil
 }
 
-// getUserCache gets complete user cache
-func getUserCache(userId int) (*userCache, error) {
-	if !common.RedisEnabled {
-		return nil, nil
-	}
-
-	group, err := getUserGroupCache(userId)
+func getUserSettingCache(userId int) (map[string]interface{}, error) {
+	setting := make(map[string]interface{})
+	cache, err := GetUserCache(userId)
 	if err != nil {
-		return nil, fmt.Errorf("get group cache: %w", err)
+		return setting, err
 	}
+	return cache.GetSetting(), nil
+}
 
-	quota, err := getUserQuotaCache(userId)
-	if err != nil {
-		return nil, fmt.Errorf("get quota cache: %w", err)
+// New functions for individual field updates
+func updateUserStatusCache(userId int, status bool) error {
+	if !common.RedisEnabled {
+		return nil
 	}
-
-	status, err := getUserStatusCache(userId)
-	if err != nil {
-		return nil, fmt.Errorf("get status cache: %w", err)
+	statusInt := common.UserStatusEnabled
+	if !status {
+		statusInt = common.UserStatusDisabled
 	}
+	return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
+}
 
-	username, err := getUserNameCache(userId)
-	if err != nil {
-		return nil, fmt.Errorf("get username cache: %w", err)
+func updateUserQuotaCache(userId int, quota int) error {
+	if !common.RedisEnabled {
+		return nil
 	}
+	return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
+}
 
-	return &userCache{
-		Id:       userId,
-		Group:    group,
-		Quota:    quota,
-		Status:   status,
-		Username: username,
-	}, nil
+func updateUserGroupCache(userId int, group string) error {
+	if !common.RedisEnabled {
+		return nil
+	}
+	return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
 }
 
-// Add atomic quota operations
-func cacheIncrUserQuota(userId int, delta int64) error {
+func updateUserNameCache(userId int, username string) error {
 	if !common.RedisEnabled {
 		return nil
 	}
-	key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
-	return common.RedisIncr(key, delta)
+	return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
 }
 
-func cacheDecrUserQuota(userId int, delta int64) error {
-	return cacheIncrUserQuota(userId, -delta)
+func updateUserSettingCache(userId int, setting string) error {
+	if !common.RedisEnabled {
+		return nil
+	}
+	return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
 }

+ 2 - 1
relay/channel/cloudflare/adaptor.go

@@ -4,13 +4,14 @@ import (
 	"bytes"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )
 
 type Adaptor struct {

+ 2 - 2
relay/relay-mj.go

@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 	}
 	defer func(ctx context.Context) {
 		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
-			err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
+			err := service.PostConsumeQuota(relayInfo, quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 
 	defer func(ctx context.Context) {
 		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
-			err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
+			err := service.PostConsumeQuota(relayInfo, quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}

+ 6 - 6
relay/relay-text.go

@@ -272,7 +272,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
+		return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %d", common.FormatQuota(userQuota), preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
 	}
 	if userQuota > 100*preConsumedQuota {
 		// 用户额度充足,判断令牌额度是否充足
@@ -282,18 +282,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 			if tokenQuota > 100*preConsumedQuota {
 				// 令牌额度充足,信任令牌
 				preConsumedQuota = 0
-				common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
+				common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
 			}
 		} else {
 			// in this case, we do not pre-consume quota
 			// because the user has enough quota
 			preConsumedQuota = 0
-			common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
+			common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
 		}
 	}
 
 	if preConsumedQuota > 0 {
-		err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+		err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
 		if err != nil {
 			return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
@@ -310,7 +310,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
 		go func() {
 			relayInfoCopy := *relayInfo
 
-			err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
+			err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
 			if err != nil {
 				common.SysError("error return pre-consumed quota: " + err.Error())
 			}
@@ -368,7 +368,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 		//}
 		quotaDelta := quota - preConsumedQuota
 		if quotaDelta != 0 {
-			err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+			err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
 			if err != nil {
 				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 			}

+ 1 - 1
relay/relay_task.go

@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 		// release quota
 		if relayInfo.ConsumeQuota && taskErr == nil {
 
-			err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
+			err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}

+ 1 - 0
router/api-router.go

@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
 				selfRoute.POST("/pay", controller.RequestEpay)
 				selfRoute.POST("/amount", controller.RequestAmount)
 				selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
+				selfRoute.PUT("/setting", controller.UpdateUserSetting)
 			}
 
 			adminRoute := userRoute.Group("/")

+ 37 - 9
service/cf_worker.go

@@ -2,6 +2,7 @@ package service
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"net/http"
 	"one-api/common"
@@ -9,19 +10,46 @@ import (
 	"strings"
 )
 
+// WorkerRequest Worker请求的数据结构
+type WorkerRequest struct {
+	URL     string            `json:"url"`
+	Key     string            `json:"key"`
+	Method  string            `json:"method,omitempty"`
+	Headers map[string]string `json:"headers,omitempty"`
+	Body    json.RawMessage   `json:"body,omitempty"`
+}
+
+// DoWorkerRequest 通过Worker发送请求
+func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
+	if !setting.EnableWorker() {
+		return nil, fmt.Errorf("worker not enabled")
+	}
+	if !strings.HasPrefix(req.URL, "https") {
+		return nil, fmt.Errorf("only support https url")
+	}
+
+	workerUrl := setting.WorkerUrl
+	if !strings.HasSuffix(workerUrl, "/") {
+		workerUrl += "/"
+	}
+
+	// 序列化worker请求数据
+	workerPayload, err := json.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
+	}
+
+	return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
+}
+
 func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
 	if setting.EnableWorker() {
 		common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
-		if !strings.HasPrefix(originUrl, "https") {
-			return nil, fmt.Errorf("only support https url")
-		}
-		workerUrl := setting.WorkerUrl
-		if !strings.HasSuffix(workerUrl, "/") {
-			workerUrl += "/"
+		req := &WorkerRequest{
+			URL: originUrl,
+			Key: setting.WorkerValidKey,
 		}
-		// post request to worker
-		data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
-		return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
+		return DoWorkerRequest(req)
 	} else {
 		common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
 		return http.Get(originUrl)

+ 5 - 5
service/channel.go

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
-	relaymodel "one-api/dto"
+	"one-api/dto"
 	"one-api/model"
 	"one-api/setting"
 	"strings"
@@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-	notifyRootUser(subject, content)
+	NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
 }
 
 func EnableChannel(channelId int, channelName string) {
 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	notifyRootUser(subject, content)
+	NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
 }
 
-func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
 	if !common.AutomaticDisableChannelEnabled {
 		return false
 	}
@@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
 	return false
 }
 
-func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
 	if !common.AutomaticEnableChannelEnabled {
 		return false
 	}

+ 116 - 0
service/notify-limit.go

@@ -0,0 +1,116 @@
+package service
+
+import (
+	"fmt"
+	"one-api/common"
+	"one-api/constant"
+	"strconv"
+	"sync"
+	"time"
+)
+
+// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
+var (
+	notifyLimitStore sync.Map
+	cleanupOnce      sync.Once
+)
+
+type limitCount struct {
+	Count     int
+	Timestamp time.Time
+}
+
+func getDuration() time.Duration {
+	minute := constant.NotificationLimitDurationMinute
+	return time.Duration(minute) * time.Minute
+}
+
+// startCleanupTask starts a background task to clean up expired entries
+func startCleanupTask() {
+	go func() {
+		for {
+			time.Sleep(time.Hour)
+			now := time.Now()
+			notifyLimitStore.Range(func(key, value interface{}) bool {
+				if limit, ok := value.(limitCount); ok {
+					if now.Sub(limit.Timestamp) >= getDuration() {
+						notifyLimitStore.Delete(key)
+					}
+				}
+				return true
+			})
+		}
+	}()
+}
+
+// CheckNotificationLimit checks if the user has exceeded their notification limit
+// Returns true if the user can send notification, false if limit exceeded
+func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
+	if common.RedisEnabled {
+		return checkRedisLimit(userId, notifyType)
+	}
+	return checkMemoryLimit(userId, notifyType)
+}
+
+func checkRedisLimit(userId int, notifyType string) (bool, error) {
+	key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+
+	// Get current count
+	count, err := common.RedisGet(key)
+	if err != nil && err.Error() != "redis: nil" {
+		return false, fmt.Errorf("failed to get notification count: %w", err)
+	}
+
+	// If key doesn't exist, initialize it
+	if count == "" {
+		err = common.RedisSet(key, "1", getDuration())
+		return true, err
+	}
+
+	currentCount, _ := strconv.Atoi(count)
+	limit := constant.NotifyLimitCount
+
+	// Check if limit is already reached
+	if currentCount >= limit {
+		return false, nil
+	}
+
+	// Only increment if under limit
+	err = common.RedisIncr(key, 1)
+	if err != nil {
+		return false, fmt.Errorf("failed to increment notification count: %w", err)
+	}
+
+	return true, nil
+}
+
+func checkMemoryLimit(userId int, notifyType string) (bool, error) {
+	// Ensure cleanup task is started
+	cleanupOnce.Do(startCleanupTask)
+
+	key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+	now := time.Now()
+
+	// Get current limit count or initialize new one
+	var currentLimit limitCount
+	if value, ok := notifyLimitStore.Load(key); ok {
+		currentLimit = value.(limitCount)
+		// Check if the entry has expired
+		if now.Sub(currentLimit.Timestamp) >= getDuration() {
+			currentLimit = limitCount{Count: 0, Timestamp: now}
+		}
+	} else {
+		currentLimit = limitCount{Count: 0, Timestamp: now}
+	}
+
+	// Increment count
+	currentLimit.Count++
+
+	// Check against limits
+	limit := constant.NotifyLimitCount
+
+	// Store updated count
+	notifyLimitStore.Store(key, currentLimit)
+
+	return currentLimit.Count <= limit, nil
+}

+ 89 - 2
service/quota.go

@@ -3,8 +3,10 @@ package service
 import (
 	"errors"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"math"
 	"one-api/common"
+	constant2 "one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
@@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 		return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
 	}
 
-	err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
+	err = PostConsumeQuota(relayInfo, quota, 0, false)
 	if err != nil {
 		return err
 	}
@@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	} else {
 		quotaDelta := quota - preConsumedQuota
 		if quotaDelta != 0 {
-			err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+			err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
 			if err != nil {
 				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 			}
@@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }
+
+func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
+	if quota < 0 {
+		return errors.New("quota 不能为负数!")
+	}
+	if relayInfo.IsPlayground {
+		return nil
+	}
+	//if relayInfo.TokenUnlimited {
+	//	return nil
+	//}
+	token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
+	if err != nil {
+		return err
+	}
+	if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
+		return errors.New("令牌额度不足")
+	}
+	err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
+
+	if quota > 0 {
+		err = model.DecreaseUserQuota(relayInfo.UserId, quota)
+	} else {
+		err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
+	}
+	if err != nil {
+		return err
+	}
+
+	if !relayInfo.IsPlayground {
+		if quota > 0 {
+			err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+		} else {
+			err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
+		}
+		if err != nil {
+			return err
+		}
+	}
+
+	if sendEmail {
+		if (quota + preConsumedQuota) != 0 {
+			checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
+		}
+	}
+
+	return nil
+}
+
+func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
+	gopool.Go(func() {
+		userCache, err := model.GetUserCache(userId)
+		if err != nil {
+			common.SysError("failed to get user cache: " + err.Error())
+		}
+		userSetting := userCache.GetSetting()
+		threshold := common.QuotaRemindThreshold
+		if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
+			threshold = int(userCustomThreshold.(float64))
+		}
+
+		//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
+		quotaTooLow := false
+		consumeQuota := quota + preConsumedQuota
+		if userCache.Quota-consumeQuota < threshold {
+			quotaTooLow = true
+		}
+		if quotaTooLow {
+			prompt := "您的额度即将用尽"
+			topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
+			content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
+			err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
+			if err != nil {
+				common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
+			}
+		}
+	})
+}

+ 65 - 5
service/user_notify.go

@@ -3,15 +3,75 @@ package service
 import (
 	"fmt"
 	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
 	"one-api/model"
+	"strings"
 )
 
-func notifyRootUser(subject string, content string) {
-	if common.RootUserEmail == "" {
-		common.RootUserEmail = model.GetRootUserEmail()
+func NotifyRootUser(t string, subject string, content string) {
+	user := model.GetRootUser().ToBaseUser()
+	_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
+}
+
+func NotifyUser(user *model.UserBase, data dto.Notify) error {
+	userSetting := user.GetSetting()
+	notifyType, ok := userSetting[constant.UserSettingNotifyType]
+	if !ok {
+		notifyType = constant.NotifyTypeEmail
 	}
-	err := common.SendEmail(subject, common.RootUserEmail, content)
+
+	// Check notification limit
+	canSend, err := CheckNotificationLimit(user.Id, data.Type)
 	if err != nil {
-		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+		common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+		return err
+	}
+	if !canSend {
+		return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
+	}
+
+	switch notifyType {
+	case constant.NotifyTypeEmail:
+		userEmail := user.Email
+		// check setting email
+		if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
+			userEmail = settingEmail.(string)
+		}
+		if userEmail == "" {
+			common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
+			return nil
+		}
+		return sendEmailNotify(userEmail, data)
+	case constant.NotifyTypeWebhook:
+		webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
+		if !ok {
+			common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
+			return nil
+		}
+		webhookURLStr, ok := webhookURL.(string)
+		if !ok {
+			common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
+			return nil
+		}
+
+		// 获取 webhook secret
+		var webhookSecret string
+		if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
+			webhookSecret, _ = secret.(string)
+		}
+
+		return SendWebhookNotify(webhookURLStr, webhookSecret, data)
+	}
+	return nil
+}
+
+func sendEmailNotify(userEmail string, data dto.Notify) error {
+	// make email content
+	content := data.Content
+	// 处理占位符
+	for _, value := range data.Values {
+		content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
 	}
+	return common.SendEmail(data.Title, userEmail, content)
 }

+ 118 - 0
service/webhook.go

@@ -0,0 +1,118 @@
+package service
+
+import (
+	"bytes"
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"one-api/dto"
+	"one-api/setting"
+	"time"
+)
+
+// WebhookPayload webhook 通知的负载数据
+type WebhookPayload struct {
+	Type      string        `json:"type"`
+	Title     string        `json:"title"`
+	Content   string        `json:"content"`
+	Values    []interface{} `json:"values,omitempty"`
+	Timestamp int64         `json:"timestamp"`
+}
+
+// generateSignature 生成 webhook 签名
+func generateSignature(secret string, payload []byte) string {
+	h := hmac.New(sha256.New, []byte(secret))
+	h.Write(payload)
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+// SendWebhookNotify 发送 webhook 通知
+func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
+	// 处理占位符
+	content := data.Content
+	for _, value := range data.Values {
+		content = fmt.Sprintf(content, value)
+	}
+
+	// 构建 webhook 负载
+	payload := WebhookPayload{
+		Type:      data.Type,
+		Title:     data.Title,
+		Content:   content,
+		Values:    data.Values,
+		Timestamp: time.Now().Unix(),
+	}
+
+	// 序列化负载
+	payloadBytes, err := json.Marshal(payload)
+	if err != nil {
+		return fmt.Errorf("failed to marshal webhook payload: %v", err)
+	}
+
+	// 创建 HTTP 请求
+	var req *http.Request
+	var resp *http.Response
+
+	if setting.EnableWorker() {
+		// 构建worker请求数据
+		workerReq := &WorkerRequest{
+			URL:    webhookURL,
+			Key:    setting.WorkerValidKey,
+			Method: http.MethodPost,
+			Headers: map[string]string{
+				"Content-Type": "application/json",
+			},
+			Body: payloadBytes,
+		}
+
+		// 如果有secret,添加签名到headers
+		if secret != "" {
+			signature := generateSignature(secret, payloadBytes)
+			workerReq.Headers["X-Webhook-Signature"] = signature
+			workerReq.Headers["Authorization"] = "Bearer " + secret
+		}
+
+		resp, err = DoWorkerRequest(workerReq)
+		if err != nil {
+			return fmt.Errorf("failed to send webhook request through worker: %v", err)
+		}
+		defer resp.Body.Close()
+
+		// 检查响应状态
+		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+			return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+		}
+	} else {
+		req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
+		if err != nil {
+			return fmt.Errorf("failed to create webhook request: %v", err)
+		}
+
+		// 设置请求头
+		req.Header.Set("Content-Type", "application/json")
+
+		// 如果有 secret,生成签名
+		if secret != "" {
+			signature := generateSignature(secret, payloadBytes)
+			req.Header.Set("X-Webhook-Signature", signature)
+		}
+
+		// 发送请求
+		client := GetImpatientHttpClient()
+		resp, err = client.Do(req)
+		if err != nil {
+			return fmt.Errorf("failed to send webhook request: %v", err)
+		}
+		defer resp.Body.Close()
+
+		// 检查响应状态
+		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+			return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+		}
+	}
+
+	return nil
+}

+ 0 - 0
setting/system-setting.go → setting/system_setting.go


+ 177 - 25
web/src/components/PersonalSetting.js

@@ -26,6 +26,10 @@ import {
     Tag,
     Typography,
     Collapsible,
+    Select,
+    Radio,
+    RadioGroup,
+    AutoComplete,
 } from '@douyinfe/semi-ui';
 import {
     getQuotaPerUnit,
@@ -67,14 +71,16 @@ const PersonalSetting = () => {
     const [transferAmount, setTransferAmount] = useState(0);
     const [isModelsExpanded, setIsModelsExpanded] = useState(false);
     const MODELS_DISPLAY_COUNT = 10;  // 默认显示的模型数量
+    const [notificationSettings, setNotificationSettings] = useState({
+        warningType: 'email',
+        warningThreshold: 100000,
+        webhookUrl: '',
+        webhookSecret: '',
+        notificationEmail: ''
+    });
+    const [showWebhookDocs, setShowWebhookDocs] = useState(false);
 
     useEffect(() => {
-        // let user = localStorage.getItem('user');
-        // if (user) {
-        //   userDispatch({ type: 'login', payload: user });
-        // }
-        // console.log(localStorage.getItem('user'))
-
         let status = localStorage.getItem('status');
         if (status) {
             status = JSON.parse(status);
@@ -105,6 +111,19 @@ const PersonalSetting = () => {
         return () => clearInterval(countdownInterval); // Clean up on unmount
     }, [disableButton, countdown]);
 
+    useEffect(() => {
+        if (userState?.user?.setting) {
+            const settings = JSON.parse(userState.user.setting);
+            setNotificationSettings({
+                warningType: settings.notify_type || 'email',
+                warningThreshold: settings.quota_warning_threshold || 500000,
+                webhookUrl: settings.webhook_url || '',
+                webhookSecret: settings.webhook_secret || '',
+                notificationEmail: settings.notification_email || ''
+            });
+        }
+    }, [userState?.user?.setting]);
+
     const handleInputChange = (name, value) => {
         setInputs((inputs) => ({...inputs, [name]: value}));
     };
@@ -300,7 +319,36 @@ const PersonalSetting = () => {
         }
     };
 
+    const handleNotificationSettingChange = (type, value) => {
+        setNotificationSettings(prev => ({
+            ...prev,
+            [type]: value.target ? value.target.value : value  // 处理 Radio 事件对象
+        }));
+    };
+
+    const saveNotificationSettings = async () => {
+        try {
+            const res = await API.put('/api/user/setting', {
+                notify_type: notificationSettings.warningType,
+                quota_warning_threshold: notificationSettings.warningThreshold,
+                webhook_url: notificationSettings.webhookUrl,
+                webhook_secret: notificationSettings.webhookSecret,
+                notification_email: notificationSettings.notificationEmail
+            });
+            
+            if (res.data.success) {
+                showSuccess(t('通知设置已更新'));
+                await getUserData();
+            } else {
+                showError(res.data.message);
+            }
+        } catch (error) {
+            showError(t('更新通知设置失败'));
+        }
+    };
+
     return (
+
         <div>
             <Layout>
                 <Layout.Content>
@@ -526,9 +574,7 @@ const PersonalSetting = () => {
                             </div>
                             <div style={{marginTop: 10}}>
                                 <Typography.Text strong>{t('微信')}</Typography.Text>
-                                <div
-                                    style={{display: 'flex', justifyContent: 'space-between'}}
-                                >
+                                <div style={{display: 'flex', justifyContent: 'space-between'}}>
                                     <div>
                                         <Input
                                             value={
@@ -541,12 +587,16 @@ const PersonalSetting = () => {
                                     </div>
                                     <div>
                                         <Button
-                                            disabled={
-                                                (userState.user && userState.user.wechat_id !== '') ||
-                                                !status.wechat_login
-                                            }
+                                            disabled={!status.wechat_login}
+                                            onClick={() => {
+                                                setShowWeChatBindModal(true);
+                                            }}
                                         >
-                                            {status.wechat_login ? t('绑定') : t('未启用')}
+                                            {userState.user && userState.user.wechat_id !== ''
+                                                ? t('修改绑定')
+                                                : status.wechat_login 
+                                                    ? t('绑定') 
+                                                    : t('未启用')}
                                         </Button>
                                     </div>
                                 </div>
@@ -672,18 +722,8 @@ const PersonalSetting = () => {
                                         style={{marginTop: '10px'}}
                                     />
                                 )}
-                                {status.wechat_login && (
-                                    <Button
-                                        onClick={() => {
-                                            setShowWeChatBindModal(true);
-                                        }}
-                                    >
-                                        {t('绑定微信账号')}
-                                    </Button>
-                                )}
                                 <Modal
                                     onCancel={() => setShowWeChatBindModal(false)}
-                                    // onOpen={() => setShowWeChatBindModal(true)}
                                     visible={showWeChatBindModal}
                                     size={'small'}
                                 >
@@ -707,9 +747,121 @@ const PersonalSetting = () => {
                                 </Modal>
                             </div>
                         </Card>
+                        <Card style={{marginTop: 10}}>
+                            <Typography.Title heading={6}>{t('通知设置')}</Typography.Title>
+                            <div style={{marginTop: 20}}>
+                                <Typography.Text strong>{t('通知方式')}</Typography.Text>
+                                <div style={{marginTop: 10}}>
+                                    <RadioGroup
+                                        value={notificationSettings.warningType}
+                                        onChange={value => handleNotificationSettingChange('warningType', value)}
+                                    >
+                                        <Radio value="email">{t('邮件通知')}</Radio>
+                                        <Radio value="webhook">{t('Webhook通知')}</Radio>
+                                    </RadioGroup>
+                                </div>
+                            </div>
+                            {notificationSettings.warningType === 'webhook' && (
+                                <>
+                                    <div style={{marginTop: 20}}>
+                                        <Typography.Text strong>{t('Webhook地址')}</Typography.Text>
+                                        <div style={{marginTop: 10}}>
+                                            <Input
+                                                value={notificationSettings.webhookUrl}
+                                                onChange={val => handleNotificationSettingChange('webhookUrl', val)}
+                                                placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')}
+                                            />
+                                            <Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
+                                                {t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')}
+                                            </Typography.Text>
+                                            <Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
+                                                <div style={{cursor: 'pointer'}} onClick={() => setShowWebhookDocs(!showWebhookDocs)}>
+                                                    {t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'}
+                                                </div>
+                                                <Collapsible isOpen={showWebhookDocs}>
+                                                    <pre style={{marginTop: 4, background: 'var(--semi-color-fill-0)', padding: 8, borderRadius: 4}}>
+{`{
+    "type": "quota_exceed",      // 通知类型
+    "title": "标题",             // 通知标题
+    "content": "通知内容",       // 通知内容,支持 {{value}} 变量占位符
+    "values": ["值1", "值2"],    // 按顺序替换content中的 {{value}} 占位符
+    "timestamp": 1739950503      // 时间戳
+}
+
+示例:
+{
+    "type": "quota_exceed",
+    "title": "额度预警通知",
+    "content": "您的额度即将用尽,当前剩余额度为 {{value}}",
+    "values": ["$0.99"],
+    "timestamp": 1739950503
+}`}
+                                                    </pre>
+                                                </Collapsible>
+                                            </Typography.Text>
+                                        </div>
+                                    </div>
+                                    <div style={{marginTop: 20}}>
+                                        <Typography.Text strong>{t('接口凭证(可选)')}</Typography.Text>
+                                        <div style={{marginTop: 10}}>
+                                            <Input
+                                                value={notificationSettings.webhookSecret}
+                                                onChange={val => handleNotificationSettingChange('webhookSecret', val)}
+                                                placeholder={t('请输入密钥')}
+                                            />
+                                            <Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
+                                                {t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')}
+                                            </Typography.Text>
+                                            <Typography.Text type="secondary" style={{marginTop: 4, display: 'block'}}>
+                                                {t('Authorization: Bearer your-secret-key')}
+                                            </Typography.Text>
+                                        </div>
+                                    </div>
+                                </>
+                            )}
+                            {notificationSettings.warningType === 'email' && (
+                                <div style={{marginTop: 20}}>
+                                    <Typography.Text strong>{t('通知邮箱')}</Typography.Text>
+                                    <div style={{marginTop: 10}}>
+                                        <Input
+                                            value={notificationSettings.notificationEmail}
+                                            onChange={val => handleNotificationSettingChange('notificationEmail', val)}
+                                            placeholder={t('留空则使用账号绑定的邮箱')}
+                                        />
+                                        <Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
+                                            {t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')}
+                                        </Typography.Text>
+                                    </div>
+                                </div>
+                            )}
+                            <div style={{marginTop: 20}}>
+                                <Typography.Text strong>{t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)}</Typography.Text>
+                                <div style={{marginTop: 10}}>
+                                    <AutoComplete
+                                        value={notificationSettings.warningThreshold}
+                                        onChange={val => handleNotificationSettingChange('warningThreshold', val)}
+                                        style={{width: 200}}
+                                        placeholder={t('请输入预警额度')}
+                                        data={[
+                                            { value: 100000, label: '0.2$' },
+                                            { value: 500000, label: '1$' },
+                                            { value: 1000000, label: '5$' },
+                                            { value: 5000000, label: '10$' }
+                                        ]}
+                                    />
+                                </div>
+                                <Typography.Text type="secondary" style={{marginTop: 10, display: 'block'}}>
+                                    {t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')}
+                                </Typography.Text>
+                            </div>
+                            <div style={{marginTop: 20}}>
+                                <Button type="primary" onClick={saveNotificationSettings}>
+                                    {t('保存设置')}
+                                </Button>
+                            </div>
+                        </Card>
                         <Modal
                             onCancel={() => setShowEmailBindModal(false)}
-                            // onOpen={() => setShowEmailBindModal(true)}
                             onOk={bindEmail}
                             visible={showEmailBindModal}
                             size={'small'}

+ 11 - 0
web/src/components/SystemSetting.js

@@ -368,6 +368,17 @@ const SystemSetting = () => {
             </a>
           </Header>
+          <Message info>
+            注意:代理功能仅对图片请求和 Webhook 请求生效,不会影响其他 API 请求。如需配置 API 请求代理,请参考
+            <a 
+              href='https://github.com/Calcium-Ion/new-api/blob/main/docs/channel/other_setting.md'
+              target='_blank'
+              rel='noreferrer'
+            >
+              {' '}API 代理设置文档
+            </a>
+            。
+          </Message>
           <Form.Group widths='equal'>
             <Form.Input
               label='Worker地址,不填写则不启用代理'

+ 1 - 1
web/src/helpers/render.js

@@ -386,7 +386,7 @@ export function renderQuotaWithPrompt(quota, digits) {
   let displayInCurrency = localStorage.getItem('display_in_currency');
   displayInCurrency = displayInCurrency === 'true';
   if (displayInCurrency) {
-    return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
+    return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
   }
   return '';
 }