Kaynağa Gözat

feat(adaptor): 新适配百炼多种图片生成模型

- wan2.6系列生图与编辑,适配多图生成计费
- wan2.5系列生图与编辑
- z-image-turbo生图,适配prompt_extend计费
CaIon 1 gün önce
ebeveyn
işleme
48d358faec

+ 36 - 1
controller/token.go

@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"fmt"
 	"net/http"
 	"strconv"
 	"strings"
@@ -149,6 +150,24 @@ func AddToken(c *gin.Context) {
 		})
 		return
 	}
+	// 非无限额度时,检查额度值是否超出有效范围
+	if !token.UnlimitedQuota {
+		if token.RemainQuota < 0 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "额度值不能为负数",
+			})
+			return
+		}
+		maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
+		if token.RemainQuota > maxQuotaValue {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
+			})
+			return
+		}
+	}
 	key, err := common.GenerateKey()
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) {
 		})
 		return
 	}
+	if !token.UnlimitedQuota {
+		if token.RemainQuota < 0 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "额度值不能为负数",
+			})
+			return
+		}
+		maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
+		if token.RemainQuota > maxQuotaValue {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
+			})
+			return
+		}
+	}
 	cleanToken, err := model.GetTokenByIds(token.Id, userId)
 	if err != nil {
 		common.ApiError(c, err)
@@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) {
 		"message": "",
 		"data":    cleanToken,
 	})
