search.go 25 KB


  1. package websearch
  2. import (
  3. "bytes"
  4. "context"
  5. _ "embed"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/bytedance/sonic"
  17. "github.com/bytedance/sonic/ast"
  18. "github.com/gin-gonic/gin"
  19. "github.com/labring/aiproxy/core/common"
  20. "github.com/labring/aiproxy/core/common/conv"
  21. "github.com/labring/aiproxy/core/middleware"
  22. "github.com/labring/aiproxy/core/model"
  23. "github.com/labring/aiproxy/core/relay/adaptor"
  24. "github.com/labring/aiproxy/core/relay/adaptors"
  25. "github.com/labring/aiproxy/core/relay/controller"
  26. "github.com/labring/aiproxy/core/relay/meta"
  27. "github.com/labring/aiproxy/core/relay/mode"
  28. "github.com/labring/aiproxy/core/relay/plugin"
  29. "github.com/labring/aiproxy/core/relay/plugin/noop"
  30. "github.com/labring/aiproxy/core/relay/plugin/patch"
  31. "github.com/labring/aiproxy/core/relay/utils"
  32. "github.com/labring/aiproxy/mcp-servers/hosted/web-search/engine"
  33. "github.com/sirupsen/logrus"
  34. "golang.org/x/sync/errgroup"
  35. )
  36. var _ plugin.Plugin = (*WebSearch)(nil)
  37. type GetChannel func(modelName string) (*model.Channel, error)
  38. // WebSearch implements web search functionality
  39. type WebSearch struct {
  40. noop.Noop
  41. GetChannel GetChannel
  42. }
  43. // NewWebSearchPlugin creates a new web search plugin
  44. func NewWebSearchPlugin(getChannel GetChannel) plugin.Plugin {
  45. return &WebSearch{
  46. GetChannel: getChannel,
  47. }
  48. }
// arxivSearchPrompts is the query-rewrite prompt template used when an arxiv
// engine is configured (see prepareSearchRewritePrompt). It is embedded at
// build time from prompts/arxiv.md; its {max_count} and {question}
// placeholders are filled by prepareSearchRewritePrompt and
// generateSearchContexts respectively.
//
//go:embed prompts/arxiv.md
var arxivSearchPrompts string

// internetSearchPrompts is the query-rewrite prompt template used when no
// arxiv engine is configured, embedded at build time from prompts/internet.md.
//
//go:embed prompts/internet.md
var internetSearchPrompts string

