Browse Source

Merge branch 'alpha' into refactor_error

# Conflicts:
#	controller/channel.go
#	middleware/distributor.go
#	model/channel.go
#	model/user.go
#	model/user_cache.go
#	relay/common/relay_info.go
CaIon 5 months ago
parent
commit
cb16bf552e

+ 10 - 0
common/gin.go

@@ -76,3 +76,13 @@ func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]
 func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
 	return c.GetTime(string(key))
 }
+
+func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
+	if value, ok := c.Get(string(key)); ok {
+		if v, ok := value.(T); ok {
+			return v, true
+		}
+	}
+	var t T
+	return t, false
+}

+ 13 - 0
common/str.go

@@ -1,6 +1,7 @@
 package common
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"math/rand"
 	"strconv"
@@ -82,3 +83,15 @@ func StringToByteSlice(s string) []byte {
 	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
 	return *(*[]byte)(unsafe.Pointer(&tmp2))
 }
+
+func EncodeBase64(str string) string {
+	return base64.StdEncoding.EncodeToString([]byte(str))
+}
+
+func GetJsonString(data any) string {
+	if data == nil {
+		return ""
+	}
+	b, _ := json.Marshal(data)
+	return string(b)
+}

+ 0 - 7
constant/channel_setting.go

@@ -1,7 +0,0 @@
-package constant
-
-var (
-	ForceFormat                     = "force_format"        // ForceFormat 强制格式化为OpenAI格式
-	ChanelSettingProxy              = "proxy"               // Proxy 代理
-	ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
-)

+ 0 - 16
constant/user_setting.go

@@ -1,16 +0,0 @@
-package constant
-
-var (
-	UserSettingNotifyType            = "notify_type"                    // QuotaWarningType 额度预警类型
-	UserSettingQuotaWarningThreshold = "quota_warning_threshold"        // QuotaWarningThreshold 额度预警阈值
-	UserSettingWebhookUrl            = "webhook_url"                    // WebhookUrl webhook地址
-	UserSettingWebhookSecret         = "webhook_secret"                 // WebhookSecret webhook密钥
-	UserSettingNotificationEmail     = "notification_email"             // NotificationEmail 通知邮箱地址
-	UserAcceptUnsetRatioModel        = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
-	UserSettingRecordIpLog          = "record_ip_log"                   // 是否记录请求和错误日志IP
-)
-
-var (
-	NotifyTypeEmail   = "email"   // Email 邮件
-	NotifyTypeWebhook = "webhook" // Webhook
-)

+ 17 - 2
controller/channel-test.go

@@ -174,8 +174,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
 	consumedTime := float64(milliseconds) / 1000.0
 	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
 		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
-	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
-		quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
+	model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
+		ChannelId:        channel.Id,
+		PromptTokens:     usage.PromptTokens,
+		CompletionTokens: usage.CompletionTokens,
+		ModelName:        info.OriginModelName,
+		TokenName:        "模型测试",
+		Quota:            quota,
+		Content:          "模型测试",
+		UseTimeSeconds:   int(consumedTime),
+		IsStream:         false,
+		Group:            info.UsingGroup,
+		Other:            other,
+	})
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }
@@ -342,6 +353,10 @@ func TestAllChannels(c *gin.Context) {
 }
 
 func AutomaticallyTestChannels(frequency int) {
+	if frequency <= 0 {
+		common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
+		return
+	}
 	for {
 		time.Sleep(time.Duration(frequency) * time.Minute)
 		common.SysLog("testing all channels")

+ 23 - 2
controller/channel.go

@@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
 }
 
 func FixChannelsAbilities(c *gin.Context) {
-	count, err := model.FixAbility()
+	success, fails, err := model.FixAbility()
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
-		"data":    count,
+		"data": gin.H{
+			"success": success,
+			"fails":   fails,
+		},
 	})
 }
 
@@ -425,6 +428,16 @@ func AddChannel(c *gin.Context) {
 		})
 		return
 	}
+
+	err = addChannelRequest.Channel.ValidateSettings()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "channel setting 格式错误:" + err.Error(),
+		})
+		return
+	}
+
 	if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -713,6 +726,14 @@ func UpdateChannel(c *gin.Context) {
 		})
 		return
 	}
