split.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. package thinksplit
  2. import (
  3. "net/http"
  4. "strconv"
  5. "github.com/bytedance/sonic"
  6. "github.com/gin-gonic/gin"
  7. "github.com/labring/aiproxy/core/common/conv"
  8. "github.com/labring/aiproxy/core/model"
  9. "github.com/labring/aiproxy/core/relay/adaptor"
  10. "github.com/labring/aiproxy/core/relay/meta"
  11. "github.com/labring/aiproxy/core/relay/mode"
  12. "github.com/labring/aiproxy/core/relay/plugin"
  13. "github.com/labring/aiproxy/core/relay/plugin/noop"
  14. "github.com/labring/aiproxy/core/relay/plugin/thinksplit/splitter"
  15. "github.com/labring/aiproxy/core/relay/utils"
  16. )
  17. var _ plugin.Plugin = (*ThinkPlugin)(nil)
  18. // ThinkPlugin implements the think content splitting functionality
  19. type ThinkPlugin struct {
  20. noop.Noop
  21. }
  22. // NewThinkPlugin creates a new think plugin instance
  23. func NewThinkPlugin() plugin.Plugin {
  24. return &ThinkPlugin{}
  25. }
  26. // getConfig retrieves the plugin configuration
  27. func (p *ThinkPlugin) getConfig(meta *meta.Meta) (*Config, error) {
  28. pluginConfig := &Config{}
  29. if err := meta.ModelConfig.LoadPluginConfig("think-split", pluginConfig); err != nil {
  30. return nil, err
  31. }
  32. return pluginConfig, nil
  33. }
  // DoResponse is the plugin hook around the adaptor's response handling.
  // Think splitting is applied only when the request is a chat completion and
  // the "think-split" config loads without error and has Enable set. In every
  // other case the call is delegated to do unchanged; a config load error is
  // deliberately treated as "disabled".
  func (p *ThinkPlugin) DoResponse(
  	meta *meta.Meta,
  	store adaptor.Store,
  	c *gin.Context,
  	resp *http.Response,
  	do adaptor.DoResponse,
  ) (model.Usage, adaptor.Error) {
  	if meta.Mode == mode.ChatCompletions {
  		if cfg, err := p.getConfig(meta); err == nil && cfg.Enable {
  			return p.handleResponse(meta, store, c, resp, do)
  		}
  	}
  	return do.DoResponse(meta, store, c, resp)
  }
  // handleResponse runs the wrapped DoResponse with c.Writer temporarily
  // replaced by a thinkResponseWriter. Both streamed and non-streamed bodies
  // written by the adaptor therefore pass through the think splitter. The
  // original writer is put back when this function returns.
  func (p *ThinkPlugin) handleResponse(
  	meta *meta.Meta,
  	store adaptor.Store,
  	c *gin.Context,
  	resp *http.Response,
  	do adaptor.DoResponse,
  ) (model.Usage, adaptor.Error) {
  	original := c.Writer
  	c.Writer = &thinkResponseWriter{ResponseWriter: original}
  	defer func() { c.Writer = original }()

  	return do.DoResponse(meta, store, c, resp)
  }
  71. // thinkResponseWriter wraps the response writer for streaming responses
  72. type thinkResponseWriter struct {
  73. gin.ResponseWriter
  74. thinkSplitter *splitter.Splitter
  75. isStream bool
  76. done bool
  77. }
  78. func (rw *thinkResponseWriter) getThinkSplitter() *splitter.Splitter {
  79. if rw.thinkSplitter == nil {
  80. rw.thinkSplitter = splitter.NewThinkSplitter()
  81. }
  82. return rw.thinkSplitter
  83. }
  84. // ignore WriteHeaderNow
  85. func (rw *thinkResponseWriter) WriteHeaderNow() {}
  // Write intercepts every body write made through the wrapper.
  //
  // Once done is set, bytes are forwarded unchanged. Otherwise b is parsed with
  // sonic. Anything that is not a JSON object is forwarded unchanged; this
  // presumably includes SSE framing and "[DONE]", if the renderer writes those
  // separately — verify.
  //
  // A JSON object is treated as a stream chunk when the response is, or was
  // already detected as, an event stream. Otherwise it is treated as the
  // complete body. Either way it is re-marshaled after splitting. For a
  // complete body, an existing Content-Length header is updated to the
  // rewritten size.
  //
  // When a rewritten payload is sent successfully, Write reports len(b) bytes
  // written, not the size of the rewritten payload. io.Writer requires
  // n <= len(b) and a non-nil error whenever n < len(b). Callers use n to
  // detect short writes of the bytes they passed in.
  func (rw *thinkResponseWriter) Write(b []byte) (int, error) {
  	if rw.done {
  		return rw.ResponseWriter.Write(b)
  	}

  	node, err := sonic.Get(b)
  	if err != nil || !node.Valid() {
  		return rw.ResponseWriter.Write(b)
  	}
  	respMap, err := node.Map()
  	if err != nil {
  		return rw.ResponseWriter.Write(b)
  	}

  	if rw.isStream || utils.IsStreamResponseWithHeader(rw.Header()) {
  		rw.isStream = true
  		rw.done = StreamSplitThink(respMap, rw.getThinkSplitter())
  		jsonData, err := sonic.Marshal(respMap)
  		if err != nil {
  			return rw.ResponseWriter.Write(b)
  		}
  		return rw.writeRewritten(b, jsonData)
  	}

  	// Non-streamed: the whole body arrives in this single write.
  	rw.done = true
  	SplitThink(respMap, rw.getThinkSplitter())
  	jsonData, err := sonic.Marshal(respMap)
  	if err != nil {
  		return rw.ResponseWriter.Write(b)
  	}
  	// The body size changed, so a previously set Content-Length would be wrong.
  	if rw.ResponseWriter.Header().Get("Content-Length") != "" {
  		rw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(len(jsonData)))
  	}
  	return rw.writeRewritten(b, jsonData)
  }

  // writeRewritten sends rewritten in place of original. On success it reports
  // len(original), so the caller sees a complete write of the bytes it supplied.
  // On failure it reports 0 bytes and the underlying error.
  func (rw *thinkResponseWriter) writeRewritten(original, rewritten []byte) (int, error) {
  	if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
  		return 0, err
  	}
  	return len(original), nil
  }
  121. func (rw *thinkResponseWriter) WriteString(s string) (int, error) {
  122. return rw.Write(conv.StringToBytes(s))
  123. }
  124. // renderCallback maybe reuse data, so don't modify data
  125. func StreamSplitThink(data map[string]any, thinkSplitter *splitter.Splitter) (done bool) {
  126. choices, ok := data["choices"].([]any)
  127. // only support one choice
  128. if !ok || len(choices) != 1 {
  129. return false
  130. }
  131. choice := choices[0]
  132. choiceMap, ok := choice.(map[string]any)
  133. if !ok {
  134. return false
  135. }
  136. delta, ok := choiceMap["delta"].(map[string]any)
  137. if !ok {
  138. return false
  139. }
  140. content, ok := delta["content"].(string)
  141. if !ok {
  142. return false
  143. }
  144. if _, ok := delta["reasoning_content"].(string); ok {
  145. return true
  146. }
  147. think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
  148. if len(think) == 0 && len(remaining) == 0 {
  149. delta["content"] = ""
  150. delete(delta, "reasoning_content")
  151. return false
  152. }
  153. if len(think) != 0 && len(remaining) != 0 {
  154. delta["content"] = conv.BytesToString(remaining)
  155. delta["reasoning_content"] = conv.BytesToString(think)
  156. return false
  157. }
  158. if len(think) > 0 {
  159. delta["content"] = ""
  160. delta["reasoning_content"] = conv.BytesToString(think)
  161. return false
  162. }
  163. if len(remaining) > 0 {
  164. delta["content"] = conv.BytesToString(remaining)
  165. delete(delta, "reasoning_content")
  166. return true
  167. }
  168. return false
  169. }
  170. func SplitThink(data map[string]any, thinkSplitter *splitter.Splitter) {
  171. choices, ok := data["choices"].([]any)
  172. if !ok {
  173. return
  174. }
  175. for _, choice := range choices {
  176. choiceMap, ok := choice.(map[string]any)
  177. if !ok {
  178. continue
  179. }
  180. message, ok := choiceMap["message"].(map[string]any)
  181. if !ok {
  182. continue
  183. }
  184. content, ok := message["content"].(string)
  185. if !ok {
  186. continue
  187. }
  188. if _, ok := message["reasoning_content"].(string); ok {
  189. continue
  190. }
  191. think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
  192. message["reasoning_content"] = conv.BytesToString(think)
  193. message["content"] = conv.BytesToString(remaining)
  194. }
  195. }