search.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855
  1. package websearch
  2. import (
  3. "bytes"
  4. "context"
  5. _ "embed"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/bytedance/sonic"
  17. "github.com/bytedance/sonic/ast"
  18. "github.com/gin-gonic/gin"
  19. "github.com/labring/aiproxy/core/common"
  20. "github.com/labring/aiproxy/core/common/conv"
  21. "github.com/labring/aiproxy/core/middleware"
  22. "github.com/labring/aiproxy/core/model"
  23. "github.com/labring/aiproxy/core/relay/adaptor"
  24. "github.com/labring/aiproxy/core/relay/adaptors"
  25. "github.com/labring/aiproxy/core/relay/controller"
  26. "github.com/labring/aiproxy/core/relay/meta"
  27. "github.com/labring/aiproxy/core/relay/mode"
  28. "github.com/labring/aiproxy/core/relay/plugin"
  29. "github.com/labring/aiproxy/core/relay/plugin/noop"
  30. "github.com/labring/aiproxy/core/relay/utils"
  31. "github.com/labring/aiproxy/mcp-servers/hosted/web-search/engine"
  32. "github.com/sirupsen/logrus"
  33. "golang.org/x/sync/errgroup"
  34. )
// Compile-time assertion that *WebSearch satisfies plugin.Plugin.
var _ plugin.Plugin = (*WebSearch)(nil)

// GetChannel resolves the channel to use for the given model name.
// In this file it is invoked only to pick the channel for an explicitly
// configured search-rewrite model.
type GetChannel func(modelName string) (*model.Channel, error)
// WebSearch is the web search plugin. It embeds noop.Noop (presumably the
// pass-through defaults for the remaining plugin hooks — confirm against the
// noop package) and provides its own ConvertRequest and DoResponse.
type WebSearch struct {
	noop.Noop
	// GetChannel resolves the channel for an explicitly configured
	// search-rewrite model.
	GetChannel GetChannel
}
  42. // NewWebSearchPlugin creates a new web search plugin
  43. func NewWebSearchPlugin(getChannel GetChannel) plugin.Plugin {
  44. return &WebSearch{
  45. GetChannel: getChannel,
  46. }
  47. }
// arxivSearchPrompts holds the contents of prompts/arxiv.md, embedded at build
// time. It is the query-rewrite prompt template chosen when an arxiv engine is
// configured.
//
//go:embed prompts/arxiv.md
var arxivSearchPrompts string

// internetSearchPrompts holds the contents of prompts/internet.md, embedded at
// build time. It is the query-rewrite prompt template chosen when no arxiv
// engine is configured.
//
//go:embed prompts/internet.md
var internetSearchPrompts string
// Keys under which this plugin stores per-request values on meta.Meta.
const (
	// searchCount is the key for the number of searches issued for the request.
	searchCount = "web-search-count"
	// rewriteUsage is the key for the model.Usage of the query-rewrite call.
	rewriteUsage = "web-search-rewrite-usage"
)
// Metadata helper functions

