relay-aiproxy.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strconv"
  11. "strings"
  12. )
  // The types below mirror the JSON payloads of AI Proxy's knowledge-base
  // ("library") chat API, documented at:
  // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
  // (the anchor is the Chinese section title "use an already-customised
  // knowledge base for conversational Q&A").
  type (
  	// AIProxyLibraryRequest is the query payload sent to the library API.
  	// requestOpenAI2AIProxyLibrary fills every field except LibraryId.
  	AIProxyLibraryRequest struct {
  		Model     string `json:"model"`
  		Query     string `json:"query"`
  		LibraryId string `json:"libraryId"`
  		Stream    bool   `json:"stream"`
  	}

  	// AIProxyLibraryError holds the error fields of a library response;
  	// aiProxyLibraryHandler treats a non-zero ErrCode as a failure.
  	AIProxyLibraryError struct {
  		ErrCode int    `json:"errCode"`
  		Message string `json:"message"`
  	}

  	// AIProxyLibraryDocument is one reference document (title and link)
  	// returned alongside an answer.
  	AIProxyLibraryDocument struct {
  		Title string `json:"title"`
  		URL   string `json:"url"`
  	}

  	// AIProxyLibraryResponse is a complete (non-streaming) library response.
  	// Because AIProxyLibraryError is embedded, its errCode and message keys
  	// sit at the top level of the same JSON object.
  	AIProxyLibraryResponse struct {
  		Success   bool                     `json:"success"`
  		Answer    string                   `json:"answer"`
  		Documents []AIProxyLibraryDocument `json:"documents"`
  		AIProxyLibraryError
  	}

  	// AIProxyLibraryStreamResponse is the payload of one "data:" event in a
  	// streaming library response.
  	AIProxyLibraryStreamResponse struct {
  		Content   string                   `json:"content"`
  		Finish    bool                     `json:"finish"`
  		Model     string                   `json:"model"`
  		Documents []AIProxyLibraryDocument `json:"documents"`
  	}
  )
  40. func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
  41. query := ""
  42. if len(request.Messages) != 0 {
  43. query = request.Messages[len(request.Messages)-1].Content
  44. }
  45. return &AIProxyLibraryRequest{
  46. Model: request.Model,
  47. Stream: request.Stream,
  48. Query: query,
  49. }
  50. }
  51. func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
  52. if len(documents) == 0 {
  53. return ""
  54. }
  55. content := "\n\n参考文档:\n"
  56. for i, document := range documents {
  57. content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
  58. }
  59. return content
  60. }
  61. func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
  62. content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
  63. choice := OpenAITextResponseChoice{
  64. Index: 0,
  65. Message: Message{
  66. Role: "assistant",
  67. Content: content,
  68. },
  69. FinishReason: "stop",
  70. }
  71. fullTextResponse := OpenAITextResponse{
  72. Id: common.GetUUID(),
  73. Object: "chat.completion",
  74. Created: common.GetTimestamp(),
  75. Choices: []OpenAITextResponseChoice{choice},
  76. }
  77. return &fullTextResponse
  78. }
  79. func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
  80. var choice ChatCompletionsStreamResponseChoice
  81. choice.Delta.Content = aiProxyDocuments2Markdown(documents)
  82. choice.FinishReason = &stopFinishReason
  83. return &ChatCompletionsStreamResponse{
  84. Id: common.GetUUID(),
  85. Object: "chat.completion.chunk",
  86. Created: common.GetTimestamp(),
  87. Model: "",
  88. Choices: []ChatCompletionsStreamResponseChoice{choice},
  89. }
  90. }
  91. func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
  92. var choice ChatCompletionsStreamResponseChoice
  93. choice.Delta.Content = response.Content
  94. return &ChatCompletionsStreamResponse{
  95. Id: common.GetUUID(),
  96. Object: "chat.completion.chunk",
  97. Created: common.GetTimestamp(),
  98. Model: response.Model,
  99. Choices: []ChatCompletionsStreamResponseChoice{choice},
  100. }
  101. }
  102. func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  103. var usage Usage
  104. scanner := bufio.NewScanner(resp.Body)
  105. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  106. if atEOF && len(data) == 0 {
  107. return 0, nil, nil
  108. }
  109. if i := strings.Index(string(data), "\n"); i >= 0 {
  110. return i + 1, data[0:i], nil
  111. }
  112. if atEOF {
  113. return len(data), data, nil
  114. }
  115. return 0, nil, nil
  116. })
  117. dataChan := make(chan string)
  118. stopChan := make(chan bool)
  119. go func() {
  120. for scanner.Scan() {
  121. data := scanner.Text()
  122. if len(data) < 5 { // ignore blank line or wrong format
  123. continue
  124. }
  125. if data[:5] != "data:" {
  126. continue
  127. }
  128. data = data[5:]
  129. dataChan <- data
  130. }
  131. stopChan <- true
  132. }()
  133. setEventStreamHeaders(c)
  134. var documents []AIProxyLibraryDocument
  135. c.Stream(func(w io.Writer) bool {
  136. select {
  137. case data := <-dataChan:
  138. var AIProxyLibraryResponse AIProxyLibraryStreamResponse
  139. err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
  140. if err != nil {
  141. common.SysError("error unmarshalling stream response: " + err.Error())
  142. return true
  143. }
  144. if len(AIProxyLibraryResponse.Documents) != 0 {
  145. documents = AIProxyLibraryResponse.Documents
  146. }
  147. response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
  148. jsonResponse, err := json.Marshal(response)
  149. if err != nil {
  150. common.SysError("error marshalling stream response: " + err.Error())
  151. return true
  152. }
  153. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  154. return true
  155. case <-stopChan:
  156. response := documentsAIProxyLibrary(documents)
  157. jsonResponse, err := json.Marshal(response)
  158. if err != nil {
  159. common.SysError("error marshalling stream response: " + err.Error())
  160. return true
  161. }
  162. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  163. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  164. return false
  165. }
  166. })
  167. err := resp.Body.Close()
  168. if err != nil {
  169. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  170. }
  171. return nil, &usage
  172. }
  173. func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  174. var AIProxyLibraryResponse AIProxyLibraryResponse
  175. responseBody, err := io.ReadAll(resp.Body)
  176. if err != nil {
  177. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  178. }
  179. err = resp.Body.Close()
  180. if err != nil {
  181. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  182. }
  183. err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
  184. if err != nil {
  185. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  186. }
  187. if AIProxyLibraryResponse.ErrCode != 0 {
  188. return &OpenAIErrorWithStatusCode{
  189. OpenAIError: OpenAIError{
  190. Message: AIProxyLibraryResponse.Message,
  191. Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
  192. Code: AIProxyLibraryResponse.ErrCode,
  193. },
  194. StatusCode: resp.StatusCode,
  195. }, nil
  196. }
  197. fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
  198. jsonResponse, err := json.Marshal(fullTextResponse)
  199. if err != nil {
  200. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  201. }
  202. c.Writer.Header().Set("Content-Type", "application/json")
  203. c.Writer.WriteHeader(resp.StatusCode)
  204. _, err = c.Writer.Write(jsonResponse)
  205. return nil, &fullTextResponse.Usage
  206. }