// relay-openai.go
  1. package controller
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
  15. var responseTextBuilder strings.Builder
  16. scanner := bufio.NewScanner(resp.Body)
  17. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  18. if atEOF && len(data) == 0 {
  19. return 0, nil, nil
  20. }
  21. if i := strings.Index(string(data), "\n"); i >= 0 {
  22. return i + 1, data[0:i], nil
  23. }
  24. if atEOF {
  25. return len(data), data, nil
  26. }
  27. return 0, nil, nil
  28. })
  29. dataChan := make(chan string, 5)
  30. stopChan := make(chan bool, 2)
  31. defer close(stopChan)
  32. defer close(dataChan)
  33. var wg sync.WaitGroup
  34. go func() {
  35. wg.Add(1)
  36. defer wg.Done()
  37. var streamItems []string
  38. for scanner.Scan() {
  39. data := scanner.Text()
  40. if len(data) < 6 { // ignore blank line or wrong format
  41. continue
  42. }
  43. if data[:6] != "data: " && data[:6] != "[DONE]" {
  44. continue
  45. }
  46. dataChan <- data
  47. data = data[6:]
  48. if !strings.HasPrefix(data, "[DONE]") {
  49. streamItems = append(streamItems, data)
  50. }
  51. }
  52. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  53. switch relayMode {
  54. case RelayModeChatCompletions:
  55. var streamResponses []ChatCompletionsStreamResponseSimple
  56. err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
  57. if err != nil {
  58. common.SysError("error unmarshalling stream response: " + err.Error())
  59. return // just ignore the error
  60. }
  61. for _, streamResponse := range streamResponses {
  62. for _, choice := range streamResponse.Choices {
  63. responseTextBuilder.WriteString(choice.Delta.Content)
  64. }
  65. }
  66. case RelayModeCompletions:
  67. var streamResponses []CompletionsStreamResponse
  68. err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
  69. if err != nil {
  70. common.SysError("error unmarshalling stream response: " + err.Error())
  71. return // just ignore the error
  72. }
  73. for _, streamResponse := range streamResponses {
  74. for _, choice := range streamResponse.Choices {
  75. responseTextBuilder.WriteString(choice.Text)
  76. }
  77. }
  78. }
  79. if len(dataChan) > 0 {
  80. // wait data out
  81. time.Sleep(2 * time.Second)
  82. }
  83. common.SafeSend(stopChan, true)
  84. }()
  85. setEventStreamHeaders(c)
  86. c.Stream(func(w io.Writer) bool {
  87. select {
  88. case data := <-dataChan:
  89. if strings.HasPrefix(data, "data: [DONE]") {
  90. data = data[:12]
  91. }
  92. // some implementations may add \r at the end of data
  93. data = strings.TrimSuffix(data, "\r")
  94. c.Render(-1, common.CustomEvent{Data: data})
  95. return true
  96. case <-stopChan:
  97. return false
  98. }
  99. })
  100. err := resp.Body.Close()
  101. if err != nil {
  102. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
  103. }
  104. wg.Wait()
  105. return nil, responseTextBuilder.String()
  106. }
  107. func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
  108. var textResponse TextResponse
  109. responseBody, err := io.ReadAll(resp.Body)
  110. if err != nil {
  111. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  112. }
  113. err = resp.Body.Close()
  114. if err != nil {
  115. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  116. }
  117. err = json.Unmarshal(responseBody, &textResponse)
  118. if err != nil {
  119. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  120. }
  121. if textResponse.Error.Type != "" {
  122. return &OpenAIErrorWithStatusCode{
  123. OpenAIError: textResponse.Error,
  124. StatusCode: resp.StatusCode,
  125. }, nil
  126. }
  127. // Reset response body
  128. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  129. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  130. // And then we will have to send an error response, but in this case, the header has already been set.
  131. // So the httpClient will be confused by the response.
  132. // For example, Postman will report error, and we cannot check the response at all.
  133. for k, v := range resp.Header {
  134. c.Writer.Header().Set(k, v[0])
  135. }
  136. c.Writer.WriteHeader(resp.StatusCode)
  137. _, err = io.Copy(c.Writer, resp.Body)
  138. if err != nil {
  139. return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  140. }
  141. err = resp.Body.Close()
  142. if err != nil {
  143. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  144. }
  145. if textResponse.Usage.TotalTokens == 0 {
  146. completionTokens := 0
  147. for _, choice := range textResponse.Choices {
  148. completionTokens += countTokenText(string(choice.Message.Content), model)
  149. }
  150. textResponse.Usage = Usage{
  151. PromptTokens: promptTokens,
  152. CompletionTokens: completionTokens,
  153. TotalTokens: promptTokens + completionTokens,
  154. }
  155. }
  156. return nil, &textResponse.Usage
  157. }