fake.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. package streamfake
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "slices"
  8. "strconv"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/common"
  13. "github.com/labring/aiproxy/core/common/conv"
  14. "github.com/labring/aiproxy/core/model"
  15. "github.com/labring/aiproxy/core/relay/adaptor"
  16. "github.com/labring/aiproxy/core/relay/meta"
  17. "github.com/labring/aiproxy/core/relay/mode"
  18. relaymodel "github.com/labring/aiproxy/core/relay/model"
  19. "github.com/labring/aiproxy/core/relay/plugin"
  20. "github.com/labring/aiproxy/core/relay/plugin/noop"
  21. "github.com/labring/aiproxy/core/relay/plugin/patch"
  22. )
  23. var _ plugin.Plugin = (*StreamFake)(nil)
  24. // StreamFake implements the stream fake functionality
  25. type StreamFake struct {
  26. noop.Noop
  27. }
  28. // NewStreamFakePlugin creates a new stream fake plugin instance
  29. func NewStreamFakePlugin() plugin.Plugin {
  30. return &StreamFake{}
  31. }
  32. // Constants for metadata keys
  33. const (
  34. fakeStreamKey = "fake_stream"
  35. )
  36. // getConfig retrieves the plugin configuration
  37. func (p *StreamFake) getConfig(meta *meta.Meta) (*Config, error) {
  38. pluginConfig := &Config{}
  39. if err := meta.ModelConfig.LoadPluginConfig("stream-fake", pluginConfig); err != nil {
  40. return nil, err
  41. }
  42. return pluginConfig, nil
  43. }
  44. // ConvertRequest modifies the request to enable streaming if it's originally non-streaming
  45. func (p *StreamFake) ConvertRequest(
  46. meta *meta.Meta,
  47. store adaptor.Store,
  48. req *http.Request,
  49. do adaptor.ConvertRequest,
  50. ) (adaptor.ConvertResult, error) {
  51. // Only process chat completions
  52. if meta.Mode != mode.ChatCompletions {
  53. return do.ConvertRequest(meta, store, req)
  54. }
  55. // Check if stream fake is enabled
  56. pluginConfig, err := p.getConfig(meta)
  57. if err != nil || !pluginConfig.Enable {
  58. return do.ConvertRequest(meta, store, req)
  59. }
  60. body, err := common.GetRequestBodyReusable(req)
  61. if err != nil {
  62. return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err)
  63. }
  64. node, err := sonic.Get(body)
  65. if err != nil {
  66. return do.ConvertRequest(meta, store, req)
  67. }
  68. stream, _ := node.Get("stream").Bool()
  69. if stream {
  70. // Already streaming, no need to fake
  71. return do.ConvertRequest(meta, store, req)
  72. }
  73. patch.AddLazyPatch(meta, patch.PatchOperation{
  74. Op: patch.OpFunction,
  75. Function: func(root *ast.Node) (bool, error) {
  76. _, err := root.Set("stream", ast.NewBool(true))
  77. if err != nil {
  78. return false, err
  79. }
  80. return true, nil
  81. },
  82. })
  83. meta.Set(fakeStreamKey, true)
  84. return do.ConvertRequest(meta, store, req)
  85. }
  86. // DoResponse handles the response processing to collect streaming data and convert back to non-streaming
  87. func (p *StreamFake) DoResponse(
  88. meta *meta.Meta,
  89. store adaptor.Store,
  90. c *gin.Context,
  91. resp *http.Response,
  92. do adaptor.DoResponse,
  93. ) (model.Usage, adaptor.Error) {
  94. // Only process chat completions
  95. if meta.Mode != mode.ChatCompletions {
  96. return do.DoResponse(meta, store, c, resp)
  97. }
  98. // Check if this is a fake stream request
  99. isFakeStream, ok := meta.Get(fakeStreamKey)
  100. if !ok {
  101. return do.DoResponse(meta, store, c, resp)
  102. }
  103. isFakeStreamBool, ok := isFakeStream.(bool)
  104. if !ok || !isFakeStreamBool {
  105. return do.DoResponse(meta, store, c, resp)
  106. }
  107. return p.handleFakeStreamResponse(meta, store, c, resp, do)
  108. }
  109. // handleFakeStreamResponse processes the streaming response and converts it back to non-streaming
  110. func (p *StreamFake) handleFakeStreamResponse(
  111. meta *meta.Meta,
  112. store adaptor.Store,
  113. c *gin.Context,
  114. resp *http.Response,
  115. do adaptor.DoResponse,
  116. ) (model.Usage, adaptor.Error) {
  117. log := common.GetLogger(c)
  118. // Create a custom response writer to collect streaming data
  119. rw := &fakeStreamResponseWriter{
  120. ResponseWriter: c.Writer,
  121. }
  122. c.Writer = rw
  123. defer func() {
  124. c.Writer = rw.ResponseWriter
  125. }()
  126. // Process the streaming response
  127. usage, relayErr := do.DoResponse(meta, store, c, resp)
  128. if relayErr != nil {
  129. return usage, relayErr
  130. }
  131. // Convert collected streaming chunks to non-streaming response
  132. respBody, err := rw.convertToNonStream()
  133. if err != nil {
  134. log.Errorf("failed to convert to non-streaming response: %v", err)
  135. return usage, relayErr
  136. }
  137. // Set appropriate headers for non-streaming response
  138. c.Header("Content-Type", "application/json")
  139. c.Header("Content-Length", strconv.Itoa(len(respBody)))
  140. // Remove streaming-specific headers
  141. c.Header("Cache-Control", "")
  142. c.Header("Connection", "")
  143. c.Header("Transfer-Encoding", "")
  144. c.Header("X-Accel-Buffering", "")
  145. // Write the non-streaming response
  146. _, _ = rw.ResponseWriter.Write(respBody)
  147. return usage, nil
  148. }
  149. // fakeStreamResponseWriter captures streaming response data
  150. type fakeStreamResponseWriter struct {
  151. gin.ResponseWriter
  152. lastChunk *ast.Node
  153. usageNode *ast.Node
  154. contentBuilder bytes.Buffer
  155. reasoningContent bytes.Buffer
  156. finishReason relaymodel.FinishReason
  157. logprobsContent []ast.Node
  158. toolCalls []*relaymodel.ToolCall
  159. }
  160. // ignore flush
  161. func (rw *fakeStreamResponseWriter) Flush() {}
  162. // ignore WriteHeaderNow
  163. func (rw *fakeStreamResponseWriter) WriteHeaderNow() {}
  164. func (rw *fakeStreamResponseWriter) Write(b []byte) (int, error) {
  165. // Parse streaming data
  166. _ = rw.parseStreamingData(b)
  167. return len(b), nil
  168. }
  169. func (rw *fakeStreamResponseWriter) WriteString(s string) (int, error) {
  170. return rw.Write(conv.StringToBytes(s))
  171. }
  172. // parseStreamingData extracts individual chunks from streaming response
  173. func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error {
  174. node, err := sonic.Get(data)
  175. if err != nil {
  176. return err
  177. }
  178. rw.lastChunk = &node
  179. usageNode := node.Get("usage")
  180. if err := usageNode.Check(); err != nil {
  181. if !errors.Is(err, ast.ErrNotExist) {
  182. return err
  183. }
  184. } else {
  185. rw.usageNode = usageNode
  186. }
  187. choicesNode := node.Get("choices")
  188. if err := choicesNode.Check(); err != nil {
  189. return err
  190. }
  191. return choicesNode.ForEach(func(_ ast.Sequence, choiceNode *ast.Node) bool {
  192. deltaNode := choiceNode.Get("delta")
  193. if err := deltaNode.Check(); err != nil {
  194. return true
  195. }
  196. content, err := deltaNode.Get("content").String()
  197. if err == nil {
  198. rw.contentBuilder.WriteString(content)
  199. }
  200. reasoningContent, err := deltaNode.Get("reasoning_content").String()
  201. if err == nil {
  202. rw.reasoningContent.WriteString(reasoningContent)
  203. }
  204. _ = deltaNode.Get("tool_calls").
  205. ForEach(func(_ ast.Sequence, toolCallNode *ast.Node) bool {
  206. toolCallRaw, err := toolCallNode.Raw()
  207. if err != nil {
  208. return true
  209. }
  210. var toolCall relaymodel.ToolCall
  211. if err := sonic.UnmarshalString(toolCallRaw, &toolCall); err != nil {
  212. return true
  213. }
  214. rw.toolCalls = mergeToolCalls(rw.toolCalls, &toolCall)
  215. return true
  216. })
  217. finishReason, err := choiceNode.Get("finish_reason").String()
  218. if err == nil && finishReason != "" {
  219. rw.finishReason = finishReason
  220. }
  221. logprobsContentNode := choiceNode.GetByPath("logprobs", "content")
  222. if err := logprobsContentNode.Check(); err == nil {
  223. l, err := logprobsContentNode.Len()
  224. if err != nil {
  225. return true
  226. }
  227. rw.logprobsContent = slices.Grow(rw.logprobsContent, l)
  228. _ = logprobsContentNode.ForEach(
  229. func(_ ast.Sequence, logprobsContentNode *ast.Node) bool {
  230. rw.logprobsContent = append(rw.logprobsContent, *logprobsContentNode)
  231. return true
  232. },
  233. )
  234. }
  235. return true
  236. })
  237. }
  238. func (rw *fakeStreamResponseWriter) convertToNonStream() ([]byte, error) {
  239. lastChunk := rw.lastChunk
  240. if lastChunk == nil {
  241. return nil, errors.New("last chunk is nil")
  242. }
  243. _, err := lastChunk.Set("object", ast.NewString(relaymodel.ChatCompletionObject))
  244. if err != nil {
  245. return nil, err
  246. }
  247. if rw.usageNode != nil {
  248. _, err = lastChunk.Set("usage", *rw.usageNode)
  249. if err != nil {
  250. return nil, err
  251. }
  252. }
  253. message := map[string]any{
  254. "role": "assistant",
  255. "content": rw.contentBuilder.String(),
  256. }
  257. reasoningContent := rw.reasoningContent.String()
  258. if reasoningContent != "" {
  259. message["reasoning_content"] = reasoningContent
  260. }
  261. if len(rw.toolCalls) > 0 {
  262. slices.SortFunc(rw.toolCalls, func(a, b *relaymodel.ToolCall) int {
  263. return a.Index - b.Index
  264. })
  265. message["tool_calls"] = rw.toolCalls
  266. }
  267. if len(rw.logprobsContent) > 0 {
  268. message["logprobs"] = map[string]any{
  269. "content": rw.logprobsContent,
  270. }
  271. }
  272. _, err = lastChunk.SetAny("choices", []any{
  273. map[string]any{
  274. "index": 0,
  275. "message": message,
  276. "finish_reason": rw.finishReason,
  277. },
  278. })
  279. if err != nil {
  280. return nil, err
  281. }
  282. return lastChunk.MarshalJSON()
  283. }
  284. func mergeToolCalls(
  285. oldToolCalls []*relaymodel.ToolCall,
  286. newToolCall *relaymodel.ToolCall,
  287. ) []*relaymodel.ToolCall {
  288. findedToolCallIndex := slices.IndexFunc(oldToolCalls, func(t *relaymodel.ToolCall) bool {
  289. return t.Index == newToolCall.Index
  290. })
  291. if findedToolCallIndex != -1 {
  292. oldToolCall := oldToolCalls[findedToolCallIndex]
  293. oldToolCalls[findedToolCallIndex] = mergeToolCall(oldToolCall, newToolCall)
  294. } else {
  295. oldToolCalls = append(oldToolCalls, newToolCall)
  296. }
  297. return oldToolCalls
  298. }
  299. func mergeToolCall(oldToolCall, newToolCall *relaymodel.ToolCall) *relaymodel.ToolCall {
  300. if oldToolCall == nil {
  301. return newToolCall
  302. }
  303. if newToolCall == nil {
  304. return oldToolCall
  305. }
  306. merged := &relaymodel.ToolCall{
  307. Index: oldToolCall.Index,
  308. ID: oldToolCall.ID,
  309. Type: oldToolCall.Type,
  310. Function: oldToolCall.Function,
  311. }
  312. merged.Function.Arguments += newToolCall.Function.Arguments
  313. return merged
  314. }