// search.go (24 KB)
  1. package websearch
  2. import (
  3. "bytes"
  4. "context"
  5. _ "embed"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/bytedance/sonic"
  17. "github.com/bytedance/sonic/ast"
  18. "github.com/gin-gonic/gin"
  19. "github.com/labring/aiproxy/core/common"
  20. "github.com/labring/aiproxy/core/common/conv"
  21. "github.com/labring/aiproxy/core/middleware"
  22. "github.com/labring/aiproxy/core/model"
  23. "github.com/labring/aiproxy/core/relay/adaptor"
  24. "github.com/labring/aiproxy/core/relay/adaptors"
  25. "github.com/labring/aiproxy/core/relay/controller"
  26. "github.com/labring/aiproxy/core/relay/meta"
  27. "github.com/labring/aiproxy/core/relay/mode"
  28. "github.com/labring/aiproxy/core/relay/plugin"
  29. "github.com/labring/aiproxy/core/relay/plugin/noop"
  30. "github.com/labring/aiproxy/core/relay/utils"
  31. "github.com/labring/aiproxy/mcp-servers/hosted/web-search/engine"
  32. "golang.org/x/sync/errgroup"
  33. )
  34. var _ plugin.Plugin = (*WebSearch)(nil)
  35. type GetChannel func(modelName string) (*model.Channel, error)
  36. // WebSearch implements web search functionality
  37. type WebSearch struct {
  38. noop.Noop
  39. GetChannel GetChannel
  40. }
  41. // NewWebSearchPlugin creates a new web search plugin
  42. func NewWebSearchPlugin(getChannel GetChannel) plugin.Plugin {
  43. return &WebSearch{
  44. GetChannel: getChannel,
  45. }
  46. }
  47. //go:embed prompts/arxiv.md
  48. var arxivSearchPrompts string
  49. //go:embed prompts/internet.md
  50. var internetSearchPrompts string
  51. // Constants for metadata keys
  52. const (
  53. searchCount = "web-search-count"
  54. rewriteUsage = "web-search-rewrite-usage"
  55. )
  56. // Metadata helper functions
  57. func setSearchCount(m *meta.Meta, count int) {
  58. m.Set(searchCount, count)
  59. }
  60. func getSearchCount(m *meta.Meta) int {
  61. return m.GetInt(searchCount)
  62. }
  63. func setRewriteUsage(m *meta.Meta, usage model.Usage) {
  64. m.Set(rewriteUsage, usage)
  65. }
  66. func getRewriteUsage(m *meta.Meta) *model.Usage {
  67. usage, ok := m.Get(rewriteUsage)
  68. if !ok {
  69. return nil
  70. }
  71. u, ok := usage.(model.Usage)
  72. if !ok {
  73. panic(fmt.Sprintf("rewrite usage type %T is not a model.Usage", usage))
  74. }
  75. return &u
  76. }
  77. func (p *WebSearch) getConfig(meta *meta.Meta) (Config, error) {
  78. pluginConfig := Config{}
  79. if err := meta.ModelConfig.LoadPluginConfig("web-search", &pluginConfig); err != nil {
  80. return Config{}, err
  81. }
  82. return pluginConfig, nil
  83. }
  84. // ConvertRequest intercepts and modifies requests to add web search capabilities
  85. func (p *WebSearch) ConvertRequest(
  86. meta *meta.Meta,
  87. store adaptor.Store,
  88. req *http.Request,
  89. do adaptor.ConvertRequest,
  90. ) (adaptor.ConvertResult, error) {
  91. // Skip if not chat completions mode
  92. if meta.Mode != mode.ChatCompletions {
  93. return do.ConvertRequest(meta, store, req)
  94. }
  95. // Load plugin configuration
  96. pluginConfig, err := p.getConfig(meta)
  97. if err != nil {
  98. return do.ConvertRequest(meta, store, req)
  99. }
  100. // Skip if plugin is disabled
  101. if !pluginConfig.Enable {
  102. return do.ConvertRequest(meta, store, req)
  103. }
  104. // Apply default configuration values if needed
  105. if err := p.validateAndApplyDefaults(&pluginConfig); err != nil {
  106. return do.ConvertRequest(meta, store, req)
  107. }
  108. // Initialize search engines
  109. engines, arxivExists, err := p.initializeSearchEngines(pluginConfig.SearchFrom)
  110. if err != nil || len(engines) == 0 {
  111. return do.ConvertRequest(meta, store, req)
  112. }
  113. // Read and parse request body
  114. body, err := common.GetRequestBodyReusable(req)
  115. if err != nil {
  116. return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err)
  117. }
  118. var chatRequest map[string]any
  119. if err := sonic.Unmarshal(body, &chatRequest); err != nil {
  120. return do.ConvertRequest(meta, store, req)
  121. }
  122. // Check if web search should be enabled for this request
  123. webSearchOptions, hasWebSearchOptions := chatRequest["web_search_options"].(map[string]any)
  124. if !pluginConfig.ForceSearch && !hasWebSearchOptions {
  125. return do.ConvertRequest(meta, store, req)
  126. }
  127. webSearchEnable, ok := webSearchOptions["enable"].(bool)
  128. if ok && !webSearchEnable {
  129. return do.ConvertRequest(meta, store, req)
  130. }
  131. // Extract user query from messages
  132. messages, ok := chatRequest["messages"].([]any)
  133. if !ok || len(messages) == 0 {
  134. return do.ConvertRequest(meta, store, req)
  135. }
  136. queryIndex, query := p.extractUserQuery(messages)
  137. if query == "" {
  138. return do.ConvertRequest(meta, store, req)
  139. }
  140. // Prepare search rewrite prompt if configured
  141. searchRewritePrompt := p.prepareSearchRewritePrompt(
  142. pluginConfig.SearchRewrite,
  143. arxivExists,
  144. webSearchOptions,
  145. )
  146. // Generate search contexts
  147. searchContexts, err := p.generateSearchContexts(
  148. meta,
  149. store,
  150. pluginConfig,
  151. query,
  152. searchRewritePrompt,
  153. )
  154. if err != nil {
  155. return do.ConvertRequest(meta, store, req)
  156. }
  157. if len(searchContexts) == 0 {
  158. return do.ConvertRequest(meta, store, req)
  159. }
  160. // Execute searches
  161. searchResult := p.executeSearches(context.Background(), engines, searchContexts)
  162. if searchResult.Count == 0 || len(searchResult.Results) == 0 {
  163. return do.ConvertRequest(meta, store, req)
  164. }
  165. setSearchCount(meta, searchResult.Count)
  166. // Format search results and modify request
  167. p.formatSearchResults(messages, queryIndex, query, searchResult.Results, pluginConfig)
  168. delete(chatRequest, "web_search_options")
  169. // Create new request body
  170. modifiedBody, err := sonic.Marshal(chatRequest)
  171. if err != nil {
  172. return do.ConvertRequest(meta, store, req)
  173. }
  174. // Update the request
  175. common.SetRequestBody(req, modifiedBody)
  176. defer common.SetRequestBody(req, body)
  177. // Store references in context if needed
  178. if pluginConfig.NeedReference {
  179. meta.Set("references", searchResult.Results)
  180. }
  181. return do.ConvertRequest(meta, store, req)
  182. }
  183. // validateAndApplyDefaults validates configuration and applies default values
  184. func (p *WebSearch) validateAndApplyDefaults(config *Config) error {
  185. // Set default max results
  186. if config.MaxResults == 0 {
  187. config.MaxResults = 10
  188. }
  189. // Configure reference settings
  190. if config.NeedReference {
  191. if config.ReferenceFormat != "" && !strings.Contains(config.ReferenceFormat, "%s") {
  192. return errors.New("invalid reference format")
  193. }
  194. }
  195. // Set default prompt template if not provided
  196. if config.PromptTemplate == "" {
  197. config.PromptTemplate = p.getDefaultPromptTemplate(config.NeedReference)
  198. }
  199. // Validate prompt template
  200. if !strings.Contains(config.PromptTemplate, "{search_results}") ||
  201. !strings.Contains(config.PromptTemplate, "{question}") {
  202. return errors.New("invalid prompt template")
  203. }
  204. return nil
  205. }
  206. // getDefaultPromptTemplate returns the appropriate default prompt template based on configuration
  207. func (p *WebSearch) getDefaultPromptTemplate(needReference bool) string {
  208. if needReference {
  209. return `# 以下内容是基于用户发送的消息的搜索结果:
  210. {search_results}
  211. 在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。
  212. 在回答时,请注意以下几点:
  213. - 今天是北京时间:{cur_date}。
  214. - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
  215. - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。
  216. - 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
  217. - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
  218. - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
  219. - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
  220. - 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。
  221. - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
  222. # 用户消息为:
  223. {question}`
  224. }
  225. return `# 以下内容是基于用户发送的消息的搜索结果:
  226. {search_results}
  227. 在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。
  228. 在回答时,请注意以下几点:
  229. - 今天是北京时间:{cur_date}。
  230. - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
  231. - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内。如非必要,不要主动告诉用户搜索结果未提供的内容。
  232. - 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
  233. - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
  234. - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
  235. - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
  236. - 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。
  237. - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
  238. # 用户消息为:
  239. {question}`
  240. }
  241. // initializeSearchEngines creates search engine instances based on configuration
  242. func (p *WebSearch) initializeSearchEngines(configs []EngineConfig) ([]engine.Engine, bool, error) {
  243. var (
  244. engines []engine.Engine
  245. arxivExists bool
  246. )
  247. for _, e := range configs {
  248. switch e.Type {
  249. case "bing":
  250. var spec BingSpec
  251. if err := e.LoadSpec(&spec); err != nil {
  252. return nil, false, err
  253. }
  254. engines = append(engines, engine.NewBingEngine(spec.APIKey))
  255. case "bingcn":
  256. engines = append(engines, engine.NewBingCNEngine())
  257. case "google":
  258. var spec GoogleSpec
  259. if err := e.LoadSpec(&spec); err != nil {
  260. return nil, false, err
  261. }
  262. engines = append(engines, engine.NewGoogleEngine(spec.APIKey, spec.CX))
  263. case "arxiv":
  264. engines = append(engines, engine.NewArxivEngine())
  265. arxivExists = true
  266. case "searchxng":
  267. var spec SearchXNGSpec
  268. if err := e.LoadSpec(&spec); err != nil {
  269. return nil, false, err
  270. }
  271. engines = append(engines, engine.NewSearchXNGEngine(spec.BaseURL))
  272. default:
  273. return nil, false, fmt.Errorf("unsupported engine type: %s", e.Type)
  274. }
  275. }
  276. return engines, arxivExists, nil
  277. }
  278. // extractUserQuery finds the last user message in the conversation
  279. func (p *WebSearch) extractUserQuery(messages []any) (int, string) {
  280. for i := len(messages) - 1; i >= 0; i-- {
  281. msg, ok := messages[i].(map[string]any)
  282. if !ok {
  283. continue
  284. }
  285. if role, ok := msg["role"].(string); ok && role == "user" {
  286. if content, ok := msg["content"].(string); ok {
  287. return i, content
  288. }
  289. return i, ""
  290. }
  291. }
  292. return -1, ""
  293. }
  294. // prepareSearchRewritePrompt prepares the prompt for search query rewriting
  295. func (p *WebSearch) prepareSearchRewritePrompt(
  296. searchRewrite SearchRewrite,
  297. arxivExists bool,
  298. webSearchOptions map[string]any,
  299. ) string {
  300. if !searchRewrite.Enable {
  301. return ""
  302. }
  303. // Select appropriate prompt template
  304. var searchRewritePromptTemplate string
  305. if arxivExists {
  306. searchRewritePromptTemplate = arxivSearchPrompts
  307. } else {
  308. searchRewritePromptTemplate = internetSearchPrompts
  309. }
  310. // Adjust max count based on search context size if specified
  311. maxCount := searchRewrite.MaxCount
  312. if webSearchOptions != nil {
  313. if searchContextSize, ok := webSearchOptions["search_context_size"].(string); ok {
  314. switch searchContextSize {
  315. case "low":
  316. maxCount = 1
  317. case "medium":
  318. maxCount = 3
  319. case "high":
  320. maxCount = 5
  321. }
  322. }
  323. }
  324. // Replace placeholder with actual max count
  325. return strings.ReplaceAll(searchRewritePromptTemplate, "{max_count}", strconv.Itoa(maxCount))
  326. }
  327. // generateSearchContexts creates search contexts based on the user query
  328. func (p *WebSearch) generateSearchContexts(
  329. m *meta.Meta,
  330. store adaptor.Store,
  331. config Config,
  332. query, searchRewritePrompt string,
  333. ) ([]engine.SearchQuery, error) {
  334. if searchRewritePrompt == "" {
  335. return []engine.SearchQuery{{
  336. Queries: []string{query},
  337. Language: config.DefaultLanguage,
  338. }}, nil
  339. }
  340. // Prepare request for query rewriting
  341. rewriteBody, err := sonic.Marshal(map[string]any{
  342. "stream": false,
  343. "max_tokens": 4096,
  344. "model": config.SearchRewrite.ModelName,
  345. "messages": []map[string]any{
  346. {
  347. "role": "user",
  348. "content": strings.ReplaceAll(searchRewritePrompt, "{question}", query),
  349. },
  350. },
  351. })
  352. if err != nil {
  353. return nil, err
  354. }
  355. // Set up test context for rewrite request
  356. w := httptest.NewRecorder()
  357. newc, _ := gin.CreateTestContext(w)
  358. newc.Request = &http.Request{
  359. URL: &url.URL{},
  360. Body: io.NopCloser(bytes.NewReader(rewriteBody)),
  361. Header: make(http.Header),
  362. }
  363. middleware.SetRequestID(newc, "web-search-rewrite")
  364. // Set up metadata for rewrite request
  365. modelName := config.SearchRewrite.ModelName
  366. if modelName == "" {
  367. modelName = m.OriginModel
  368. }
  369. newMeta := meta.NewMeta(
  370. nil,
  371. mode.ChatCompletions,
  372. modelName,
  373. model.ModelConfig{
  374. Model: modelName,
  375. Type: mode.ChatCompletions,
  376. },
  377. meta.WithRequestID("web-search-rewrite"),
  378. )
  379. // Set appropriate channel
  380. if config.SearchRewrite.ModelName == "" {
  381. newMeta.CopyChannelFromMeta(m)
  382. } else {
  383. channel, err := p.GetChannel(config.SearchRewrite.ModelName)
  384. if err != nil {
  385. return nil, err
  386. }
  387. newMeta.SetChannel(channel)
  388. }
  389. // Get adaptor and handle request
  390. adaptor, ok := adaptors.GetAdaptor(newMeta.Channel.Type)
  391. if !ok {
  392. return nil, errors.New("adaptor not found")
  393. }
  394. result := controller.Handle(adaptor, newc, newMeta, store)
  395. if result.Error != nil {
  396. return nil, result.Error
  397. }
  398. setRewriteUsage(m, result.Usage)
  399. // Extract content from response
  400. contentNode, err := sonic.Get(w.Body.Bytes(), "choices", 0, "message", "content")
  401. if err != nil {
  402. return nil, err
  403. }
  404. content, err := contentNode.String()
  405. if err != nil || content == "" {
  406. return nil, err
  407. }
  408. if strings.Contains(content, "none") {
  409. return nil, nil
  410. }
  411. // Parse search queries from LLM response
  412. return p.parseSearchContexts(config.DefaultLanguage, content), nil
  413. }
  414. // parseSearchContexts extracts search queries from LLM response
  415. func (p *WebSearch) parseSearchContexts(defaultLanguage, content string) []engine.SearchQuery {
  416. var searchContexts []engine.SearchQuery
  417. for _, line := range strings.Split(content, "\n") {
  418. line = strings.TrimSpace(line)
  419. if line == "" {
  420. continue
  421. }
  422. parts := strings.SplitN(line, ":", 2)
  423. if len(parts) != 2 {
  424. continue
  425. }
  426. engineType := strings.TrimSpace(parts[0])
  427. queryStr := strings.TrimSpace(parts[1])
  428. var ctx engine.SearchQuery
  429. ctx.Language = defaultLanguage
  430. switch engineType {
  431. case "internet":
  432. ctx.Queries = []string{queryStr}
  433. default:
  434. // Arxiv category
  435. ctx.ArxivCategory = engineType
  436. ctx.Queries = strings.Split(queryStr, ",")
  437. for i := range ctx.Queries {
  438. ctx.Queries[i] = strings.TrimSpace(ctx.Queries[i])
  439. }
  440. }
  441. if len(ctx.Queries) > 0 {
  442. searchContexts = append(searchContexts, ctx)
  443. if ctx.ArxivCategory != "" {
  444. // Conduct inquiries in all areas to increase recall.
  445. backupCtx := ctx
  446. backupCtx.ArxivCategory = ""
  447. searchContexts = append(searchContexts, backupCtx)
  448. }
  449. }
  450. }
  451. return searchContexts
  452. }
  453. // Search result structure
  454. type searchResult struct {
  455. Results []engine.SearchResult
  456. Count int
  457. }
  458. // executeSearches performs searches using all configured engines
  459. func (p *WebSearch) executeSearches(
  460. ctx context.Context,
  461. engines []engine.Engine,
  462. searchContexts []engine.SearchQuery,
  463. ) *searchResult {
  464. var (
  465. allResults []engine.SearchResult
  466. mu sync.Mutex
  467. )
  468. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  469. defer cancel()
  470. g, ctx := errgroup.WithContext(ctx)
  471. for _, eng := range engines {
  472. for _, searchCtx := range searchContexts {
  473. g.Go(func() error {
  474. results, err := eng.Search(ctx, engine.SearchQuery{
  475. Queries: searchCtx.Queries,
  476. MaxResults: 10,
  477. Language: searchCtx.Language,
  478. ArxivCategory: searchCtx.ArxivCategory,
  479. })
  480. if err != nil {
  481. return err
  482. }
  483. mu.Lock()
  484. allResults = append(allResults, results...)
  485. mu.Unlock()
  486. return nil
  487. })
  488. }
  489. }
  490. _ = g.Wait()
  491. seen := make(map[string]bool)
  492. var uniqueResults []engine.SearchResult
  493. for _, result := range allResults {
  494. if !seen[result.Link] {
  495. seen[result.Link] = true
  496. uniqueResults = append(uniqueResults, result)
  497. }
  498. }
  499. return &searchResult{
  500. Results: uniqueResults,
  501. Count: len(engines) * len(searchContexts),
  502. }
  503. }
  504. // formatSearchResults formats search results for the prompt
  505. func (p *WebSearch) formatSearchResults(
  506. messages []any,
  507. queryIndex int,
  508. query string,
  509. searchResults []engine.SearchResult,
  510. config Config,
  511. ) {
  512. message, ok := messages[queryIndex].(map[string]any)
  513. if !ok {
  514. return
  515. }
  516. var formattedResults []string
  517. for i, result := range searchResults {
  518. if config.NeedReference {
  519. formattedResults = append(formattedResults,
  520. fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", i+1, result.Content, i+1))
  521. } else {
  522. formattedResults = append(formattedResults,
  523. fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content))
  524. }
  525. }
  526. // Fill template
  527. curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日")
  528. searchResultsStr := strings.Join(formattedResults, "\n")
  529. prompt := strings.Replace(config.PromptTemplate, "{search_results}", searchResultsStr, 1)
  530. prompt = strings.Replace(prompt, "{question}", query, 1)
  531. prompt = strings.Replace(prompt, "{cur_date}", curDate, 1)
  532. // Update message
  533. message["content"] = prompt
  534. }
  535. // Custom response writer to handle metadata and references
  536. type responseWriter struct {
  537. gin.ResponseWriter
  538. refWritten bool
  539. referenceFormat string
  540. references []engine.SearchResult
  541. referencesLocation string
  542. webSearchCount int
  543. rewriteUsage *model.Usage
  544. rewriteUsageWritten bool
  545. rewriteUsageField string
  546. isStream bool
  547. }
  548. // Write overrides the standard Write method to inject metadata
  549. func (rw *responseWriter) Write(b []byte) (int, error) {
  550. if rw.isStream || utils.IsStreamResponseWithHeader(rw.Header()) {
  551. rw.isStream = true
  552. }
  553. node, err := sonic.Get(b)
  554. if err != nil || !node.Valid() {
  555. return rw.ResponseWriter.Write(b)
  556. }
  557. // Process the response node
  558. rw.processRewriteUsage(&node)
  559. rw.processWebSearchCount(&node)
  560. rw.processReferences(&node)
  561. // Marshal the modified node
  562. nb, err := sonic.Marshal(&node)
  563. if err != nil {
  564. return rw.ResponseWriter.Write(b)
  565. }
  566. if !rw.isStream {
  567. if rw.ResponseWriter.Header().Get("Content-Length") != "" {
  568. rw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(len(nb)))
  569. }
  570. }
  571. return rw.ResponseWriter.Write(nb)
  572. }
  573. // processRewriteUsage adds rewrite usage information to the response
  574. func (rw *responseWriter) processRewriteUsage(node *ast.Node) {
  575. if rw.rewriteUsage != nil && !rw.rewriteUsageWritten {
  576. field := rw.rewriteUsageField
  577. if field == "" {
  578. field = "rewrite_usage"
  579. }
  580. _, _ = node.SetAny(field, rw.rewriteUsage)
  581. rw.rewriteUsageWritten = true
  582. }
  583. }
  584. // processWebSearchCount adds or increments web search count in usage statistics
  585. func (rw *responseWriter) processWebSearchCount(node *ast.Node) {
  586. if rw.webSearchCount <= 0 {
  587. return
  588. }
  589. usageNode := node.Get("usage")
  590. if usageNode == nil || !usageNode.Valid() {
  591. return
  592. }
  593. // Check if web_search_count already exists
  594. existingCount := usageNode.Get("web_search_count")
  595. if existingCount != nil && existingCount.Valid() {
  596. // If exists, add to the existing value
  597. currentCount, _ := existingCount.Int64()
  598. _, _ = usageNode.Set(
  599. "web_search_count",
  600. ast.NewNumber(strconv.FormatInt(currentCount+int64(rw.webSearchCount), 10)),
  601. )
  602. } else {
  603. // If not exists, set the value
  604. _, _ = usageNode.Set("web_search_count", ast.NewNumber(strconv.FormatInt(int64(rw.webSearchCount), 10)))
  605. }
  606. }
  607. func buildReferenceContent(searchResults []engine.SearchResult) string {
  608. formattedReferences := make([]string, len(searchResults))
  609. for i, result := range searchResults {
  610. formattedReferences[i] = fmt.Sprintf("[%d] [%s](%s)", i+1, result.Title, result.Link)
  611. }
  612. return strings.Join(formattedReferences, "\n\n")
  613. }
  614. // processReferences adds reference information to the content
  615. func (rw *responseWriter) processReferences(node *ast.Node) {
  616. if rw.refWritten || len(rw.references) == 0 {
  617. return
  618. }
  619. rw.refWritten = true
  620. if rw.referencesLocation == "" || rw.referencesLocation == "content" {
  621. var contentNode *ast.Node
  622. if rw.isStream {
  623. contentNode = node.GetByPath("choices", 0, "delta", "content")
  624. } else {
  625. contentNode = node.GetByPath("choices", 0, "message", "content")
  626. }
  627. if contentNode != nil && contentNode.Valid() {
  628. content, err := contentNode.String()
  629. if err == nil {
  630. format := rw.referenceFormat
  631. if format == "" {
  632. format = "**References:**\n%s"
  633. }
  634. ref := fmt.Sprintf(format, buildReferenceContent(rw.references))
  635. refContent := fmt.Sprintf("%s\n\n%s", ref, content)
  636. *contentNode = ast.NewString(refContent)
  637. }
  638. }
  639. } else {
  640. var outterLocation *ast.Node
  641. if rw.isStream {
  642. outterLocation = node.GetByPath("choices", 0, "delta")
  643. } else {
  644. outterLocation = node.GetByPath("choices", 0, "message")
  645. }
  646. if outterLocation != nil && outterLocation.Valid() {
  647. _, _ = outterLocation.SetAny(rw.referencesLocation, rw.references)
  648. }
  649. }
  650. }
  651. // WriteString implements the WriteString method for the custom response writer
  652. func (rw *responseWriter) WriteString(s string) (int, error) {
  653. return rw.Write(conv.StringToBytes(s))
  654. }
  655. // DoResponse handles response modification for references
  656. func (p *WebSearch) DoResponse(
  657. meta *meta.Meta,
  658. store adaptor.Store,
  659. c *gin.Context,
  660. resp *http.Response,
  661. do adaptor.DoResponse,
  662. ) (model.Usage, adaptor.Error) {
  663. if meta.Mode != mode.ChatCompletions {
  664. return do.DoResponse(meta, store, c, resp)
  665. }
  666. var references []engine.SearchResult
  667. referencesI, ok := meta.Get("references")
  668. if ok {
  669. references, ok = referencesI.([]engine.SearchResult)
  670. if !ok {
  671. panic(fmt.Sprintf("references type %T is not a []engine.SearchResult", referencesI))
  672. }
  673. }
  674. count := getSearchCount(meta)
  675. var rewriteUsage *model.Usage
  676. rewriteUsageField := ""
  677. pluginConfig, _ := p.getConfig(meta)
  678. if pluginConfig.SearchRewrite.AddRewriteUsage {
  679. rewriteUsage = getRewriteUsage(meta)
  680. rewriteUsageField = pluginConfig.SearchRewrite.RewriteUsageField
  681. }
  682. // Check if we need to wrap the response writer
  683. if len(references) > 0 || rewriteUsage != nil || count > 0 {
  684. rw := &responseWriter{
  685. ResponseWriter: c.Writer,
  686. webSearchCount: count,
  687. rewriteUsageWritten: false,
  688. rewriteUsage: rewriteUsage,
  689. rewriteUsageField: rewriteUsageField,
  690. references: references,
  691. referencesLocation: pluginConfig.ReferenceLocation,
  692. referenceFormat: pluginConfig.ReferenceFormat,
  693. }
  694. c.Writer = rw
  695. defer func() {
  696. c.Writer = rw.ResponseWriter
  697. }()
  698. }
  699. return p.doResponseWithCount(meta, store, c, resp, do, count)
  700. }
  701. // doResponseWithCount adds search count to usage statistics
  702. func (p *WebSearch) doResponseWithCount(
  703. meta *meta.Meta,
  704. store adaptor.Store,
  705. c *gin.Context,
  706. resp *http.Response,
  707. do adaptor.DoResponse,
  708. count int,
  709. ) (model.Usage, adaptor.Error) {
  710. u, err := do.DoResponse(meta, store, c, resp)
  711. if err != nil {
  712. return model.Usage{}, err
  713. }
  714. u.WebSearchCount += model.ZeroNullInt64(int64(count))
  715. return u, nil
  716. }