split.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package thinksplit
  2. import (
  3. "net/http"
  4. "strconv"
  5. "github.com/bytedance/sonic"
  6. "github.com/gin-gonic/gin"
  7. "github.com/labring/aiproxy/core/common/conv"
  8. "github.com/labring/aiproxy/core/model"
  9. "github.com/labring/aiproxy/core/relay/adaptor"
  10. "github.com/labring/aiproxy/core/relay/meta"
  11. "github.com/labring/aiproxy/core/relay/mode"
  12. "github.com/labring/aiproxy/core/relay/plugin"
  13. "github.com/labring/aiproxy/core/relay/plugin/noop"
  14. "github.com/labring/aiproxy/core/relay/plugin/thinksplit/splitter"
  15. "github.com/labring/aiproxy/core/relay/utils"
  16. )
  17. var _ plugin.Plugin = (*ThinkPlugin)(nil)
  18. // ThinkPlugin implements the think content splitting functionality
  19. type ThinkPlugin struct {
  20. noop.Noop
  21. }
  22. // NewThinkPlugin creates a new think plugin instance
  23. func NewThinkPlugin() plugin.Plugin {
  24. return &ThinkPlugin{}
  25. }
  26. // getConfig retrieves the plugin configuration
  27. func (p *ThinkPlugin) getConfig(meta *meta.Meta) (*Config, error) {
  28. pluginConfig := &Config{}
  29. if err := meta.ModelConfig.LoadPluginConfig("think-split", pluginConfig); err != nil {
  30. return nil, err
  31. }
  32. return pluginConfig, nil
  33. }
  34. // DoResponse handles the response processing to split think content
  35. func (p *ThinkPlugin) DoResponse(
  36. meta *meta.Meta,
  37. store adaptor.Store,
  38. c *gin.Context,
  39. resp *http.Response,
  40. do adaptor.DoResponse,
  41. ) (model.Usage, adaptor.Error) {
  42. // Only process chat completions
  43. if meta.Mode != mode.ChatCompletions {
  44. return do.DoResponse(meta, store, c, resp)
  45. }
  46. // Check if think splitting is enabled
  47. pluginConfig, err := p.getConfig(meta)
  48. if err != nil || !pluginConfig.Enable {
  49. return do.DoResponse(meta, store, c, resp)
  50. }
  51. return p.handleResponse(meta, store, c, resp, do)
  52. }
  53. // handleResponse processes streaming responses
  54. func (p *ThinkPlugin) handleResponse(
  55. meta *meta.Meta,
  56. store adaptor.Store,
  57. c *gin.Context,
  58. resp *http.Response,
  59. do adaptor.DoResponse,
  60. ) (model.Usage, adaptor.Error) {
  61. // Create a custom response writer
  62. rw := &thinkResponseWriter{
  63. ResponseWriter: c.Writer,
  64. }
  65. c.Writer = rw
  66. defer func() {
  67. c.Writer = rw.ResponseWriter
  68. }()
  69. return do.DoResponse(meta, store, c, resp)
  70. }
  71. // thinkResponseWriter wraps the response writer for streaming responses
  72. type thinkResponseWriter struct {
  73. gin.ResponseWriter
  74. thinkSplitter *splitter.Splitter
  75. isStream bool
  76. done bool
  77. }
  78. func (rw *thinkResponseWriter) getThinkSplitter() *splitter.Splitter {
  79. if rw.thinkSplitter == nil {
  80. rw.thinkSplitter = splitter.NewThinkSplitter()
  81. }
  82. return rw.thinkSplitter
  83. }
  84. func (rw *thinkResponseWriter) Write(b []byte) (int, error) {
  85. if rw.done {
  86. return rw.ResponseWriter.Write(b)
  87. }
  88. // For streaming responses, process each chunk
  89. node, err := sonic.Get(b)
  90. if err != nil || !node.Valid() {
  91. return rw.ResponseWriter.Write(b)
  92. }
  93. // Process the chunk
  94. respMap, err := node.Map()
  95. if err != nil {
  96. return rw.ResponseWriter.Write(b)
  97. }
  98. // Check if this is a streaming response chunk
  99. if rw.isStream || utils.IsStreamResponseWithHeader(rw.Header()) {
  100. rw.isStream = true
  101. rw.done = StreamSplitThink(respMap, rw.getThinkSplitter())
  102. jsonData, err := sonic.Marshal(respMap)
  103. if err != nil {
  104. return rw.ResponseWriter.Write(b)
  105. }
  106. return rw.ResponseWriter.Write(jsonData)
  107. }
  108. rw.done = true
  109. SplitThink(respMap, rw.getThinkSplitter())
  110. jsonData, err := sonic.Marshal(respMap)
  111. if err != nil {
  112. return rw.ResponseWriter.Write(b)
  113. }
  114. if rw.ResponseWriter.Header().Get("Content-Length") != "" {
  115. rw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(len(jsonData)))
  116. }
  117. return rw.ResponseWriter.Write(jsonData)
  118. }
  119. func (rw *thinkResponseWriter) WriteString(s string) (int, error) {
  120. return rw.Write(conv.StringToBytes(s))
  121. }
  122. // renderCallback maybe reuse data, so don't modify data
  123. func StreamSplitThink(data map[string]any, thinkSplitter *splitter.Splitter) (done bool) {
  124. choices, ok := data["choices"].([]any)
  125. // only support one choice
  126. if !ok || len(choices) != 1 {
  127. return false
  128. }
  129. choice := choices[0]
  130. choiceMap, ok := choice.(map[string]any)
  131. if !ok {
  132. return false
  133. }
  134. delta, ok := choiceMap["delta"].(map[string]any)
  135. if !ok {
  136. return false
  137. }
  138. content, ok := delta["content"].(string)
  139. if !ok {
  140. return false
  141. }
  142. if _, ok := delta["reasoning_content"].(string); ok {
  143. return true
  144. }
  145. think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
  146. if len(think) == 0 && len(remaining) == 0 {
  147. delta["content"] = ""
  148. delete(delta, "reasoning_content")
  149. return false
  150. }
  151. if len(think) != 0 && len(remaining) != 0 {
  152. delta["content"] = conv.BytesToString(remaining)
  153. delta["reasoning_content"] = conv.BytesToString(think)
  154. return false
  155. }
  156. if len(think) > 0 {
  157. delta["content"] = ""
  158. delta["reasoning_content"] = conv.BytesToString(think)
  159. return false
  160. }
  161. if len(remaining) > 0 {
  162. delta["content"] = conv.BytesToString(remaining)
  163. delete(delta, "reasoning_content")
  164. return true
  165. }
  166. return false
  167. }
  168. func SplitThink(data map[string]any, thinkSplitter *splitter.Splitter) {
  169. choices, ok := data["choices"].([]any)
  170. if !ok {
  171. return
  172. }
  173. for _, choice := range choices {
  174. choiceMap, ok := choice.(map[string]any)
  175. if !ok {
  176. continue
  177. }
  178. message, ok := choiceMap["message"].(map[string]any)
  179. if !ok {
  180. continue
  181. }
  182. content, ok := message["content"].(string)
  183. if !ok {
  184. continue
  185. }
  186. if _, ok := message["reasoning_content"].(string); ok {
  187. continue
  188. }
  189. think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
  190. message["reasoning_content"] = conv.BytesToString(think)
  191. message["content"] = conv.BytesToString(remaining)
  192. }
  193. }