// setSearchCount records the number of searches issued for this request on m
// under the searchCount key.
func setSearchCount(m *meta.Meta, count int) {
	m.Set(searchCount, count)
}
// getSearchCount reads the value stored under the searchCount key on m via
// m.GetInt.
func getSearchCount(m *meta.Meta) int {
	return m.GetInt(searchCount)
}
// setRewriteUsage stores the usage of the query-rewrite call on m under the
// rewriteUsage key. The value is stored as a model.Usage (not a pointer),
// which is the type getRewriteUsage expects to read back.
func setRewriteUsage(m *meta.Meta, usage model.Usage) {
	m.Set(rewriteUsage, usage)
}
  67. func getRewriteUsage(m *meta.Meta) *model.Usage {
  68. usage, ok := m.Get(rewriteUsage)
  69. if !ok {
  70. return nil
  71. }
  72. u, ok := usage.(model.Usage)
  73. if !ok {
  74. panic(fmt.Sprintf("rewrite usage type %T is not a model.Usage", usage))
  75. }
  76. return &u
  77. }
  78. func (p *WebSearch) getConfig(meta *meta.Meta) (Config, error) {
  79. pluginConfig := Config{}
  80. if err := meta.ModelConfig.LoadPluginConfig("web-search", &pluginConfig); err != nil {
  81. return Config{}, err
  82. }
  83. return pluginConfig, nil
  84. }
  85. // ConvertRequest intercepts and modifies requests to add web search capabilities
  86. func (p *WebSearch) ConvertRequest(
  87. meta *meta.Meta,
  88. store adaptor.Store,
  89. req *http.Request,
  90. do adaptor.ConvertRequest,
  91. ) (adaptor.ConvertResult, error) {
  92. // Skip if not chat completions mode
  93. if meta.Mode != mode.ChatCompletions {
  94. return do.ConvertRequest(meta, store, req)
  95. }
  96. // Load plugin configuration
  97. pluginConfig, err := p.getConfig(meta)
  98. if err != nil {
  99. return do.ConvertRequest(meta, store, req)
  100. }
  101. // Skip if plugin is disabled
  102. if !pluginConfig.Enable {
  103. return do.ConvertRequest(meta, store, req)
  104. }
  105. // Apply default configuration values if needed
  106. if err := p.validateAndApplyDefaults(&pluginConfig); err != nil {
  107. return do.ConvertRequest(meta, store, req)
  108. }
  109. // Initialize search engines
  110. engines, arxivExists, err := p.initializeSearchEngines(pluginConfig.SearchFrom)
  111. if err != nil || len(engines) == 0 {
  112. return do.ConvertRequest(meta, store, req)
  113. }
  114. // Read and parse request body
  115. body, err := common.GetRequestBodyReusable(req)
  116. if err != nil {
  117. return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err)
  118. }
  119. var chatRequest map[string]any
  120. if err := sonic.Unmarshal(body, &chatRequest); err != nil {
  121. return do.ConvertRequest(meta, store, req)
  122. }
  123. // Check if web search should be enabled for this request
  124. webSearchOptions, hasWebSearchOptions := chatRequest["web_search_options"].(map[string]any)
  125. if !pluginConfig.ForceSearch && !hasWebSearchOptions {
  126. return do.ConvertRequest(meta, store, req)
  127. }
  128. webSearchEnable, ok := webSearchOptions["enable"].(bool)
  129. if ok && !webSearchEnable {
  130. return do.ConvertRequest(meta, store, req)
  131. }
  132. // Extract user query from messages
  133. messages, ok := chatRequest["messages"].([]any)
  134. if !ok || len(messages) == 0 {
  135. return do.ConvertRequest(meta, store, req)
  136. }
  137. queryIndex, query := p.extractUserQuery(messages)
  138. if query == "" {
  139. return do.ConvertRequest(meta, store, req)
  140. }
  141. // Prepare search rewrite prompt if configured
  142. searchRewritePrompt := p.prepareSearchRewritePrompt(
  143. pluginConfig.SearchRewrite,
  144. arxivExists,
  145. webSearchOptions,
  146. )
  147. // Generate search contexts
  148. searchContexts, err := p.generateSearchContexts(
  149. meta,
  150. store,
  151. pluginConfig,
  152. query,
  153. searchRewritePrompt,
  154. )
  155. if err != nil {
  156. return do.ConvertRequest(meta, store, req)
  157. }
  158. if len(searchContexts) == 0 {
  159. return do.ConvertRequest(meta, store, req)
  160. }
  161. // Execute searches
  162. searchResult := p.executeSearches(context.Background(), engines, searchContexts)
  163. if searchResult.Count == 0 || len(searchResult.Results) == 0 {
  164. return do.ConvertRequest(meta, store, req)
  165. }
  166. setSearchCount(meta, searchResult.Count)
  167. // Format search results and modify request
  168. p.formatSearchResults(messages, queryIndex, query, searchResult.Results, pluginConfig)
  169. delete(chatRequest, "web_search_options")
  170. // Create new request body
  171. modifiedBody, err := sonic.Marshal(chatRequest)
  172. if err != nil {
  173. return do.ConvertRequest(meta, store, req)
  174. }
  175. // Update the request
  176. common.SetRequestBody(req, modifiedBody)
  177. defer common.SetRequestBody(req, body)
  178. // Store references in context if needed
  179. if pluginConfig.NeedReference {
  180. meta.Set("references", searchResult.Results)
  181. }
  182. return do.ConvertRequest(meta, store, req)
  183. }
  184. // validateAndApplyDefaults validates configuration and applies default values
  185. func (p *WebSearch) validateAndApplyDefaults(config *Config) error {
  186. // Set default max results
  187. if config.MaxResults == 0 {
  188. config.MaxResults = 10
  189. }
  190. // Configure reference settings
  191. if config.NeedReference {
  192. if config.ReferenceFormat != "" && !strings.Contains(config.ReferenceFormat, "%s") {
  193. return errors.New("invalid reference format")
  194. }
  195. }
  196. // Set default prompt template if not provided
  197. if config.PromptTemplate == "" {
  198. config.PromptTemplate = p.getDefaultPromptTemplate(config.NeedReference)
  199. }
  200. // Validate prompt template
  201. if !strings.Contains(config.PromptTemplate, "{search_results}") ||
  202. !strings.Contains(config.PromptTemplate, "{question}") {
  203. return errors.New("invalid prompt template")
  204. }
  205. return nil
  206. }
