Переглянути джерело

feat: Improve image download and validation in GetImageFromUrl

[email protected] 10 місяців тому
батько
коміт
6ecfb81cbc
1 змінених файлів з 44 додано та 11 видалено
  1. 44 11
      service/image.go

+ 44 - 11
service/image.go

@@ -7,7 +7,9 @@ import (
 	"fmt"
 	"image"
 	"io"
+	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"strings"
 
 	"golang.org/x/image/webp"
@@ -23,7 +25,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
 	decodedData, err := base64.StdEncoding.DecodeString(base64String)
 	if err != nil {
 		fmt.Println("Error: Failed to decode base64 string")
-		return image.Config{}, "", "", err
+		return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error())
 	}
 
 	// 创建一个bytes.Buffer用于存储解码后的数据
@@ -61,20 +63,51 @@ func DecodeBase64FileData(base64String string) (string, string, error) {
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 	resp, err := DoDownloadRequest(url)
 	if err != nil {
-		return "", "", err
-	}
-	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
-		return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
+		return "", "", fmt.Errorf("failed to download image: %w", err)
 	}
 	defer resp.Body.Close()
-	buffer := bytes.NewBuffer(nil)
-	_, err = buffer.ReadFrom(resp.Body)
+
+	// Check HTTP status code
+	if resp.StatusCode != http.StatusOK {
+		return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode)
+	}
+
+	contentType := resp.Header.Get("Content-Type")
+	if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") {
+		return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType)
+	}
+	maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024)
+
+	// Check Content-Length if available
+	if resp.ContentLength > maxImageSize {
+		return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize)
+	}
+
+	// Use LimitReader to prevent reading oversized images
+	limitReader := io.LimitReader(resp.Body, maxImageSize)
+	buffer := &bytes.Buffer{}
+
+	written, err := io.Copy(buffer, limitReader)
 	if err != nil {
-		return
+		return "", "", fmt.Errorf("failed to read image data: %w", err)
+	}
+	if written >= maxImageSize {
+		return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize)
 	}
-	mimeType = resp.Header.Get("Content-Type")
+
 	data = base64.StdEncoding.EncodeToString(buffer.Bytes())
-	return
+	mimeType = contentType
+
+	// Handle application/octet-stream type
+	if mimeType == "application/octet-stream" {
+		_, format, _, err := DecodeBase64ImageData(data)
+		if err != nil {
+			return "", "", err
+		}
+		mimeType = "image/" + format
+	}
+
+	return mimeType, data, nil
 }
 
 func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
@@ -92,7 +125,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 
 	mimeType := response.Header.Get("Content-Type")
 
-	if !strings.HasPrefix(mimeType, "image/") {
+	if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") {
 		return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
 	}