Procházet zdrojové kódy

feat: support gemini tool calling (close #368)

CalciumIon před 1 rokem
rodič
revize
f96291a25a

+ 10 - 1
dto/text_request.go

@@ -148,7 +148,7 @@ func (m Message) ParseContent() []MediaMessage {
 					if ok {
 						subObj["detail"] = detail.(string)
 					} else {
-						subObj["detail"] = "auto"
+						subObj["detail"] = "high"
 					}
 					contentList = append(contentList, MediaMessage{
 						Type: ContentTypeImageURL,
@@ -157,7 +157,16 @@ func (m Message) ParseContent() []MediaMessage {
 							Detail: subObj["detail"].(string),
 						},
 					})
+				} else if url, ok := contentMap["image_url"].(string); ok {
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeImageURL,
+						ImageUrl: MessageImageUrl{
+							Url:    url,
+							Detail: "high",
+						},
+					})
 				}
+
 			}
 		}
 		return contentList

+ 1 - 1
relay/channel/gemini/adaptor.go

@@ -47,7 +47,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 	action := "generateContent"
 	if info.IsStream {
-		action = "streamGenerateContent"
+		action = "streamGenerateContent?alt=sse"
 	}
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }

+ 8 - 2
relay/channel/gemini/dto.go

@@ -12,9 +12,15 @@ type GeminiInlineData struct {
 	Data     string `json:"data"`
 }
 
+type FunctionCall struct {
+	FunctionName string `json:"name"`
+	Arguments    any    `json:"args"`
+}
+
 type GeminiPart struct {
-	Text       string            `json:"text,omitempty"`
-	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+	Text         string            `json:"text,omitempty"`
+	InlineData   *GeminiInlineData `json:"inlineData,omitempty"`
+	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
 }
 
 type GeminiChatContent struct {

+ 84 - 89
relay/channel/gemini/relay-gemini.go

@@ -4,18 +4,14 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
-	"log"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
-	"time"
-
-	"github.com/gin-gonic/gin"
 )
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
@@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
 	}
-	if textRequest.Functions != nil {
+	if textRequest.Tools != nil {
+		functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
+		for _, tool := range textRequest.Tools {
+			functions = append(functions, tool.Function)
+		}
+		geminiRequest.Tools = []GeminiChatTools{
+			{
+				FunctionDeclarations: functions,
+			},
+		}
+	} else if textRequest.Functions != nil {
 		geminiRequest.Tools = []GeminiChatTools{
 			{
 				FunctionDeclarations: textRequest.Functions,
@@ -126,6 +132,30 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }
 
+func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
+	var toolCalls []dto.ToolCall
+
+	item := candidate.Content.Parts[0]
+	if item.FunctionCall == nil {
+		return toolCalls
+	}
+	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
+	if err != nil {
+		//common.SysError("getToolCalls failed: " + err.Error())
+		return toolCalls
+	}
+	toolCall := dto.ToolCall{
+		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
+		Type: "function",
+		Function: dto.FunctionCall{
+			Arguments: string(argsBytes),
+			Name:      item.FunctionCall.FunctionName,
+		},
+	}
+	toolCalls = append(toolCalls, toolCall)
+	return toolCalls
+}
+
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -144,8 +174,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 			FinishReason: relaycommon.StopFinishReason,
 		}
 		if len(candidate.Content.Parts) > 0 {
-			content, _ = json.Marshal(candidate.Content.Parts[0].Text)
-			choice.Message.Content = content
+			if candidate.Content.Parts[0].FunctionCall != nil {
+				choice.Message.ToolCalls = getToolCalls(&candidate)
+			} else {
+				choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
+			}
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
@@ -154,7 +187,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 
 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.SetContentString(geminiResponse.GetResponseText())
+	//choice.Delta.SetContentString(geminiResponse.GetResponseText())
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+		respFirst := geminiResponse.Candidates[0].Content.Parts[0]
+		if respFirst.FunctionCall != nil {
+			// function response
+			choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
+		} else {
+			// text response
+			choice.Delta.SetContentString(respFirst.Text)
+		}
+	}
 	choice.FinishReason = &relaycommon.StopFinishReason
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
@@ -165,92 +208,47 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 
 func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseText := ""
-	responseJson := ""
 	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
-	dataChan := make(chan string, 5)
-	stopChan := make(chan bool, 2)
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+	for scanner.Scan() {
+		data := scanner.Text()
+		info.SetFirstResponseTime()
+		data = strings.TrimSpace(data)
+		if !strings.HasPrefix(data, "data: ") {
+			continue
 		}
-		if atEOF {
-			return len(data), data, nil
+		data = strings.TrimPrefix(data, "data: ")
+		data = strings.TrimSuffix(data, "\"")
+		var geminiResponse GeminiChatResponse
+		err := json.Unmarshal([]byte(data), &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			continue
 		}
-		return 0, nil, nil
-	})
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			responseJson += data
-			data = strings.TrimSpace(data)
-			if !strings.HasPrefix(data, "\"text\": \"") {
-				continue
-			}
-			data = strings.TrimPrefix(data, "\"text\": \"")
-			data = strings.TrimSuffix(data, "\"")
-			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-				// send data timeout, stop the stream
-				common.LogError(c, "send data timeout, stop the stream")
-				break
-			}
+
+		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		if response == nil {
+			continue
 		}
-		stopChan <- true
-	}()
-	isFirst := true
-	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			// this is used to prevent annoying \ related format bug
-			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
-			type dummyStruct struct {
-				Content string `json:"content"`
-			}
-			var dummy dummyStruct
-			err := json.Unmarshal([]byte(data), &dummy)
-			responseText += dummy.Content
-			var choice dto.ChatCompletionsStreamResponseChoice
-			choice.Delta.SetContentString(dummy.Content)
-			response := dto.ChatCompletionsStreamResponse{
-				Id:      id,
-				Object:  "chat.completion.chunk",
-				Created: createAt,
-				Model:   info.UpstreamModelName,
-				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			return false
+		response.Id = id
+		response.Created = createAt
+		responseText += response.Choices[0].Delta.GetContentString()
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
 		}
-	})
-	var geminiChatResponses []GeminiChatResponse
-	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
-	if err != nil {
-		log.Printf("cannot get gemini usage: %s", err.Error())
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-	} else {
-		for _, response := range geminiChatResponses {
-			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, err.Error())
 		}
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
 	if info.ShouldIncludeUsage {
 		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
 		err := service.ObjectData(c, response)
@@ -259,10 +257,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 		}
 	}
 	service.Done(c)
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
-	}
+	resp.Body.Close()
 	return nil, usage
 }