Ver Fonte

Merge remote-tracking branch 'new-api/main' into gpt-image

# Conflicts:
#	relay/relay-image.go
CaIon há 8 meses atrás
pai
commit
a03c615fa4

+ 4 - 0
common/constants.go

@@ -62,6 +62,10 @@ var EmailDomainWhitelist = []string{
 	"yahoo.com",
 	"foxmail.com",
 }
+var EmailLoginAuthServerList = []string{
+	"smtp.sendcloud.net",
+	"smtp.azurecomm.net",
+}
 
 var DebugEnabled bool
 var MemoryCacheEnabled bool

+ 2 - 1
common/email.go

@@ -5,6 +5,7 @@ import (
 	"encoding/base64"
 	"fmt"
 	"net/smtp"
+	"slices"
 	"strings"
 	"time"
 )
@@ -79,7 +80,7 @@ func SendEmail(subject string, receiver string, content string) error {
 		if err != nil {
 			return err
 		}
-	} else if isOutlookServer(SMTPAccount) || SMTPServer == "smtp.azurecomm.net" {
+	} else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
 		auth = LoginAuth(SMTPAccount, SMTPToken)
 		err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
 	} else {

+ 89 - 0
common/limiter/limiter.go

@@ -0,0 +1,89 @@
+package limiter
+
+import (
+	"context"
+	_ "embed"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+	"one-api/common"
+	"sync"
+)
+
+//go:embed lua/rate_limit.lua
+var rateLimitScript string
+
+type RedisLimiter struct {
+	client         *redis.Client
+	limitScriptSHA string
+}
+
+var (
+	instance *RedisLimiter
+	once     sync.Once
+)
+
+func New(ctx context.Context, r *redis.Client) *RedisLimiter {
+	once.Do(func() {
+		// 预加载脚本
+		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
+		if err != nil {
+			common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
+		}
+		instance = &RedisLimiter{
+			client:         r,
+			limitScriptSHA: limitSHA,
+		}
+	})
+
+	return instance
+}
+
+func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
+	// 默认配置
+	config := &Config{
+		Capacity:  10,
+		Rate:      1,
+		Requested: 1,
+	}
+
+	// 应用选项模式
+	for _, opt := range opts {
+		opt(config)
+	}
+
+	// 执行限流
+	result, err := rl.client.EvalSha(
+		ctx,
+		rl.limitScriptSHA,
+		[]string{key},
+		config.Requested,
+		config.Rate,
+		config.Capacity,
+	).Int()
+
+	if err != nil {
+		return false, fmt.Errorf("rate limit failed: %w", err)
+	}
+	return result == 1, nil
+}
+
+// Config 配置选项模式
+type Config struct {
+	Capacity  int64
+	Rate      int64
+	Requested int64
+}
+
+type Option func(*Config)
+
+func WithCapacity(c int64) Option {
+	return func(cfg *Config) { cfg.Capacity = c }
+}
+
+func WithRate(r int64) Option {
+	return func(cfg *Config) { cfg.Rate = r }
+}
+
+func WithRequested(n int64) Option {
+	return func(cfg *Config) { cfg.Requested = n }
+}

+ 44 - 0
common/limiter/lua/rate_limit.lua

@@ -0,0 +1,44 @@
+-- 令牌桶限流器
+-- KEYS[1]: 限流器唯一标识
+-- ARGV[1]: 请求令牌数 (通常为1)
+-- ARGV[2]: 令牌生成速率 (每秒)
+-- ARGV[3]: 桶容量
+
+local key = KEYS[1]
+local requested = tonumber(ARGV[1])
+local rate = tonumber(ARGV[2])
+local capacity = tonumber(ARGV[3])
+
+-- 获取当前时间(Redis服务器时间)
+local now = redis.call('TIME')
+local nowInSeconds = tonumber(now[1])
+
+-- 获取桶状态
+local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
+local tokens = tonumber(bucket[1])
+local last_time = tonumber(bucket[2])
+
+-- 初始化桶(首次请求或过期)
+if not tokens or not last_time then
+    tokens = capacity
+    last_time = nowInSeconds
+else
+    -- 计算新增令牌
+    local elapsed = nowInSeconds - last_time
+    local add_tokens = elapsed * rate
+    tokens = math.min(capacity, tokens + add_tokens)
+    last_time = nowInSeconds
+end
+
+-- 判断是否允许请求
+local allowed = false
+if tokens >= requested then
+    tokens = tokens - requested
+    allowed = true
+end
+
+---- 更新桶状态并设置过期时间
+redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
+--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
+
+return allowed and 1 or 0

+ 1 - 1
common/utils.go

@@ -7,7 +7,6 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"github.com/pkg/errors"
 	"html/template"
 	"io"
 	"log"