// getDefaultPromptTemplate returns the built-in (Chinese-language) prompt
// template used when the configuration does not provide one.
//
// Both variants contain the "{search_results}", "{cur_date}" and "{question}"
// placeholders. When needReference is true the template describes results as
// "[webpage X begin]...[webpage X end]" and instructs the model to cite them
// as [X]; otherwise results are described as "[webpage begin]...[webpage end]"
// and the model is told not to cite sources.
func (p *WebSearch) getDefaultPromptTemplate(needReference bool) string {
	if needReference {
		// Variant with numbered pages and citation instructions.
		return `# 以下内容是基于用户发送的消息的搜索结果:
{search_results}
在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。
在回答时,请注意以下几点:
- 今天是北京时间:{cur_date}。
- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
- 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。
- 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
- 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
# 用户消息为:
{question}`
	}
	// Variant without page numbers; the model is asked not to cite sources.
	return `# 以下内容是基于用户发送的消息的搜索结果:
{search_results}
在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。
在回答时,请注意以下几点:
- 今天是北京时间:{cur_date}。
- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
- 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内。如非必要,不要主动告诉用户搜索结果未提供的内容。
- 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
- 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
# 用户消息为:
{question}`
}
  242. // initializeSearchEngines creates search engine instances based on configuration
  243. func (p *WebSearch) initializeSearchEngines(configs []EngineConfig) ([]engine.Engine, bool, error) {
  244. var (
  245. engines []engine.Engine
  246. arxivExists bool
  247. )
  248. for _, e := range configs {
  249. switch e.Type {
  250. case "bing":
  251. var spec BingSpec
  252. if err := e.LoadSpec(&spec); err != nil {
  253. return nil, false, err
  254. }
  255. engines = append(engines, engine.NewBingEngine(spec.APIKey))
  256. case "bingcn":
  257. engines = append(engines, engine.NewBingCNEngine())
  258. case "google":
  259. var spec GoogleSpec
  260. if err := e.LoadSpec(&spec); err != nil {
  261. return nil, false, err
  262. }
  263. engines = append(engines, engine.NewGoogleEngine(spec.APIKey, spec.CX))
  264. case "arxiv":
  265. engines = append(engines, engine.NewArxivEngine())
  266. arxivExists = true
  267. case "searchxng":
  268. var spec SearchXNGSpec
  269. if err := e.LoadSpec(&spec); err != nil {
  270. return nil, false, err
  271. }
  272. engines = append(engines, engine.NewSearchXNGEngine(spec.BaseURL))
  273. default:
  274. return nil, false, fmt.Errorf("unsupported engine type: %s", e.Type)
  275. }
  276. }
  277. return engines, arxivExists, nil
  278. }
  279. // extractUserQuery finds the last user message in the conversation
  280. func (p *WebSearch) extractUserQuery(messages []any) (int, string) {
  281. for i := len(messages) - 1; i >= 0; i-- {
  282. msg, ok := messages[i].(map[string]any)
  283. if !ok {
  284. continue
  285. }
  286. if role, ok := msg["role"].(string); ok && role == "user" {
  287. if content, ok := msg["content"].(string); ok {
  288. return i, content
  289. }
  290. return i, ""
  291. }
  292. }
  293. return -1, ""
  294. }
  295. // prepareSearchRewritePrompt prepares the prompt for search query rewriting
  296. func (p *WebSearch) prepareSearchRewritePrompt(
  297. searchRewrite SearchRewrite,
  298. arxivExists bool,
  299. webSearchOptions map[string]any,
  300. ) string {
  301. if !searchRewrite.Enable {
  302. return ""
  303. }
  304. // Select appropriate prompt template
  305. var searchRewritePromptTemplate string
  306. if arxivExists {
  307. searchRewritePromptTemplate = arxivSearchPrompts
  308. } else {
  309. searchRewritePromptTemplate = internetSearchPrompts
  310. }
  311. // Adjust max count based on search context size if specified
  312. maxCount := searchRewrite.MaxCount
  313. if webSearchOptions != nil {
  314. if searchContextSize, ok := webSearchOptions["search_context_size"].(string); ok {
  315. switch searchContextSize {
  316. case "low":
  317. maxCount = 1
  318. case "medium":
  319. maxCount = 3
  320. case "high":
  321. maxCount = 5
  322. }
  323. }
  324. }
  325. // Replace placeholder with actual max count
  326. return strings.ReplaceAll(searchRewritePromptTemplate, "{max_count}", strconv.Itoa(maxCount))
  327. }
  328. // generateSearchContexts creates search contexts based on the user query
  329. func (p *WebSearch) generateSearchContexts(
  330. m *meta.Meta,
  331. store adaptor.Store,
  332. config Config,
  333. query, searchRewritePrompt string,
  334. ) ([]engine.SearchQuery, error) {
  335. if searchRewritePrompt == "" {
  336. return []engine.SearchQuery{{
  337. Queries: []string{query},
  338. Language: config.DefaultLanguage,
  339. }}, nil
  340. }
  341. // Prepare request for query rewriting
  342. rewriteBody, err := sonic.Marshal(map[string]any{
  343. "stream": false,
  344. "max_tokens": 4096,
  345. "model": config.SearchRewrite.ModelName,
  346. "messages": []map[string]any{
  347. {
  348. "role": "user",
  349. "content": strings.ReplaceAll(searchRewritePrompt, "{question}", query),
  350. },
  351. },
  352. })
  353. if err != nil {
  354. return nil, err
  355. }
  356. // Set up test context for rewrite request
  357. w := httptest.NewRecorder()
  358. newc, _ := gin.CreateTestContext(w)
  359. newc.Request = &http.Request{
  360. URL: &url.URL{},
  361. Body: io.NopCloser(bytes.NewReader(rewriteBody)),
  362. Header: make(http.Header),
  363. }
  364. middleware.SetRequestID(newc, "web-search-rewrite")
  365. // Set up metadata for rewrite request
  366. modelName := config.SearchRewrite.ModelName
  367. if modelName == "" {
  368. modelName = m.OriginModel
  369. }
  370. newMeta := meta.NewMeta(
  371. nil,
  372. mode.ChatCompletions,
  373. modelName,
  374. model.ModelConfig{
  375. Model: modelName,
  376. Type: mode.ChatCompletions,
  377. },
  378. meta.WithRequestID("web-search-rewrite"),
  379. )
  380. // Set appropriate channel
  381. if config.SearchRewrite.ModelName == "" {
  382. newMeta.CopyChannelFromMeta(m)
  383. } else {
  384. channel, err := p.GetChannel(config.SearchRewrite.ModelName)
  385. if err != nil {
  386. return nil, err
  387. }
  388. newMeta.SetChannel(channel)
  389. }
  390. // Get adaptor and handle request
  391. adaptor, ok := adaptors.GetAdaptor(newMeta.Channel.Type)
  392. if !ok {
  393. return nil, errors.New("adaptor not found")
  394. }
  395. result := controller.Handle(adaptor, newc, newMeta, store)
  396. if result.Error != nil {
  397. return nil, result.Error
  398. }
  399. setRewriteUsage(m, result.Usage)
  400. // Extract content from response
  401. contentNode, err := sonic.Get(w.Body.Bytes(), "choices", 0, "message", "content")
  402. if err != nil {
  403. return nil, err
  404. }
  405. content, err := contentNode.String()
  406. if err != nil || content == "" {
  407. return nil, err
  408. }
  409. if strings.Contains(content, "none") {
  410. return nil, nil
  411. }
  412. // Parse search queries from LLM response
  413. return p.parseSearchContexts(config.DefaultLanguage, content), nil
  414. }
  415. // parseSearchContexts extracts search queries from LLM response
  416. func (p *WebSearch) parseSearchContexts(defaultLanguage, content string) []engine.SearchQuery {
  417. var searchContexts []engine.SearchQuery
  418. for line := range strings.SplitSeq(content, "\n") {
  419. line = strings.TrimSpace(line)
  420. if line == "" {
  421. continue
  422. }
  423. parts := strings.SplitN(line, ":", 2)
  424. if len(parts) != 2 {
  425. continue
  426. }
  427. engineType := strings.TrimSpace(parts[0])
  428. queryStr := strings.TrimSpace(parts[1])
  429. var ctx engine.SearchQuery
  430. ctx.Language = defaultLanguage
  431. switch engineType {
  432. case "internet":
  433. ctx.Queries = []string{queryStr}
  434. default:
  435. // Arxiv category
  436. ctx.ArxivCategory = engineType
  437. ctx.Queries = strings.Split(queryStr, ",")
  438. for i := range ctx.Queries {
  439. ctx.Queries[i] = strings.TrimSpace(ctx.Queries[i])
  440. }
  441. }
  442. if len(ctx.Queries) > 0 {
  443. searchContexts = append(searchContexts, ctx)
  444. if ctx.ArxivCategory != "" {
  445. // Conduct inquiries in all areas to increase recall.
  446. backupCtx := ctx
  447. backupCtx.ArxivCategory = ""
  448. searchContexts = append(searchContexts, backupCtx)
  449. }
  450. }
  451. }
  452. return searchContexts
  453. }
