package websearch import ( "bytes" "context" _ "embed" "errors" "fmt" "io" "net/http" "net/http/httptest" "net/url" "strconv" "strings" "sync" "time" "github.com/bytedance/sonic" "github.com/bytedance/sonic/ast" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/common" "github.com/labring/aiproxy/core/common/conv" "github.com/labring/aiproxy/core/middleware" "github.com/labring/aiproxy/core/model" "github.com/labring/aiproxy/core/relay/adaptor" "github.com/labring/aiproxy/core/relay/adaptors" "github.com/labring/aiproxy/core/relay/controller" "github.com/labring/aiproxy/core/relay/meta" "github.com/labring/aiproxy/core/relay/mode" "github.com/labring/aiproxy/core/relay/plugin" "github.com/labring/aiproxy/core/relay/plugin/noop" "github.com/labring/aiproxy/core/relay/plugin/patch" "github.com/labring/aiproxy/core/relay/utils" "github.com/labring/aiproxy/mcp-servers/hosted/web-search/engine" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" ) var _ plugin.Plugin = (*WebSearch)(nil) type GetChannel func(modelName string) (*model.Channel, error) // WebSearch implements web search functionality type WebSearch struct { noop.Noop GetChannel GetChannel } // NewWebSearchPlugin creates a new web search plugin func NewWebSearchPlugin(getChannel GetChannel) plugin.Plugin { return &WebSearch{ GetChannel: getChannel, } } //go:embed prompts/arxiv.md var arxivSearchPrompts string //go:embed prompts/internet.md var internetSearchPrompts string // Constants for metadata keys const ( searchCount = "web-search-count" rewriteUsage = "web-search-rewrite-usage" ) // Metadata helper functions func setSearchCount(m *meta.Meta, count int) { m.Set(searchCount, count) } func getSearchCount(m *meta.Meta) int { return m.GetInt(searchCount) } func setRewriteUsage(m *meta.Meta, usage model.Usage) { m.Set(rewriteUsage, usage) } func getRewriteUsage(m *meta.Meta) *model.Usage { usage, ok := m.Get(rewriteUsage) if !ok { return nil } u, ok := usage.(model.Usage) if !ok { panic(fmt.Sprintf("rewrite usage type %T is not a model.Usage", usage)) } return &u } func (p *WebSearch) getConfig(meta *meta.Meta) (Config, error) { pluginConfig := Config{} if err := meta.ModelConfig.LoadPluginConfig("web-search", &pluginConfig); err != nil { return Config{}, err } return pluginConfig, nil } func lazyRemoveSearchOption(meta *meta.Meta) { patch.AddLazyPatch(meta, patch.PatchOperation{ Op: patch.OpFunction, Function: func(root *ast.Node) (bool, error) { ok, err := root.Unset("web_search_options") if err != nil { return false, err } return ok, nil }, }) } func fallback( meta *meta.Meta, store adaptor.Store, req *http.Request, do adaptor.ConvertRequest, ) (adaptor.ConvertResult, error) { lazyRemoveSearchOption(meta) return do.ConvertRequest(meta, store, req) } // ConvertRequest intercepts and modifies requests to add web search capabilities func (p *WebSearch) ConvertRequest( meta *meta.Meta, store adaptor.Store, req *http.Request, do adaptor.ConvertRequest, ) (adaptor.ConvertResult, error) { // Skip if not chat completions mode if meta.Mode != mode.ChatCompletions { return do.ConvertRequest(meta, store, req) } // Load plugin configuration pluginConfig, err := p.getConfig(meta) if err != nil { return do.ConvertRequest(meta, store, req) } // Skip if plugin is disabled if !pluginConfig.Enable { return do.ConvertRequest(meta, store, req) } // Apply default configuration values if needed if err := p.validateAndApplyDefaults(&pluginConfig); err != nil { return fallback(meta, store, req, do) } // Initialize search engines engines, arxivExists, err := p.initializeSearchEngines(pluginConfig.SearchFrom) if err != nil || len(engines) == 0 { return fallback(meta, store, req, do) } // Read and parse request body body, err := common.GetRequestBodyReusable(req) if err != nil { return adaptor.ConvertResult{}, fmt.Errorf("failed to read request body: %w", err) } var chatRequest map[string]any if err := sonic.Unmarshal(body, &chatRequest); err != nil { return fallback(meta, store, req, do) } // Check if web search should be enabled for this request webSearchOptions, hasWebSearchOptions := chatRequest["web_search_options"].(map[string]any) if !pluginConfig.ForceSearch && !hasWebSearchOptions { return fallback(meta, store, req, do) } webSearchEnable, ok := webSearchOptions["enable"].(bool) if ok && !webSearchEnable { return fallback(meta, store, req, do) } // Extract user query from messages messages, ok := chatRequest["messages"].([]any) if !ok || len(messages) == 0 { return fallback(meta, store, req, do) } queryIndex, query := p.extractUserQuery(messages) if query == "" { return fallback(meta, store, req, do) } // Prepare search rewrite prompt if configured searchRewritePrompt := p.prepareSearchRewritePrompt( pluginConfig.SearchRewrite, arxivExists, webSearchOptions, ) // Generate search contexts searchContexts, err := p.generateSearchContexts( meta, store, pluginConfig, query, searchRewritePrompt, ) if err != nil { return fallback(meta, store, req, do) } if len(searchContexts) == 0 { return fallback(meta, store, req, do) } // Execute searches searchResult := p.executeSearches(context.Background(), engines, searchContexts) if searchResult.Count == 0 || len(searchResult.Results) == 0 { return fallback(meta, store, req, do) } setSearchCount(meta, searchResult.Count) // Format search results and modify request p.formatSearchResults(messages, queryIndex, query, searchResult.Results, pluginConfig) delete(chatRequest, "web_search_options") // Create new request body modifiedBody, err := sonic.Marshal(chatRequest) if err != nil { return fallback(meta, store, req, do) } // Update the request common.SetRequestBody(req, modifiedBody) defer common.SetRequestBody(req, body) // Store references in context if needed if pluginConfig.NeedReference { meta.Set("references", searchResult.Results) } return do.ConvertRequest(meta, store, req) } // validateAndApplyDefaults validates configuration and applies default values func (p *WebSearch) validateAndApplyDefaults(config *Config) error { // Set default max results if config.MaxResults == 0 { config.MaxResults = 10 } // Configure reference settings if config.NeedReference { if config.ReferenceFormat != "" && !strings.Contains(config.ReferenceFormat, "%s") { return errors.New("invalid reference format") } } // Set default prompt template if not provided if config.PromptTemplate == "" { config.PromptTemplate = p.getDefaultPromptTemplate(config.NeedReference) } // Validate prompt template if !strings.Contains(config.PromptTemplate, "{search_results}") || !strings.Contains(config.PromptTemplate, "{question}") { return errors.New("invalid prompt template") } return nil } // getDefaultPromptTemplate returns the appropriate default prompt template based on configuration func (p *WebSearch) getDefaultPromptTemplate(needReference bool) string { if needReference { return `# 以下内容是基于用户发送的消息的搜索结果: {search_results} 在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。 在回答时,请注意以下几点: - 今天是北京时间:{cur_date}。 - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。 - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。 - 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。 - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。 - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。 - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。 - 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。 - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。 # 用户消息为: {question}` } return `# 以下内容是基于用户发送的消息的搜索结果: {search_results} 在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。 在回答时,请注意以下几点: - 今天是北京时间:{cur_date}。 - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。 - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内。如非必要,不要主动告诉用户搜索结果未提供的内容。 - 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。 - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。 - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。 - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。 - 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。 - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。 # 用户消息为: {question}` } // initializeSearchEngines creates search engine instances based on configuration func (p *WebSearch) initializeSearchEngines(configs []EngineConfig) ([]engine.Engine, bool, error) { var ( engines []engine.Engine arxivExists bool ) for _, e := range configs { switch e.Type { case "bing": var spec BingSpec if err := e.LoadSpec(&spec); err != nil { return nil, false, err } engines = append(engines, engine.NewBingEngine(spec.APIKey)) case "bingcn": engines = append(engines, engine.NewBingCNEngine()) case "google": var spec GoogleSpec if err := e.LoadSpec(&spec); err != nil { return nil, false, err } engines = append(engines, engine.NewGoogleEngine(spec.APIKey, spec.CX)) case "arxiv": engines = append(engines, engine.NewArxivEngine()) arxivExists = true case "searchxng": var spec SearchXNGSpec if err := e.LoadSpec(&spec); err != nil { return nil, false, err } engines = append(engines, engine.NewSearchXNGEngine(spec.BaseURL)) default: return nil, false, fmt.Errorf("unsupported engine type: %s", e.Type) } } return engines, arxivExists, nil } // extractUserQuery finds the last user message in the conversation func (p *WebSearch) extractUserQuery(messages []any) (int, string) { for i := len(messages) - 1; i >= 0; i-- { msg, ok := messages[i].(map[string]any) if !ok { continue } if role, ok := msg["role"].(string); ok && role == "user" { if content, ok := msg["content"].(string); ok { return i, content } return i, "" } } return -1, "" } // prepareSearchRewritePrompt prepares the prompt for search query rewriting func (p *WebSearch) prepareSearchRewritePrompt( searchRewrite SearchRewrite, arxivExists bool, webSearchOptions map[string]any, ) string { if !searchRewrite.Enable { return "" } // Select appropriate prompt template var searchRewritePromptTemplate string if arxivExists { searchRewritePromptTemplate = arxivSearchPrompts } else { searchRewritePromptTemplate = internetSearchPrompts } // Adjust max count based on search context size if specified maxCount := searchRewrite.MaxCount if webSearchOptions != nil { if searchContextSize, ok := webSearchOptions["search_context_size"].(string); ok { switch searchContextSize { case "low": maxCount = 1 case "medium": maxCount = 3 case "high": maxCount = 5 } } } // Replace placeholder with actual max count return strings.ReplaceAll(searchRewritePromptTemplate, "{max_count}", strconv.Itoa(maxCount)) } // generateSearchContexts creates search contexts based on the user query func (p *WebSearch) generateSearchContexts( m *meta.Meta, store adaptor.Store, config Config, query, searchRewritePrompt string, ) ([]engine.SearchQuery, error) { if searchRewritePrompt == "" { return []engine.SearchQuery{{ Queries: []string{query}, Language: config.DefaultLanguage, }}, nil } // Prepare request for query rewriting rewriteBody, err := sonic.Marshal(map[string]any{ "stream": false, "max_tokens": 4096, "model": config.SearchRewrite.ModelName, "messages": []map[string]any{ { "role": "user", "content": strings.ReplaceAll(searchRewritePrompt, "{question}", query), }, }, }) if err != nil { return nil, err } // Set up test context for rewrite request w := httptest.NewRecorder() newc, _ := gin.CreateTestContext(w) newc.Request = &http.Request{ URL: &url.URL{}, Body: io.NopCloser(bytes.NewReader(rewriteBody)), Header: make(http.Header), } middleware.SetRequestID(newc, "web-search-rewrite") // Set up metadata for rewrite request modelName := config.SearchRewrite.ModelName if modelName == "" { modelName = m.OriginModel } newMeta := meta.NewMeta( nil, mode.ChatCompletions, modelName, model.ModelConfig{ Model: modelName, Type: mode.ChatCompletions, }, meta.WithRequestID("web-search-rewrite"), ) // Set appropriate channel if config.SearchRewrite.ModelName == "" { newMeta.CopyChannelFromMeta(m) } else { channel, err := p.GetChannel(config.SearchRewrite.ModelName) if err != nil { return nil, err } newMeta.SetChannel(channel) } // Get adaptor and handle request adaptor, ok := adaptors.GetAdaptor(newMeta.Channel.Type) if !ok { return nil, errors.New("adaptor not found") } result := controller.Handle(adaptor, newc, newMeta, store) if result.Error != nil { return nil, result.Error } setRewriteUsage(m, result.Usage) // Extract content from response contentNode, err := sonic.Get(w.Body.Bytes(), "choices", 0, "message", "content") if err != nil { return nil, err } content, err := contentNode.String() if err != nil || content == "" { return nil, err } if strings.Contains(content, "none") { return nil, nil } // Parse search queries from LLM response return p.parseSearchContexts(config.DefaultLanguage, content), nil } // parseSearchContexts extracts search queries from LLM response func (p *WebSearch) parseSearchContexts(defaultLanguage, content string) []engine.SearchQuery { var searchContexts []engine.SearchQuery for line := range strings.SplitSeq(content, "\n") { line = strings.TrimSpace(line) if line == "" { continue } parts := strings.SplitN(line, ":", 2) if len(parts) != 2 { continue } engineType := strings.TrimSpace(parts[0]) queryStr := strings.TrimSpace(parts[1]) var ctx engine.SearchQuery ctx.Language = defaultLanguage switch engineType { case "internet": ctx.Queries = []string{queryStr} default: // Arxiv category ctx.ArxivCategory = engineType ctx.Queries = strings.Split(queryStr, ",") for i := range ctx.Queries { ctx.Queries[i] = strings.TrimSpace(ctx.Queries[i]) } } if len(ctx.Queries) > 0 { searchContexts = append(searchContexts, ctx) if ctx.ArxivCategory != "" { // Conduct inquiries in all areas to increase recall. backupCtx := ctx backupCtx.ArxivCategory = "" searchContexts = append(searchContexts, backupCtx) } } } return searchContexts } // Search result structure type searchResult struct { Results []engine.SearchResult Count int } // executeSearches performs searches using all configured engines func (p *WebSearch) executeSearches( ctx context.Context, engines []engine.Engine, searchContexts []engine.SearchQuery, ) *searchResult { var ( allResults []engine.SearchResult mu sync.Mutex ) ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() g, ctx := errgroup.WithContext(ctx) for _, eng := range engines { for _, searchCtx := range searchContexts { g.Go(func() error { results, err := eng.Search(ctx, engine.SearchQuery{ Queries: searchCtx.Queries, MaxResults: 10, Language: searchCtx.Language, ArxivCategory: searchCtx.ArxivCategory, }) if err != nil { logrus.Errorf("search error: %v", err) return err } mu.Lock() allResults = append(allResults, results...) mu.Unlock() return nil }) } } _ = g.Wait() seen := make(map[string]bool) var uniqueResults []engine.SearchResult for _, result := range allResults { if !seen[result.Link] { seen[result.Link] = true uniqueResults = append(uniqueResults, result) } } return &searchResult{ Results: uniqueResults, Count: len(engines) * len(searchContexts), } } // formatSearchResults formats search results for the prompt func (p *WebSearch) formatSearchResults( messages []any, queryIndex int, query string, searchResults []engine.SearchResult, config Config, ) { message, ok := messages[queryIndex].(map[string]any) if !ok { return } var formattedResults []string for i, result := range searchResults { if config.NeedReference { formattedResults = append(formattedResults, fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", i+1, result.Content, i+1)) } else { formattedResults = append(formattedResults, fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content)) } } // Fill template curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日") searchResultsStr := strings.Join(formattedResults, "\n") prompt := strings.Replace(config.PromptTemplate, "{search_results}", searchResultsStr, 1) prompt = strings.Replace(prompt, "{question}", query, 1) prompt = strings.Replace(prompt, "{cur_date}", curDate, 1) // Update message message["content"] = prompt } // Custom response writer to handle metadata and references type responseWriter struct { gin.ResponseWriter refWritten bool referenceFormat string references []engine.SearchResult referencesLocation string webSearchCount int rewriteUsage *model.Usage rewriteUsageWritten bool rewriteUsageField string isStream bool } // Write overrides the standard Write method to inject metadata func (rw *responseWriter) Write(b []byte) (int, error) { if rw.isStream || utils.IsStreamResponseWithHeader(rw.Header()) { rw.isStream = true } node, err := sonic.Get(b) if err != nil || !node.Valid() { return rw.ResponseWriter.Write(b) } // Process the response node rw.processRewriteUsage(&node) rw.processWebSearchCount(&node) rw.processReferences(&node) // Marshal the modified node nb, err := sonic.Marshal(&node) if err != nil { return rw.ResponseWriter.Write(b) } if !rw.isStream { if rw.ResponseWriter.Header().Get("Content-Length") != "" { rw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(len(nb))) } } return rw.ResponseWriter.Write(nb) } // processRewriteUsage adds rewrite usage information to the response func (rw *responseWriter) processRewriteUsage(node *ast.Node) { if rw.rewriteUsage != nil && !rw.rewriteUsageWritten { field := rw.rewriteUsageField if field == "" { field = "rewrite_usage" } _, _ = node.SetAny(field, rw.rewriteUsage) rw.rewriteUsageWritten = true } } // processWebSearchCount adds or increments web search count in usage statistics func (rw *responseWriter) processWebSearchCount(node *ast.Node) { if rw.webSearchCount <= 0 { return } usageNode := node.Get("usage") if usageNode == nil || !usageNode.Valid() { return } // Check if web_search_count already exists existingCount := usageNode.Get("web_search_count") if existingCount != nil && existingCount.Valid() { // If exists, add to the existing value currentCount, _ := existingCount.Int64() _, _ = usageNode.Set( "web_search_count", ast.NewNumber(strconv.FormatInt(currentCount+int64(rw.webSearchCount), 10)), ) } else { // If not exists, set the value _, _ = usageNode.Set("web_search_count", ast.NewNumber(strconv.FormatInt(int64(rw.webSearchCount), 10))) } } func buildReferenceContent(searchResults []engine.SearchResult) string { formattedReferences := make([]string, len(searchResults)) for i, result := range searchResults { formattedReferences[i] = fmt.Sprintf("[%d] [%s](%s)", i+1, result.Title, result.Link) } return strings.Join(formattedReferences, "\n\n") } // processReferences adds reference information to the content func (rw *responseWriter) processReferences(node *ast.Node) { if rw.refWritten || len(rw.references) == 0 { return } rw.refWritten = true if rw.referencesLocation == "" || rw.referencesLocation == "content" { var contentNode *ast.Node if rw.isStream { contentNode = node.GetByPath("choices", 0, "delta", "content") } else { contentNode = node.GetByPath("choices", 0, "message", "content") } if contentNode != nil && contentNode.Valid() { content, err := contentNode.String() if err == nil { format := rw.referenceFormat if format == "" { format = "**References:**\n%s" } ref := fmt.Sprintf(format, buildReferenceContent(rw.references)) refContent := fmt.Sprintf("%s\n\n%s", ref, content) *contentNode = ast.NewString(refContent) } } } else { var outterLocation *ast.Node if rw.isStream { outterLocation = node.GetByPath("choices", 0, "delta") } else { outterLocation = node.GetByPath("choices", 0, "message") } if outterLocation != nil && outterLocation.Valid() { _, _ = outterLocation.SetAny(rw.referencesLocation, rw.references) } } } // WriteString implements the WriteString method for the custom response writer func (rw *responseWriter) WriteString(s string) (int, error) { return rw.Write(conv.StringToBytes(s)) } // DoResponse handles response modification for references func (p *WebSearch) DoResponse( meta *meta.Meta, store adaptor.Store, c *gin.Context, resp *http.Response, do adaptor.DoResponse, ) (model.Usage, adaptor.Error) { if meta.Mode != mode.ChatCompletions { return do.DoResponse(meta, store, c, resp) } var references []engine.SearchResult referencesI, ok := meta.Get("references") if ok { references, ok = referencesI.([]engine.SearchResult) if !ok { panic(fmt.Sprintf("references type %T is not a []engine.SearchResult", referencesI)) } } count := getSearchCount(meta) var rewriteUsage *model.Usage rewriteUsageField := "" pluginConfig, _ := p.getConfig(meta) if pluginConfig.SearchRewrite.AddRewriteUsage { rewriteUsage = getRewriteUsage(meta) rewriteUsageField = pluginConfig.SearchRewrite.RewriteUsageField } // Check if we need to wrap the response writer if len(references) > 0 || rewriteUsage != nil || count > 0 { rw := &responseWriter{ ResponseWriter: c.Writer, webSearchCount: count, rewriteUsageWritten: false, rewriteUsage: rewriteUsage, rewriteUsageField: rewriteUsageField, references: references, referencesLocation: pluginConfig.ReferenceLocation, referenceFormat: pluginConfig.ReferenceFormat, } c.Writer = rw defer func() { c.Writer = rw.ResponseWriter }() } return p.doResponseWithCount(meta, store, c, resp, do, count) } // doResponseWithCount adds search count to usage statistics func (p *WebSearch) doResponseWithCount( meta *meta.Meta, store adaptor.Store, c *gin.Context, resp *http.Response, do adaptor.DoResponse, count int, ) (model.Usage, adaptor.Error) { u, err := do.DoResponse(meta, store, c, resp) if err != nil { return model.Usage{}, err } u.WebSearchCount += model.ZeroNullInt64(int64(count)) return u, nil }