feitianbubu 5 месяцев назад
Родитель
Сommit
aa8d112c58

+ 2 - 0
common/api_type.go

@@ -63,6 +63,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = constant.APITypeXai
 	case constant.ChannelTypeCoze:
 		apiType = constant.APITypeCoze
+	case constant.ChannelTypeJimeng:
+		apiType = constant.APITypeJimeng
 	}
 	if apiType == -1 {
 		return constant.APITypeOpenAI, false

+ 1 - 0
constant/api_type.go

@@ -30,5 +30,6 @@ const (
 	APITypeXinference
 	APITypeXai
 	APITypeCoze
+	APITypeJimeng
 	APITypeDummy // this one is only for count, do not add any channel after this
 )

+ 3 - 0
relay/channel/api_request.go

@@ -203,6 +203,9 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
 	}
 }
 
+func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+	return doRequest(c, req, info)
+}
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error

+ 135 - 0
relay/channel/jimeng/adaptor.go

@@ -0,0 +1,135 @@
+package jimeng
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+	return errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+type LogoInfo struct {
+	AddLogo         bool    `json:"add_logo,omitempty"`
+	Position        int     `json:"position,omitempty"`
+	Language        int     `json:"language,omitempty"`
+	Opacity         float64 `json:"opacity,omitempty"`
+	LogoTextContent string  `json:"logo_text_content,omitempty"`
+}
+
+type imageRequestPayload struct {
+	ReqKey     string   `json:"req_key"`                      // Service identifier, fixed value: jimeng_high_aes_general_v21_L
+	Prompt     string   `json:"prompt"`                       // Prompt for image generation, supports both Chinese and English
+	Seed       int64    `json:"seed,omitempty"`               // Random seed, default -1 (random)
+	Width      int      `json:"width,omitempty"`              // Image width, default 512, range [256, 768]
+	Height     int      `json:"height,omitempty"`             // Image height, default 512, range [256, 768]
+	UsePreLLM  bool     `json:"use_pre_llm,omitempty"`        // Enable text expansion, default true
+	UseSR      bool     `json:"use_sr,omitempty"`             // Enable super resolution, default true
+	ReturnURL  bool     `json:"return_url,omitempty"`         // Whether to return image URL (valid for 24 hours)
+	LogoInfo   LogoInfo `json:"logo_info,omitempty"`          // Watermark information
+	ImageUrls  []string `json:"image_urls,omitempty"`         // Image URLs for input
+	BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	payload := imageRequestPayload{
+		ReqKey: request.Model,
+		Prompt: request.Prompt,
+	}
+	if request.ResponseFormat == "" || request.ResponseFormat == "url" {
+		payload.ReturnURL = true // Default to returning image URLs
+	}
+
+	if len(request.ExtraFields) > 0 {
+		if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
+		}
+	}
+
+	return payload, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	fullRequestURL, err := a.GetRequestURL(info)
+	if err != nil {
+		return nil, fmt.Errorf("get request url failed: %w", err)
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	err = Sign(c, req, info.ApiKey)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := channel.DoRequest(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+		err, usage = jimengImageHandler(c, resp, info)
+	} else if info.IsStream {
+		err, usage = openai.OaiStreamHandler(c, resp, info)
+	} else {
+		err, usage = openai.OpenaiHandler(c, resp, info)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 9 - 0
relay/channel/jimeng/constants.go

@@ -0,0 +1,9 @@
+package jimeng
+
+const (
+	ChannelName = "jimeng"
+)
+
+var ModelList = []string{
+	"jimeng_high_aes_general_v21_L",
+}

+ 91 - 0
relay/channel/jimeng/image.go

@@ -0,0 +1,91 @@
+package jimeng
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type ImageResponse struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Data    struct {
+		BinaryDataBase64 []string `json:"binary_data_base64"`
+		ImageUrls        []string `json:"image_urls"`
+		RephraseResult   string   `json:"rephraser_result"`
+		RequestID        string   `json:"request_id"`
+		// Other fields are omitted for brevity
+	} `json:"data"`
+	RequestID   string `json:"request_id"`
+	Status      int    `json:"status"`
+	TimeElapsed string `json:"time_elapsed"`
+}
+
+func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
+	imageResponse := dto.ImageResponse{
+		Created: info.StartTime.Unix(),
+	}
+
+	for _, base64Data := range response.Data.BinaryDataBase64 {
+		imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+			B64Json: base64Data,
+		})
+	}
+	for _, imageUrl := range response.Data.ImageUrls {
+		imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+			Url: imageUrl,
+		})
+	}
+
+	return &imageResponse
+}
+
+// jimengImageHandler handles the Jimeng image generation response
+func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var jimengResponse ImageResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	common.CloseResponseBodyGracefully(resp)
+
+	err = json.Unmarshal(responseBody, &jimengResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// Check if the response indicates an error
+	if jimengResponse.Code != 10000 {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: jimengResponse.Message,
+				Type:    "jimeng_error",
+				Param:   "",
+				Code:    fmt.Sprintf("%d", jimengResponse.Code),
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	// Convert Jimeng response to OpenAI format
+	fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &dto.Usage{}
+}

+ 176 - 0
relay/channel/jimeng/sign.go

@@ -0,0 +1,176 @@
+package jimeng
+
+import (
+	"bytes"
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"net/url"
+	"one-api/common"
+	"sort"
+	"strings"
+	"time"
+)
+
+// SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式
+//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
+//	var bodyBytes []byte
+//	var err error
+//
+//	if req.Body != nil {
+//		bodyBytes, err = io.ReadAll(req.Body)
+//		if err != nil {
+//			return fmt.Errorf("read request body failed: %w", err)
+//		}
+//		_ = req.Body.Close()
+//		req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
+//	} else {
+//		bodyBytes = []byte{}
+//	}
+//
+//	return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
+//}
+
+const HexPayloadHashKey = "HexPayloadHash"
+
+func SetPayloadHash(c *gin.Context, req any) error {
+	body, err := json.Marshal(req)
+	if err != nil {
+		return err
+	}
+	common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
+	payloadHash := sha256.Sum256(body)
+	hexPayloadHash := hex.EncodeToString(payloadHash[:])
+	c.Set(HexPayloadHashKey, hexPayloadHash)
+	return nil
+}
+func getPayloadHash(c *gin.Context) string {
+	return c.GetString(HexPayloadHashKey)
+}
+
+func Sign(c *gin.Context, req *http.Request, apiKey string) error {
+	header := req.Header
+
+	var bodyBytes []byte
+	var err error
+
+	if req.Body != nil {
+		bodyBytes, err = io.ReadAll(req.Body)
+		if err != nil {
+			return err
+		}
+		_ = req.Body.Close()
+		req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+	}
+
+	payloadHash := sha256.Sum256(bodyBytes)
+	hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+	method := c.Request.Method
+	u := req.URL
+	keyParts := strings.Split(apiKey, "|")
+	if len(keyParts) != 2 {
+		return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
+	}
+	accessKey := strings.TrimSpace(keyParts[0])
+	secretKey := strings.TrimSpace(keyParts[1])
+	t := time.Now().UTC()
+	xDate := t.Format("20060102T150405Z")
+	shortDate := t.Format("20060102")
+
+	host := u.Host
+	header.Set("Host", host)
+	header.Set("X-Date", xDate)
+	header.Set("X-Content-Sha256", hexPayloadHash)
+
+	// Sort and encode query parameters to create canonical query string
+	queryParams := u.Query()
+	sortedKeys := make([]string, 0, len(queryParams))
+	for k := range queryParams {
+		sortedKeys = append(sortedKeys, k)
+	}
+	sort.Strings(sortedKeys)
+	var queryParts []string
+	for _, k := range sortedKeys {
+		values := queryParams[k]
+		sort.Strings(values)
+		for _, v := range values {
+			queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+		}
+	}
+	canonicalQueryString := strings.Join(queryParts, "&")
+
+	headersToSign := map[string]string{
+		"host":             host,
+		"x-date":           xDate,
+		"x-content-sha256": hexPayloadHash,
+	}
+	if header.Get("Content-Type") == "" {
+		header.Set("Content-Type", "application/json")
+	}
+	headersToSign["content-type"] = header.Get("Content-Type")
+
+	var signedHeaderKeys []string
+	for k := range headersToSign {
+		signedHeaderKeys = append(signedHeaderKeys, k)
+	}
+	sort.Strings(signedHeaderKeys)
+
+	var canonicalHeaders strings.Builder
+	for _, k := range signedHeaderKeys {
+		canonicalHeaders.WriteString(k)
+		canonicalHeaders.WriteString(":")
+		canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+		canonicalHeaders.WriteString("\n")
+	}
+	signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+		method,
+		u.Path,
+		canonicalQueryString,
+		canonicalHeaders.String(),
+		signedHeaders,
+		hexPayloadHash,
+	)
+
+	hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+	hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+	region := "cn-north-1"
+	serviceName := "cv"
+	credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+	stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+		xDate,
+		credentialScope,
+		hexHashedCanonicalRequest,
+	)
+
+	kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+	kRegion := hmacSHA256(kDate, []byte(region))
+	kService := hmacSHA256(kRegion, []byte(serviceName))
+	kSigning := hmacSHA256(kService, []byte("request"))
+	signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+	authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+		accessKey,
+		credentialScope,
+		signedHeaders,
+		signature,
+	)
+	header.Set("Authorization", authorization)
+	return nil
+}
+
+// hmacSHA256 计算 HMAC-SHA256
+func hmacSHA256(key []byte, data []byte) []byte {
+	h := hmac.New(sha256.New, key)
+	h.Write(data)
+	return h.Sum(nil)
+}

