Procházet zdrojové kódy

fix: multi splitter heads (#143)

* fix: multi splitter heads

* fix: ci
zijiren před 9 měsíci
rodič
revize
8815df9ddf

+ 129 - 60
core/common/splitter/splitter.go

@@ -1,25 +1,41 @@
 package splitter
 
-import "bytes"
+import (
+	"bytes"
+)
 
 type Splitter struct {
-	head           []byte
-	tail           []byte
-	headLen        int
-	tailLen        int
-	buffer         []byte
-	state          int
-	partialTailPos int
-	kmpNext        []int
+	heads           [][]byte
+	tails           [][]byte
+	buffer          []byte
+	state           int
+	partialTailPos  []int
+	kmpNexts        [][]int
+	impossibleHeads []bool // Track which heads cannot match (inverted logic)
+	longestHeadLen  int    // Cache the longest head length
 }
 
-func NewSplitter(head, tail []byte) *Splitter {
+func NewSplitter(heads, tails [][]byte) *Splitter {
+	kmpNexts := make([][]int, len(tails))
+	for i, tail := range tails {
+		kmpNexts[i] = computeKMPNext(tail)
+	}
+
+	// Find the longest head pattern length
+	longestHeadLen := 0
+	for _, head := range heads {
+		if len(head) > longestHeadLen {
+			longestHeadLen = len(head)
+		}
+	}
+
 	return &Splitter{
-		head:    head,
-		tail:    tail,
-		headLen: len(head),
-		tailLen: len(tail),
-		kmpNext: computeKMPNext(tail),
+		heads:           heads,
+		tails:           tails,
+		kmpNexts:        kmpNexts,
+		partialTailPos:  make([]int, len(tails)),
+		impossibleHeads: make([]bool, len(heads)), // Defaults to false (all heads initially possible)
+		longestHeadLen:  longestHeadLen,
 	}
 }
 
@@ -47,36 +63,71 @@ func (s *Splitter) Process(data []byte) ([]byte, []byte) {
 	if len(data) == 0 {
 		return nil, nil
 	}
+
 	switch s.state {
 	case 0:
 		s.buffer = append(s.buffer, data...)
-		bufferLen := len(s.buffer)
-		minLen := bufferLen
-		if minLen > s.headLen {
-			minLen = s.headLen
-		}
-		if minLen > 0 {
-			if !bytes.Equal(s.buffer[:minLen], s.head[:minLen]) {
-				s.state = 2
-				remaining := s.buffer
-				s.buffer = nil
-				return nil, remaining
+		bufLen := len(s.buffer)
+
+		headMatched := false
+		headMatchLen := 0
+		anyPossibleHead := false
+
+		// Check all heads in a single pass
+		for i, head := range s.heads {
+			// Skip if this head has already been ruled out
+			if s.impossibleHeads[i] {
+				continue
+			}
+
+			headLen := len(head)
+
+			// Check for complete match
+			if bufLen >= headLen {
+				if bytes.Equal(s.buffer[:headLen], head) {
+					headMatched = true
+					headMatchLen = headLen
+					break
+				}
+				// Mark this head as impossible to match
+				s.impossibleHeads[i] = true
+			} else {
+				// Check for partial match (potential match)
+				matchLen := bufLen
+				if bytes.Equal(s.buffer[:matchLen], head[:matchLen]) {
+					anyPossibleHead = true
+				} else {
+					// Mark this head as impossible to match
+					s.impossibleHeads[i] = true
+				}
 			}
 		}
 
-		if bufferLen < s.headLen {
-			return nil, nil
+		if headMatched {
+			// Head found, move to seeking tail
+			s.state = 1
+			s.buffer = s.buffer[headMatchLen:]
+			if len(s.buffer) == 0 {
+				return nil, nil
+			}
+			return s.processSeekTail()
 		}
 
-		s.state = 1
-		s.buffer = s.buffer[s.headLen:]
-		if len(s.buffer) == 0 {
+		if anyPossibleHead {
+			// Need more data to determine if a head matches
 			return nil, nil
 		}
-		return s.processSeekTail()
+
+		// No head matches and no partial match possible, move to done state
+		s.state = 2
+		remaining := s.buffer
+		s.buffer = nil
+		return nil, remaining
+
 	case 1:
 		s.buffer = append(s.buffer, data...)
 		return s.processSeekTail()
+
 	default:
 		return nil, data
 	}
@@ -84,38 +135,56 @@ func (s *Splitter) Process(data []byte) ([]byte, []byte) {
 
 func (s *Splitter) processSeekTail() ([]byte, []byte) {
 	data := s.buffer
-	j := s.partialTailPos
-	tail := s.tail
-	tailLen := s.tailLen
-	kmpNext := s.kmpNext
-
-	for i := range data {
-		for j > 0 && data[i] != tail[j] {
-			j = kmpNext[j-1]
-		}
-		if data[i] == tail[j] {
-			j++
-			if j == tailLen {
-				end := i - tailLen + 1
-				if end < 0 {
-					end = 0
+
+	// Check for each tail pattern
+	for tailIdx, tail := range s.tails {
+		j := s.partialTailPos[tailIdx]
+		tailLen := len(tail)
+		kmpNext := s.kmpNexts[tailIdx]
+
+		for i := range data {
+			for j > 0 && data[i] != tail[j] {
+				j = kmpNext[j-1]
+			}
+			if data[i] == tail[j] {
+				j++
+				if j == tailLen {
+					end := i - tailLen + 1
+					if end < 0 {
+						end = 0
+					}
+					result := data[:end]
+					remaining := data[i+1:]
+					s.buffer = nil
+					s.state = 2
+					return result, remaining
 				}
-				result := data[:end]
-				remaining := data[i+1:]
-				s.buffer = nil
-				s.state = 2
-				s.partialTailPos = 0
-				return result, remaining
+			}
+		}
+
+		// Update partial match position for this tail
+		s.partialTailPos[tailIdx] = j
+	}
+
+	// Determine how much of the buffer we can safely return
+	minSafePos := len(data)
+	for _, pos := range s.partialTailPos {
+		if pos > 0 {
+			// We have a partial match for this tail
+			tailMatchLen := pos
+			safePos := len(data) - tailMatchLen
+			if safePos < minSafePos {
+				minSafePos = safePos
 			}
 		}
 	}
-	splitAt := len(data) - j
-	if splitAt < 0 {
-		splitAt = 0
+
+	if minSafePos <= 0 {
+		// We can't safely return anything yet
+		return nil, nil
 	}
-	result := data[:splitAt]
-	remainingPart := data[splitAt:]
-	s.partialTailPos = j
-	s.buffer = remainingPart
+
+	result := data[:minSafePos]
+	s.buffer = data[minSafePos:]
 	return result, nil
 }

+ 7 - 5
core/common/splitter/think.go

@@ -3,15 +3,17 @@ package splitter
 import "github.com/labring/aiproxy/core/common/conv"
 
 const (
-	ThinkHead = "<think>\n"
-	ThinkTail = "</think>\n"
+	NThinkHead = "\n<think>\n"
+	ThinkHead  = "<think>\n"
+	ThinkTail  = "</think>\n"
 )
 
 var (
-	thinkHeadBytes = conv.StringToBytes(ThinkHead)
-	thinkTailBytes = conv.StringToBytes(ThinkTail)
+	nthinkHeadBytes = conv.StringToBytes(NThinkHead)
+	thinkHeadBytes  = conv.StringToBytes(ThinkHead)
+	thinkTailBytes  = conv.StringToBytes(ThinkTail)
 )
 
 func NewThinkSplitter() *Splitter {
-	return NewSplitter(thinkHeadBytes, thinkTailBytes)
+	return NewSplitter([][]byte{nthinkHeadBytes, thinkHeadBytes}, [][]byte{thinkTailBytes})
 }

+ 2 - 0
core/relay/adaptor/openai/main.go

@@ -417,6 +417,8 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response, preHandler Pr
 			return usage.ToModelUsage(), ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		}
 		SplitThink(respMap)
+		c.JSON(http.StatusOK, respMap)
+		return usage.ToModelUsage(), nil
 	}
 
 	newData, err := sonic.Marshal(&node)