-	return
 }
 
 type TokenBatch struct {

+ 3 - 3
dto/openai_image.go

@@ -167,9 +167,9 @@ func (i *ImageRequest) SetModelName(modelName string) {
 }
 
 type ImageResponse struct {
-	Data    []ImageData `json:"data"`
-	Created int64       `json:"created"`
-	Extra   any         `json:"extra,omitempty"`
+	Data     []ImageData     `json:"data"`
+	Created  int64           `json:"created"`
+	Metadata json.RawMessage `json:"metadata,omitempty"`
 }
 type ImageData struct {
 	Url           string `json:"url"`

+ 4 - 1
dto/openai_request.go

@@ -23,6 +23,8 @@ type FormatJsonSchema struct {
 	Strict      json.RawMessage `json:"strict,omitempty"`
 }
 
+// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs.
+// 参数增加规范:无引用的参数必须使用json.RawMessage类型,并添加omitempty标签
 type GeneralOpenAIRequest struct {
 	Model               string            `json:"model,omitempty"`
 	Messages            []Message         `json:"messages,omitempty"`
@@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct {
 	Reasoning json.RawMessage `json:"reasoning,omitempty"`
 	// Ali Qwen Params
 	VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
-	EnableThinking         any             `json:"enable_thinking,omitempty"`
+	EnableThinking         json.RawMessage `json:"enable_thinking,omitempty"`
 	ChatTemplateKwargs     json.RawMessage `json:"chat_template_kwargs,omitempty"`
+	EnableSearch           json.RawMessage `json:"enable_search,omitempty"`
 	// ollama Params
 	Think json.RawMessage `json:"think,omitempty"`
 	// baidu v2

+ 2 - 0
main.go

@@ -188,6 +188,7 @@ func InjectUmamiAnalytics() {
 		analyticsInjectBuilder.WriteString(umamiSiteID)
 		analyticsInjectBuilder.WriteString("\"></script>")
 	}
+	analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
 	analyticsInject := analyticsInjectBuilder.String()
 	indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
 }
@@ -209,6 +210,7 @@ func InjectGoogleAnalytics() {
 		analyticsInjectBuilder.WriteString("');")
 		analyticsInjectBuilder.WriteString("</script>")
 	}
+	analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
 	analyticsInject := analyticsInjectBuilder.String()
 	indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
 }

+ 1 - 1
relay/audio_handler.go

@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
 		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
 	} else {
-		postConsumeQuota(c, info, usage.(*dto.Usage), "")
+		postConsumeQuota(c, info, usage.(*dto.Usage))
 	}
 
 	return nil

+ 46 - 14
relay/channel/ali/adaptor.go

@@ -19,6 +19,22 @@ import (
 )
 
 type Adaptor struct {
+	IsSyncImageModel bool
+}
+
+var syncModels = []string{
+	"z-image",
+	"qwen-image",
+	"wan2.6",
+}
+
+func isSyncImageModel(modelName string) bool {
+	for _, m := range syncModels {
+		if strings.Contains(modelName, m) {
+			return true
+		}
+	}
+	return false
 }
 
 func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		case constant.RelayModeRerank:
 			fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
 		case constant.RelayModeImagesGenerations:
-			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
+			if isSyncImageModel(info.OriginModelName) {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
+			} else {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
+			}
 		case constant.RelayModeImagesEdits:
-			if isWanModel(info.OriginModelName) {
+			if isOldWanModel(info.OriginModelName) {
 				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
+			} else if isWanModel(info.OriginModelName) {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
 			} else {
 				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
 			}
@@ -72,7 +94,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 		req.Set("X-DashScope-Plugin", c.GetString("plugin"))
 	}
 	if info.RelayMode == constant.RelayModeImagesGenerations {
-		req.Set("X-DashScope-Async", "enable")
+		if isSyncImageModel(info.OriginModelName) {
+
+		} else {
+			req.Set("X-DashScope-Async", "enable")
+		}
 	}
 	if info.RelayMode == constant.RelayModeImagesEdits {
 		if isWanModel(info.OriginModelName) {
@@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 	if info.RelayMode == constant.RelayModeImagesGenerations {
-		aliRequest, err := oaiImage2Ali(request)
+		if isSyncImageModel(info.OriginModelName) {
+			a.IsSyncImageModel = true
+		}
+		aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
 		if err != nil {
-			return nil, fmt.Errorf("convert image request failed: %w", err)
+			return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
 		}
 		return aliRequest, nil
 	} else if info.RelayMode == constant.RelayModeImagesEdits {
-		if isWanModel(info.OriginModelName) {
+		if isOldWanModel(info.OriginModelName) {
 			return oaiFormEdit2WanxImageEdit(c, info, request)
 		}
+		if isSyncImageModel(info.OriginModelName) {
+			if isWanModel(info.OriginModelName) {
+				a.IsSyncImageModel = false
+			} else {
+				a.IsSyncImageModel = true
+			}
+		}
 		// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
 		// 如果用户使用表单,则需要解析表单数据
 		if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
@@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 			}
 			return aliRequest, nil
 		} else {
-			aliRequest, err := oaiImage2Ali(request)
+			aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
 			if err != nil {
-				return nil, fmt.Errorf("convert image request failed: %w", err)
+				return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
 			}
 			return aliRequest, nil
 		}
@@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	default:
 		switch info.RelayMode {
 		case constant.RelayModeImagesGenerations:
-			err, usage = aliImageHandler(c, resp, info)
+			err, usage = aliImageHandler(a, c, resp, info)
 		case constant.RelayModeImagesEdits:
-			if isWanModel(info.OriginModelName) {
-				err, usage = aliImageHandler(c, resp, info)
-			} else {
-				err, usage = aliImageEditHandler(c, resp, info)
-			}
+			err, usage = aliImageHandler(a, c, resp, info)
 		case constant.RelayModeRerank:
 			err, usage = RerankHandler(c, resp, info)
 		default:

+ 98 - 18
relay/channel/ali/dto.go

@@ -1,6 +1,13 @@
 package ali
 
-import "github.com/QuantumNous/new-api/dto"
+import (
+	"strings"
+
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/gin-gonic/gin"
+)
 
 type AliMessage struct {
 	Content any    `json:"content"`
@@ -65,6 +72,7 @@ type AliUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
 	TotalTokens  int `json:"total_tokens"`
+	ImageCount   int `json:"image_count,omitempty"`
 }
 
 type TaskResult struct {
@@ -75,14 +83,78 @@ type TaskResult struct {
 }
 
 type AliOutput struct {
-	TaskId       string           `json:"task_id,omitempty"`
-	TaskStatus   string           `json:"task_status,omitempty"`
-	Text         string           `json:"text"`
-	FinishReason string           `json:"finish_reason"`
-	Message      string           `json:"message,omitempty"`
-	Code         string           `json:"code,omitempty"`
-	Results      []TaskResult     `json:"results,omitempty"`
-	Choices      []map[string]any `json:"choices,omitempty"`
+	TaskId       string       `json:"task_id,omitempty"`
+	TaskStatus   string       `json:"task_status,omitempty"`
+	Text         string       `json:"text"`
+	FinishReason string       `json:"finish_reason"`
+	Message      string       `json:"message,omitempty"`
+	Code         string       `json:"code,omitempty"`
+	Results      []TaskResult `json:"results,omitempty"`
+	Choices      []struct {
+		FinishReason string `json:"finish_reason,omitempty"`
+		Message      struct {
+			Role             string            `json:"role,omitempty"`
+			Content          []AliMediaContent `json:"content,omitempty"`
+			ReasoningContent string            `json:"reasoning_content,omitempty"`
+		} `json:"message,omitempty"`
+	} `json:"choices,omitempty"`
+}
+
+func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
+	var imageData []dto.ImageData
+	if len(o.Choices) > 0 {
+		for _, choice := range o.Choices {
+			var data dto.ImageData
+			for _, content := range choice.Message.Content {
+				if content.Image != "" {
+					if strings.HasPrefix(content.Image, "http") {
+						var b64Json string
+						if responseFormat == "b64_json" {
+							_, b64, err := service.GetImageFromUrl(content.Image)
+							if err != nil {
+								logger.LogError(c, "get_image_data_failed: "+err.Error())
+								continue
+							}
+							b64Json = b64
+						}
+						data.Url = content.Image
+						data.B64Json = b64Json
+					} else {
+						data.B64Json = content.Image
+					}
+				} else if content.Text != "" {
+					data.RevisedPrompt = content.Text
+				}
+			}
+			imageData = append(imageData, data)
+		}
+	}
+
+	return imageData
+}
+
+func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
+	var imageData []dto.ImageData
+	for _, data := range o.Results {
+		var b64Json string
+		if responseFormat == "b64_json" {
+			_, b64, err := service.GetImageFromUrl(data.Url)
+			if err != nil {
+				logger.LogError(c, "get_image_data_failed: "+err.Error())
+				continue
+			}
+			b64Json = b64
+		} else {
+			b64Json = data.B64Image
+		}
+
+		imageData = append(imageData, dto.ImageData{
+			Url:           data.Url,
+			B64Json:       b64Json,
+			RevisedPrompt: "",
+		})
+	}
+	return imageData
 }
 
 type AliResponse struct {
@@ -92,18 +164,26 @@ type AliResponse struct {
 }
 
 type AliImageRequest struct {
-	Model          string `json:"model"`
-	Input          any    `json:"input"`
-	Parameters     any    `json:"parameters,omitempty"`
-	ResponseFormat string `json:"response_format,omitempty"`
+	Model          string             `json:"model"`
+	Input          any                `json:"input"`
+	Parameters     AliImageParameters `json:"parameters,omitempty"`
+	ResponseFormat string             `json:"response_format,omitempty"`
 }
 
 type AliImageParameters struct {
-	Size      string `json:"size,omitempty"`
-	N         int    `json:"n,omitempty"`
-	Steps     string `json:"steps,omitempty"`
-	Scale     string `json:"scale,omitempty"`
-	Watermark *bool  `json:"watermark,omitempty"`
+	Size         string `json:"size,omitempty"`
+	N            int    `json:"n,omitempty"`
+	Steps        string `json:"steps,omitempty"`
+	Scale        string `json:"scale,omitempty"`
+	Watermark    *bool  `json:"watermark,omitempty"`
+	PromptExtend *bool  `json:"prompt_extend,omitempty"`
+}
+
+func (p *AliImageParameters) PromptExtendValue() bool {
+	if p != nil && p.PromptExtend != nil {
+		return *p.PromptExtend
+	}
+	return false
 }
 
 type AliImageInput struct {

+ 83 - 84
relay/channel/ali/image.go

@@ -21,17 +21,25 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
+func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
 	var imageRequest AliImageRequest
 	imageRequest.Model = request.Model
 	imageRequest.ResponseFormat = request.ResponseFormat
 	logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
+	logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync))
 	if request.Extra != nil {
 		if val, ok := request.Extra["parameters"]; ok {
 			err := common.Unmarshal(val, &imageRequest.Parameters)
 			if err != nil {
 				return nil, fmt.Errorf("invalid parameters field: %w", err)
 			}
+		} else {
+			// 兼容没有parameters字段的情况,从openai标准字段中提取参数
+			imageRequest.Parameters = AliImageParameters{
+				Size:      strings.Replace(request.Size, "x", "*", -1),
+				N:         int(request.N),
+				Watermark: request.Watermark,
+			}
 		}
 		if val, ok := request.Extra["input"]; ok {
 			err := common.Unmarshal(val, &imageRequest.Input)
@@ -41,23 +49,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
 		}
 	}
 
-	if imageRequest.Parameters == nil {
-		imageRequest.Parameters = AliImageParameters{
-			Size:      strings.Replace(request.Size, "x", "*", -1),
-			N:         int(request.N),
-			Watermark: request.Watermark,
+	if strings.Contains(request.Model, "z-image") {
+		// z-image 开启prompt_extend后,按2倍计费
+		if imageRequest.Parameters.PromptExtendValue() {
+			info.PriceData.AddOtherRatio("prompt_extend", 2)
 		}
 	}
 
-	if imageRequest.Input == nil {
-		imageRequest.Input = AliImageInput{
-			Prompt: request.Prompt,
+	// 检查n参数
+	if imageRequest.Parameters.N != 0 {
+		info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
+	}
+
+	// 同步图片模型和异步图片模型请求格式不一样
+	if isSync {
+		if imageRequest.Input == nil {
+			imageRequest.Input = AliImageInput{
+				Messages: []AliMessage{
+					{
+						Role: "user",
+						Content: []AliMediaContent{
+							{
+								Text: request.Prompt,
+							},
+						},
+					},
+				},
+			}
+		}
+	} else {
+		if imageRequest.Input == nil {
+			imageRequest.Input = AliImageInput{
+				Prompt: request.Prompt,
+			}
 		}
 	}
 
 	return &imageRequest, nil
 }
-
 func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
 	mf := c.Request.MultipartForm
 	if mf == nil {
@@ -199,6 +228,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
 	var taskResponse AliResponse
 	var responseBody []byte
 
+	time.Sleep(time.Duration(5) * time.Second)
+
 	for {
 		logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
 		step++
@@ -238,32 +269,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
 		Created: info.StartTime.Unix(),
 	}
 
-	for _, data := range response.Output.Results {
-		var b64Json string
-		if responseFormat == "b64_json" {
-			_, b64, err := service.GetImageFromUrl(data.Url)
-			if err != nil {
-				logger.LogError(c, "get_image_data_failed: "+err.Error())
-				continue
-			}
-			b64Json = b64
-		} else {
-			b64Json = data.B64Image
-		}
-
-		imageResponse.Data = append(imageResponse.Data, dto.ImageData{
-			Url:           data.Url,
-			B64Json:       b64Json,
-			RevisedPrompt: "",
-		})
+	if len(response.Output.Results) > 0 {
+		imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
+	} else if len(response.Output.Choices) > 0 {
+		imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
 	}
-	var mapResponse map[string]any
-	_ = common.Unmarshal(originBody, &mapResponse)
-	imageResponse.Extra = mapResponse
+
+	imageResponse.Metadata = originBody
 	return &imageResponse
 }
 
-func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
 	responseFormat := c.GetString("response_format")
 
 	var aliTaskResponse AliResponse
@@ -282,66 +298,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 		return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
 	}
 
-	aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponse), nil
-	}
-
-	if aliResponse.Output.TaskStatus != "SUCCEEDED" {
-		return types.WithOpenAIError(types.OpenAIError{
-			Message: aliResponse.Output.Message,
-			Type:    "ali_error",
-			Param:   "",
-			Code:    aliResponse.Output.Code,
-		}, resp.StatusCode), nil
-	}
-
-	fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
-	jsonResponse, err := common.Marshal(fullTextResponse)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
-	}
-	service.IOCopyBytesGracefully(c, resp, jsonResponse)
-	return nil, &dto.Usage{}
-}
-
-func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
-	var aliResponse AliResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
+	var (
+		aliResponse    *AliResponse
+		originRespBody []byte
+	)
+
+	if a.IsSyncImageModel {
+		aliResponse = &aliTaskResponse
+		originRespBody = responseBody
+	} else {
+		// 异步图片模型需要轮询任务结果
+		aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeBadResponse), nil
+		}
+		if aliResponse.Output.TaskStatus != "SUCCEEDED" {
+			return types.WithOpenAIError(types.OpenAIError{
+				Message: aliResponse.Output.Message,
+				Type:    "ali_error",
+				Param:   "",
+				Code:    aliResponse.Output.Code,
+			}, resp.StatusCode), nil
+		}
 	}
 