+ 5 - 2
relay/relay_adaptor.go

@@ -15,6 +15,7 @@ import (
 	"one-api/relay/channel/deepseek"
 	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
+	"one-api/relay/channel/jimeng"
 	"one-api/relay/channel/jina"
 	"one-api/relay/channel/mistral"
 	"one-api/relay/channel/mokaai"
@@ -23,7 +24,7 @@ import (
 	"one-api/relay/channel/palm"
 	"one-api/relay/channel/perplexity"
 	"one-api/relay/channel/siliconflow"
-	"one-api/relay/channel/task/jimeng"
+	taskjimeng "one-api/relay/channel/task/jimeng"
 	"one-api/relay/channel/task/kling"
 	"one-api/relay/channel/task/suno"
 	"one-api/relay/channel/tencent"
@@ -93,6 +94,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &xai.Adaptor{}
 	case constant.APITypeCoze:
 		return &coze.Adaptor{}
+	case constant.APITypeJimeng:
+		return &jimeng.Adaptor{}
 	}
 	return nil
 }
@@ -106,7 +109,7 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
 	case commonconstant.TaskPlatformKling:
 		return &kling.TaskAdaptor{}
 	case commonconstant.TaskPlatformJimeng:
-		return &jimeng.TaskAdaptor{}
+		return &taskjimeng.TaskAdaptor{}
 	}
 	return nil
 }