fake.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. package streamfake
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "slices"
  8. "strconv"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/common"
  13. "github.com/labring/aiproxy/core/common/conv"
  14. "github.com/labring/aiproxy/core/model"
  15. "github.com/labring/aiproxy/core/relay/adaptor"
  16. "github.com/labring/aiproxy/core/relay/meta"
  17. "github.com/labring/aiproxy/core/relay/mode"
  18. relaymodel "github.com/labring/aiproxy/core/relay/model"
  19. "github.com/labring/aiproxy/core/relay/plugin"
  20. "github.com/labring/aiproxy/core/relay/plugin/noop"
  21. "github.com/labring/aiproxy/core/relay/plugin/patch"
  22. "github.com/labring/aiproxy/core/relay/render"
  23. )
  24. var _ plugin.Plugin = (*StreamFake)(nil)
  25. // StreamFake implements the stream fake functionality
  26. type StreamFake struct {
  27. noop.Noop
  28. }
  29. // NewStreamFakePlugin creates a new stream fake plugin instance
  30. func NewStreamFakePlugin() plugin.Plugin {
  31. return &StreamFake{}
  32. }
  33. // Constants for metadata keys
  34. const (
  35. fakeStreamKey = "fake_stream"
  36. )
  37. // getConfig retrieves the plugin configuration
  38. func (p *StreamFake) getConfig(meta *meta.Meta) (*Config, error) {
  39. pluginConfig := &Config{}
  40. if err := meta.ModelConfig.LoadPluginConfig("stream-fake", pluginConfig); err != nil {
  41. return nil, err
  42. }
  43. return pluginConfig, nil
  44. }
  45. // ConvertRequest modifies the request to enable streaming if it's originally non-streaming
  46. func (p *StreamFake) ConvertRequest(
  47. meta *meta.Meta,
  48. store adaptor.Store,
  49. req *http.Request,
  50. do adaptor.ConvertRequest,
  51. ) (adaptor.ConvertResult, error) {
  52. // Only process chat completions
  53. if meta.Mode != mode.ChatCompletions {
  54. return do.ConvertRequest(meta, store, req)
  55. }
  56. // Check if stream fake is enabled
  57. pluginConfig, err := p.getConfig(meta)
  58. if err != nil || !pluginConfig.Enable {
  59. return do.ConvertRequest(meta, store, req)
  60. }
  61. body, err := common.GetRequestBodyReusable(req)
  62. if err != nil {
  63. return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err)
  64. }
  65. node, err := sonic.Get(body)
  66. if err != nil {
  67. return do.ConvertRequest(meta, store, req)
  68. }
  69. stream, _ := node.Get("stream").Bool()
  70. if stream {
  71. // Already streaming, no need to fake
  72. return do.ConvertRequest(meta, store, req)
  73. }
  74. patch.AddLazyPatch(meta, patch.PatchOperation{
  75. Op: patch.OpFunction,
  76. Function: func(root *ast.Node) (bool, error) {
  77. _, err := root.Set("stream", ast.NewBool(true))
  78. if err != nil {
  79. return false, err
  80. }
  81. return true, nil
  82. },
  83. })
  84. meta.Set(fakeStreamKey, true)
  85. return do.ConvertRequest(meta, store, req)
  86. }
  87. // DoResponse handles the response processing to collect streaming data and convert back to non-streaming
  88. func (p *StreamFake) DoResponse(
  89. meta *meta.Meta,
  90. store adaptor.Store,
  91. c *gin.Context,
  92. resp *http.Response,
  93. do adaptor.DoResponse,
  94. ) (model.Usage, adaptor.Error) {
  95. // Only process chat completions
  96. if meta.Mode != mode.ChatCompletions {
  97. return do.DoResponse(meta, store, c, resp)
  98. }
  99. // Check if this is a fake stream request
  100. isFakeStream, ok := meta.Get(fakeStreamKey)
  101. if !ok {
  102. return do.DoResponse(meta, store, c, resp)
  103. }
  104. isFakeStreamBool, ok := isFakeStream.(bool)
  105. if !ok || !isFakeStreamBool {
  106. return do.DoResponse(meta, store, c, resp)
  107. }
  108. return p.handleFakeStreamResponse(meta, store, c, resp, do)
  109. }
  110. // handleFakeStreamResponse processes the streaming response and converts it back to non-streaming
  111. func (p *StreamFake) handleFakeStreamResponse(
  112. meta *meta.Meta,
  113. store adaptor.Store,
  114. c *gin.Context,
  115. resp *http.Response,
  116. do adaptor.DoResponse,
  117. ) (model.Usage, adaptor.Error) {
  118. log := common.GetLogger(c)
  119. // Create a custom response writer to collect streaming data
  120. rw := &fakeStreamResponseWriter{
  121. ResponseWriter: c.Writer,
  122. }
  123. c.Writer = rw
  124. defer func() {
  125. c.Writer = rw.ResponseWriter
  126. }()
  127. // Process the streaming response
  128. usage, relayErr := do.DoResponse(meta, store, c, resp)
  129. if relayErr != nil {
  130. return usage, relayErr
  131. }
  132. // Convert collected streaming chunks to non-streaming response
  133. respBody, err := rw.convertToNonStream()
  134. if err != nil {
  135. log.Errorf("failed to convert to non-streaming response: %v", err)
  136. return usage, relayErr
  137. }
  138. // Set appropriate headers for non-streaming response
  139. c.Header("Content-Type", "application/json")
  140. c.Header("Content-Length", strconv.Itoa(len(respBody)))
  141. // Remove streaming-specific headers
  142. c.Header("Cache-Control", "")
  143. c.Header("Connection", "")
  144. c.Header("Transfer-Encoding", "")
  145. c.Header("X-Accel-Buffering", "")
  146. // Write the non-streaming response
  147. _, _ = rw.ResponseWriter.Write(respBody)
  148. return usage, nil
  149. }
  150. // fakeStreamResponseWriter captures streaming response data
  151. type fakeStreamResponseWriter struct {
  152. gin.ResponseWriter
  153. lastChunk *ast.Node
  154. usageNode *ast.Node
  155. contentBuilder bytes.Buffer
  156. reasoningContent bytes.Buffer
  157. finishReason relaymodel.FinishReason
  158. logprobsContent []ast.Node
  159. toolCalls []*relaymodel.ToolCall
  160. }
  161. // ignore flush
  162. func (rw *fakeStreamResponseWriter) Flush() {}
  163. // ignore WriteHeaderNow
  164. func (rw *fakeStreamResponseWriter) WriteHeaderNow() {}
  165. func (rw *fakeStreamResponseWriter) Write(b []byte) (int, error) {
  166. // Parse streaming data
  167. _ = rw.parseStreamingData(b)
  168. return len(b), nil
  169. }
  170. func (rw *fakeStreamResponseWriter) WriteString(s string) (int, error) {
  171. return rw.Write(conv.StringToBytes(s))
  172. }
  173. // parseStreamingData extracts individual chunks from streaming response
  174. func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error {
  175. if render.IsValidSSEData(data) {
  176. return nil
  177. }
  178. node, err := sonic.Get(data)
  179. if err != nil {
  180. return err
  181. }
  182. rw.lastChunk = &node
  183. usageNode := node.Get("usage")
  184. if err := usageNode.Check(); err != nil {
  185. if !errors.Is(err, ast.ErrNotExist) {
  186. return err
  187. }
  188. } else {
  189. rw.usageNode = usageNode
  190. }
  191. choicesNode := node.Get("choices")
  192. if err := choicesNode.Check(); err != nil {
  193. return err
  194. }
  195. return choicesNode.ForEach(func(_ ast.Sequence, choiceNode *ast.Node) bool {
  196. deltaNode := choiceNode.Get("delta")
  197. if err := deltaNode.Check(); err != nil {
  198. return true
  199. }
  200. content, err := deltaNode.Get("content").String()
  201. if err == nil {
  202. rw.contentBuilder.WriteString(content)
  203. }
  204. reasoningContent, err := deltaNode.Get("reasoning_content").String()
  205. if err == nil {
  206. rw.reasoningContent.WriteString(reasoningContent)
  207. }
  208. _ = deltaNode.Get("tool_calls").
  209. ForEach(func(_ ast.Sequence, toolCallNode *ast.Node) bool {
  210. if toolCallNode == nil || toolCallNode.TypeSafe() == ast.V_NULL {
  211. return true
  212. }
  213. toolCallRaw, err := toolCallNode.Raw()
  214. if err != nil {
  215. return true
  216. }
  217. var toolCall relaymodel.ToolCall
  218. if err := sonic.UnmarshalString(toolCallRaw, &toolCall); err != nil {
  219. return true
  220. }
  221. rw.toolCalls = mergeToolCalls(rw.toolCalls, &toolCall)
  222. return true
  223. })
  224. finishReason, err := choiceNode.Get("finish_reason").String()
  225. if err == nil && finishReason != "" {
  226. rw.finishReason = finishReason
  227. }
  228. logprobsContentNode := choiceNode.GetByPath("logprobs", "content")
  229. if err := logprobsContentNode.Check(); err == nil {
  230. l, err := logprobsContentNode.Len()
  231. if err != nil {
  232. return true
  233. }
  234. rw.logprobsContent = slices.Grow(rw.logprobsContent, l)
  235. _ = logprobsContentNode.ForEach(
  236. func(_ ast.Sequence, logprobsContentNode *ast.Node) bool {
  237. rw.logprobsContent = append(rw.logprobsContent, *logprobsContentNode)
  238. return true
  239. },
  240. )
  241. }
  242. return true
  243. })
  244. }
  245. func (rw *fakeStreamResponseWriter) convertToNonStream() ([]byte, error) {
  246. lastChunk := rw.lastChunk
  247. if lastChunk == nil {
  248. return nil, errors.New("last chunk is nil")
  249. }
  250. _, err := lastChunk.Set("object", ast.NewString(relaymodel.ChatCompletionObject))
  251. if err != nil {
  252. return nil, err
  253. }
  254. if rw.usageNode != nil {
  255. _, err = lastChunk.Set("usage", *rw.usageNode)
  256. if err != nil {
  257. return nil, err
  258. }
  259. }
  260. message := map[string]any{
  261. "role": "assistant",
  262. "content": rw.contentBuilder.String(),
  263. }
  264. reasoningContent := rw.reasoningContent.String()
  265. if reasoningContent != "" {
  266. message["reasoning_content"] = reasoningContent
  267. }
  268. if len(rw.toolCalls) > 0 {
  269. message["tool_calls"] = rw.buildToolCalls()
  270. }
  271. if len(rw.logprobsContent) > 0 {
  272. message["logprobs"] = map[string]any{
  273. "content": rw.logprobsContent,
  274. }
  275. }
  276. _, err = lastChunk.SetAny("choices", []any{
  277. map[string]any{
  278. "index": 0,
  279. "message": message,
  280. "finish_reason": rw.finishReason,
  281. },
  282. })
  283. if err != nil {
  284. return nil, err
  285. }
  286. return lastChunk.MarshalJSON()
  287. }
  288. func (rw *fakeStreamResponseWriter) buildToolCalls() []*relaymodel.ToolCall {
  289. if len(rw.toolCalls) == 0 {
  290. return nil
  291. }
  292. slices.SortFunc(rw.toolCalls, func(a, b *relaymodel.ToolCall) int {
  293. return a.Index - b.Index
  294. })
  295. if rw.toolCalls[0].Index == 0 {
  296. return rw.toolCalls
  297. }
  298. // fix tool call index start with 0
  299. for i, v := range rw.toolCalls {
  300. v.Index = i
  301. }
  302. return rw.toolCalls
  303. }
  304. func mergeToolCalls(
  305. oldToolCalls []*relaymodel.ToolCall,
  306. newToolCall *relaymodel.ToolCall,
  307. ) []*relaymodel.ToolCall {
  308. findedToolCallIndex := slices.IndexFunc(oldToolCalls, func(t *relaymodel.ToolCall) bool {
  309. return t.Index == newToolCall.Index
  310. })
  311. if findedToolCallIndex != -1 {
  312. oldToolCall := oldToolCalls[findedToolCallIndex]
  313. oldToolCalls[findedToolCallIndex] = mergeToolCall(oldToolCall, newToolCall)
  314. } else {
  315. oldToolCalls = append(oldToolCalls, newToolCall)
  316. }
  317. return oldToolCalls
  318. }
  319. func mergeToolCall(oldToolCall, newToolCall *relaymodel.ToolCall) *relaymodel.ToolCall {
  320. if oldToolCall == nil {
  321. return newToolCall
  322. }
  323. if newToolCall == nil {
  324. return oldToolCall
  325. }
  326. merged := &relaymodel.ToolCall{
  327. Index: oldToolCall.Index,
  328. ID: oldToolCall.ID,
  329. Type: oldToolCall.Type,
  330. Function: oldToolCall.Function,
  331. }
  332. // Update ID if new one is provided and old one is empty
  333. if newToolCall.ID != "" && oldToolCall.ID == "" {
  334. merged.ID = newToolCall.ID
  335. }
  336. // Update function name if new one is provided and old one is empty
  337. if newToolCall.Function.Name != "" && oldToolCall.Function.Name == "" {
  338. merged.Function.Name = newToolCall.Function.Name
  339. }
  340. // Only append arguments if they're not empty
  341. if newToolCall.Function.Arguments != "" {
  342. merged.Function.Arguments += newToolCall.Function.Arguments
  343. }
  344. return merged
  345. }