Ver Fonte

fix(relay): improve error handling for unsupported MIME types by sanitizing URLs

CaIon há 6 meses atrás
pai
commit
0199896d9a
3 ficheiros alterados com 18 adições e 7 exclusões
  1. 4 1
      relay/channel/gemini/relay-gemini.go
  2. 10 6
      service/error.go
  3. 4 0
      service/file_decoder.go

+ 4 - 1
relay/channel/gemini/relay-gemini.go

@@ -324,7 +324,10 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 
 					// 校验 MimeType 是否在 Gemini 支持的白名单中
 					if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
-						return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
+						url := part.GetImageMedia().Url
+						url = strings.TrimPrefix(url, "http://")
+						url = strings.TrimPrefix(url, "https://")
+						return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
 					}
 
 					parts = append(parts, GeminiPart{

+ 10 - 6
service/error.go

@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
 func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
 	text := err.Error()
 	lowerText := strings.ToLower(text)
-	if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
-		common.SysLog(fmt.Sprintf("error: %s", text))
-		text = "请求上游地址失败"
+	if !strings.HasPrefix(lowerText, "get file base64 from url") {
+		if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+			common.SysLog(fmt.Sprintf("error: %s", text))
+			text = "请求上游地址失败"
+		}
 	}
 	openAIError := dto.OpenAIError{
 		Message: text,
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
 func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
 	text := err.Error()
 	lowerText := strings.ToLower(text)
-	if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
-		common.SysLog(fmt.Sprintf("error: %s", text))
-		text = "请求上游地址失败"
+	if !strings.HasPrefix(lowerText, "get file base64 from url") {
+		if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+			common.SysLog(fmt.Sprintf("error: %s", text))
+			text = "请求上游地址失败"
+		}
 	}
 	claudeError := dto.ClaudeError{
 		Message: text,

+ 4 - 0
service/file_decoder.go

@@ -33,6 +33,10 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
 	base64Data := base64.StdEncoding.EncodeToString(fileBytes)
 
 	mimeType := resp.Header.Get("Content-Type")
+	if len(strings.Split(mimeType, ";")) > 1 {
+		// If Content-Type has parameters, take the first part
+		mimeType = strings.Split(mimeType, ";")[0]
+	}
 	if mimeType == "application/octet-stream" {
 		if common.DebugEnabled {
 			println("MIME type is application/octet-stream, trying to guess from URL or filename")