CalciumIon пре 1 година
родитељ
комит
c9100b219f
5 измењених фајлова са 256 додато и 74 уклоњено
  1. 7 5
      dto/dalle.go
  2. 31 23
      relay/channel/ali/adaptor.go
  3. 30 3
      relay/channel/ali/dto.go
  4. 177 0
      relay/channel/ali/image.go
  5. 11 43
      relay/channel/ali/text.go

+ 7 - 5
dto/dalle.go

@@ -12,9 +12,11 @@ type ImageRequest struct {
 }
 
 type ImageResponse struct {
-	Created int `json:"created"`
-	Data    []struct {
-		Url     string `json:"url"`
-		B64Json string `json:"b64_json"`
-	}
+	Data    []ImageData `json:"data"`
+	Created int64       `json:"created"`
+}
+type ImageData struct {
+	Url           string `json:"url"`
+	B64Json       string `json:"b64_json"`
+	RevisedPrompt string `json:"revised_prompt"`
 }

+ 31 - 23
relay/channel/ali/adaptor.go

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 )
@@ -15,23 +16,18 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
-
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
-	if info.RelayMode == constant.RelayModeEmbeddings {
+	var fullRequestURL string
+	switch info.RelayMode {
+	case constant.RelayModeEmbeddings:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+	case constant.RelayModeImagesGenerations:
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
+	default:
+		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
 	}
 	return fullRequestURL, nil
 }
@@ -57,13 +53,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
 		return baiduEmbeddingRequest, nil
 	default:
