// relay-utils.go — token counting and relay helper utilities.
  1. package controller
import (
	"fmt"
	"strings"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/pkoukk/tiktoken-go"

	"one-api/common"
)
// stopFinishReason is the finish_reason value reported for a normal completion stop.
var stopFinishReason = "stop"

// tokenEncoderMap caches tiktoken encoders per model name so each encoding
// table is loaded at most once.
// NOTE(review): this map is both read and written by getTokenEncoder; if that
// runs on concurrent request goroutines, access must be synchronized — confirm.
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
  10. func getTokenEncoder(model string) *tiktoken.Tiktoken {
  11. if tokenEncoder, ok := tokenEncoderMap[model]; ok {
  12. return tokenEncoder
  13. }
  14. tokenEncoder, err := tiktoken.EncodingForModel(model)
  15. if err != nil {
  16. common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
  17. tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
  18. if err != nil {
  19. common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
  20. }
  21. }
  22. tokenEncoderMap[model] = tokenEncoder
  23. return tokenEncoder
  24. }
  25. func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
  26. if common.ApproximateTokenEnabled {
  27. return int(float64(len(text)) * 0.38)
  28. }
  29. return len(tokenEncoder.Encode(text, nil, nil))
  30. }
  31. func countTokenMessages(messages []Message, model string) int {
  32. tokenEncoder := getTokenEncoder(model)
  33. // Reference:
  34. // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  35. // https://github.com/pkoukk/tiktoken-go/issues/6
  36. //
  37. // Every message follows <|start|>{role/name}\n{content}<|end|>\n
  38. var tokensPerMessage int
  39. var tokensPerName int
  40. if model == "gpt-3.5-turbo-0301" {
  41. tokensPerMessage = 4
  42. tokensPerName = -1 // If there's a name, the role is omitted
  43. } else {
  44. tokensPerMessage = 3
  45. tokensPerName = 1
  46. }
  47. tokenNum := 0
  48. for _, message := range messages {
  49. tokenNum += tokensPerMessage
  50. tokenNum += getTokenNum(tokenEncoder, message.Content)
  51. tokenNum += getTokenNum(tokenEncoder, message.Role)
  52. if message.Name != nil {
  53. tokenNum += tokensPerName
  54. tokenNum += getTokenNum(tokenEncoder, *message.Name)
  55. }
  56. }
  57. tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
  58. return tokenNum
  59. }
  60. func countTokenInput(input any, model string) int {
  61. switch input.(type) {
  62. case string:
  63. return countTokenText(input.(string), model)
  64. case []string:
  65. text := ""
  66. for _, s := range input.([]string) {
  67. text += s
  68. }
  69. return countTokenText(text, model)
  70. }
  71. return 0
  72. }
  73. func countTokenText(text string, model string) int {
  74. tokenEncoder := getTokenEncoder(model)
  75. return getTokenNum(tokenEncoder, text)
  76. }
  77. func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
  78. openAIError := OpenAIError{
  79. Message: err.Error(),
  80. Type: "one_api_error",
  81. Code: code,
  82. }
  83. return &OpenAIErrorWithStatusCode{
  84. OpenAIError: openAIError,
  85. StatusCode: statusCode,
  86. }
  87. }
  88. func shouldDisableChannel(err *OpenAIError) bool {
  89. if !common.AutomaticDisableChannelEnabled {
  90. return false
  91. }
  92. if err == nil {
  93. return false
  94. }
  95. if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
  96. return true
  97. }
  98. return false
  99. }
  100. func setEventStreamHeaders(c *gin.Context) {
  101. c.Writer.Header().Set("Content-Type", "text/event-stream")
  102. c.Writer.Header().Set("Cache-Control", "no-cache")
  103. c.Writer.Header().Set("Connection", "keep-alive")
  104. c.Writer.Header().Set("Transfer-Encoding", "chunked")
  105. c.Writer.Header().Set("X-Accel-Buffering", "no")
  106. }