-	service.CloseResponseBodyGracefully(resp)
-	err = common.Unmarshal(responseBody, &aliResponse)
-	if err != nil {
-		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
+	//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
+	if a.IsSyncImageModel {
+		logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
+	} else {
+		logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
 	}
 
-	if aliResponse.Message != "" {
-		logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
-		return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
+	imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
+	// 可能生成多张图片,修正计费数量n
+	if aliResponse.Usage.ImageCount != 0 {
+		info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
+	} else if len(imageResponses.Data) != 0 {
+		info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
 	}
-	var fullTextResponse dto.ImageResponse
-	if len(aliResponse.Output.Choices) > 0 {
-		fullTextResponse = dto.ImageResponse{
-			Created: info.StartTime.Unix(),
-			Data: []dto.ImageData{
-				{
-					Url:     aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
-					B64Json: "",
-				},
-			},
-		}
-	}
-
-	var mapResponse map[string]any
-	_ = common.Unmarshal(responseBody, &mapResponse)
-	fullTextResponse.Extra = mapResponse
-	jsonResponse, err := common.Marshal(fullTextResponse)
+	jsonResponse, err := common.Marshal(imageResponses)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	service.IOCopyBytesGracefully(c, resp, jsonResponse)
+
 	return nil, &dto.Usage{}
 }

+ 11 - 3
relay/channel/ali/image_wan.go

@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
 	if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
 		return nil, fmt.Errorf("get image base64s from form failed: %w", err)
 	}