-		baiduRequest := requestOpenAI2Ali(*request)
-		return baiduRequest, nil
+		aliReq := requestOpenAI2Ali(*request)
+		return aliReq, nil
 	}
 }
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	aliRequest := oaiImage2Ali(request)
+	return aliRequest, nil
+}
+
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, nil
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -71,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = aliStreamHandler(c, resp)
-	} else {
-		switch info.RelayMode {
-		case constant.RelayModeEmbeddings:
-			err, usage = aliEmbeddingHandler(c, resp)
-		default:
-			err, usage = aliHandler(c, resp)
+	switch info.RelayMode {
+	case constant.RelayModeImagesGenerations:
+		err, usage = aliImageHandler(c, resp, info)
+	case constant.RelayModeEmbeddings:
+		err, usage = aliEmbeddingHandler(c, resp)
+	default:
+		if info.IsStream {
+			err, usage = openai.OpenaiStreamHandler(c, resp, info)
+		} else {
+			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 		}
 	}
 	return

+ 30 - 3
relay/channel/ali/dto.go

@@ -60,13 +60,40 @@ type AliUsage struct {
 	TotalTokens  int `json:"total_tokens"`
 }
 
+type TaskResult struct {
+	B64Image string `json:"b64_image,omitempty"`
+	Url      string `json:"url,omitempty"`
+	Code     string `json:"code,omitempty"`
+	Message  string `json:"message,omitempty"`
+}
+
 type AliOutput struct {
-	Text         string `json:"text"`
-	FinishReason string `json:"finish_reason"`
+	TaskId       string       `json:"task_id,omitempty"`
+	TaskStatus   string       `json:"task_status,omitempty"`
+	Text         string       `json:"text"`
+	FinishReason string       `json:"finish_reason"`
+	Message      string       `json:"message,omitempty"`
+	Code         string       `json:"code,omitempty"`
+	Results      []TaskResult `json:"results,omitempty"`
 }
 
-type AliChatResponse struct {
+type AliResponse struct {
 	Output AliOutput `json:"output"`
 	Usage  AliUsage  `json:"usage"`
 	AliError
 }
+
+type AliImageRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Prompt         string `json:"prompt"`
+		NegativePrompt string `json:"negative_prompt,omitempty"`
+	} `json:"input"`
+	Parameters struct {
+		Size  string `json:"size,omitempty"`
+		N     int    `json:"n,omitempty"`
+		Steps string `json:"steps,omitempty"`
+		Scale string `json:"scale,omitempty"`
+	} `json:"parameters,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+}

+ 177 - 0
relay/channel/ali/image.go

@@ -0,0 +1,177 @@
+package ali
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+	"time"
+)
+
+func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
+	var imageRequest AliImageRequest
+	imageRequest.Input.Prompt = request.Prompt
+	imageRequest.Model = request.Model
+	imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
+	imageRequest.Parameters.N = request.N
+	imageRequest.ResponseFormat = request.ResponseFormat
+
+	return &imageRequest
+}
+
+func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
+	url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
+
+	var aliResponse AliResponse
+
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return &aliResponse, err, nil
+	}
+
+	req.Header.Set("Authorization", "Bearer "+key)
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		common.SysError("updateTask client.Do err: " + err.Error())
+		return &aliResponse, err, nil
+	}
+	defer resp.Body.Close()
+
+	responseBody, err := io.ReadAll(resp.Body)
+
+	var response AliResponse
+	err = json.Unmarshal(responseBody, &response)
+	if err != nil {
+		common.SysError("updateTask NewDecoder err: " + err.Error())
+		return &aliResponse, err, nil
+	}
+
+	return &response, nil, responseBody
+}
+
+func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
+	waitSeconds := 3
+	step := 0
+	maxStep := 20
+
+	var taskResponse AliResponse
+	var responseBody []byte
+
+	for {
+		step++
+		rsp, err, body := updateTask(info, taskID, key)
+		responseBody = body
+		if err != nil {
+			return &taskResponse, responseBody, err
+		}
+
+		if rsp.Output.TaskStatus == "" {
+			return &taskResponse, responseBody, nil
+		}
+
+		switch rsp.Output.TaskStatus {
+		case "FAILED":
+			fallthrough
+		case "CANCELED":
+			fallthrough
+		case "SUCCEEDED":
+			fallthrough
+		case "UNKNOWN":
+			return rsp, responseBody, nil
+		}
+		if step >= maxStep {
+			break
+		}
+		time.Sleep(time.Duration(waitSeconds) * time.Second)
+	}
+
+	return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
+}
+
+func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
+	imageResponse := dto.ImageResponse{
+		Created: info.StartTime.Unix(),
+	}
+
+	for _, data := range response.Output.Results {
+		var b64Json string
+		if responseFormat == "b64_json" {
+			_, b64, err := service.GetImageFromUrl(data.Url)
+			if err != nil {
+				common.LogError(c, "get_image_data_failed: "+err.Error())
+				continue
+			}
+			b64Json = b64
+		} else {
+			b64Json = data.B64Image
+		}
+
+		imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+			Url:           data.Url,
+			B64Json:       b64Json,
+			RevisedPrompt: "",
+		})
+	}
+	return &imageResponse
+}
+
+func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	apiKey := c.Request.Header.Get("Authorization")
+	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+	responseFormat := c.GetString("response_format")
+
+	var aliTaskResponse AliResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &aliTaskResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliTaskResponse.Message != "" {
+		common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
+		return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
+	}
+
+	aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Output.TaskStatus != "SUCCEEDED" {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: aliResponse.Output.Message,
+				Type:    "ali_error",
+				Param:   "",
+				Code:    aliResponse.Output.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, nil
+}

+ 11 - 43
relay/channel/ali/relay-ali.go → relay/channel/ali/text.go

@@ -16,34 +16,13 @@ import (
 
 const EnableSearchModelSuffix = "-internet"
 
-func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
-	messages := make([]AliMessage, 0, len(request.Messages))
-	//prompt := ""
-	for i := 0; i < len(request.Messages); i++ {
-		message := request.Messages[i]
-		messages = append(messages, AliMessage{
-			Content: message.StringContent(),
-			Role:    strings.ToLower(message.Role),
-		})
-	}
-	enableSearch := false
-	aliModel := request.Model
-	if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
-		enableSearch = true
-		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
-	}
-	return &AliChatRequest{
-		Model: request.Model,
-		Input: AliInput{
-			//Prompt:  prompt,
-			Messages: messages,
-		},
-		Parameters: AliParameters{
-			IncrementalOutput: request.Stream,
-			Seed:              uint64(request.Seed),
-			EnableSearch:      enableSearch,
-		},
+func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+	if request.TopP >= 1 {
+		request.TopP = 0.999
+	} else if request.TopP <= 0 {
+		request.TopP = 0.001
 	}
+	return &request
 }
 
 func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
@@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
 	return &openAIEmbeddingResponse
 }
 
-func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
+func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(response.Output.Text)
 	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
@@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
+func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.SetContentString(aliResponse.Output.Text)
 	if aliResponse.Output.FinishReason != "null" {
@@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions
 func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var usage dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
+	scanner.Split(bufio.ScanLines)
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
 	go func() {
@@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
-			var aliResponse AliChatResponse
+			var aliResponse AliResponse
 			err := json.Unmarshal([]byte(data), &aliResponse)
 			if err != nil {
 				common.SysError("error unmarshalling stream response: " + err.Error())
@@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
 }
 
 func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	var aliResponse AliChatResponse
+	var aliResponse AliResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil