Просмотр исходного кода

feat: zhipu v4 image generations

Seefs 1 месяц назад
Родитель
Сommit
2e37347851
3 измененных файлов с 138 добавлено и 4 удалено
  1. 5 2
      dto/openai_image.go
  2. 6 2
      relay/channel/zhipu_4v/adaptor.go
  3. 127 0
      relay/channel/zhipu_4v/image.go

+ 5 - 2
dto/openai_image.go

@@ -27,8 +27,11 @@ type ImageRequest struct {
 	OutputCompression json.RawMessage `json:"output_compression,omitempty"`
 	PartialImages     json.RawMessage `json:"partial_images,omitempty"`
 	// Stream            bool            `json:"stream,omitempty"`
-	Watermark *bool           `json:"watermark,omitempty"`
-	Image     json.RawMessage `json:"image,omitempty"`
+	Watermark *bool `json:"watermark,omitempty"`
+	// zhipu 4v
+	WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
+	UserId           json.RawMessage `json:"user_id,omitempty"`
+	Image            json.RawMessage `json:"image,omitempty"`
 	// 用匿名参数接收额外参数
 	Extra map[string]json.RawMessage `json:"-"`
 }

+ 6 - 2
relay/channel/zhipu_4v/adaptor.go

@@ -36,8 +36,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -63,6 +62,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 				return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil
 			}
 			return fmt.Sprintf("%s/api/paas/v4/embeddings", baseURL), nil
+		case relayconstant.RelayModeImagesGenerations:
+			return fmt.Sprintf("%s/api/paas/v4/images/generations", baseURL), nil
 		default:
 			if hasSpecialPlan && specialPlan.OpenAIBaseURL != "" {
 				return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil
@@ -114,6 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 			return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
 		}
 	default:
+		if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+			return zhipu4vImageHandler(c, resp, info)
+		}
 		adaptor := openai.Adaptor{}
 		return adaptor.DoResponse(c, resp, info)
 	}

+ 127 - 0
relay/channel/zhipu_4v/image.go

@@ -0,0 +1,127 @@
+package zhipu_4v
+
+import (
+	"io"
+	"net/http"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
+
+	"github.com/gin-gonic/gin"
+)
+
+type zhipuImageRequest struct {
+	Model            string `json:"model"`
+	Prompt           string `json:"prompt"`
+	Quality          string `json:"quality,omitempty"`
+	Size             string `json:"size,omitempty"`
+	WatermarkEnabled *bool  `json:"watermark_enabled,omitempty"`
+	UserID           string `json:"user_id,omitempty"`
+}
+
+type zhipuImageResponse struct {
+	Created       *int64            `json:"created,omitempty"`
+	Data          []zhipuImageData  `json:"data,omitempty"`
+	ContentFilter any               `json:"content_filter,omitempty"`
+	Usage         *dto.Usage        `json:"usage,omitempty"`
+	Error         *zhipuImageError  `json:"error,omitempty"`
+	RequestID     string            `json:"request_id,omitempty"`
+	ExtendParam   map[string]string `json:"extendParam,omitempty"`
+}
+
+type zhipuImageError struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+}
+
+type zhipuImageData struct {
+	Url      string `json:"url,omitempty"`
+	ImageUrl string `json:"image_url,omitempty"`
+	B64Json  string `json:"b64_json,omitempty"`
+	B64Image string `json:"b64_image,omitempty"`
+}
+
+type openAIImagePayload struct {
+	Created int64             `json:"created"`
+	Data    []openAIImageData `json:"data"`
+}
+
+type openAIImageData struct {
+	B64Json string `json:"b64_json"`
+}
+
+func zhipu4vImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
+	}
+	service.CloseResponseBodyGracefully(resp)
+
+	var zhipuResp zhipuImageResponse
+	if err := common.Unmarshal(responseBody, &zhipuResp); err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if zhipuResp.Error != nil && zhipuResp.Error.Message != "" {
+		return nil, types.WithOpenAIError(types.OpenAIError{
+			Message: zhipuResp.Error.Message,
+			Type:    "zhipu_image_error",
+			Code:    zhipuResp.Error.Code,
+		}, resp.StatusCode)
+	}
+
+	payload := openAIImagePayload{}
+	if zhipuResp.Created != nil && *zhipuResp.Created != 0 {
+		payload.Created = *zhipuResp.Created
+	} else {
+		payload.Created = info.StartTime.Unix()
+	}
+	for _, data := range zhipuResp.Data {
+		url := data.Url
+		if url == "" {
+			url = data.ImageUrl
+		}
+		if url == "" {
+			logger.LogWarn(c, "zhipu_image_missing_url")
+			continue
+		}
+
+		var b64 string
+		switch {
+		case data.B64Json != "":
+			b64 = data.B64Json
+		case data.B64Image != "":
+			b64 = data.B64Image
+		default:
+			_, downloaded, err := service.GetImageFromUrl(url)
+			if err != nil {
+				logger.LogError(c, "zhipu_image_get_b64_failed: "+err.Error())
+				continue
+			}
+			b64 = downloaded
+		}
+
+		if b64 == "" {
+			logger.LogWarn(c, "zhipu_image_empty_b64")
+			continue
+		}
+
+		imageData := openAIImageData{
+			B64Json: b64,
+		}
+		payload.Data = append(payload.Data, imageData)
+	}
+
+	jsonResp, err := common.Marshal(payload)
+	if err != nil {
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+	}
+
+	service.IOCopyBytesGracefully(c, resp, jsonResp)
+
+	return &dto.Usage{}, nil
+}