-	wanParams := WanImageParameters{
+	//wanParams := WanImageParameters{
+	//	N: int(request.N),
+	//}
+	imageRequest.Input = wanInput
+	imageRequest.Parameters = AliImageParameters{
 		N: int(request.N),
 	}
-	imageRequest.Input = wanInput
-	imageRequest.Parameters = wanParams
+	info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
+
 	return &imageRequest, nil
 }
 
+func isOldWanModel(modelName string) bool {
+	return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
+}
+
 func isWanModel(modelName string) bool {
 	return strings.Contains(modelName, "wan")
 }

+ 26 - 20
relay/compatible_handler.go

@@ -184,19 +184,19 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 	if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
 		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
 	} else {
-		postConsumeQuota(c, info, usage.(*dto.Usage), "")
+		postConsumeQuota(c, info, usage.(*dto.Usage))
 	}
 	return nil
 }
 
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
 	if usage == nil {
 		usage = &dto.Usage{
 			PromptTokens:     relayInfo.GetEstimatePromptTokens(),
 			CompletionTokens: 0,
 			TotalTokens:      relayInfo.GetEstimatePromptTokens(),
 		}
-		extraContent += "(可能是请求出错)"
+		extraContent = append(extraContent, "上游无计费信息")
 	}
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
@@ -246,8 +246,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 			dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
 				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
 				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-			extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
