Browse Source

feat: support dify upload image file

[email protected] 10 months ago
parent
commit
dd393cd0d9

+ 9 - 0
controller/image.go

@@ -0,0 +1,9 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+)
+
+// GetImage is a gin handler stub. Its body is currently empty.
+func GetImage(c *gin.Context) {
+
+}

+ 31 - 23
dto/openai_request.go

@@ -113,9 +113,21 @@ type MediaContent struct {
 	InputAudio any    `json:"input_audio,omitempty"`
 }
 
+// GetImageMedia returns the image payload stored in ImageUrl, or nil when the
+// field is unset or holds something that is not a MessageImageUrl.
+//
+// ImageUrl is an untyped field, so the dynamic type is checked instead of
+// asserted blindly; an unchecked assertion panics on any other stored type.
+// A MessageImageUrl stored by value is returned as a pointer to a copy, so
+// writes through that pointer do not reach the stored value.
+func (m *MediaContent) GetImageMedia() *MessageImageUrl {
+	switch v := m.ImageUrl.(type) {
+	case *MessageImageUrl:
+		return v
+	case MessageImageUrl:
+		return &v
+	}
+	return nil
+}
+
+// MessageImageUrl describes one image attached to a message: where it comes
+// from (Url), the requested detail level (Detail) and its detected format.
 type MessageImageUrl struct {
-	Url    string `json:"url"`
-	Detail string `json:"detail"`
+	Url      string `json:"url"`
+	Detail   string `json:"detail"`
+	// MimeType is assigned by getImageToken from the decoded image format.
+	// It has no json tag, so encoding/json emits it under the key "MimeType".
+	// NOTE(review): confirm whether it should be hidden from serialized requests.
+	MimeType string
+}
+
+// IsRemoteImage reports whether Url begins with "http", i.e. whether the
+// image is referenced by a web address rather than carried inline.
+func (m *MessageImageUrl) IsRemoteImage() bool {
+	const remotePrefix = "http"
+	return strings.HasPrefix(m.Url, remotePrefix)
+}
 
 type MessageInputAudio struct {
@@ -244,43 +256,39 @@ func (m *Message) ParseContent() []MediaContent {
 
 			case ContentTypeImageURL:
 				imageUrl := contentItem["image_url"]
+				temp := &MessageImageUrl{
+					Detail: "high",
+				}
 				switch v := imageUrl.(type) {
 				case string:
-					contentList = append(contentList, MediaContent{
-						Type: ContentTypeImageURL,
-						ImageUrl: MessageImageUrl{
-							Url:    v,
-							Detail: "high",
-						},
-					})
+					temp.Url = v
 				case map[string]interface{}:
 					url, ok1 := v["url"].(string)
 					detail, ok2 := v["detail"].(string)
-					if !ok2 {
-						detail = "high"
+					if ok2 {
+						temp.Detail = detail
 					}
 					if ok1 {
-						contentList = append(contentList, MediaContent{
-							Type: ContentTypeImageURL,
-							ImageUrl: MessageImageUrl{
-								Url:    url,
-								Detail: detail,
-							},
-						})
+						temp.Url = url
 					}
 				}
+				contentList = append(contentList, MediaContent{
+					Type:     ContentTypeImageURL,
+					ImageUrl: temp,
+				})
 
 			case ContentTypeInputAudio:
 				if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
 					data, ok1 := audioData["data"].(string)
 					format, ok2 := audioData["format"].(string)
 					if ok1 && ok2 {
+						temp := &MessageInputAudio{
+							Data:   data,
+							Format: format,
+						}
 						contentList = append(contentList, MediaContent{
-							Type: ContentTypeInputAudio,
-							InputAudio: MessageInputAudio{
-								Data:   data,
-								Format: format,
-							},
+							Type:       ContentTypeInputAudio,
+							InputAudio: temp,
 						})
 					}
 				}

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -414,7 +414,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
 				// 加密的不管, 只输出明文的推理过程
 				thinkingContent = message.Thinking
 			case "text":
-				responseText = *message.Text
+				responseText = message.GetText()
 			}
 		}
 	}

+ 1 - 1
relay/channel/dify/adaptor.go

@@ -74,7 +74,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Dify(*request), nil
+	return requestOpenAI2Dify(c, info, *request), nil
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