// Keys under which per-request web-search state is stored in meta.Meta while
// the request is converted, so that DoResponse can read it back.
const (
	// searchCount holds the number of searches issued (an int).
	searchCount = "web-search-count"
	// rewriteUsage holds the model.Usage of the query-rewrite LLM call.
	rewriteUsage = "web-search-rewrite-usage"
)
  58. // Metadata helper functions
  59. func setSearchCount(m *meta.Meta, count int) {
  60. m.Set(searchCount, count)
  61. }
  62. func getSearchCount(m *meta.Meta) int {
  63. return m.GetInt(searchCount)
  64. }
  65. func setRewriteUsage(m *meta.Meta, usage model.Usage) {
  66. m.Set(rewriteUsage, usage)
  67. }
  68. func getRewriteUsage(m *meta.Meta) *model.Usage {
  69. usage, ok := m.Get(rewriteUsage)
  70. if !ok {
  71. return nil
  72. }
  73. u, ok := usage.(model.Usage)
  74. if !ok {
  75. panic(fmt.Sprintf("rewrite usage type %T is not a model.Usage", usage))
  76. }
  77. return &u
  78. }
  79. func (p *WebSearch) getConfig(meta *meta.Meta) (Config, error) {
  80. pluginConfig := Config{}
  81. if err := meta.ModelConfig.LoadPluginConfig("web-search", &pluginConfig); err != nil {
  82. return Config{}, err
  83. }
  84. return pluginConfig, nil
  85. }
  86. func lazyRemoveSearchOption(meta *meta.Meta) {
  87. patch.AddLazyPatch(meta, patch.PatchOperation{
  88. Op: patch.OpFunction,
  89. Function: func(root *ast.Node) (bool, error) {
  90. ok, err := root.Unset("web_search_options")
  91. if err != nil {
  92. return false, err
  93. }
  94. return ok, nil
  95. },
  96. })
  97. }
  98. func fallback(
  99. meta *meta.Meta,
  100. store adaptor.Store,
  101. req *http.Request,
  102. do adaptor.ConvertRequest,
  103. ) (adaptor.ConvertResult, error) {
  104. lazyRemoveSearchOption(meta)
  105. return do.ConvertRequest(meta, store, req)
  106. }
  107. // ConvertRequest intercepts and modifies requests to add web search capabilities
  108. func (p *WebSearch) ConvertRequest(
  109. meta *meta.Meta,
  110. store adaptor.Store,
  111. req *http.Request,
  112. do adaptor.ConvertRequest,
  113. ) (adaptor.ConvertResult, error) {
  114. // Skip if not chat completions mode
  115. if meta.Mode != mode.ChatCompletions {
  116. return do.ConvertRequest(meta, store, req)
  117. }
  118. // Load plugin configuration
  119. pluginConfig, err := p.getConfig(meta)
  120. if err != nil {
  121. return do.ConvertRequest(meta, store, req)
  122. }
  123. // Skip if plugin is disabled
  124. if !pluginConfig.Enable {
  125. return do.ConvertRequest(meta, store, req)
  126. }
  127. // Apply default configuration values if needed
  128. if err := p.validateAndApplyDefaults(&pluginConfig); err != nil {
  129. return fallback(meta, store, req, do)
  130. }
  131. // Initialize search engines
  132. engines, arxivExists, err := p.initializeSearchEngines(pluginConfig.SearchFrom)
  133. if err != nil || len(engines) == 0 {
  134. return fallback(meta, store, req, do)
  135. }
  136. // Read and parse request body
  137. body, err := common.GetRequestBodyReusable(req)
  138. if err != nil {
  139. return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err)
  140. }
  141. var chatRequest map[string]any
  142. if err := sonic.Unmarshal(body, &chatRequest); err != nil {
  143. return fallback(meta, store, req, do)
  144. }
  145. // Check if web search should be enabled for this request
  146. webSearchOptions, hasWebSearchOptions := chatRequest["web_search_options"].(map[string]any)
  147. if !pluginConfig.ForceSearch && !hasWebSearchOptions {
  148. return fallback(meta, store, req, do)
  149. }
  150. webSearchEnable, ok := webSearchOptions["enable"].(bool)
  151. if ok && !webSearchEnable {
  152. return fallback(meta, store, req, do)
  153. }
  154. // Extract user query from messages
  155. messages, ok := chatRequest["messages"].([]any)
  156. if !ok || len(messages) == 0 {
  157. return fallback(meta, store, req, do)
  158. }
  159. queryIndex, query := p.extractUserQuery(messages)
  160. if query == "" {
  161. return fallback(meta, store, req, do)
  162. }
  163. // Prepare search rewrite prompt if configured
  164. searchRewritePrompt := p.prepareSearchRewritePrompt(
  165. pluginConfig.SearchRewrite,
  166. arxivExists,
  167. webSearchOptions,
  168. )
  169. // Generate search contexts
  170. searchContexts, err := p.generateSearchContexts(
  171. meta,
  172. store,
  173. pluginConfig,
  174. query,
  175. searchRewritePrompt,
  176. )
  177. if err != nil {
  178. return fallback(meta, store, req, do)
  179. }
  180. if len(searchContexts) == 0 {
  181. return fallback(meta, store, req, do)
  182. }
  183. // Execute searches
  184. searchResult := p.executeSearches(context.Background(), engines, searchContexts)
  185. if searchResult.Count == 0 || len(searchResult.Results) == 0 {
  186. return fallback(meta, store, req, do)
  187. }
  188. setSearchCount(meta, searchResult.Count)
  189. // Format search results and modify request
  190. p.formatSearchResults(messages, queryIndex, query, searchResult.Results, pluginConfig)
  191. delete(chatRequest, "web_search_options")
  192. // Create new request body
  193. modifiedBody, err := sonic.Marshal(chatRequest)
  194. if err != nil {
  195. return fallback(meta, store, req, do)
  196. }
  197. // Update the request
  198. common.SetRequestBody(req, modifiedBody)
  199. defer common.SetRequestBody(req, body)
  200. // Store references in context if needed
  201. if pluginConfig.NeedReference {
  202. meta.Set("references", searchResult.Results)
  203. }
  204. return do.ConvertRequest(meta, store, req)
  205. }
  206. // validateAndApplyDefaults validates configuration and applies default values
  207. func (p *WebSearch) validateAndApplyDefaults(config *Config) error {
  208. // Set default max results
  209. if config.MaxResults == 0 {
  210. config.MaxResults = 10
  211. }
  212. // Configure reference settings
  213. if config.NeedReference {
  214. if config.ReferenceFormat != "" && !strings.Contains(config.ReferenceFormat, "%s") {
  215. return errors.New("invalid reference format")
  216. }
  217. }
  218. // Set default prompt template if not provided
  219. if config.PromptTemplate == "" {
  220. config.PromptTemplate = p.getDefaultPromptTemplate(config.NeedReference)
  221. }
  222. // Validate prompt template
  223. if !strings.Contains(config.PromptTemplate, "{search_results}") ||
  224. !strings.Contains(config.PromptTemplate, "{question}") {
  225. return errors.New("invalid prompt template")
  226. }
  227. return nil
  228. }
  229. // getDefaultPromptTemplate returns the appropriate default prompt template based on configuration
  230. func (p *WebSearch) getDefaultPromptTemplate(needReference bool) string {
  231. if needReference {
  232. return `# 以下内容是基于用户发送的消息的搜索结果:
  233. {search_results}
  234. 在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。
  235. 在回答时,请注意以下几点:
  236. - 今天是北京时间:{cur_date}。
  237. - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
  238. - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。
  239. - 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
  240. - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
  241. - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
  242. - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
  243. - 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。
  244. - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
  245. # 用户消息为:
  246. {question}`
  247. }
  248. return `# 以下内容是基于用户发送的消息的搜索结果:
  249. {search_results}
  250. 在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。
  251. 在回答时,请注意以下几点:
  252. - 今天是北京时间:{cur_date}。
  253. - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
  254. - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内。如非必要,不要主动告诉用户搜索结果未提供的内容。
  255. - 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
  256. - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
  257. - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
  258. - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
  259. - 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。
  260. - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
  261. # 用户消息为:
  262. {question}`
  263. }
  264. // initializeSearchEngines creates search engine instances based on configuration
  265. func (p *WebSearch) initializeSearchEngines(configs []EngineConfig) ([]engine.Engine, bool, error) {
  266. var (
  267. engines []engine.Engine
  268. arxivExists bool
  269. )
  270. for _, e := range configs {
  271. switch e.Type {
  272. case "bing":
  273. var spec BingSpec
  274. if err := e.LoadSpec(&spec); err != nil {
  275. return nil, false, err
  276. }
  277. engines = append(engines, engine.NewBingEngine(spec.APIKey))
  278. case "bingcn":
  279. engines = append(engines, engine.NewBingCNEngine())
  280. case "google":
  281. var spec GoogleSpec
  282. if err := e.LoadSpec(&spec); err != nil {
  283. return nil, false, err
  284. }
  285. engines = append(engines, engine.NewGoogleEngine(spec.APIKey, spec.CX))
  286. case "arxiv":
  287. engines = append(engines, engine.NewArxivEngine())
  288. arxivExists = true
  289. case "searchxng":
  290. var spec SearchXNGSpec
  291. if err := e.LoadSpec(&spec); err != nil {
  292. return nil, false, err
  293. }
  294. engines = append(engines, engine.NewSearchXNGEngine(spec.BaseURL))
  295. default:
  296. return nil, false, fmt.Errorf("unsupported engine type: %s", e.Type)
  297. }
  298. }
  299. return engines, arxivExists, nil
  300. }
  301. // extractUserQuery finds the last user message in the conversation
  302. func (p *WebSearch) extractUserQuery(messages []any) (int, string) {
  303. for i := len(messages) - 1; i >= 0; i-- {
  304. msg, ok := messages[i].(map[string]any)
  305. if !ok {
  306. continue
  307. }
  308. if role, ok := msg["role"].(string); ok && role == "user" {
  309. if content, ok := msg["content"].(string); ok {
  310. return i, content
  311. }
  312. return i, ""
  313. }
  314. }
  315. return -1, ""
  316. }
  317. // prepareSearchRewritePrompt prepares the prompt for search query rewriting
  318. func (p *WebSearch) prepareSearchRewritePrompt(
  319. searchRewrite SearchRewrite,
  320. arxivExists bool,
  321. webSearchOptions map[string]any,
  322. ) string {
  323. if !searchRewrite.Enable {
  324. return ""
  325. }
  326. // Select appropriate prompt template
  327. var searchRewritePromptTemplate string
  328. if arxivExists {
  329. searchRewritePromptTemplate = arxivSearchPrompts
  330. } else {
  331. searchRewritePromptTemplate = internetSearchPrompts
  332. }
  333. // Adjust max count based on search context size if specified
  334. maxCount := searchRewrite.MaxCount
  335. if webSearchOptions != nil {
  336. if searchContextSize, ok := webSearchOptions["search_context_size"].(string); ok {
  337. switch searchContextSize {
  338. case "low":
  339. maxCount = 1
  340. case "medium":
  341. maxCount = 3
  342. case "high":
  343. maxCount = 5
  344. }
  345. }
  346. }
  347. // Replace placeholder with actual max count
  348. return strings.ReplaceAll(searchRewritePromptTemplate, "{max_count}", strconv.Itoa(maxCount))
  349. }
  350. // generateSearchContexts creates search contexts based on the user query
  351. func (p *WebSearch) generateSearchContexts(
  352. m *meta.Meta,
  353. store adaptor.Store,
  354. config Config,
  355. query, searchRewritePrompt string,
  356. ) ([]engine.SearchQuery, error) {
  357. if searchRewritePrompt == "" {
  358. return []engine.SearchQuery{{
  359. Queries: []string{query},
  360. Language: config.DefaultLanguage,
  361. }}, nil
  362. }
  363. // Prepare request for query rewriting
  364. rewriteBody, err := sonic.Marshal(map[string]any{
  365. "stream": false,
  366. "max_tokens": 4096,
  367. "model": config.SearchRewrite.ModelName,
  368. "messages": []map[string]any{
  369. {
  370. "role": "user",
  371. "content": strings.ReplaceAll(searchRewritePrompt, "{question}", query),
  372. },
  373. },
  374. })
  375. if err != nil {
  376. return nil, err
  377. }
  378. // Set up test context for rewrite request
  379. w := httptest.NewRecorder()
  380. newc, _ := gin.CreateTestContext(w)
  381. newc.Request = &http.Request{
  382. URL: &url.URL{},
  383. Body: io.NopCloser(bytes.NewReader(rewriteBody)),
  384. Header: make(http.Header),
  385. }
  386. middleware.SetRequestID(newc, "web-search-rewrite")
  387. // Set up metadata for rewrite request
  388. modelName := config.SearchRewrite.ModelName
  389. if modelName == "" {
  390. modelName = m.OriginModel
  391. }
  392. newMeta := meta.NewMeta(
  393. nil,
  394. mode.ChatCompletions,
  395. modelName,
  396. model.ModelConfig{
  397. Model: modelName,
  398. Type: mode.ChatCompletions,
  399. },
  400. meta.WithRequestID("web-search-rewrite"),
  401. )
  402. // Set appropriate channel
  403. if config.SearchRewrite.ModelName == "" {
  404. newMeta.CopyChannelFromMeta(m)
  405. } else {
  406. channel, err := p.GetChannel(config.SearchRewrite.ModelName)
  407. if err != nil {
  408. return nil, err
  409. }
  410. newMeta.SetChannel(channel)
  411. }
  412. // Get adaptor and handle request
  413. adaptor, ok := adaptors.GetAdaptor(newMeta.Channel.Type)
  414. if !ok {
  415. return nil, errors.New("adaptor not found")
  416. }
  417. result := controller.Handle(adaptor, newc, newMeta, store)
  418. if result.Error != nil {
  419. return nil, result.Error
  420. }
  421. setRewriteUsage(m, result.Usage)
  422. // Extract content from response
  423. contentNode, err := sonic.Get(w.Body.Bytes(), "choices", 0, "message", "content")
  424. if err != nil {
  425. return nil, err
  426. }
  427. content, err := contentNode.String()
  428. if err != nil || content == "" {
  429. return nil, err
  430. }
  431. if strings.Contains(content, "none") {
  432. return nil, nil
  433. }
  434. // Parse search queries from LLM response
  435. return p.parseSearchContexts(config.DefaultLanguage, content), nil
  436. }
  437. // parseSearchContexts extracts search queries from LLM response
  438. func (p *WebSearch) parseSearchContexts(defaultLanguage, content string) []engine.SearchQuery {
  439. var searchContexts []engine.SearchQuery
  440. for line := range strings.SplitSeq(content, "\n") {
  441. line = strings.TrimSpace(line)
  442. if line == "" {
  443. continue
  444. }
  445. parts := strings.SplitN(line, ":", 2)
  446. if len(parts) != 2 {
  447. continue
  448. }
  449. engineType := strings.TrimSpace(parts[0])
  450. queryStr := strings.TrimSpace(parts[1])
  451. var ctx engine.SearchQuery
  452. ctx.Language = defaultLanguage
  453. switch engineType {
  454. case "internet":
  455. ctx.Queries = []string{queryStr}
  456. default:
  457. // Arxiv category
  458. ctx.ArxivCategory = engineType
  459. ctx.Queries = strings.Split(queryStr, ",")
  460. for i := range ctx.Queries {
  461. ctx.Queries[i] = strings.TrimSpace(ctx.Queries[i])
  462. }
  463. }
  464. if len(ctx.Queries) > 0 {
  465. searchContexts = append(searchContexts, ctx)
  466. if ctx.ArxivCategory != "" {
  467. // Conduct inquiries in all areas to increase recall.
  468. backupCtx := ctx
  469. backupCtx.ArxivCategory = ""
  470. searchContexts = append(searchContexts, backupCtx)
  471. }
  472. }
  473. }
  474. return searchContexts
  475. }
  476. // Search result structure
  477. type searchResult struct {
  478. Results []engine.SearchResult
  479. Count int
  480. }
  481. // executeSearches performs searches using all configured engines
  482. func (p *WebSearch) executeSearches(
  483. ctx context.Context,
  484. engines []engine.Engine,
  485. searchContexts []engine.SearchQuery,
  486. ) *searchResult {
  487. var (
  488. allResults []engine.SearchResult
  489. mu sync.Mutex
  490. )
  491. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  492. defer cancel()
  493. g, ctx := errgroup.WithContext(ctx)
  494. for _, eng := range engines {
  495. for _, searchCtx := range searchContexts {
  496. g.Go(func() error {
  497. results, err := eng.Search(ctx, engine.SearchQuery{
  498. Queries: searchCtx.Queries,
  499. MaxResults: 10,
  500. Language: searchCtx.Language,
  501. ArxivCategory: searchCtx.ArxivCategory,
  502. })
  503. if err != nil {
  504. logrus.Errorf("search error: %v", err)
  505. return err
  506. }
  507. mu.Lock()
  508. allResults = append(allResults, results...)
  509. mu.Unlock()
  510. return nil
  511. })
  512. }
  513. }
  514. _ = g.Wait()
  515. seen := make(map[string]bool)
  516. var uniqueResults []engine.SearchResult
  517. for _, result := range allResults {
  518. if !seen[result.Link] {
  519. seen[result.Link] = true
  520. uniqueResults = append(uniqueResults, result)
  521. }
  522. }
  523. return &searchResult{
  524. Results: uniqueResults,
  525. Count: len(engines) * len(searchContexts),
  526. }
  527. }
  528. // formatSearchResults formats search results for the prompt
  529. func (p *WebSearch) formatSearchResults(
  530. messages []any,
  531. queryIndex int,
  532. query string,
  533. searchResults []engine.SearchResult,
  534. config Config,
  535. ) {
  536. message, ok := messages[queryIndex].(map[string]any)
  537. if !ok {
  538. return
  539. }
  540. var formattedResults []string
  541. for i, result := range searchResults {
  542. if config.NeedReference {
  543. formattedResults = append(formattedResults,
  544. fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", i+1, result.Content, i+1))
  545. } else {
  546. formattedResults = append(formattedResults,
  547. fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content))
  548. }
  549. }
  550. // Fill template
  551. curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日")
  552. searchResultsStr := strings.Join(formattedResults, "\n")
  553. prompt := strings.Replace(config.PromptTemplate, "{search_results}", searchResultsStr, 1)
  554. prompt = strings.Replace(prompt, "{question}", query, 1)
  555. prompt = strings.Replace(prompt, "{cur_date}", curDate, 1)
  556. // Update message
  557. message["content"] = prompt
  558. }
  559. // Custom response writer to handle metadata and references
  560. type responseWriter struct {
  561. gin.ResponseWriter
  562. refWritten bool
  563. referenceFormat string
  564. references []engine.SearchResult
  565. referencesLocation string
  566. webSearchCount int
  567. rewriteUsage *model.Usage
  568. rewriteUsageWritten bool
  569. rewriteUsageField string
  570. isStream bool
  571. }
  572. // Write overrides the standard Write method to inject metadata
  573. func (rw *responseWriter) Write(b []byte) (int, error) {
  574. if rw.isStream || utils.IsStreamResponseWithHeader(rw.Header()) {
  575. rw.isStream = true
  576. }
  577. node, err := sonic.Get(b)
  578. if err != nil || !node.Valid() {
  579. return rw.ResponseWriter.Write(b)
  580. }
  581. // Process the response node
  582. rw.processRewriteUsage(&node)
  583. rw.processWebSearchCount(&node)
  584. rw.processReferences(&node)
  585. // Marshal the modified node
  586. nb, err := sonic.Marshal(&node)
  587. if err != nil {
  588. return rw.ResponseWriter.Write(b)
  589. }
  590. if !rw.isStream {
  591. if rw.ResponseWriter.Header().Get("Content-Length") != "" {
  592. rw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(len(nb)))
  593. }
  594. }
  595. return rw.ResponseWriter.Write(nb)
  596. }
  597. // processRewriteUsage adds rewrite usage information to the response
  598. func (rw *responseWriter) processRewriteUsage(node *ast.Node) {
  599. if rw.rewriteUsage != nil && !rw.rewriteUsageWritten {
  600. field := rw.rewriteUsageField
  601. if field == "" {
  602. field = "rewrite_usage"
  603. }
  604. _, _ = node.SetAny(field, rw.rewriteUsage)
  605. rw.rewriteUsageWritten = true
  606. }
  607. }
  608. // processWebSearchCount adds or increments web search count in usage statistics
  609. func (rw *responseWriter) processWebSearchCount(node *ast.Node) {
  610. if rw.webSearchCount <= 0 {
  611. return
  612. }
  613. usageNode := node.Get("usage")
  614. if usageNode == nil || !usageNode.Valid() {
  615. return
  616. }
  617. // Check if web_search_count already exists
  618. existingCount := usageNode.Get("web_search_count")
  619. if existingCount != nil && existingCount.Valid() {
  620. // If exists, add to the existing value
  621. currentCount, _ := existingCount.Int64()
  622. _, _ = usageNode.Set(
  623. "web_search_count",
  624. ast.NewNumber(strconv.FormatInt(currentCount+int64(rw.webSearchCount), 10)),
  625. )
  626. } else {
  627. // If not exists, set the value
  628. _, _ = usageNode.Set("web_search_count", ast.NewNumber(strconv.FormatInt(int64(rw.webSearchCount), 10)))
  629. }
  630. }
  631. func buildReferenceContent(searchResults []engine.SearchResult) string {
  632. formattedReferences := make([]string, len(searchResults))
  633. for i, result := range searchResults {
  634. formattedReferences[i] = fmt.Sprintf("[%d] [%s](%s)", i+1, result.Title, result.Link)
  635. }
  636. return strings.Join(formattedReferences, "\n\n")
  637. }
  638. // processReferences adds reference information to the content
  639. func (rw *responseWriter) processReferences(node *ast.Node) {
  640. if rw.refWritten || len(rw.references) == 0 {
  641. return
  642. }
  643. rw.refWritten = true
  644. if rw.referencesLocation == "" || rw.referencesLocation == "content" {
  645. var contentNode *ast.Node
  646. if rw.isStream {
  647. contentNode = node.GetByPath("choices", 0, "delta", "content")
  648. } else {
  649. contentNode = node.GetByPath("choices", 0, "message", "content")
  650. }
  651. if contentNode != nil && contentNode.Valid() {
  652. content, err := contentNode.String()
  653. if err == nil {
  654. format := rw.referenceFormat
  655. if format == "" {
  656. format = "**References:**\n%s"
  657. }
  658. ref := fmt.Sprintf(format, buildReferenceContent(rw.references))
  659. refContent := fmt.Sprintf("%s\n\n%s", ref, content)
  660. *contentNode = ast.NewString(refContent)
  661. }
  662. }
  663. } else {
  664. var outterLocation *ast.Node
  665. if rw.isStream {
  666. outterLocation = node.GetByPath("choices", 0, "delta")
  667. } else {
  668. outterLocation = node.GetByPath("choices", 0, "message")
  669. }
  670. if outterLocation != nil && outterLocation.Valid() {
  671. _, _ = outterLocation.SetAny(rw.referencesLocation, rw.references)
  672. }
  673. }
  674. }
  675. // WriteString implements the WriteString method for the custom response writer
  676. func (rw *responseWriter) WriteString(s string) (int, error) {
  677. return rw.Write(conv.StringToBytes(s))
  678. }
  679. // DoResponse handles response modification for references
  680. func (p *WebSearch) DoResponse(
  681. meta *meta.Meta,
  682. store adaptor.Store,
  683. c *gin.Context,
  684. resp *http.Response,
  685. do adaptor.DoResponse,
  686. ) (model.Usage, adaptor.Error) {
  687. if meta.Mode != mode.ChatCompletions {
  688. return do.DoResponse(meta, store, c, resp)
  689. }
  690. var references []engine.SearchResult
  691. referencesI, ok := meta.Get("references")
  692. if ok {
  693. references, ok = referencesI.([]engine.SearchResult)
  694. if !ok {
  695. panic(fmt.Sprintf("references type %T is not a []engine.SearchResult", referencesI))
  696. }
  697. }
  698. count := getSearchCount(meta)
  699. var rewriteUsage *model.Usage
  700. rewriteUsageField := ""
  701. pluginConfig, _ := p.getConfig(meta)
  702. if pluginConfig.SearchRewrite.AddRewriteUsage {
  703. rewriteUsage = getRewriteUsage(meta)
  704. rewriteUsageField = pluginConfig.SearchRewrite.RewriteUsageField
  705. }
  706. // Check if we need to wrap the response writer
  707. if len(references) > 0 || rewriteUsage != nil || count > 0 {
  708. rw := &responseWriter{
  709. ResponseWriter: c.Writer,
  710. webSearchCount: count,
  711. rewriteUsageWritten: false,
  712. rewriteUsage: rewriteUsage,
  713. rewriteUsageField: rewriteUsageField,
  714. references: references,
  715. referencesLocation: pluginConfig.ReferenceLocation,
  716. referenceFormat: pluginConfig.ReferenceFormat,
  717. }
  718. c.Writer = rw
  719. defer func() {
  720. c.Writer = rw.ResponseWriter
  721. }()
  722. }
  723. return p.doResponseWithCount(meta, store, c, resp, do, count)
  724. }
  725. // doResponseWithCount adds search count to usage statistics
  726. func (p *WebSearch) doResponseWithCount(
  727. meta *meta.Meta,
  728. store adaptor.Store,
  729. c *gin.Context,
  730. resp *http.Response,
  731. do adaptor.DoResponse,
  732. count int,
  733. ) (model.Usage, adaptor.Error) {
  734. u, err := do.DoResponse(meta, store, c, resp)
  735. if err != nil {
  736. return model.Usage{}, err
  737. }
  738. u.WebSearchCount += model.ZeroNullInt64(int64(count))
  739. return u, nil
  740. }