| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502 |
- package controller
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "runtime"
- "sync"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/google/uuid"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/common/mcpproxy"
- "github.com/labring/aiproxy/core/middleware"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/openapi-mcp/convert"
- "github.com/redis/go-redis/v9"
- )
- // mcpEndpointProvider implements the EndpointProvider interface for MCP
- type mcpEndpointProvider struct {
- key string
- t model.PublicMCPType
- }
- func newEndpoint(key string, t model.PublicMCPType) mcpproxy.EndpointProvider {
- return &mcpEndpointProvider{
- key: key,
- t: t,
- }
- }
- func (m *mcpEndpointProvider) NewEndpoint() (newSession string, newEndpoint string) {
- session := uuid.NewString()
- endpoint := fmt.Sprintf("/mcp/message?sessionId=%s&key=%s&type=%s", session, m.key, m.t)
- return session, endpoint
- }
- func (m *mcpEndpointProvider) LoadEndpoint(endpoint string) (session string) {
- parsedURL, err := url.Parse(endpoint)
- if err != nil {
- return ""
- }
- return parsedURL.Query().Get("sessionId")
- }
- // Global variables for session management
- var (
- memStore mcpproxy.SessionManager = mcpproxy.NewMemStore()
- redisStore mcpproxy.SessionManager
- redisStoreOnce = &sync.Once{}
- )
- func getStore() mcpproxy.SessionManager {
- if common.RedisEnabled {
- redisStoreOnce.Do(func() {
- redisStore = newRedisStoreManager(common.RDB)
- })
- return redisStore
- }
- return memStore
- }
- // Redis-based session manager
- type redisStoreManager struct {
- rdb *redis.Client
- }
- func newRedisStoreManager(rdb *redis.Client) mcpproxy.SessionManager {
- return &redisStoreManager{
- rdb: rdb,
- }
- }
- var redisStoreManagerScript = redis.NewScript(`
- local key = KEYS[1]
- local value = redis.call('GET', key)
- if not value then
- return nil
- end
- redis.call('EXPIRE', key, 300)
- return value
- `)
- func (r *redisStoreManager) Get(sessionID string) (string, bool) {
- ctx := context.Background()
- result, err := redisStoreManagerScript.Run(ctx, r.rdb, []string{"mcp:session:" + sessionID}).Result()
- if err != nil || result == nil {
- return "", false
- }
- return result.(string), true
- }
- func (r *redisStoreManager) Set(sessionID, endpoint string) {
- ctx := context.Background()
- r.rdb.Set(ctx, "mcp:session:"+sessionID, endpoint, time.Minute*5)
- }
- func (r *redisStoreManager) Delete(session string) {
- ctx := context.Background()
- r.rdb.Del(ctx, "mcp:session:"+session)
- }
- // MCPSseProxy godoc
- //
- // @Summary MCP SSE Proxy
- // @Router /mcp/public/{id}/sse [get]
- func MCPSseProxy(c *gin.Context) {
- mcpID := c.Param("id")
- publicMcp, err := model.GetPublicMCPByID(mcpID)
- if err != nil {
- middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
- switch publicMcp.Type {
- case model.PublicMCPTypeProxySSE:
- handleProxySSE(c, publicMcp)
- case model.PublicMCPTypeOpenAPI:
- handleOpenAPI(c, publicMcp)
- default:
- middleware.AbortLogWithMessage(c, http.StatusBadRequest, "unknow mcp type")
- return
- }
- }
- // handleProxySSE processes SSE proxy requests
- func handleProxySSE(c *gin.Context, publicMcp *model.PublicMCP) {
- config := publicMcp.ProxySSEConfig
- if config == nil || config.URL == "" {
- return
- }
- backendURL, err := url.Parse(config.URL)
- if err != nil {
- middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
- headers := make(map[string]string)
- backendQuery := &url.Values{}
- group := middleware.GetGroup(c)
- token := middleware.GetToken(c)
- // Process reusing parameters if any
- if err := processReusingParams(config.ReusingParams, publicMcp.ID, group.ID, headers, backendQuery); err != nil {
- middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
- backendURL.RawQuery = backendQuery.Encode()
- mcpproxy.SSEHandler(
- c.Writer,
- c.Request,
- getStore(),
- newEndpoint(token.Key, publicMcp.Type),
- backendURL.String(),
- headers,
- )
- }
- // handleOpenAPI processes OpenAPI requests
- func handleOpenAPI(c *gin.Context, publicMcp *model.PublicMCP) {
- config := publicMcp.OpenAPIConfig
- if config == nil || (config.OpenAPISpec == "" && config.OpenAPIContent == "") {
- return
- }
- // Parse OpenAPI specification
- parser := convert.NewParser()
- var err error
- var openAPIFrom string
- if config.OpenAPISpec != "" {
- openAPIFrom, err = parseOpenAPIFromURL(config, parser)
- } else {
- err = parseOpenAPIFromContent(config, parser)
- }
- if err != nil {
- return
- }
- // Convert to MCP server
- converter := convert.NewConverter(parser, convert.Options{
- OpenAPIFrom: openAPIFrom,
- })
- s, err := converter.Convert()
- if err != nil {
- return
- }
- token := middleware.GetToken(c)
- // Setup SSE server
- newSession, newEndpoint := newEndpoint(token.Key, publicMcp.Type).NewEndpoint()
- store := getStore()
- store.Set(newSession, "openapi")
- defer func() {
- store.Delete(newSession)
- }()
- server := NewSSEServer(
- s,
- WithMessageEndpoint(newEndpoint),
- )
- ctx, cancel := context.WithCancel(c.Request.Context())
- defer cancel()
- // Start message processing goroutine
- go processOpenAPIMessages(ctx, newSession, server)
- // Handle SSE connection
- server.HandleSSE(c.Writer, c.Request)
- }
- // parseOpenAPIFromURL parses OpenAPI spec from a URL
- func parseOpenAPIFromURL(config *model.MCPOpenAPIConfig, parser *convert.Parser) (string, error) {
- spec, err := url.Parse(config.OpenAPISpec)
- if err != nil || (spec.Scheme != "http" && spec.Scheme != "https") {
- return "", errors.New("invalid OpenAPI spec URL")
- }
- openAPIFrom := spec.String()
- if config.V2 {
- err = parser.ParseFileV2(openAPIFrom)
- } else {
- err = parser.ParseFile(openAPIFrom)
- }
- return openAPIFrom, err
- }
- // parseOpenAPIFromContent parses OpenAPI spec from content string
- func parseOpenAPIFromContent(config *model.MCPOpenAPIConfig, parser *convert.Parser) error {
- if config.V2 {
- return parser.ParseV2([]byte(config.OpenAPIContent))
- }
- return parser.Parse([]byte(config.OpenAPIContent))
- }
- // processOpenAPIMessages handles message processing for OpenAPI
- func processOpenAPIMessages(ctx context.Context, sessionID string, server *SSEServer) {
- mpscInstance := getMpsc()
- for {
- select {
- case <-ctx.Done():
- return
- default:
- data, err := mpscInstance.recv(ctx, sessionID)
- if err != nil {
- return
- }
- if err := server.HandleMessage(data); err != nil {
- return
- }
- }
- }
- }
- // processReusingParams handles the reusing parameters for MCP proxy
- func processReusingParams(reusingParams map[string]model.ReusingParam, mcpID string, groupID string, headers map[string]string, backendQuery *url.Values) error {
- if len(reusingParams) == 0 {
- return nil
- }
- param, err := model.GetGroupPublicMCPReusingParam(mcpID, groupID)
- if err != nil {
- return err
- }
- for k, v := range reusingParams {
- paramValue, ok := param.ReusingParams[k]
- if !ok {
- if v.Required {
- return fmt.Errorf("%s required", k)
- }
- continue
- }
- switch v.Type {
- case model.ParamTypeHeader:
- headers[k] = paramValue
- case model.ParamTypeQuery:
- backendQuery.Set(k, paramValue)
- default:
- return errors.New("unknow param type")
- }
- }
- return nil
- }
- // MCPMessage godoc
- //
- // @Summary MCP SSE Proxy
- // @Router /mcp/message [post]
- func MCPMessage(c *gin.Context) {
- token := middleware.GetToken(c)
- mcpTypeStr, _ := c.GetQuery("type")
- if mcpTypeStr == "" {
- return
- }
- mcpType := model.PublicMCPType(mcpTypeStr)
- sessionID, _ := c.GetQuery("sessionId")
- if sessionID == "" {
- return
- }
- switch mcpType {
- case model.PublicMCPTypeProxySSE:
- mcpproxy.ProxyHandler(
- c.Writer,
- c.Request,
- getStore(),
- newEndpoint(token.Key, mcpType),
- )
- case model.PublicMCPTypeOpenAPI:
- backend, ok := getStore().Get(sessionID)
- if !ok || backend != "openapi" {
- return
- }
- mpscInstance := getMpsc()
- body, err := io.ReadAll(c.Request.Body)
- if err != nil {
- _ = c.AbortWithError(http.StatusInternalServerError, err)
- return
- }
- err = mpscInstance.send(c.Request.Context(), sessionID, body)
- if err != nil {
- _ = c.AbortWithError(http.StatusInternalServerError, err)
- return
- }
- c.Writer.WriteHeader(http.StatusAccepted)
- }
- }
- // Interface for multi-producer, single-consumer message passing
- type mpsc interface {
- recv(ctx context.Context, id string) ([]byte, error)
- send(ctx context.Context, id string, data []byte) error
- }
- // Global MPSC instances
- var (
- memMpsc mpsc = newChannelMpsc()
- redisMpsc mpsc
- redisMpscOnce = &sync.Once{}
- )
- func getMpsc() mpsc {
- if common.RedisEnabled {
- redisMpscOnce.Do(func() {
- redisMpsc = newRedisMPSC(common.RDB)
- })
- return redisMpsc
- }
- return memMpsc
- }
- // In-memory channel-based MPSC implementation
- type channelMpsc struct {
- channels map[string]chan []byte
- lastAccess map[string]time.Time
- channelMutex sync.RWMutex
- }
- // newChannelMpsc creates a new channel-based mpsc implementation
- func newChannelMpsc() *channelMpsc {
- c := &channelMpsc{
- channels: make(map[string]chan []byte),
- lastAccess: make(map[string]time.Time),
- }
- // Start a goroutine to clean up expired channels
- go c.cleanupExpiredChannels()
- return c
- }
- // cleanupExpiredChannels periodically checks for and removes channels that haven't been accessed in 5 minutes
- func (c *channelMpsc) cleanupExpiredChannels() {
- ticker := time.NewTicker(1 * time.Minute)
- defer ticker.Stop()
- for range ticker.C {
- c.channelMutex.Lock()
- now := time.Now()
- for id, lastAccess := range c.lastAccess {
- if now.Sub(lastAccess) > 5*time.Minute {
- // Close and delete the channel
- if ch, exists := c.channels[id]; exists {
- close(ch)
- delete(c.channels, id)
- }
- delete(c.lastAccess, id)
- }
- }
- c.channelMutex.Unlock()
- }
- }
- // getOrCreateChannel gets an existing channel or creates a new one for the session
- func (c *channelMpsc) getOrCreateChannel(id string) chan []byte {
- c.channelMutex.RLock()
- ch, exists := c.channels[id]
- c.channelMutex.RUnlock()
- if !exists {
- c.channelMutex.Lock()
- if ch, exists = c.channels[id]; !exists {
- ch = make(chan []byte, 10)
- c.channels[id] = ch
- }
- c.lastAccess[id] = time.Now()
- c.channelMutex.Unlock()
- } else {
- c.channelMutex.Lock()
- c.lastAccess[id] = time.Now()
- c.channelMutex.Unlock()
- }
- return ch
- }
- // recv receives data for the specified session
- func (c *channelMpsc) recv(ctx context.Context, id string) ([]byte, error) {
- ch := c.getOrCreateChannel(id)
- select {
- case data, ok := <-ch:
- if !ok {
- return nil, fmt.Errorf("channel closed for session %s", id)
- }
- return data, nil
- case <-ctx.Done():
- return nil, ctx.Err()
- }
- }
- // send sends data to the specified session
- func (c *channelMpsc) send(ctx context.Context, id string, data []byte) error {
- ch := c.getOrCreateChannel(id)
- select {
- case ch <- data:
- return nil
- case <-ctx.Done():
- return ctx.Err()
- default:
- return fmt.Errorf("channel buffer full for session %s", id)
- }
- }
- // Redis-based MPSC implementation
- type redisMPSC struct {
- rdb *redis.Client
- }
- // newRedisMPSC creates a new Redis MPSC instance
- func newRedisMPSC(rdb *redis.Client) *redisMPSC {
- return &redisMPSC{rdb: rdb}
- }
- func (r *redisMPSC) send(ctx context.Context, id string, data []byte) error {
- // Set expiration to 5 minutes when sending data
- pipe := r.rdb.Pipeline()
- pipe.LPush(ctx, id, data)
- pipe.Expire(ctx, id, 5*time.Minute)
- _, err := pipe.Exec(ctx)
- return err
- }
- func (r *redisMPSC) recv(ctx context.Context, id string) ([]byte, error) {
- for {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- result, err := r.rdb.BRPop(ctx, time.Second, id).Result()
- if err != nil {
- if errors.Is(err, redis.Nil) {
- runtime.Gosched()
- continue
- }
- return nil, err
- }
- if len(result) != 2 {
- return nil, errors.New("invalid BRPop result")
- }
- return []byte(result[1]), nil
- }
- }
- }
|