// searchResult is the aggregated outcome of executeSearches.
type searchResult struct {
	// Results holds the collected results with duplicate links removed.
	Results []engine.SearchResult
	// Count is the number of searches issued (engines × search contexts).
	Count int
}
  459. // executeSearches performs searches using all configured engines
  460. func (p *WebSearch) executeSearches(
  461. ctx context.Context,
  462. engines []engine.Engine,
  463. searchContexts []engine.SearchQuery,
  464. ) *searchResult {
  465. var (
  466. allResults []engine.SearchResult
  467. mu sync.Mutex
  468. )
  469. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  470. defer cancel()
  471. g, ctx := errgroup.WithContext(ctx)
  472. for _, eng := range engines {
  473. for _, searchCtx := range searchContexts {
  474. g.Go(func() error {
  475. results, err := eng.Search(ctx, engine.SearchQuery{
  476. Queries: searchCtx.Queries,
  477. MaxResults: 10,
  478. Language: searchCtx.Language,
  479. ArxivCategory: searchCtx.ArxivCategory,
  480. })
  481. if err != nil {
  482. logrus.Errorf("search error: %v", err)
  483. return err
  484. }
  485. mu.Lock()
  486. allResults = append(allResults, results...)
  487. mu.Unlock()
  488. return nil
  489. })
  490. }
  491. }
  492. _ = g.Wait()
  493. seen := make(map[string]bool)
  494. var uniqueResults []engine.SearchResult
  495. for _, result := range allResults {
  496. if !seen[result.Link] {
  497. seen[result.Link] = true
  498. uniqueResults = append(uniqueResults, result)
  499. }
  500. }
  501. return &searchResult{
  502. Results: uniqueResults,
  503. Count: len(engines) * len(searchContexts),
  504. }
  505. }
  506. // formatSearchResults formats search results for the prompt
  507. func (p *WebSearch) formatSearchResults(
  508. messages []any,
  509. queryIndex int,
  510. query string,
  511. searchResults []engine.SearchResult,
  512. config Config,
  513. ) {
  514. message, ok := messages[queryIndex].(map[string]any)
  515. if !ok {
  516. return
  517. }
  518. var formattedResults []string
  519. for i, result := range searchResults {
  520. if config.NeedReference {
  521. formattedResults = append(formattedResults,
  522. fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", i+1, result.Content, i+1))
  523. } else {
  524. formattedResults = append(formattedResults,
  525. fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content))
  526. }
  527. }
  528. // Fill template
  529. curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日")
  530. searchResultsStr := strings.Join(formattedResults, "\n")
  531. prompt := strings.Replace(config.PromptTemplate, "{search_results}", searchResultsStr, 1)
  532. prompt = strings.Replace(prompt, "{question}", query, 1)
  533. prompt = strings.Replace(prompt, "{cur_date}", curDate, 1)
  534. // Update message
  535. message["content"] = prompt
  536. }
