소스 검색

fix image token calculate

CaIon 2 년 전
부모
커밋
e5c2524f15
3개의 변경된 파일, 75개의 추가 및 2개의 삭제
  1. 72 2
      controller/relay-utils.go
  2. 1 0
      go.mod
  3. 2 0
      go.sum

+ 72 - 2
controller/relay-utils.go

@@ -2,10 +2,18 @@ package controller
 
 import (
+	"bytes"
 	"encoding/json"
+	"errors"
 	"fmt"
+	"github.com/chai2010/webp"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
+	"image"
+	_ "image/gif"
+	_ "image/jpeg"
+	_ "image/png"
 	"io"
+	"log"
+	"math"
 	"net/http"
 	"one-api/common"
 	"strconv"
@@ -63,6 +71,64 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
+// getImageToken estimates how many prompt tokens a single image part costs.
+//
+// A detail of "low" costs a flat 85 tokens and the image is never fetched.
+// Otherwise the first 8192 bytes of imageUrl.Url are downloaded to read the
+// image dimensions (gif, jpeg, png, with a webp fallback). An image smaller
+// than 512x512 whose detail is "auto" or empty also costs 85 tokens; any
+// other image costs 170 tokens per 512px tile of the rescaled image, plus 85.
+//
+// An error is returned when the image cannot be fetched, the server answers
+// with a non-2xx status, or no decoder can determine a non-zero size.
+func getImageToken(imageUrl MessageImageUrl) (int, error) {
+	if imageUrl.Detail == "low" {
+		return 85, nil
+	}
+
+	// NOTE(review): http.Get goes through http.DefaultClient, which sets no
+	// timeout, and the URL is presumably client-supplied; consider a shared
+	// client with a timeout and host restrictions.
+	response, err := http.Get(imageUrl.Url)
+	if err != nil {
+		return 0, fmt.Errorf("fetch image: %w", err)
+	}
+	defer response.Body.Close()
+	if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
+		return 0, fmt.Errorf("fetch image: unexpected status %s", response.Status)
+	}
+
+	// Only the header is needed for the dimensions, so cap the download. The
+	// bytes are buffered because a failed decode attempt consumes its reader,
+	// and the webp fallback has to start again from the first byte.
+	header, err := io.ReadAll(io.LimitReader(response.Body, 8192))
+	if err != nil {
+		return 0, fmt.Errorf("read image header: %w", err)
+	}
+
+	config, _, err := image.DecodeConfig(bytes.NewReader(header))
+	if err != nil {
+		common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
+		config, err = webp.DecodeConfig(bytes.NewReader(header))
+		if err != nil {
+			return 0, fmt.Errorf("fail to decode image config: %w", err)
+		}
+	}
+	// A decoder may succeed yet report an empty image; there is no error value
+	// to quote in that case, so build a standalone one.
+	if config.Width <= 0 || config.Height <= 0 {
+		return 0, errors.New("fail to decode image config: image has no width or height")
+	}
+	if config.Width < 512 && config.Height < 512 {
+		if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
+			return 85, nil
+		}
+	}
+	log.Printf("width: %d, height: %d", config.Width, config.Height)
+
+	tiles := highDetailTiles(config.Width, config.Height)
+	log.Printf("tiles: %d", tiles)
+	return tiles*170 + 85, nil
+}
+
+// highDetailTiles returns how many 512px tiles a width x height image
+// occupies after rescaling: the image is first shrunk to fit within
+// 2048x2048, then shrunk until its shorter side is at most 768, keeping the
+// aspect ratio and rounding the dependent side up. Images are never enlarged.
+func highDetailTiles(width, height int) int {
+	shortSide, longSide := width, height
+	if longSide < shortSide {
+		shortSide, longSide = longSide, shortSide
+	}
+	// Multiply before dividing in both steps so that ratios which divide
+	// evenly stay exact and Ceil does not round them up by one pixel.
+	if longSide > 2048 {
+		shortSide = int(math.Ceil(float64(shortSide) * 2048 / float64(longSide)))
+		longSide = 2048
+	}
+	if shortSide > 768 {
+		longSide = int(math.Ceil(float64(longSide) * 768 / float64(shortSide)))
+		shortSide = 768
+	}
+	// Each side is divided by 512 and rounded up to whole tiles.
+	return ((shortSide + 511) / 512) * ((longSide + 511) / 512)
+}
+
 func countTokenMessages(messages []Message, model string) (int, error) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
@@ -100,8 +166,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 		} else {
 			for _, m := range arrayContent {
 				if m.Type == "image_url" {
-					//TODO: getImageToken
-					tokenNum += 1000
+					imageTokenNum, err := getImageToken(m.ImageUrl)
+					if err != nil {
+						return 0, err
+					}
+					tokenNum += imageTokenNum
+					log.Printf("image token num: %d", imageTokenNum)
 				} else {
 					tokenNum += getTokenNum(tokenEncoder, m.Text)
 				}

+ 1 - 0
go.mod

@@ -4,6 +4,7 @@ module one-api
 go 1.18
 
 require (
+	github.com/chai2010/webp v1.1.1
 	github.com/gin-contrib/cors v1.4.0
 	github.com/gin-contrib/gzip v0.0.6
 	github.com/gin-contrib/sessions v0.0.5

+ 2 - 0
go.sum

@@ -3,6 +3,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
 github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
 github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
+github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
 github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=