// Package streamfake provides a relay plugin that transparently upgrades
// non-streaming chat completion requests to streaming ones and reassembles the
// streamed chunks into a single non-streaming JSON response.
- package streamfake
- import (
- "bytes"
- "errors"
- "fmt"
- "net/http"
- "slices"
- "strconv"
- "github.com/bytedance/sonic"
- "github.com/bytedance/sonic/ast"
- "github.com/gin-gonic/gin"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/common/conv"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/core/relay/adaptor"
- "github.com/labring/aiproxy/core/relay/meta"
- "github.com/labring/aiproxy/core/relay/mode"
- relaymodel "github.com/labring/aiproxy/core/relay/model"
- "github.com/labring/aiproxy/core/relay/plugin"
- "github.com/labring/aiproxy/core/relay/plugin/noop"
- "github.com/labring/aiproxy/core/relay/plugin/patch"
- )
- var _ plugin.Plugin = (*StreamFake)(nil)
- // StreamFake implements the stream fake functionality
- type StreamFake struct {
- noop.Noop
- }
- // NewStreamFakePlugin creates a new stream fake plugin instance
- func NewStreamFakePlugin() plugin.Plugin {
- return &StreamFake{}
- }
// fakeStreamKey is the meta key under which ConvertRequest records that a
// non-streaming request was switched to streaming; DoResponse reads it back.
const fakeStreamKey = "fake_stream"
- // getConfig retrieves the plugin configuration
- func (p *StreamFake) getConfig(meta *meta.Meta) (*Config, error) {
- pluginConfig := &Config{}
- if err := meta.ModelConfig.LoadPluginConfig("stream-fake", pluginConfig); err != nil {
- return nil, err
- }
- return pluginConfig, nil
- }
- // ConvertRequest modifies the request to enable streaming if it's originally non-streaming
- func (p *StreamFake) ConvertRequest(
- meta *meta.Meta,
- store adaptor.Store,
- req *http.Request,
- do adaptor.ConvertRequest,
- ) (adaptor.ConvertResult, error) {
- // Only process chat completions
- if meta.Mode != mode.ChatCompletions {
- return do.ConvertRequest(meta, store, req)
- }
- // Check if stream fake is enabled
- pluginConfig, err := p.getConfig(meta)
- if err != nil || !pluginConfig.Enable {
- return do.ConvertRequest(meta, store, req)
- }
- body, err := common.GetRequestBodyReusable(req)
- if err != nil {
- return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err)
- }
- node, err := sonic.Get(body)
- if err != nil {
- return do.ConvertRequest(meta, store, req)
- }
- stream, _ := node.Get("stream").Bool()
- if stream {
- // Already streaming, no need to fake
- return do.ConvertRequest(meta, store, req)
- }
- patch.AddLazyPatch(meta, patch.PatchOperation{
- Op: patch.OpFunction,
- Function: func(root *ast.Node) (bool, error) {
- _, err := root.Set("stream", ast.NewBool(true))
- if err != nil {
- return false, err
- }
- return true, nil
- },
- })
- meta.Set(fakeStreamKey, true)
- return do.ConvertRequest(meta, store, req)
- }
- // DoResponse handles the response processing to collect streaming data and convert back to non-streaming
- func (p *StreamFake) DoResponse(
- meta *meta.Meta,
- store adaptor.Store,
- c *gin.Context,
- resp *http.Response,
- do adaptor.DoResponse,
- ) (model.Usage, adaptor.Error) {
- // Only process chat completions
- if meta.Mode != mode.ChatCompletions {
- return do.DoResponse(meta, store, c, resp)
- }
- // Check if this is a fake stream request
- isFakeStream, ok := meta.Get(fakeStreamKey)
- if !ok {
- return do.DoResponse(meta, store, c, resp)
- }
- isFakeStreamBool, ok := isFakeStream.(bool)
- if !ok || !isFakeStreamBool {
- return do.DoResponse(meta, store, c, resp)
- }
- return p.handleFakeStreamResponse(meta, store, c, resp, do)
- }
- // handleFakeStreamResponse processes the streaming response and converts it back to non-streaming
- func (p *StreamFake) handleFakeStreamResponse(
- meta *meta.Meta,
- store adaptor.Store,
- c *gin.Context,
- resp *http.Response,
- do adaptor.DoResponse,
- ) (model.Usage, adaptor.Error) {
- log := common.GetLogger(c)
- // Create a custom response writer to collect streaming data
- rw := &fakeStreamResponseWriter{
- ResponseWriter: c.Writer,
- }
- c.Writer = rw
- defer func() {
- c.Writer = rw.ResponseWriter
- }()
- // Process the streaming response
- usage, relayErr := do.DoResponse(meta, store, c, resp)
- if relayErr != nil {
- return usage, relayErr
- }
- // Convert collected streaming chunks to non-streaming response
- respBody, err := rw.convertToNonStream()
- if err != nil {
- log.Errorf("failed to convert to non-streaming response: %v", err)
- return usage, relayErr
- }
- // Set appropriate headers for non-streaming response
- c.Header("Content-Type", "application/json")
- c.Header("Content-Length", strconv.Itoa(len(respBody)))
- // Remove streaming-specific headers
- c.Header("Cache-Control", "")
- c.Header("Connection", "")
- c.Header("Transfer-Encoding", "")
- c.Header("X-Accel-Buffering", "")
- // Write the non-streaming response
- _, _ = rw.ResponseWriter.Write(respBody)
- return usage, nil
- }
- // fakeStreamResponseWriter captures streaming response data
- type fakeStreamResponseWriter struct {
- gin.ResponseWriter
- lastChunk *ast.Node
- usageNode *ast.Node
- contentBuilder bytes.Buffer
- reasoningContent bytes.Buffer
- finishReason relaymodel.FinishReason
- logprobsContent []ast.Node
- toolCalls []*relaymodel.ToolCall
- }
- // ignore flush
- func (rw *fakeStreamResponseWriter) Flush() {}
- // ignore WriteHeaderNow
- func (rw *fakeStreamResponseWriter) WriteHeaderNow() {}
- func (rw *fakeStreamResponseWriter) Write(b []byte) (int, error) {
- // Parse streaming data
- _ = rw.parseStreamingData(b)
- return len(b), nil
- }
- func (rw *fakeStreamResponseWriter) WriteString(s string) (int, error) {
- return rw.Write(conv.StringToBytes(s))
- }
- // parseStreamingData extracts individual chunks from streaming response
- func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error {
- node, err := sonic.Get(data)
- if err != nil {
- return err
- }
- rw.lastChunk = &node
- usageNode := node.Get("usage")
- if err := usageNode.Check(); err != nil {
- if !errors.Is(err, ast.ErrNotExist) {
- return err
- }
- } else {
- rw.usageNode = usageNode
- }
- choicesNode := node.Get("choices")
- if err := choicesNode.Check(); err != nil {
- return err
- }
- return choicesNode.ForEach(func(_ ast.Sequence, choiceNode *ast.Node) bool {
- deltaNode := choiceNode.Get("delta")
- if err := deltaNode.Check(); err != nil {
- return true
- }
- content, err := deltaNode.Get("content").String()
- if err == nil {
- rw.contentBuilder.WriteString(content)
- }
- reasoningContent, err := deltaNode.Get("reasoning_content").String()
- if err == nil {
- rw.reasoningContent.WriteString(reasoningContent)
- }
- _ = deltaNode.Get("tool_calls").
- ForEach(func(_ ast.Sequence, toolCallNode *ast.Node) bool {
- toolCallRaw, err := toolCallNode.Raw()
- if err != nil {
- return true
- }
- var toolCall relaymodel.ToolCall
- if err := sonic.UnmarshalString(toolCallRaw, &toolCall); err != nil {
- return true
- }
- rw.toolCalls = mergeToolCalls(rw.toolCalls, &toolCall)
- return true
- })
- finishReason, err := choiceNode.Get("finish_reason").String()
- if err == nil && finishReason != "" {
- rw.finishReason = finishReason
- }
- logprobsContentNode := choiceNode.GetByPath("logprobs", "content")
- if err := logprobsContentNode.Check(); err == nil {
- l, err := logprobsContentNode.Len()
- if err != nil {
- return true
- }
- rw.logprobsContent = slices.Grow(rw.logprobsContent, l)
- _ = logprobsContentNode.ForEach(
- func(_ ast.Sequence, logprobsContentNode *ast.Node) bool {
- rw.logprobsContent = append(rw.logprobsContent, *logprobsContentNode)
- return true
- },
- )
- }
- return true
- })
- }
- func (rw *fakeStreamResponseWriter) convertToNonStream() ([]byte, error) {
- lastChunk := rw.lastChunk
- if lastChunk == nil {
- return nil, errors.New("last chunk is nil")
- }
- _, err := lastChunk.Set("object", ast.NewString(relaymodel.ChatCompletionObject))
- if err != nil {
- return nil, err
- }
- if rw.usageNode != nil {
- _, err = lastChunk.Set("usage", *rw.usageNode)
- if err != nil {
- return nil, err
- }
- }
- message := map[string]any{
- "role": "assistant",
- "content": rw.contentBuilder.String(),
- }
- reasoningContent := rw.reasoningContent.String()
- if reasoningContent != "" {
- message["reasoning_content"] = reasoningContent
- }
- if len(rw.toolCalls) > 0 {
- slices.SortFunc(rw.toolCalls, func(a, b *relaymodel.ToolCall) int {
- return a.Index - b.Index
- })
- message["tool_calls"] = rw.toolCalls
- }
- if len(rw.logprobsContent) > 0 {
- message["logprobs"] = map[string]any{
- "content": rw.logprobsContent,
- }
- }
- _, err = lastChunk.SetAny("choices", []any{
- map[string]any{
- "index": 0,
- "message": message,
- "finish_reason": rw.finishReason,
- },
- })
- if err != nil {
- return nil, err
- }
- return lastChunk.MarshalJSON()
- }
- func mergeToolCalls(
- oldToolCalls []*relaymodel.ToolCall,
- newToolCall *relaymodel.ToolCall,
- ) []*relaymodel.ToolCall {
- findedToolCallIndex := slices.IndexFunc(oldToolCalls, func(t *relaymodel.ToolCall) bool {
- return t.Index == newToolCall.Index
- })
- if findedToolCallIndex != -1 {
- oldToolCall := oldToolCalls[findedToolCallIndex]
- oldToolCalls[findedToolCallIndex] = mergeToolCall(oldToolCall, newToolCall)
- } else {
- oldToolCalls = append(oldToolCalls, newToolCall)
- }
- return oldToolCalls
- }
- func mergeToolCall(oldToolCall, newToolCall *relaymodel.ToolCall) *relaymodel.ToolCall {
- if oldToolCall == nil {
- return newToolCall
- }
- if newToolCall == nil {
- return oldToolCall
- }
- merged := &relaymodel.ToolCall{
- Index: oldToolCall.Index,
- ID: oldToolCall.ID,
- Type: oldToolCall.Type,
- Function: oldToolCall.Function,
- }
- merged.Function.Arguments += newToolCall.Function.Arguments
- return merged
- }
|