// responseWriter wraps a gin.ResponseWriter and rewrites JSON payloads passing
// through Write to add the rewrite usage, the web search count and the search
// references.
type responseWriter struct {
	gin.ResponseWriter
	// refWritten is set once processReferences has handled a payload.
	refWritten bool
	// referenceFormat is the fmt format (expected to contain one %s) used to
	// render the reference list; empty selects the built-in default.
	referenceFormat string
	// references are the search results to expose to the client.
	references []engine.SearchResult
	// referencesLocation is "" or "content" to prepend references to the
	// message content, or the name of a field to set on the message/delta.
	referencesLocation string
	// webSearchCount is added to usage.web_search_count when positive.
	webSearchCount int
	// rewriteUsage is the usage of the query-rewrite call, if it should be
	// reported.
	rewriteUsage *model.Usage
	// rewriteUsageWritten is set once the rewrite usage has been handled.
	rewriteUsageWritten bool
	// rewriteUsageField is the top-level field for the rewrite usage; empty
	// selects "rewrite_usage".
	rewriteUsageField string
	// isStream latches to true once the response is detected as streaming.
	isStream bool
}
  550. // Write overrides the standard Write method to inject metadata
  551. func (rw *responseWriter) Write(b []byte) (int, error) {
  552. if rw.isStream || utils.IsStreamResponseWithHeader(rw.Header()) {
  553. rw.isStream = true
  554. }
  555. node, err := sonic.Get(b)
  556. if err != nil || !node.Valid() {
  557. return rw.ResponseWriter.Write(b)
  558. }
  559. // Process the response node
  560. rw.processRewriteUsage(&node)
  561. rw.processWebSearchCount(&node)
  562. rw.processReferences(&node)
  563. // Marshal the modified node
  564. nb, err := sonic.Marshal(&node)
  565. if err != nil {
  566. return rw.ResponseWriter.Write(b)
  567. }
  568. if !rw.isStream {
  569. if rw.ResponseWriter.Header().Get("Content-Length") != "" {
  570. rw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(len(nb)))
  571. }
  572. }
  573. return rw.ResponseWriter.Write(nb)
  574. }
  575. // processRewriteUsage adds rewrite usage information to the response
  576. func (rw *responseWriter) processRewriteUsage(node *ast.Node) {
  577. if rw.rewriteUsage != nil && !rw.rewriteUsageWritten {
  578. field := rw.rewriteUsageField
  579. if field == "" {
  580. field = "rewrite_usage"
  581. }
  582. _, _ = node.SetAny(field, rw.rewriteUsage)
  583. rw.rewriteUsageWritten = true
  584. }
  585. }
  586. // processWebSearchCount adds or increments web search count in usage statistics
  587. func (rw *responseWriter) processWebSearchCount(node *ast.Node) {
  588. if rw.webSearchCount <= 0 {
  589. return
  590. }
  591. usageNode := node.Get("usage")
  592. if usageNode == nil || !usageNode.Valid() {
  593. return
  594. }
  595. // Check if web_search_count already exists
  596. existingCount := usageNode.Get("web_search_count")
  597. if existingCount != nil && existingCount.Valid() {
  598. // If exists, add to the existing value
  599. currentCount, _ := existingCount.Int64()
  600. _, _ = usageNode.Set(
  601. "web_search_count",
  602. ast.NewNumber(strconv.FormatInt(currentCount+int64(rw.webSearchCount), 10)),
  603. )
  604. } else {
  605. // If not exists, set the value
  606. _, _ = usageNode.Set("web_search_count", ast.NewNumber(strconv.FormatInt(int64(rw.webSearchCount), 10)))
  607. }
  608. }
  609. func buildReferenceContent(searchResults []engine.SearchResult) string {
  610. formattedReferences := make([]string, len(searchResults))
  611. for i, result := range searchResults {
  612. formattedReferences[i] = fmt.Sprintf("[%d] [%s](%s)", i+1, result.Title, result.Link)
  613. }
  614. return strings.Join(formattedReferences, "\n\n")
  615. }
  616. // processReferences adds reference information to the content
  617. func (rw *responseWriter) processReferences(node *ast.Node) {
  618. if rw.refWritten || len(rw.references) == 0 {
  619. return
  620. }
  621. rw.refWritten = true
  622. if rw.referencesLocation == "" || rw.referencesLocation == "content" {
  623. var contentNode *ast.Node
  624. if rw.isStream {
  625. contentNode = node.GetByPath("choices", 0, "delta", "content")
  626. } else {
  627. contentNode = node.GetByPath("choices", 0, "message", "content")
  628. }
  629. if contentNode != nil && contentNode.Valid() {
  630. content, err := contentNode.String()
  631. if err == nil {
  632. format := rw.referenceFormat
  633. if format == "" {
  634. format = "**References:**\n%s"
  635. }
  636. ref := fmt.Sprintf(format, buildReferenceContent(rw.references))
  637. refContent := fmt.Sprintf("%s\n\n%s", ref, content)
  638. *contentNode = ast.NewString(refContent)
  639. }
  640. }
  641. } else {
  642. var outterLocation *ast.Node
  643. if rw.isStream {
  644. outterLocation = node.GetByPath("choices", 0, "delta")
  645. } else {
  646. outterLocation = node.GetByPath("choices", 0, "message")
  647. }
  648. if outterLocation != nil && outterLocation.Valid() {
  649. _, _ = outterLocation.SetAny(rw.referencesLocation, rw.references)
  650. }
  651. }
  652. }
  653. // WriteString implements the WriteString method for the custom response writer
  654. func (rw *responseWriter) WriteString(s string) (int, error) {
  655. return rw.Write(conv.StringToBytes(s))
  656. }
  657. // DoResponse handles response modification for references
  658. func (p *WebSearch) DoResponse(
  659. meta *meta.Meta,
  660. store adaptor.Store,
  661. c *gin.Context,
  662. resp *http.Response,
  663. do adaptor.DoResponse,
  664. ) (model.Usage, adaptor.Error) {
  665. if meta.Mode != mode.ChatCompletions {
  666. return do.DoResponse(meta, store, c, resp)
  667. }
  668. var references []engine.SearchResult
  669. referencesI, ok := meta.Get("references")
  670. if ok {
  671. references, ok = referencesI.([]engine.SearchResult)
  672. if !ok {
  673. panic(fmt.Sprintf("references type %T is not a []engine.SearchResult", referencesI))
  674. }
  675. }
  676. count := getSearchCount(meta)
  677. var rewriteUsage *model.Usage
  678. rewriteUsageField := ""
  679. pluginConfig, _ := p.getConfig(meta)
  680. if pluginConfig.SearchRewrite.AddRewriteUsage {
  681. rewriteUsage = getRewriteUsage(meta)
  682. rewriteUsageField = pluginConfig.SearchRewrite.RewriteUsageField
  683. }
  684. // Check if we need to wrap the response writer
  685. if len(references) > 0 || rewriteUsage != nil || count > 0 {
  686. rw := &responseWriter{
  687. ResponseWriter: c.Writer,
  688. webSearchCount: count,
  689. rewriteUsageWritten: false,
  690. rewriteUsage: rewriteUsage,
  691. rewriteUsageField: rewriteUsageField,
  692. references: references,
  693. referencesLocation: pluginConfig.ReferenceLocation,
  694. referenceFormat: pluginConfig.ReferenceFormat,
  695. }
  696. c.Writer = rw
  697. defer func() {
  698. c.Writer = rw.ResponseWriter
  699. }()
  700. }
  701. return p.doResponseWithCount(meta, store, c, resp, do, count)
  702. }
  703. // doResponseWithCount adds search count to usage statistics
  704. func (p *WebSearch) doResponseWithCount(
  705. meta *meta.Meta,
  706. store adaptor.Store,
  707. c *gin.Context,
  708. resp *http.Response,
  709. do adaptor.DoResponse,
  710. count int,
  711. ) (model.Usage, adaptor.Error) {
  712. u, err := do.DoResponse(meta, store, c, resp)
  713. if err != nil {
  714. return model.Usage{}, err
  715. }
  716. u.WebSearchCount += model.ZeroNullInt64(int64(count))
  717. return u, nil
  718. }