token_counter.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "image"
  6. _ "image/gif"
  7. _ "image/jpeg"
  8. _ "image/png"
  9. "log"
  10. "math"
  11. "path/filepath"
  12. "strings"
  13. "unicode/utf8"
  14. "github.com/QuantumNous/new-api/common"
  15. "github.com/QuantumNous/new-api/constant"
  16. "github.com/QuantumNous/new-api/dto"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. constant2 "github.com/QuantumNous/new-api/relay/constant"
  19. "github.com/QuantumNous/new-api/types"
  20. "github.com/gin-gonic/gin"
  21. )
  22. func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
  23. if fileMeta == nil {
  24. return 0, fmt.Errorf("image_url_is_nil")
  25. }
  26. // Defaults for 4o/4.1/4.5 family unless overridden below
  27. baseTokens := 85
  28. tileTokens := 170
  29. // Model classification
  30. lowerModel := strings.ToLower(model)
  31. // Special cases from existing behavior
  32. if strings.HasPrefix(lowerModel, "glm-4") {
  33. return 1047, nil
  34. }
  35. // Patch-based models (32x32 patches, capped at 1536, with multiplier)
  36. isPatchBased := false
  37. multiplier := 1.0
  38. switch {
  39. case strings.Contains(lowerModel, "gpt-4.1-mini"):
  40. isPatchBased = true
  41. multiplier = 1.62
  42. case strings.Contains(lowerModel, "gpt-4.1-nano"):
  43. isPatchBased = true
  44. multiplier = 2.46
  45. case strings.HasPrefix(lowerModel, "o4-mini"):
  46. isPatchBased = true
  47. multiplier = 1.72
  48. case strings.HasPrefix(lowerModel, "gpt-5-mini"):
  49. isPatchBased = true
  50. multiplier = 1.62
  51. case strings.HasPrefix(lowerModel, "gpt-5-nano"):
  52. isPatchBased = true
  53. multiplier = 2.46
  54. }
  55. // Tile-based model tokens and bases per doc
  56. if !isPatchBased {
  57. if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
  58. baseTokens = 2833
  59. tileTokens = 5667
  60. } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
  61. baseTokens = 70
  62. tileTokens = 140
  63. } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
  64. baseTokens = 75
  65. tileTokens = 150
  66. } else if strings.Contains(lowerModel, "computer-use-preview") {
  67. baseTokens = 65
  68. tileTokens = 129
  69. } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
  70. baseTokens = 85
  71. tileTokens = 170
  72. }
  73. }
  74. // Respect existing feature flags/short-circuits
  75. if fileMeta.Detail == "low" && !isPatchBased {
  76. return baseTokens, nil
  77. }
  78. // Whether to count image tokens at all
  79. if !constant.GetMediaToken {
  80. return 3 * baseTokens, nil
  81. }
  82. if !constant.GetMediaTokenNotStream && !stream {
  83. return 3 * baseTokens, nil
  84. }
  85. // Normalize detail
  86. if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
  87. fileMeta.Detail = "high"
  88. }
  89. // Decode image to get dimensions
  90. var config image.Config
  91. var err error
  92. var format string
  93. var b64str string
  94. if fileMeta.ParsedData != nil {
  95. config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
  96. } else {
  97. if strings.HasPrefix(fileMeta.OriginData, "http") {
  98. config, format, err = DecodeUrlImageData(fileMeta.OriginData)
  99. } else {
  100. common.SysLog(fmt.Sprintf("decoding image"))
  101. config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
  102. }
  103. fileMeta.MimeType = format
  104. }
  105. if err != nil {
  106. return 0, err
  107. }
  108. if config.Width == 0 || config.Height == 0 {
  109. // not an image
  110. if format != "" && b64str != "" {
  111. // file type
  112. return 3 * baseTokens, nil
  113. }
  114. return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
  115. }
  116. width := config.Width
  117. height := config.Height
  118. log.Printf("format: %s, width: %d, height: %d", format, width, height)
  119. if isPatchBased {
  120. // 32x32 patch-based calculation with 1536 cap and model multiplier
  121. ceilDiv := func(a, b int) int { return (a + b - 1) / b }
  122. rawPatchesW := ceilDiv(width, 32)
  123. rawPatchesH := ceilDiv(height, 32)
  124. rawPatches := rawPatchesW * rawPatchesH
  125. if rawPatches > 1536 {
  126. // scale down
  127. area := float64(width * height)
  128. r := math.Sqrt(float64(32*32*1536) / area)
  129. wScaled := float64(width) * r
  130. hScaled := float64(height) * r
  131. // adjust to fit whole number of patches after scaling
  132. adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
  133. adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
  134. adj := math.Min(adjW, adjH)
  135. if !math.IsNaN(adj) && adj > 0 {
  136. r = r * adj
  137. }
  138. wScaled = float64(width) * r
  139. hScaled = float64(height) * r
  140. patchesW := math.Ceil(wScaled / 32.0)
  141. patchesH := math.Ceil(hScaled / 32.0)
  142. imageTokens := int(patchesW * patchesH)
  143. if imageTokens > 1536 {
  144. imageTokens = 1536
  145. }
  146. return int(math.Round(float64(imageTokens) * multiplier)), nil
  147. }
  148. // below cap
  149. imageTokens := rawPatches
  150. return int(math.Round(float64(imageTokens) * multiplier)), nil
  151. }
  152. // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
  153. // Step 1: fit within 2048x2048 square
  154. maxSide := math.Max(float64(width), float64(height))
  155. fitScale := 1.0
  156. if maxSide > 2048 {
  157. fitScale = maxSide / 2048.0
  158. }
  159. fitW := int(math.Round(float64(width) / fitScale))
  160. fitH := int(math.Round(float64(height) / fitScale))
  161. // Step 2: scale so that shortest side is exactly 768
  162. minSide := math.Min(float64(fitW), float64(fitH))
  163. if minSide == 0 {
  164. return baseTokens, nil
  165. }
  166. shortScale := 768.0 / minSide
  167. finalW := int(math.Round(float64(fitW) * shortScale))
  168. finalH := int(math.Round(float64(fitH) * shortScale))
  169. // Count 512px tiles
  170. tilesW := (finalW + 512 - 1) / 512
  171. tilesH := (finalH + 512 - 1) / 512
  172. tiles := tilesW * tilesH
  173. if common.DebugEnabled {
  174. log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
  175. }
  176. return tiles*tileTokens + baseTokens, nil
  177. }
  178. func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
  179. // 是否统计token
  180. if !constant.CountToken {
  181. return 0, nil
  182. }
  183. if meta == nil {
  184. return 0, errors.New("token count meta is nil")
  185. }
  186. if info.RelayFormat == types.RelayFormatOpenAIRealtime {
  187. return 0, nil
  188. }
  189. if info.RelayMode == constant2.RelayModeAudioTranscription || info.RelayMode == constant2.RelayModeAudioTranslation {
  190. multiForm, err := common.ParseMultipartFormReusable(c)
  191. if err != nil {
  192. return 0, fmt.Errorf("error parsing multipart form: %v", err)
  193. }
  194. fileHeaders := multiForm.File["file"]
  195. totalAudioToken := 0
  196. for _, fileHeader := range fileHeaders {
  197. file, err := fileHeader.Open()
  198. if err != nil {
  199. return 0, fmt.Errorf("error opening audio file: %v", err)
  200. }
  201. defer file.Close()
  202. // get ext and io.seeker
  203. ext := filepath.Ext(fileHeader.Filename)
  204. duration, err := common.GetAudioDuration(c.Request.Context(), file, ext)
  205. if err != nil {
  206. return 0, fmt.Errorf("error getting audio duration: %v", err)
  207. }
  208. // 一分钟 1000 token,与 $price / minute 对齐
  209. totalAudioToken += int(math.Round(math.Ceil(duration) / 60.0 * 1000))
  210. }
  211. return totalAudioToken, nil
  212. }
  213. model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
  214. tkm := 0
  215. if meta.TokenType == types.TokenTypeTextNumber {
  216. tkm += utf8.RuneCountInString(meta.CombineText)
  217. } else {
  218. tkm += CountTextToken(meta.CombineText, model)
  219. }
  220. if info.RelayFormat == types.RelayFormatOpenAI {
  221. tkm += meta.ToolsCount * 8
  222. tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
  223. tkm += meta.NameCount * 3
  224. tkm += 3
  225. }
  226. shouldFetchFiles := true
  227. if info.RelayFormat == types.RelayFormatGemini {
  228. shouldFetchFiles = false
  229. }
  230. // 是否本地计算媒体token数量
  231. if !constant.GetMediaToken {
  232. shouldFetchFiles = false
  233. }
  234. // 是否在非流模式下本地计算媒体token数量
  235. if !constant.GetMediaTokenNotStream && !info.IsStream {
  236. shouldFetchFiles = false
  237. }
  238. for _, file := range meta.Files {
  239. if strings.HasPrefix(file.OriginData, "http") {
  240. if shouldFetchFiles {
  241. mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
  242. if err != nil {
  243. return 0, fmt.Errorf("error getting file base64 from url: %v", err)
  244. }
  245. if strings.HasPrefix(mineType, "image/") {
  246. file.FileType = types.FileTypeImage
  247. } else if strings.HasPrefix(mineType, "video/") {
  248. file.FileType = types.FileTypeVideo
  249. } else if strings.HasPrefix(mineType, "audio/") {
  250. file.FileType = types.FileTypeAudio
  251. } else {
  252. file.FileType = types.FileTypeFile
  253. }
  254. file.MimeType = mineType
  255. }
  256. } else if strings.HasPrefix(file.OriginData, "data:") {
  257. // get mime type from base64 header
  258. parts := strings.SplitN(file.OriginData, ",", 2)
  259. if len(parts) >= 1 {
  260. header := parts[0]
  261. // Extract mime type from "data:mime/type;base64" format
  262. if strings.Contains(header, ":") && strings.Contains(header, ";") {
  263. mimeStart := strings.Index(header, ":") + 1
  264. mimeEnd := strings.Index(header, ";")
  265. if mimeStart < mimeEnd {
  266. mineType := header[mimeStart:mimeEnd]
  267. if strings.HasPrefix(mineType, "image/") {
  268. file.FileType = types.FileTypeImage
  269. } else if strings.HasPrefix(mineType, "video/") {
  270. file.FileType = types.FileTypeVideo
  271. } else if strings.HasPrefix(mineType, "audio/") {
  272. file.FileType = types.FileTypeAudio
  273. } else {
  274. file.FileType = types.FileTypeFile
  275. }
  276. file.MimeType = mineType
  277. }
  278. }
  279. }
  280. }
  281. }
  282. for i, file := range meta.Files {
  283. switch file.FileType {
  284. case types.FileTypeImage:
  285. if common.IsOpenAITextModel(model) {
  286. token, err := getImageToken(file, model, info.IsStream)
  287. if err != nil {
  288. return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
  289. }
  290. tkm += token
  291. } else {
  292. tkm += 520
  293. }
  294. case types.FileTypeAudio:
  295. tkm += 256
  296. case types.FileTypeVideo:
  297. tkm += 4096 * 2
  298. case types.FileTypeFile:
  299. tkm += 4096
  300. default:
  301. tkm += 4096 // Default case for unknown file types
  302. }
  303. }
  304. common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
  305. return tkm, nil
  306. }
  307. func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
  308. audioToken := 0
  309. textToken := 0
  310. switch request.Type {
  311. case dto.RealtimeEventTypeSessionUpdate:
  312. if request.Session != nil {
  313. msgTokens := CountTextToken(request.Session.Instructions, model)
  314. textToken += msgTokens
  315. }
  316. case dto.RealtimeEventResponseAudioDelta:
  317. // count audio token
  318. atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
  319. if err != nil {
  320. return 0, 0, fmt.Errorf("error counting audio token: %v", err)
  321. }
  322. audioToken += atk
  323. case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
  324. // count text token
  325. tkm := CountTextToken(request.Delta, model)
  326. textToken += tkm
  327. case dto.RealtimeEventInputAudioBufferAppend:
  328. // count audio token
  329. atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
  330. if err != nil {
  331. return 0, 0, fmt.Errorf("error counting audio token: %v", err)
  332. }
  333. audioToken += atk
  334. case dto.RealtimeEventConversationItemCreated:
  335. if request.Item != nil {
  336. switch request.Item.Type {
  337. case "message":
  338. for _, content := range request.Item.Content {
  339. if content.Type == "input_text" {
  340. tokens := CountTextToken(content.Text, model)
  341. textToken += tokens
  342. }
  343. }
  344. }
  345. }
  346. case dto.RealtimeEventTypeResponseDone:
  347. // count tools token
  348. if !info.IsFirstRequest {
  349. if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
  350. for _, tool := range info.RealtimeTools {
  351. toolTokens := CountTokenInput(tool, model)
  352. textToken += 8
  353. textToken += toolTokens
  354. }
  355. }
  356. }
  357. }
  358. return textToken, audioToken, nil
  359. }
  360. func CountTokenInput(input any, model string) int {
  361. switch v := input.(type) {
  362. case string:
  363. return CountTextToken(v, model)
  364. case []string:
  365. text := ""
  366. for _, s := range v {
  367. text += s
  368. }
  369. return CountTextToken(text, model)
  370. case []interface{}:
  371. text := ""
  372. for _, item := range v {
  373. text += fmt.Sprintf("%v", item)
  374. }
  375. return CountTextToken(text, model)
  376. }
  377. return CountTokenInput(fmt.Sprintf("%v", input), model)
  378. }
  379. func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
  380. if audioBase64 == "" {
  381. return 0, nil
  382. }
  383. duration, err := parseAudio(audioBase64, audioFormat)
  384. if err != nil {
  385. return 0, err
  386. }
  387. return int(duration / 60 * 100 / 0.06), nil
  388. }
  389. func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
  390. if audioBase64 == "" {
  391. return 0, nil
  392. }
  393. duration, err := parseAudio(audioBase64, audioFormat)
  394. if err != nil {
  395. return 0, err
  396. }
  397. return int(duration / 60 * 200 / 0.24), nil
  398. }
  399. // CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
  400. func CountTextToken(text string, model string) int {
  401. if text == "" {
  402. return 0
  403. }
  404. if common.IsOpenAITextModel(model) {
  405. tokenEncoder := getTokenEncoder(model)
  406. return getTokenNum(tokenEncoder, text)
  407. } else {
  408. // 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
  409. return EstimateTokenByModel(model, text)
  410. }
  411. }