-				webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
+			extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
+				webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
 		}
 	} else if strings.HasSuffix(modelName, "search-preview") {
 		// search-preview 模型不支持 response api
@@ -258,8 +258,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
 		dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
 			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-		extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
-			searchContextSize, dWebSearchQuota.String())
+		extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
+			searchContextSize, dWebSearchQuota.String()))
 	}
 	// claude web search tool 计费
 	var dClaudeWebSearchQuota decimal.Decimal
@@ -269,8 +269,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
 		dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
 			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
-		extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
-			claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
+		extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
+			claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
 	}
 	// file search tool 计费
 	var dFileSearchQuota decimal.Decimal
@@ -281,8 +281,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 			dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
 				Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
 				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-			extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
-				fileSearchTool.CallCount, dFileSearchQuota.String())
+			extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
+				fileSearchTool.CallCount, dFileSearchQuota.String()))
 		}
 	}
 	var dImageGenerationCallQuota decimal.Decimal
@@ -290,7 +290,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	if ctx.GetBool("image_generation_call") {
 		imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
 		dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-		extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
+		extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
 	}
 
 	var quotaCalculateDecimal decimal.Decimal
@@ -331,7 +331,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 				// 重新计算 base tokens
 				baseTokens = baseTokens.Sub(dAudioTokens)
 				audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-				extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