@@ -22,6 +21,7 @@ import (
 	"time"
 
 	"github.com/google/uuid"
+	"github.com/pkg/errors"
 )
 
 func OpenBrowser(url string) {

+ 4 - 1
controller/channel-test.go

@@ -103,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 
 	request := buildTestRequest(testModel)
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
+	// 创建一个用于日志的 info 副本,移除 ApiKey
+	logInfo := *info
+	logInfo.ApiKey = ""
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
 
 	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
 	if err != nil {

+ 22 - 14
middleware/model-rate-limit.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/common/limiter"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -78,34 +79,41 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 		ctx := context.Background()
 		rdb := common.RDB
 
-		// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
-		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
-		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		// 1. 检查成功请求数限制
+		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
+		allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
-			fmt.Println("检查请求数限制失败:", err.Error())
+			fmt.Println("检查成功请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
 			return
 		}
 		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+			return
 		}
 
-		// 2. 检查成功请求数限制
-		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
-		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
+		//2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
+		totalKey := fmt.Sprintf("rateLimit:%s", userId)
+		// 初始化
+		tb := limiter.New(ctx, rdb)
+		allowed, err = tb.Allow(
+			ctx,
+			totalKey,
+			limiter.WithCapacity(int64(totalMaxCount)*duration),
+			limiter.WithRate(int64(totalMaxCount)),
+			limiter.WithRequested(duration),
+		)
+
 		if err != nil {
-			fmt.Println("检查成功请求数限制失败:", err.Error())
+			fmt.Println("检查请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
 			return
 		}
+
 		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
-			return
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
 		}
 
-		// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
-		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
-
 		// 4. 处理请求
 		c.Next()
 

+ 1 - 1
model/user.go

@@ -108,7 +108,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
 
 func GetMaxUserId() int {
 	var user User
-	DB.Last(&user)
+	DB.Unscoped().Last(&user)
 	return user.Id
 }
 

+ 10 - 4
relay/channel/claude/relay-claude.go

@@ -246,17 +246,23 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 					} else {
 						imageUrl := mediaMessage.GetImageMedia()
 						claudeMediaMessage.Type = "image"
-						claudeMediaMessage.Source = &dto.ClaudeMessageSource{}
+						claudeMediaMessage.Source = &dto.ClaudeMessageSource{
+							Type: "base64",
+						}
 						// 判断是否是url
 						if strings.HasPrefix(imageUrl.Url, "http") {
-							claudeMediaMessage.Source.Type = "url"
-							claudeMediaMessage.Source.Url = imageUrl.Url
+							// 是url,获取图片的类型和base64编码的数据
+							fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+							if err != nil {
+								return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+							}
+							claudeMediaMessage.Source.MediaType = fileData.MimeType
+							claudeMediaMessage.Source.Data = fileData.Base64Data
 						} else {
 							_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
 							if err != nil {
 								return nil, err
 							}
-							claudeMediaMessage.Source.Type = "base64"
 							claudeMediaMessage.Source.MediaType = "image/" + format
 							claudeMediaMessage.Source.Data = base64String
 						}

+ 6 - 1
relay/channel/deepseek/adaptor.go

@@ -11,6 +11,7 @@ import (
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+	"strings"
 )
 
 type Adaptor struct {
@@ -36,9 +37,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	fimBaseUrl := info.BaseUrl
+	if !strings.HasSuffix(info.BaseUrl, "/beta") {
+		fimBaseUrl += "/beta"
+	}
 	switch info.RelayMode {
 	case constant.RelayModeCompletions:
-		return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
+		return fmt.Sprintf("%s/completions", fimBaseUrl), nil
 	default:
 		return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
 	}

+ 46 - 1
relay/relay-image.go

@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -17,6 +16,8 @@ import (
 	"one-api/service"
 	"one-api/setting"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
@@ -81,6 +82,50 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 		imageRequest.Size = "1024x1024"
 	}
 
+	err := common.UnmarshalBodyReusable(c, imageRequest)
+	if err != nil {
+		return nil, err
+	}
+	if imageRequest.Prompt == "" {
+		return nil, errors.New("prompt is required")
+	}
+	if strings.Contains(imageRequest.Size, "×") {
+		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+	}
+	if imageRequest.N == 0 {
+		imageRequest.N = 1
+	}
+	if imageRequest.Size == "" {
+		imageRequest.Size = "1024x1024"
+	}
+	if imageRequest.Model == "" {
+		imageRequest.Model = "dall-e-2"
+	}
+	// x.ai grok-2-image not support size, quality or style
+	if imageRequest.Size == "empty" {
+		imageRequest.Size = ""
+	}
+
+	// Not "256x256", "512x512", or "1024x1024"
+	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
+		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
+			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
+		}
+	} else if imageRequest.Model == "dall-e-3" {
+		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
+			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
+		}
+		if imageRequest.Quality == "" {
+			imageRequest.Quality = "standard"
+		}
+		//if imageRequest.N != 1 {
+		//	return nil, errors.New("n must be 1")
+		//}
+	}
+	// N should between 1 and 10
+	//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
+	//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
+	//}
 	if setting.ShouldCheckPromptSensitive() {
 		words, err := service.CheckSensitiveInput(imageRequest.Prompt)
 		if err != nil {

+ 8 - 3
web/src/components/ModelPricing.js

@@ -81,7 +81,7 @@ const ModelPricing = () => {
   }
 
   function renderAvailable(available) {
-    return (
+    return available ? (
       <Popover
         content={
           <div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
@@ -98,7 +98,7 @@ const ModelPricing = () => {
       >
         <IconVerify style={{ color: 'green' }} size='large' />
       </Popover>
-    );
+    ) : null;
   }
 
   const columns = [
@@ -109,7 +109,12 @@ const ModelPricing = () => {
         // if record.enable_groups contains selectedGroup, then available is true
         return renderAvailable(record.enable_groups.includes(selectedGroup));
       },
-      sorter: (a, b) => a.available - b.available,
+      sorter: (a, b) => {
+        const aAvailable = a.enable_groups.includes(selectedGroup);
+        const bAvailable = b.enable_groups.includes(selectedGroup);
+        return Number(aAvailable) - Number(bAvailable);
+      },
+      defaultSortOrder: 'descend',
     },
     {
       title: t('模型名称'),