+	err = channel.ValidateSettings()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "channel setting 格式错误:" + err.Error(),
+		})
+		return
+	}
 	if channel.Type == constant.ChannelTypeVertexAi {
 		if channel.Other == "" {
 			c.JSON(http.StatusOK, gin.H{

+ 14 - 13
controller/user.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"net/url"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/model"
 	"one-api/setting"
 	"strconv"
@@ -961,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) {
 	}
 
 	// 验证预警类型
-	if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
+	if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": "无效的预警类型",
@@ -979,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) {
 	}
 
 	// 如果是webhook类型,验证webhook地址
-	if req.QuotaWarningType == constant.NotifyTypeWebhook {
+	if req.QuotaWarningType == dto.NotifyTypeWebhook {
 		if req.WebhookUrl == "" {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
@@ -998,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) {
 	}
 
 	// 如果是邮件类型,验证邮箱地址
-	if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+	if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
 		// 验证邮箱格式
 		if !strings.Contains(req.NotificationEmail, "@") {
 			c.JSON(http.StatusOK, gin.H{
@@ -1020,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) {
 	}
 
 	// 构建设置
-	settings := map[string]interface{}{
-		constant.UserSettingNotifyType:            req.QuotaWarningType,
-		constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
-		"accept_unset_model_ratio_model":          req.AcceptUnsetModelRatioModel,
-		constant.UserSettingRecordIpLog:           req.RecordIpLog,
+	settings := dto.UserSetting{
+		NotifyType:            req.QuotaWarningType,
+		QuotaWarningThreshold: req.QuotaWarningThreshold,
+		AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
+		RecordIpLog:           req.RecordIpLog,
 	}
 
 	// 如果是webhook类型,添加webhook相关设置
-	if req.QuotaWarningType == constant.NotifyTypeWebhook {
-		settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
+	if req.QuotaWarningType == dto.NotifyTypeWebhook {
+		settings.WebhookUrl = req.WebhookUrl
 		if req.WebhookSecret != "" {
-			settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
+			settings.WebhookSecret = req.WebhookSecret
 		}
 	}
 
 	// 如果提供了通知邮箱,添加到设置中
-	if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
-		settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
+	if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
+		settings.NotificationEmail = req.NotificationEmail
 	}
 
 	// 更新用户设置

+ 7 - 0
dto/channel_settings.go

@@ -0,0 +1,7 @@
+package dto
+
+type ChannelSettings struct {
+	ForceFormat       bool   `json:"force_format,omitempty"`
+	ThinkingToContent bool   `json:"thinking_to_content,omitempty"`
+	Proxy             string `json:"proxy"`
+}

+ 16 - 0
dto/user_settings.go

@@ -0,0 +1,16 @@
+package dto
+
+type UserSetting struct {
+	NotifyType            string  `json:"notify_type,omitempty"`                    // QuotaWarningType 额度预警类型
+	QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"`        // QuotaWarningThreshold 额度预警阈值
+	WebhookUrl            string  `json:"webhook_url,omitempty"`                    // WebhookUrl webhook地址
+	WebhookSecret         string  `json:"webhook_secret,omitempty"`                 // WebhookSecret webhook密钥
+	NotificationEmail     string  `json:"notification_email,omitempty"`             // NotificationEmail 通知邮箱地址
+	AcceptUnsetRatioModel bool    `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
+	RecordIpLog           bool    `json:"record_ip_log,omitempty"`                  // 是否记录请求和错误日志IP
+}
+
+var (
+	NotifyTypeEmail   = "email"   // Email 邮件
+	NotifyTypeWebhook = "webhook" // Webhook
+)

+ 4 - 3
main.go

@@ -39,7 +39,6 @@ func main() {
 		return
 	}
 
-	common.SetupLogger()
 	common.SysLog("New API " + common.Version + " started")
 	if os.Getenv("GIN_MODE") != "debug" {
 		gin.SetMode(gin.ReleaseMode)
@@ -69,9 +68,9 @@ func main() {
 				if r := recover(); r != nil {
 					common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
 					// Retry once
-					_, fixErr := model.FixAbility()
+					_, _, fixErr := model.FixAbility()
 					if fixErr != nil {
-						common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+						common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
 					}
 				}
 			}()
@@ -169,6 +168,8 @@ func InitResources() error {
 		common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
 	}
 
+	common.SetupLogger()
+
 	// 加载环境变量
 	common.InitEnv()
 

+ 33 - 61
model/ability.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"one-api/common"
 	"strings"
+	"sync"
 
 	"github.com/samber/lo"
 	"gorm.io/gorm"
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
 	return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
 }
 
-func FixAbility() (int, error) {
-	var channelIds []int
-	count := 0
-	// Find all channel ids from channel table
-	err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
-	if err != nil {
-		common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
-		return 0, err
-	}
+var fixLock = sync.Mutex{}
 
-	// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
-	if len(channelIds) > 0 {
-		// Process deletion in chunks to avoid "too many placeholders" error
-		for _, chunk := range lo.Chunk(channelIds, 100) {
-			err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
-			if err != nil {
-				common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
-				return 0, err
-			}
-		}
-	} else {
-		// If no channels exist, delete all abilities
-		err = DB.Delete(&Ability{}).Error
-		if err != nil {
-			common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
-			return 0, err
-		}
-		common.SysLog("Delete all abilities successfully")
-		return 0, nil
+func FixAbility() (int, int, error) {
+	lock := fixLock.TryLock()
+	if !lock {
+		return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
 	}
-
-	common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
-	count += len(channelIds)
-
-	// Use channelIds to find channel not in abilities table
-	var abilityChannelIds []int
-	err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
+	defer fixLock.Unlock()
+	var channels []*Channel
+	// Find all channels
+	err := DB.Model(&Channel{}).Find(&channels).Error
 	if err != nil {
-		common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
-		return count, err
+		return 0, 0, err
 	}
-
-	var channels []Channel
-	if len(abilityChannelIds) == 0 {
-		err = DB.Find(&channels).Error
-	} else {
-		// Process query in chunks to avoid "too many placeholders" error
-		err = nil
-		for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
-			var channelsChunk []Channel
-			err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
-			if err != nil {
-				common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
-				return count, err
-			}
-			channels = append(channels, channelsChunk...)
-		}
+	if len(channels) == 0 {
+		return 0, 0, nil
 	}
-
-	for _, channel := range channels {
-		err := channel.UpdateAbilities(nil)
+	successCount := 0
+	failCount := 0
+	for _, chunk := range lo.Chunk(channels, 50) {
+		ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
+		// Delete all abilities of this channel
+		err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
 		if err != nil {
-			common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
-		} else {
-			common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
-			count++
+			common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+			failCount += len(chunk)
+			continue
+		}
+		// Then add new abilities
+		for _, channel := range chunk {
+			err = channel.AddAbilities()
+			if err != nil {
+				common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
+				failCount++
+			} else {
+				successCount++
+			}
 		}
 	}
 	InitChannelCache()
-	return count, nil
+	return successCount, failCount, nil
 }

+ 15 - 3
model/channel.go

@@ -7,6 +7,7 @@ import (
 	"math/rand"
 	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"strings"
 	"sync"
 
@@ -610,8 +611,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
 	return tags, nil
 }
 
-func (channel *Channel) GetSetting() map[string]interface{} {
-	setting := make(map[string]interface{})
+func (channel *Channel) ValidateSettings() error {
+	channelParams := &dto.ChannelSettings{}
+	if channel.Setting != nil && *channel.Setting != "" {
+		err := json.Unmarshal([]byte(*channel.Setting), channelParams)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (channel *Channel) GetSetting() dto.ChannelSettings {
+	setting := dto.ChannelSettings{}
 	if channel.Setting != nil && *channel.Setting != "" {
 		err := json.Unmarshal([]byte(*channel.Setting), &setting)
 		if err != nil {
@@ -621,7 +633,7 @@ func (channel *Channel) GetSetting() map[string]interface{} {
 	return setting
 }
 
-func (channel *Channel) SetSetting(setting map[string]interface{}) {
+func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
 	settingBytes, err := json.Marshal(setting)
 	if err != nil {
 		common.SysError("failed to marshal setting: " + err.Error())

+ 35 - 26
model/log.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"fmt"
 	"one-api/common"
-	"one-api/constant"
 	"os"
 	"strings"
 	"time"
@@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
 	// 判断是否需要记录 IP
 	needRecordIp := false
 	if settingMap, err := GetUserSetting(userId, false); err == nil {
-		if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
-			if vb, ok := v.(bool); ok && vb {
-				needRecordIp = true
-			}
+		if settingMap.RecordIpLog {
+			needRecordIp = true
 		}
 	}
 	log := &Log{
@@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
 	}
 }
 
-func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
-	modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
-	isStream bool, group string, other map[string]interface{}) {
-	common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+type RecordConsumeLogParams struct {
+	ChannelId        int                    `json:"channel_id"`
+	PromptTokens     int                    `json:"prompt_tokens"`
+	CompletionTokens int                    `json:"completion_tokens"`
+	ModelName        string                 `json:"model_name"`
+	TokenName        string                 `json:"token_name"`
+	Quota            int                    `json:"quota"`
+	Content          string                 `json:"content"`
+	TokenId          int                    `json:"token_id"`
+	UserQuota        int                    `json:"user_quota"`
+	UseTimeSeconds   int                    `json:"use_time_seconds"`
+	IsStream         bool                   `json:"is_stream"`
+	Group            string                 `json:"group"`
+	Other            map[string]interface{} `json:"other"`
+}
+
+func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
+	common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
 	if !common.LogConsumeEnabled {
 		return
 	}
 	username := c.GetString("username")
-	otherStr := common.MapToJsonStr(other)
+	otherStr := common.MapToJsonStr(params.Other)
 	// 判断是否需要记录 IP
 	needRecordIp := false
 	if settingMap, err := GetUserSetting(userId, false); err == nil {
-		if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
-			if vb, ok := v.(bool); ok && vb {
-				needRecordIp = true
-			}
+		if settingMap.RecordIpLog {
+			needRecordIp = true
 		}
 	}
 	log := &Log{
@@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
 		Username:         username,
 		CreatedAt:        common.GetTimestamp(),
 		Type:             LogTypeConsume,
-		Content:          content,
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TokenName:        tokenName,
-		ModelName:        modelName,
-		Quota:            quota,
-		ChannelId:        channelId,
-		TokenId:          tokenId,
-		UseTime:          useTimeSeconds,
-		IsStream:         isStream,
-		Group:            group,
+		Content:          params.Content,
+		PromptTokens:     params.PromptTokens,
+		CompletionTokens: params.CompletionTokens,
+		TokenName:        params.TokenName,
+		ModelName:        params.ModelName,
+		Quota:            params.Quota,
+		ChannelId:        params.ChannelId,
+		TokenId:          params.TokenId,
+		UseTime:          params.UseTimeSeconds,
+		IsStream:         params.IsStream,
+		Group:            params.Group,
 		Ip: func() string {
 			if needRecordIp {
 				return c.ClientIP()
@@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
 	}
 	if common.DataExportEnabled {
 		gopool.Go(func() {
-			LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
+			LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
 		})
 	}
 }

+ 15 - 18
model/user.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"one-api/common"
+	"one-api/dto"
 	"strconv"
 	"strings"
 
@@ -68,19 +69,18 @@ func (user *User) SetAccessToken(token string) {
 	user.AccessToken = &token
 }
 
-func (user *User) GetSetting() (map[string]interface{}, error) {
-	if user.Setting == "" {
-		return map[string]interface{}{}, nil
-	}
-	toMap, err := common.StrToMap(user.Setting)
-	if err != nil {
-		common.SysError("failed to convert setting to map: " + err.Error())
-		return nil, fmt.Errorf("failed to convert setting to map")
+func (user *User) GetSetting() dto.UserSetting {
+	setting := dto.UserSetting{}
+	if user.Setting != "" {
+		err := json.Unmarshal([]byte(user.Setting), &setting)
+		if err != nil {
+			common.SysError("failed to unmarshal setting: " + err.Error())
+		}
 	}
-	return toMap, nil
+	return setting
 }
 
-func (user *User) SetSetting(setting map[string]interface{}) {
+func (user *User) SetSetting(setting dto.UserSetting) {
 	settingBytes, err := json.Marshal(setting)
 	if err != nil {
 		common.SysError("failed to marshal setting: " + err.Error())
@@ -631,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
 }
 
 // GetUserSetting gets setting from Redis first, falls back to DB if needed
-func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
+func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
 	var setting string
 	defer func() {
 		// Update Redis cache asynchronously on successful DB read
@@ -653,15 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
 	fromDB = true
 	err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
 	if err != nil {
-		return map[string]interface{}{}, err
+		return settingMap, err
 	}
-
-	toMap, err := common.StrToMap(setting)
-	if err != nil {
-		common.SysError("failed to convert setting to map: " + err.Error())
-		return nil, fmt.Errorf("failed to convert setting to map")
+	userBase := &UserBase{
+		Setting: setting,
 	}
-	return toMap, nil
+	return userBase.GetSetting(), nil
 }
 
 func IncreaseUserQuota(id int, quota int, db bool) (err error) {

+ 11 - 21
model/user_cache.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -32,25 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) {
 	common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
 }
 
-func (user *UserBase) GetSetting() map[string]interface{} {
-	if user.Setting == "" {
-		return nil
-	}
-	toMap, err := common.StrToMap(user.Setting)
-	if err != nil {
-		common.SysError("failed to convert user setting to map: " + err.Error())
-		return nil
-	}
-	return toMap
-}
-
-func (user *UserBase) SetSetting(setting map[string]interface{}) {
-	settingBytes, err := json.Marshal(setting)
-	if err != nil {
-		common.SysError("failed to marshal setting: " + err.Error())
-		return
+func (user *UserBase) GetSetting() dto.UserSetting {
+	setting := dto.UserSetting{}
+	if user.Setting != "" {
+		err := common.Unmarshal([]byte(user.Setting), &setting)
+		if err != nil {
+			common.SysError("failed to unmarshal setting: " + err.Error())
+		}
 	}
-	user.Setting = string(settingBytes)
+	return setting
 }
 
 // getUserCacheKey returns the key for user cache
@@ -179,11 +170,10 @@ func getUserNameCache(userId int) (string, error) {
 	return cache.Username, nil
 }
 
-func getUserSettingCache(userId int) (map[string]interface{}, error) {
-	setting := make(map[string]interface{})
+func getUserSettingCache(userId int) (dto.UserSetting, error) {
 	cache, err := GetUserCache(userId)
 	if err != nil {
-		return setting, err
+		return dto.UserSetting{}, err
 	}
 	return cache.GetSetting(), nil
 }

+ 2 - 2
relay/channel/api_request.go

@@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error
-	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
-		client, err = service.NewProxyHttpClient(proxyURL.(string))
+	if info.ChannelSetting.Proxy != "" {
+		client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
 		if err != nil {
 			return nil, fmt.Errorf("new proxy http client failed: %w", err)
 		}

+ 10 - 1
relay/channel/baidu_v2/adaptor.go

@@ -43,7 +43,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Set("Authorization", "Bearer "+info.ApiKey)
+        keyParts := strings.Split(info.ApiKey, "|")
+	if len(keyParts) == 0 || keyParts[0] == "" {
+        	return errors.New("invalid API key: authorization token is required")
+        }
+	 if len(keyParts) > 1 {
+               if keyParts[1] != "" {
+                       req.Set("appid", keyParts[1])
+               }
+        }
+	req.Set("Authorization", "Bearer "+keyParts[0])
 	return nil
 }
 

+ 2 - 2
relay/channel/coze/relay-coze.go

@@ -279,8 +279,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
 func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error // 声明 err 变量
-	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
-		client, err = service.NewProxyHttpClient(proxyURL.(string))
+	if info.ChannelSetting.Proxy != "" {
+		client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
 		if err != nil {
 			return nil, fmt.Errorf("new proxy http client failed: %w", err)
 		}

+ 2 - 2
relay/channel/openai/adaptor.go

@@ -54,7 +54,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	a.ChannelType = info.ChannelType
 
 	// initialize ThinkingContentInfo when thinking_to_content is enabled
-	if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
+	if info.ChannelSetting.ThinkingToContent {
 		info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
 			IsFirstThinkingContent:  true,
 			SendLastThinkingContent: false,
@@ -146,7 +146,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
 		header.Set("Authorization", "Bearer "+info.ApiKey)
 	}
 	if info.ChannelType == constant.ChannelTypeOpenRouter {
-		header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
+		header.Set("HTTP-Referer", "https://www.newapi.ai")
 		header.Set("X-Title", "New API")
 	}
 	return nil

+ 6 - 6
relay/channel/openai/relay-openai.go

@@ -126,12 +126,12 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	var forceFormat bool
 	var thinkToContent bool
 
-	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
-		forceFormat = forceFmt
+	if info.ChannelSetting.ForceFormat {
+		forceFormat = true
 	}
 
-	if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
-		thinkToContent = think2Content
+	if info.ChannelSetting.ThinkingToContent {
+		thinkToContent = true
 	}
 
 	var (
@@ -199,8 +199,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	}
 
 	forceFormat := false
-	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
-		forceFormat = forceFmt
+	if info.ChannelSetting.ForceFormat {
+		forceFormat = true
 	}
 
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {

+ 2 - 2
relay/channel/vertex/service_account.go

@@ -106,8 +106,8 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
 
 	var client *http.Client
 	var err error
-	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
-		client, err = service.NewProxyHttpClient(proxyURL.(string))
+	if info.ChannelSetting.Proxy != "" {
+		client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
 		if err != nil {
 			return "", fmt.Errorf("new proxy http client failed: %w", err)
 		}

+ 18 - 10
relay/common/relay_info.go

@@ -97,9 +97,9 @@ type RelayInfo struct {
 	IsFirstRequest       bool
 	AudioUsage           bool
 	ReasoningEffort      string
-	ChannelSetting       map[string]interface{}
+	ChannelSetting       dto.ChannelSettings
 	ParamOverride        map[string]interface{}
-	UserSetting          map[string]interface{}
+	UserSetting          dto.UserSetting
 	UserEmail            string
 	UserQuota            int
 	RelayFormat          string
@@ -213,7 +213,6 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
 	channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
-	channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
 	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
 
 	tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
@@ -227,7 +226,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 
 	info := &RelayInfo{
 		UserQuota:         common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
-		UserSetting:       common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
 		UserEmail:         common.GetContextKeyString(c, constant.ContextKeyUserEmail),
 		isFirstResponse:   true,
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
@@ -246,12 +244,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		OriginModelName:   common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
 		UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
 		//RecodeModelName:   c.GetString("original_model"),
-		IsModelMapped:     false,
-		ApiType:           apiType,
-		ApiVersion:        c.GetString("api_version"),
-		ApiKey:            strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
-		Organization:      c.GetString("channel_organization"),
-		ChannelSetting:    channelSetting,
+		IsModelMapped: false,
+		ApiType:       apiType,
+		ApiVersion:    c.GetString("api_version"),
+		ApiKey:        strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+		Organization:  c.GetString("channel_organization"),
+
 		ChannelCreateTime: c.GetInt64("channel_create_time"),
 		ParamOverride:     paramOverride,
 		RelayFormat:       RelayFormatOpenAI,
@@ -277,6 +275,16 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if streamSupportedChannels[info.ChannelType] {
 		info.SupportStreamOptions = true
 	}
+
+	channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
+	if ok {
+		info.ChannelSetting = channelSetting
+	}
+	userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
+	if ok {
+		info.UserSetting = userSetting
+	}
+
 	return info
 }
 

+ 2 - 6
relay/helper/price.go

@@ -3,7 +3,6 @@ package helper
 import (
 	"fmt"
 	"one-api/common"
-	constant2 "one-api/constant"
 	relaycommon "one-api/relay/common"
 	"one-api/setting/ratio_setting"
 
@@ -83,11 +82,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 		modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
 		if !success {
 			acceptUnsetRatio := false
-			if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
-				b, ok := accept.(bool)
-				if ok {
-					acceptUnsetRatio = b
-				}
+			if info.UserSetting.AcceptUnsetRatioModel {
+				acceptUnsetRatio = true
 			}
 			if !acceptUnsetRatio {
 				return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)

+ 31 - 14
relay/relay-mj.go

@@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) {
 	}
 	var httpClient *http.Client
 	if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
-		if proxy, ok := channel.GetSetting()["proxy"]; ok {
-			if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
-				if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
-					c.JSON(400, gin.H{
-						"error": "proxy_url_invalid",
-					})
-					return
-				}
+		proxy := channel.GetSetting().Proxy
+		if proxy != "" {
+			if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
+				c.JSON(400, gin.H{
+					"error": "proxy_url_invalid",
+				})
+				return
 			}
 		}
 	}
@@ -175,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 	startTime := time.Now().UnixNano() / int64(time.Millisecond)
 	tokenId := c.GetInt("token_id")
 	userId := c.GetInt("id")
-	group := c.GetString("group")
+	//group := c.GetString("group")
 	channelId := c.GetInt("channel_id")
 	relayInfo := relaycommon.GenRelayInfo(c)
 	var swapFaceRequest dto.SwapFaceRequest
@@ -221,8 +220,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 			tokenName := c.GetString("token_name")
 			logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
 			other := service.GenerateMjOtherInfo(priceData)
-			model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
-				priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
+			model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+				ChannelId: channelId,
+				ModelName: modelName,
+				TokenName: tokenName,
+				Quota:     priceData.Quota,
+				Content:   logContent,
+				TokenId:   tokenId,
+				UserQuota: userQuota,
+				Group:     relayInfo.UsingGroup,
+				Other:     other,
+			})
 			model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
 			model.UpdateChannelUsedQuota(channelId, priceData.Quota)
 		}
@@ -363,7 +371,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 
 func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 
-	tokenId := c.GetInt("token_id")
+	//tokenId := c.GetInt("token_id")
 	//channelType := c.GetInt("channel")
 	userId := c.GetInt("id")
 	group := c.GetString("group")
@@ -518,8 +526,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			tokenName := c.GetString("token_name")
 			logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
 			other := service.GenerateMjOtherInfo(priceData)
-			model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
-				priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
+			model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+				ChannelId: channelId,
+				ModelName: modelName,
+				TokenName: tokenName,
+				Quota:     priceData.Quota,
+				Content:   logContent,
+				TokenId:   relayInfo.TokenId,
+				UserQuota: userQuota,
+				Group:     group,
+				Other:     other,
+			})
 			model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
 			model.UpdateChannelUsedQuota(channelId, priceData.Quota)
 		}

+ 15 - 2
relay/relay-text.go

@@ -537,6 +537,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		other["audio_input_token_count"] = audioTokens
 		other["audio_input_price"] = audioInputPrice
 	}
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
-		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+		ChannelId:        relayInfo.ChannelId,
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		ModelName:        logModel,
+		TokenName:        tokenName,
+		Quota:            quota,
+		Content:          logContent,
+		TokenId:          relayInfo.TokenId,
+		UserQuota:        userQuota,
+		UseTimeSeconds:   int(useTimeSeconds),
+		IsStream:         relayInfo.IsStream,
+		Group:            relayInfo.UsingGroup,
+		Other:            other,
+	})
 }

+ 11 - 2
relay/relay_task.go

@@ -139,8 +139,17 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 				if hasUserGroupRatio {
 					other["user_group_ratio"] = userGroupRatio
 				}
-				model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
-					modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
+				model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+					ChannelId: relayInfo.ChannelId,
+					ModelName: modelName,
+					TokenName: tokenName,
+					Quota:     quota,
+					Content:   logContent,
+					TokenId:   relayInfo.TokenId,
+					UserQuota: userQuota,
+					Group:     relayInfo.UsingGroup,
+					Other:     other,
+				})
 				model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 				model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 			}

+ 48 - 8
service/quota.go

@@ -209,8 +209,21 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	}
 	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
 		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
-		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+		ChannelId:        relayInfo.ChannelId,
+		PromptTokens:     usage.InputTokens,
+		CompletionTokens: usage.OutputTokens,
+		ModelName:        logModel,
+		TokenName:        tokenName,
+		Quota:            quota,
+		Content:          logContent,
+		TokenId:          relayInfo.TokenId,
+		UserQuota:        userQuota,
+		UseTimeSeconds:   int(useTimeSeconds),
+		IsStream:         relayInfo.IsStream,
+		Group:            relayInfo.UsingGroup,
+		Other:            other,
+	})
 }
 
 func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
@@ -286,8 +299,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 
 	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
 		cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
-		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+		ChannelId:        relayInfo.ChannelId,
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		ModelName:        modelName,
+		TokenName:        tokenName,
+		Quota:            quota,
+		Content:          logContent,
+		TokenId:          relayInfo.TokenId,
+		UserQuota:        userQuota,
+		UseTimeSeconds:   int(useTimeSeconds),
+		IsStream:         relayInfo.IsStream,
+		Group:            relayInfo.UsingGroup,
+		Other:            other,
+	})
+
 }
 
 func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
@@ -384,8 +411,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	}
 	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
 		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
-		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+		ChannelId:        relayInfo.ChannelId,
+		PromptTokens:     usage.PromptTokens,
+		CompletionTokens: usage.CompletionTokens,
+		ModelName:        logModel,
+		TokenName:        tokenName,
+		Quota:            quota,
+		Content:          logContent,
+		TokenId:          relayInfo.TokenId,
+		UserQuota:        userQuota,
+		UseTimeSeconds:   int(useTimeSeconds),
+		IsStream:         relayInfo.IsStream,
+		Group:            relayInfo.UsingGroup,
+		Other:            other,
+	})
 }
 
 func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
@@ -447,8 +487,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
 	gopool.Go(func() {
 		userSetting := relayInfo.UserSetting
 		threshold := common.QuotaRemindThreshold
-		if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
-			threshold = int(userCustomThreshold.(float64))
+		if userSetting.QuotaWarningThreshold != 0 {
+			threshold = int(userSetting.QuotaWarningThreshold)
 		}
 
 		//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0

+ 10 - 22
service/user_notify.go

@@ -3,7 +3,6 @@ package service
 import (
 	"fmt"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	"strings"
@@ -17,10 +16,10 @@ func NotifyRootUser(t string, subject string, content string) {
 	}
 }
 
-func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
-	notifyType, ok := userSetting[constant.UserSettingNotifyType]
-	if !ok {
-		notifyType = constant.NotifyTypeEmail
+func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
+	notifyType := userSetting.NotifyType
+	if notifyType == "" {
+		notifyType = dto.NotifyTypeEmail
 	}
 
 	// Check notification limit
@@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}
 	}
 
 	switch notifyType {
-	case constant.NotifyTypeEmail:
+	case dto.NotifyTypeEmail:
 		// check setting email
-		if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
-			userEmail = settingEmail.(string)
-		}
+		userEmail = userSetting.NotificationEmail
 		if userEmail == "" {
 			common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
 			return nil
 		}
 		return sendEmailNotify(userEmail, data)
-	case constant.NotifyTypeWebhook:
-		webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
-		if !ok {
+	case dto.NotifyTypeWebhook:
+		webhookURLStr := userSetting.WebhookUrl
+		if webhookURLStr == "" {
 			common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
 			return nil
 		}
-		webhookURLStr, ok := webhookURL.(string)
-		if !ok {
-			common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
-			return nil
-		}
 
 		// 获取 webhook secret
-		var webhookSecret string
-		if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
-			webhookSecret, _ = secret.(string)
-		}
-
+		webhookSecret := userSetting.WebhookSecret
 		return SendWebhookNotify(webhookURLStr, webhookSecret, data)
 	}
 	return nil

+ 4 - 1
setting/chat.go

@@ -6,8 +6,11 @@ import (
 )
 
 var Chats = []map[string]string{
+	//{
+	//	"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
+	//},
 	{
-		"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
+		"Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
 	},
 	{
 		"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",

+ 2 - 2
web/src/components/table/ChannelsTable.js

@@ -1461,9 +1461,9 @@ const ChannelsTable = () => {
 
   const fixChannelsAbilities = async () => {
     const res = await API.post(`/api/channel/fix`);
-    const { success, message, data } = res.data;
+    const { success, message, data  } = res.data;
     if (success) {
-      showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data));
+      showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
       await refresh();
     } else {
       showError(message);

+ 16 - 3
web/src/components/table/TokensTable.js

@@ -432,9 +432,22 @@ const TokensTable = () => {
     if (serverAddress === '') {
       serverAddress = window.location.origin;
     }
-    let encodedServerAddress = encodeURIComponent(serverAddress);
-    url = url.replaceAll('{address}', encodedServerAddress);
-    url = url.replaceAll('{key}', 'sk-' + record.key);
+    if (url.includes('{cherryConfig}') === true) {
+      let cherryConfig = {
+        id: 'new-api',
+        baseUrl: serverAddress,
+        apiKey: 'sk-' + record.key,
+      }
+      // 替换 {cherryConfig} 为base64编码的JSON字符串
+      let encodedConfig = encodeURIComponent(
+        btoa(JSON.stringify(cherryConfig))
+      );
+      url = url.replaceAll('{cherryConfig}', encodedConfig);
+    } else {
+      let encodedServerAddress = encodeURIComponent(serverAddress);
+      url = url.replaceAll('{address}', encodedServerAddress);
+      url = url.replaceAll('{key}', 'sk-' + record.key);
+    }
 
     window.open(url, '_blank');
   };

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -261,7 +261,7 @@ const EditChannel = (props) => {
     if (isEdit) {
       // 如果是编辑模式,使用已有的channel id获取模型列表
       const res = await API.get('/api/channel/fetch_models/' + channelId);
-      if (res.data && res.data?.success) {
+      if (res.data && res.data.success) {
         models.push(...res.data.data);
       } else {
         err = true;