+				extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
 			}
 		}
 		promptQuota := baseTokens.Add(cachedTokensWithRatio).
@@ -356,17 +356,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	// 添加 image generation call 计费
 	quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
 
+	if len(relayInfo.PriceData.OtherRatios) > 0 {
+		for key, otherRatio := range relayInfo.PriceData.OtherRatios {
+			dOtherRatio := decimal.NewFromFloat(otherRatio)
+			quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
+			extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
+		}
+	}
+
 	quota := int(quotaCalculateDecimal.Round(0).IntPart())
 	totalTokens := promptTokens + completionTokens
 
-	var logContent string
+	//var logContent string
 
 	// record all the consume log even if quota is 0
 	if totalTokens == 0 {
 		// in this case, must be some error happened
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
-		logContent += fmt.Sprintf("(可能是上游超时)")
+		extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
 		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
 	} else {
@@ -405,15 +413,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	logModel := modelName
 	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 		logModel = "gpt-4-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", modelName)
+		extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
 	}
 	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
 		logModel = "gpt-4o-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", modelName)
-	}
-	if extraContent != "" {
-		logContent += ", " + extraContent
+		extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
 	}
+	logContent := strings.Join(extraContent, ", ")
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
 	if imageTokens != 0 {
 		other["image"] = true

+ 1 - 1
relay/embedding_handler.go

@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 	}
-	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	postConsumeQuota(c, info, usage.(*dto.Usage))
 	return nil
 }

+ 2 - 2
relay/gemini_handler.go

@@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		return openaiErr
 	}
 
-	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	postConsumeQuota(c, info, usage.(*dto.Usage))
 	return nil
 }
 
@@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
 		return openaiErr
 	}
 
-	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	postConsumeQuota(c, info, usage.(*dto.Usage))
 	return nil
 }

+ 9 - 3
relay/image_handler.go

@@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		quality = "hd"
 	}
 
-	var logContent string
+	var logContent []string
 
 	if len(request.Size) > 0 {
-		logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
+		logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
+	}
+	if len(quality) > 0 {
+		logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
+	}
+	if request.N > 0 {
+		logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N))
 	}
 
-	postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
+	postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
 	return nil
 }

+ 1 - 1
relay/rerank_handler.go

@@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 	}
-	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	postConsumeQuota(c, info, usage.(*dto.Usage))
 	return nil
 }

+ 1 - 1
relay/responses_handler.go

@@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
 		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
 	} else {
-		postConsumeQuota(c, info, usage.(*dto.Usage), "")
+		postConsumeQuota(c, info, usage.(*dto.Usage))
 	}
 	return nil
 }

+ 11 - 1
types/price_data.go

@@ -26,12 +26,22 @@ type PriceData struct {
 	GroupRatioInfo       GroupRatioInfo
 }
 
+func (p *PriceData) AddOtherRatio(key string, ratio float64) {
+	if p.OtherRatios == nil {
+		p.OtherRatios = make(map[string]float64)
+	}
+	if ratio <= 0 {
+		return
+	}
+	p.OtherRatios[key] = ratio
+}
+
 type PerCallPriceData struct {
 	ModelPrice     float64
 	Quota          int
 	GroupRatioInfo GroupRatioInfo
 }
 
-func (p PriceData) ToSetting() string {
+func (p *PriceData) ToSetting() string {
 	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
 }