+ 10 - 0
relay/channel/dify/dto.go

@@ -8,6 +8,14 @@ type DifyChatRequest struct {
 	ResponseMode     string                 `json:"response_mode"`
 	User             string                 `json:"user"`
 	AutoGenerateName bool                   `json:"auto_generate_name"`
+	Files            []DifyFile             `json:"files"`
+}
+
+// DifyFile is one entry of the "files" list in a Dify chat request.
+type DifyFile struct {
+	// Type is the file category; uploadDifyFile sets it to "image".
+	Type string `json:"type"`
+	// TransferMode is "remote_url" or "local_file" in the code that builds it.
+	TransferMode string `json:"transfer_mode"`
+	// URL is the image address for remote files; omitted when empty.
+	URL string `json:"url,omitempty"`
+	// UploadFileId is the id returned by the upload endpoint; omitted when empty.
+	UploadFileId string `json:"upload_file_id,omitempty"`
 }
 
 type DifyMetaData struct {
@@ -17,6 +25,8 @@ type DifyMetaData struct {
 type DifyData struct {
 	WorkflowId string `json:"workflow_id"`
 	NodeId     string `json:"node_id"`
+	NodeType   string `json:"node_type"`
+	Status     string `json:"status"`
 }
 
 type DifyChatCompletionResponse struct {

+ 175 - 31
relay/channel/dify/relay-dify.go

@@ -2,9 +2,12 @@ package dify
 
 import (
 	"bufio"
+	"bytes"
+	"encoding/base64"
 	"encoding/json"
-	"github.com/gin-gonic/gin"
+	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -12,35 +15,163 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"os"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
-func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
-	content := ""
+// uploadDifyFile uploads the inline image carried by media to the
+// "/v1/files/upload" endpoint under info.BaseUrl and returns a DifyFile that
+// references the upload by id (type "image", transfer mode "local_file").
+//
+// Only dto.ContentTypeImageURL media is handled. The image URL is treated as
+// base64 data, optionally preceded by a data-URL header that ends in a comma.
+// Every failure is logged through common.SysError and reported as a nil
+// result, so callers must check for nil. c is currently unused.
+func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
+	if media.Type != dto.ContentTypeImageURL {
+		return nil
+	}
+	imageMedia := media.GetImageMedia()
+	if imageMedia == nil {
+		common.SysError("failed to upload dify file: image media is nil")
+		return nil
+	}
+	uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
+
+	// Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
+	base64Data := imageMedia.Url
+	if idx := strings.Index(base64Data, ","); idx != -1 {
+		base64Data = base64Data[idx+1:]
+	}
+
+	// Decode base64 string
+	decodedData, err := base64.StdEncoding.DecodeString(base64Data)
+	if err != nil {
+		common.SysError("failed to decode base64: " + err.Error())
+		return nil
+	}
+
+	// Create temporary file
+	// NOTE(review): the upload below is fed from decodedData in memory, so
+	// this temp copy looks redundant; it is kept to preserve existing behavior.
+	tempFile, err := os.CreateTemp("", "dify-upload-*")
+	if err != nil {
+		common.SysError("failed to create temp file: " + err.Error())
+		return nil
+	}
+	// Close before removing: deleting a file that is still open fails on
+	// some platforms.
+	defer func() {
+		tempFile.Close()
+		os.Remove(tempFile.Name())
+	}()
+
+	// Write decoded data to temp file
+	if _, err := tempFile.Write(decodedData); err != nil {
+		common.SysError("failed to write to temp file: " + err.Error())
+		return nil
+	}
+
+	// Create multipart form
+	body := &bytes.Buffer{}
+	writer := multipart.NewWriter(body)
+
+	// Add user field
+	if err := writer.WriteField("user", user); err != nil {
+		common.SysError("failed to add user field: " + err.Error())
+		return nil
+	}
+
+	// Derive the file extension from the mime type, defaulting to JPEG.
+	mimeType := imageMedia.MimeType
+	if mimeType == "" {
+		mimeType = "image/jpeg" // default mime type
+	}
+
+	// Create form file
+	part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
+	if err != nil {
+		common.SysError("failed to create form file: " + err.Error())
+		return nil
+	}
+
+	// Copy file content to form
+	if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
+		common.SysError("failed to copy file content: " + err.Error())
+		return nil
+	}
+	// Close writes the terminating boundary; without it the body is not a
+	// complete multipart message.
+	if err := writer.Close(); err != nil {
+		common.SysError("failed to close multipart writer: " + err.Error())
+		return nil
+	}
+
+	// Create HTTP request
+	req, err := http.NewRequest("POST", uploadUrl, body)
+	if err != nil {
+		common.SysError("failed to create request: " + err.Error())
+		return nil
+	}
+
+	req.Header.Set("Content-Type", writer.FormDataContentType())
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+
+	// Send request
+	client := service.GetImpatientHttpClient()
+	resp, err := client.Do(req)
+	if err != nil {
+		common.SysError("failed to send request: " + err.Error())
+		return nil
+	}
+	defer resp.Body.Close()
+
+	// A non-2xx reply carries no usable upload id; stop here instead of
+	// returning a DifyFile that references nothing.
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+		common.SysError(fmt.Sprintf("failed to upload file, status code: %d", resp.StatusCode))
+		return nil
+	}
+
+	// Parse response
+	var result struct {
+		Id string `json:"id"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		common.SysError("failed to decode response: " + err.Error())
+		return nil
+	}
+	if result.Id == "" {
+		common.SysError("failed to upload file: response has no id")
+		return nil
+	}
+
+	return &DifyFile{
+		UploadFileId: result.Id,
+		Type:         "image",
+		TransferMode: "local_file",
+	}
+}
+
+// requestOpenAI2Dify converts an OpenAI-style chat request into a Dify chat
+// request.
+//
+// All messages are flattened into one query string, each prefixed with
+// "SYSTEM: ", "ASSISTANT: " or "USER: ". Images found in user messages go
+// into Files: remote images are passed by URL, anything else is uploaded
+// through uploadDifyFile. Images that cannot be parsed or uploaded are
+// skipped. When the request names no user, helper.GetResponseID(c) is used.
+// The response mode is "streaming" when request.Stream is set, otherwise
+// "blocking".
+func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
+	difyReq := DifyChatRequest{
+		Inputs:           make(map[string]interface{}),
+		AutoGenerateName: false,
+	}
+
+	user := request.User
+	if user == "" {
+		user = helper.GetResponseID(c)
+	}
+	difyReq.User = user
+
+	files := make([]DifyFile, 0)
+	var content strings.Builder
 	for _, message := range request.Messages {
 		if message.Role == "system" {
-			content += "SYSTEM: \n" + message.StringContent() + "\n"
+			content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
 		} else if message.Role == "assistant" {
-			content += "ASSISTANT: \n" + message.StringContent() + "\n"
+			content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
 		} else {
-			content += "USER: \n" + message.StringContent() + "\n"
+			parseContent := message.ParseContent()
+			for _, mediaContent := range parseContent {
+				switch mediaContent.Type {
+				case dto.ContentTypeText:
+					content.WriteString("USER: \n" + mediaContent.Text + "\n")
+				case dto.ContentTypeImageURL:
+					media := mediaContent.GetImageMedia()
+					if media == nil {
+						// No usable image payload on this part; skip it.
+						continue
+					}
+					var file *DifyFile
+					if media.IsRemoteImage() {
+						// Allocate the struct: assigning fields through the nil
+						// pointer declared above would panic. Type matches the
+						// "image" value uploadDifyFile uses for local uploads.
+						file = &DifyFile{
+							Type:         "image",
+							TransferMode: "remote_url",
+							URL:          media.Url,
+						}
+					} else {
+						file = uploadDifyFile(c, info, difyReq.User, mediaContent)
+					}
+					if file != nil {
+						files = append(files, *file)
+					}
+				}
+			}
 		}
 	}
+	difyReq.Query = content.String()
+	difyReq.Files = files
 	mode := "blocking"
 	if request.Stream {
 		mode = "streaming"
 	}
-	user := request.User
-	if user == "" {
-		user = "api-user"
-	}
-	return &DifyChatRequest{
-		Inputs:           make(map[string]interface{}),
-		Query:            content,
-		ResponseMode:     mode,
-		User:             user,
-		AutoGenerateName: false,
-	}
+	difyReq.ResponseMode = mode
+	return &difyReq
 }
 
 func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
@@ -50,10 +181,22 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
 		Model:   "dify",
 	}
 	var choice dto.ChatCompletionsStreamResponseChoice
-	if constant.DifyDebug && difyResponse.Event == "workflow_started" {
-		choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
-	} else if constant.DifyDebug && difyResponse.Event == "node_started" {
-		choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
+	if strings.HasPrefix(difyResponse.Event, "workflow_") {
+		if constant.DifyDebug {
+			text := "Workflow: " + difyResponse.Data.WorkflowId
+			if difyResponse.Event == "workflow_finished" {
+				text += " " + difyResponse.Data.Status
+			}
+			choice.Delta.SetReasoningContent(text + "\n")
+		}
+	} else if strings.HasPrefix(difyResponse.Event, "node_") {
+		if constant.DifyDebug {
+			text := "Node: " + difyResponse.Data.NodeType
+			if difyResponse.Event == "node_finished" {
+				text += " " + difyResponse.Data.Status
+			}
+			choice.Delta.SetReasoningContent(text + "\n")
+		}
 	} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
 		choice.Delta.SetContentString(difyResponse.Answer)
 	}
@@ -66,38 +209,38 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	usage := &dto.Usage{}
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
+	var nodeToken int
 
 	helper.SetEventStreamHeaders(c)
 
-	for scanner.Scan() {
-		data := scanner.Text()
-		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
-			continue
-		}
-		data = strings.TrimPrefix(data, "data:")
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var difyResponse DifyChunkChatCompletionResponse
 		err := json.Unmarshal([]byte(data), &difyResponse)
 		if err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
-			continue
+			return true
 		}
 		var openaiResponse dto.ChatCompletionsStreamResponse
 		if difyResponse.Event == "message_end" {
 			usage = &difyResponse.MetaData.Usage
-			break
+			return false
 		} else if difyResponse.Event == "error" {
-			break
+			return false
 		} else {
 			openaiResponse = *streamResponseDify2OpenAI(difyResponse)
 			if len(openaiResponse.Choices) != 0 {
 				responseText += openaiResponse.Choices[0].Delta.GetContentString()
+				if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
+					nodeToken += 1
+				}
 			}
 		}
 		err = helper.ObjectData(c, openaiResponse)
 		if err != nil {
 			common.SysError(err.Error())
 		}
-	}
+		return true
+	})
 	if err := scanner.Err(); err != nil {
 		common.SysError("error reading stream: " + err.Error())
 	}
@@ -112,6 +255,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
+	usage.CompletionTokens += nodeToken
 	return nil, usage
 }
 

+ 8 - 12
service/token_counter.go

@@ -86,6 +86,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 }
 
 func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
+	if imageUrl == nil {
+		return 0, fmt.Errorf("image_url_is_nil")
+	}
 	baseTokens := 85
 	if model == "glm-4v" {
 		return 1047, nil
@@ -93,10 +96,10 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
 	if imageUrl.Detail == "low" {
 		return baseTokens, nil
 	}
-	// TODO: 非流模式下不计算图片token数量
 	if !constant.GetMediaTokenNotStream && !stream {
-		return 256, nil
+		return 3 * baseTokens, nil
 	}
+
 	// 同步One API的图片计费逻辑
 	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
 		imageUrl.Detail = "high"
@@ -126,18 +129,11 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
 	if err != nil {
 		return 0, err
 	}
+	imageUrl.MimeType = format
 
 	if config.Width == 0 || config.Height == 0 {
 		return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
 	}
-	//// TODO: 适配官方auto计费
-	//if config.Width < 512 && config.Height < 512 {
-	//	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
-	//		// 如果图片尺寸小于512,强制使用low
-	//		imageUrl.Detail = "low"
-	//		return 85, nil
-	//	}
-	//}
 
 	shortSide := config.Width
 	otherSide := config.Height
@@ -392,8 +388,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 			arrayContent := message.ParseContent()
 			for _, m := range arrayContent {
 				if m.Type == dto.ContentTypeImageURL {
-					imageUrl := m.ImageUrl.(dto.MessageImageUrl)
-					imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
+					imageUrl := m.GetImageMedia()
+					imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
 					if err != nil {
 						return 0, err
 					}