|
|
@@ -2,10 +2,18 @@ package controller
|
|
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/chai2010/webp"
	"github.com/gin-gonic/gin"
	"github.com/pkoukk/tiktoken-go"
	"image"
	_ "image/gif"
	_ "image/jpeg"
	_ "image/png"
	"io"
	"log"
	"math"
	"net/http"
	"one-api/common"
	"strconv"
|
|
|
@@ -63,6 +71,64 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
|
return len(tokenEncoder.Encode(text, nil, nil))
|
|
|
}
|
|
|
|
|
|
+func getImageToken(imageUrl MessageImageUrl) (int, error) {
|
|
|
+ if imageUrl.Detail == "low" {
|
|
|
+ return 85, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ response, err := http.Get(imageUrl.Url)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("Error: Failed to get the URL")
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ defer response.Body.Close()
|
|
|
+
|
|
|
+ // 限制读取的字节数,防止下载整个图片
|
|
|
+ limitReader := io.LimitReader(response.Body, 8192)
|
|
|
+
|
|
|
+ // 读取图片的头部信息来获取图片尺寸
|
|
|
+ config, _, err := image.DecodeConfig(limitReader)
|
|
|
+ if err != nil {
|
|
|
+ common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
|
|
+ config, err = webp.DecodeConfig(limitReader)
|
|
|
+ if err != nil {
|
|
|
+ common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if config.Width == 0 || config.Height == 0 {
|
|
|
+ return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
|
|
|
+ }
|
|
|
+ if config.Width < 512 && config.Height < 512 {
|
|
|
+ if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
|
|
+ return 85, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ shortSide := config.Width
|
|
|
+ otherSide := config.Height
|
|
|
+ log.Printf("width: %d, height: %d", config.Width, config.Height)
|
|
|
+ // 缩放倍数
|
|
|
+ scale := 1.0
|
|
|
+ if config.Height < shortSide {
|
|
|
+ shortSide = config.Height
|
|
|
+ otherSide = config.Width
|
|
|
+ }
|
|
|
+
|
|
|
+ // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
|
|
|
+ if shortSide > 768 {
|
|
|
+ scale = float64(shortSide) / 768
|
|
|
+ shortSide = 768
|
|
|
+ }
|
|
|
+ // 将另一边按照相同的比例缩小,向上取整
|
|
|
+ otherSide = int(math.Ceil(float64(otherSide) / scale))
|
|
|
+ log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
|
|
|
+ // 计算图片的token数量(边的长度除以512,向上取整)
|
|
|
+ tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
|
|
|
+ log.Printf("tiles: %d", tiles)
|
|
|
+ return tiles*170 + 85, nil
|
|
|
+}
|
|
|
+
|
|
|
func countTokenMessages(messages []Message, model string) (int, error) {
|
|
|
//recover when panic
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
@@ -100,8 +166,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
|
|
|
} else {
|
|
|
for _, m := range arrayContent {
|
|
|
if m.Type == "image_url" {
|
|
|
- //TODO: getImageToken
|
|
|
- tokenNum += 1000
|
|
|
+ imageTokenNum, err := getImageToken(m.ImageUrl)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ tokenNum += imageTokenNum
|
|
|
+ log.Printf("image token num: %d", imageTokenNum)
|
|
|
} else {
|
|
|
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
|
|
}
|