| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- package controller
- import (
- "context"
- "encoding"
- "sync"
- "time"
- "github.com/bytedance/sonic"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/common/conv"
- "github.com/labring/aiproxy/core/model"
- mcpservers "github.com/labring/aiproxy/mcp-servers"
- "github.com/mark3labs/mcp-go/client/transport"
- "github.com/mark3labs/mcp-go/mcp"
- "github.com/redis/go-redis/v9"
- log "github.com/sirupsen/logrus"
- )
- const (
- ToolCacheKey = "tool:%s:%d" // mcp_id:updated_at
- )
- type redisToolSlice []mcp.Tool
- var (
- _ redis.Scanner = (*redisToolSlice)(nil)
- _ encoding.BinaryMarshaler = (*redisToolSlice)(nil)
- )
- func (r *redisToolSlice) ScanRedis(value string) error {
- return sonic.Unmarshal(conv.StringToBytes(value), r)
- }
- func (r redisToolSlice) MarshalBinary() ([]byte, error) {
- return sonic.Marshal(r)
- }
- type toolCacheMemItem struct {
- tools []mcp.Tool
- lastUsedAt time.Time
- expiresAt time.Time
- }
- type toolMemoryCache struct {
- mu sync.Mutex
- items map[string]toolCacheMemItem
- cleanStartOnce sync.Once
- }
- var toolMemCache = &toolMemoryCache{
- items: make(map[string]toolCacheMemItem),
- }
- func getToolCacheKey(mcpID string, updatedAt int64) string {
- return common.RedisKeyf(ToolCacheKey, mcpID, updatedAt)
- }
- func (c *toolMemoryCache) set(key string, tools []mcp.Tool) {
- c.startCleanupOnStart()
- c.mu.Lock()
- defer c.mu.Unlock()
- now := time.Now()
- c.items[key] = toolCacheMemItem{
- tools: tools,
- lastUsedAt: now,
- expiresAt: now.Add(time.Hour),
- }
- }
- func (c *toolMemoryCache) get(key string) ([]mcp.Tool, bool) {
- c.mu.Lock()
- defer c.mu.Unlock()
- item, exists := c.items[key]
- if !exists {
- return nil, false
- }
- now := time.Now()
- if now.After(item.lastUsedAt.Add(time.Minute*10)) ||
- now.After(item.expiresAt) {
- delete(c.items, key)
- return nil, false
- }
- item.lastUsedAt = now
- return item.tools, true
- }
- func (c *toolMemoryCache) cleanup() {
- c.mu.Lock()
- defer c.mu.Unlock()
- now := time.Now()
- for key, item := range c.items {
- if now.After(item.lastUsedAt.Add(time.Minute*10)) ||
- now.After(item.expiresAt) {
- delete(c.items, key)
- }
- }
- }
- func (c *toolMemoryCache) startCleanupOnStart() {
- c.cleanStartOnce.Do(func() {
- go func() {
- ticker := time.NewTicker(time.Minute * 10)
- defer ticker.Stop()
- for range ticker.C {
- c.cleanup()
- }
- }()
- })
- }
- func CacheSetTools(mcpID string, updatedAt int64, tools []mcp.Tool) error {
- key := getToolCacheKey(mcpID, updatedAt)
- if common.RedisEnabled {
- redisKey := common.RedisKeyf(ToolCacheKey, mcpID, updatedAt)
- pipe := common.RDB.Pipeline()
- pipe.HSet(context.Background(), redisKey, tools)
- pipe.Expire(context.Background(), redisKey, time.Hour)
- _, err := pipe.Exec(context.Background())
- return err
- }
- toolMemCache.set(key, tools)
- return nil
- }
- func CacheGetTools(mcpID string, updatedAt int64) ([]mcp.Tool, bool) {
- key := getToolCacheKey(mcpID, updatedAt)
- if common.RedisEnabled {
- tools := redisToolSlice{}
- err := common.RDB.HGetAll(context.Background(), key).Scan(&tools)
- if err != nil {
- log.Errorf("failed to get tools cache from redis (%s): %v", key, err)
- } else {
- return tools, true
- }
- }
- item, exists := toolMemCache.get(key)
- if exists {
- return item, true
- }
- return nil, false
- }
- func checkParamsIsFull(params model.Params, reusing map[string]model.ReusingParam) bool {
- for key, r := range reusing {
- if !r.Required {
- continue
- }
- if params == nil {
- return false
- }
- if v, ok := params[key]; !ok || v == "" {
- return false
- }
- }
- return true
- }
- func getPublicMCPTools(
- ctx context.Context,
- publicMcp model.PublicMCP,
- testConfig model.TestConfig,
- params map[string]string,
- reusing map[string]model.ReusingParam,
- ) (tools []mcp.Tool, err error) {
- tools, exists := CacheGetTools(publicMcp.ID, publicMcp.UpdateAt.Unix())
- if exists {
- return tools, nil
- }
- defer func() {
- if err != nil {
- return
- }
- if err := CacheSetTools(publicMcp.ID, publicMcp.UpdateAt.Unix(), tools); err != nil {
- log.Errorf("failed to set tools cache in redis: %v", err)
- }
- }()
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- switch publicMcp.Type {
- case model.PublicMCPTypeEmbed:
- return getEmbedMCPTools(ctx, publicMcp, testConfig, params, reusing)
- case model.PublicMCPTypeOpenAPI:
- return getOpenAPIMCPTools(ctx, publicMcp)
- case model.PublicMCPTypeProxySSE:
- return getProxySSEMCPTools(ctx, publicMcp, testConfig, params, reusing)
- case model.PublicMCPTypeProxyStreamable:
- return getProxyStreamableMCPTools(ctx, publicMcp, testConfig, params, reusing)
- default:
- return nil, nil
- }
- }
- func getEmbedMCPTools(
- ctx context.Context,
- publicMcp model.PublicMCP,
- testConfig model.TestConfig,
- params map[string]string,
- reusing map[string]model.ReusingParam,
- ) ([]mcp.Tool, error) {
- tools, err := mcpservers.ListTools(ctx, publicMcp.ID)
- if err == nil {
- return tools, nil
- }
- if publicMcp.EmbedConfig == nil {
- return nil, nil
- }
- var effectiveParams map[string]string
- switch {
- case testConfig.Enabled && checkParamsIsFull(testConfig.Params, reusing):
- effectiveParams = testConfig.Params
- case checkParamsIsFull(params, reusing):
- effectiveParams = params
- default:
- return nil, nil
- }
- server, err := mcpservers.GetMCPServer(
- publicMcp.ID,
- publicMcp.EmbedConfig.Init,
- effectiveParams,
- )
- if err != nil {
- return nil, err
- }
- return mcpservers.ListServerTools(ctx, server)
- }
- func getOpenAPIMCPTools(ctx context.Context, publicMcp model.PublicMCP) ([]mcp.Tool, error) {
- if publicMcp.OpenAPIConfig == nil {
- return nil, nil
- }
- server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
- if err != nil {
- return nil, err
- }
- return mcpservers.ListServerTools(ctx, server)
- }
- func getProxySSEMCPTools(
- ctx context.Context,
- publicMcp model.PublicMCP,
- testConfig model.TestConfig,
- params map[string]string,
- reusing map[string]model.ReusingParam,
- ) ([]mcp.Tool, error) {
- if publicMcp.ProxyConfig == nil {
- return nil, nil
- }
- var effectiveParams map[string]string
- switch {
- case testConfig.Enabled && checkParamsIsFull(testConfig.Params, reusing):
- effectiveParams = testConfig.Params
- case checkParamsIsFull(params, reusing):
- effectiveParams = params
- default:
- return nil, nil
- }
- url, headers, err := prepareProxyConfig(
- publicMcp.ToPublicMCPCache(),
- staticParams(effectiveParams),
- )
- if err != nil {
- return nil, err
- }
- client, err := transport.NewSSE(url, transport.WithHeaders(headers))
- if err != nil {
- return nil, err
- }
- defer client.Close()
- if err := client.Start(ctx); err != nil {
- return nil, err
- }
- return mcpservers.ListServerTools(ctx, mcpservers.WrapMCPClient2Server(client))
- }
- func getProxyStreamableMCPTools(
- ctx context.Context,
- publicMcp model.PublicMCP,
- testConfig model.TestConfig,
- params map[string]string,
- reusing map[string]model.ReusingParam,
- ) ([]mcp.Tool, error) {
- if publicMcp.ProxyConfig == nil {
- return nil, nil
- }
- var effectiveParams map[string]string
- switch {
- case testConfig.Enabled && checkParamsIsFull(testConfig.Params, reusing):
- effectiveParams = testConfig.Params
- case checkParamsIsFull(params, reusing):
- effectiveParams = params
- default:
- return nil, nil
- }
- url, headers, err := prepareProxyConfig(
- publicMcp.ToPublicMCPCache(),
- staticParams(effectiveParams),
- )
- if err != nil {
- return nil, err
- }
- client, err := transport.NewStreamableHTTP(
- url,
- transport.WithHTTPHeaders(headers),
- )
- if err != nil {
- return nil, err
- }
- defer client.Close()
- if err := client.Start(ctx); err != nil {
- return nil, err
- }
- return mcpservers.ListServerTools(ctx, mcpservers.WrapMCPClient2Server(client))
- }
|