|
|
@@ -12,35 +12,35 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
- "github.com/google/uuid"
|
|
|
"github.com/labring/aiproxy/core/common"
|
|
|
"github.com/labring/aiproxy/core/common/mcpproxy"
|
|
|
"github.com/labring/aiproxy/core/middleware"
|
|
|
"github.com/labring/aiproxy/core/model"
|
|
|
"github.com/labring/aiproxy/openapi-mcp/convert"
|
|
|
+ "github.com/mark3labs/mcp-go/server"
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
)
|
|
|
|
|
|
-// mcpEndpointProvider implements the EndpointProvider interface for MCP
|
|
|
-type mcpEndpointProvider struct {
|
|
|
+// publicMcpEndpointProvider implements the EndpointProvider interface for MCP
|
|
|
+type publicMcpEndpointProvider struct {
|
|
|
key string
|
|
|
t model.PublicMCPType
|
|
|
}
|
|
|
|
|
|
-func newEndpoint(key string, t model.PublicMCPType) mcpproxy.EndpointProvider {
|
|
|
- return &mcpEndpointProvider{
|
|
|
+func newPublicMcpEndpoint(key string, t model.PublicMCPType) mcpproxy.EndpointProvider {
|
|
|
+ return &publicMcpEndpointProvider{
|
|
|
key: key,
|
|
|
t: t,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (m *mcpEndpointProvider) NewEndpoint() (newSession string, newEndpoint string) {
|
|
|
- session := uuid.NewString()
|
|
|
- endpoint := fmt.Sprintf("/mcp/message?sessionId=%s&key=%s&type=%s", session, m.key, m.t)
|
|
|
+func (m *publicMcpEndpointProvider) NewEndpoint() (newSession string, newEndpoint string) {
|
|
|
+ session := common.ShortUUID()
|
|
|
+ endpoint := fmt.Sprintf("/mcp/public/message?sessionId=%s&key=%s&type=%s", session, m.key, m.t)
|
|
|
return session, endpoint
|
|
|
}
|
|
|
|
|
|
-func (m *mcpEndpointProvider) LoadEndpoint(endpoint string) (session string) {
|
|
|
+func (m *publicMcpEndpointProvider) LoadEndpoint(endpoint string) (session string) {
|
|
|
parsedURL, err := url.Parse(endpoint)
|
|
|
if err != nil {
|
|
|
return ""
|
|
|
@@ -107,11 +107,11 @@ func (r *redisStoreManager) Delete(session string) {
|
|
|
r.rdb.Del(ctx, "mcp:session:"+session)
|
|
|
}
|
|
|
|
|
|
-// MCPSseProxy godoc
|
|
|
+// PublicMCPSseServer godoc
|
|
|
//
|
|
|
-// @Summary MCP SSE Proxy
|
|
|
+// @Summary Public MCP SSE Server
|
|
|
// @Router /mcp/public/{id}/sse [get]
|
|
|
-func MCPSseProxy(c *gin.Context) {
|
|
|
+func PublicMCPSseServer(c *gin.Context) {
|
|
|
mcpID := c.Param("id")
|
|
|
|
|
|
publicMcp, err := model.GetPublicMCPByID(mcpID)
|
|
|
@@ -122,18 +122,22 @@ func MCPSseProxy(c *gin.Context) {
|
|
|
|
|
|
switch publicMcp.Type {
|
|
|
case model.PublicMCPTypeProxySSE:
|
|
|
- handleProxySSE(c, publicMcp)
|
|
|
+ handlePublicProxySSE(c, publicMcp.ID, publicMcp.ProxySSEConfig)
|
|
|
case model.PublicMCPTypeOpenAPI:
|
|
|
- handleOpenAPI(c, publicMcp)
|
|
|
+ server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
|
|
|
+ if err != nil {
|
|
|
+ middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+ handleMCPServer(c, server, model.PublicMCPTypeOpenAPI)
|
|
|
default:
|
|
|
middleware.AbortLogWithMessage(c, http.StatusBadRequest, "unknow mcp type")
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// handleProxySSE processes SSE proxy requests
|
|
|
-func handleProxySSE(c *gin.Context, publicMcp *model.PublicMCP) {
|
|
|
- config := publicMcp.ProxySSEConfig
|
|
|
+// handlePublicProxySSE processes SSE proxy requests
|
|
|
+func handlePublicProxySSE(c *gin.Context, mcpID string, config *model.PublicMCPProxySSEConfig) {
|
|
|
if config == nil || config.URL == "" {
|
|
|
return
|
|
|
}
|
|
|
@@ -150,27 +154,33 @@ func handleProxySSE(c *gin.Context, publicMcp *model.PublicMCP) {
|
|
|
token := middleware.GetToken(c)
|
|
|
|
|
|
// Process reusing parameters if any
|
|
|
- if err := processReusingParams(config.ReusingParams, publicMcp.ID, group.ID, headers, backendQuery); err != nil {
|
|
|
+ if err := processReusingParams(config.ReusingParams, mcpID, group.ID, headers, backendQuery); err != nil {
|
|
|
middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ for k, v := range config.Headers {
|
|
|
+ headers[k] = v
|
|
|
+ }
|
|
|
+ for k, v := range config.Querys {
|
|
|
+ backendQuery.Set(k, v)
|
|
|
+ }
|
|
|
+
|
|
|
backendURL.RawQuery = backendQuery.Encode()
|
|
|
mcpproxy.SSEHandler(
|
|
|
c.Writer,
|
|
|
c.Request,
|
|
|
getStore(),
|
|
|
- newEndpoint(token.Key, publicMcp.Type),
|
|
|
+ newPublicMcpEndpoint(token.Key, model.PublicMCPTypeProxySSE),
|
|
|
backendURL.String(),
|
|
|
headers,
|
|
|
)
|
|
|
}
|
|
|
|
|
|
-// handleOpenAPI processes OpenAPI requests
|
|
|
-func handleOpenAPI(c *gin.Context, publicMcp *model.PublicMCP) {
|
|
|
- config := publicMcp.OpenAPIConfig
|
|
|
+// newOpenAPIMCPServer creates a new MCP server from OpenAPI configuration
|
|
|
+func newOpenAPIMCPServer(config *model.MCPOpenAPIConfig) (*server.MCPServer, error) {
|
|
|
if config == nil || (config.OpenAPISpec == "" && config.OpenAPIContent == "") {
|
|
|
- return
|
|
|
+ return nil, errors.New("invalid OpenAPI configuration")
|
|
|
}
|
|
|
|
|
|
// Parse OpenAPI specification
|
|
|
@@ -185,38 +195,45 @@ func handleOpenAPI(c *gin.Context, publicMcp *model.PublicMCP) {
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
// Convert to MCP server
|
|
|
converter := convert.NewConverter(parser, convert.Options{
|
|
|
- OpenAPIFrom: openAPIFrom,
|
|
|
+ OpenAPIFrom: openAPIFrom,
|
|
|
+ ServerAddr: config.ServerAddr,
|
|
|
+ Authorization: config.Authorization,
|
|
|
})
|
|
|
s, err := converter.Convert()
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- token := middleware.GetToken(c)
|
|
|
+ return s, nil
|
|
|
+}
|
|
|
|
|
|
- // Setup SSE server
|
|
|
- newSession, newEndpoint := newEndpoint(token.Key, publicMcp.Type).NewEndpoint()
|
|
|
- store := getStore()
|
|
|
- store.Set(newSession, "openapi")
|
|
|
- defer func() {
|
|
|
- store.Delete(newSession)
|
|
|
- }()
|
|
|
+// handleMCPServer handles the SSE connection for an MCP server
|
|
|
+func handleMCPServer(c *gin.Context, s *server.MCPServer, mcpType model.PublicMCPType) {
|
|
|
+ token := middleware.GetToken(c)
|
|
|
|
|
|
+ newSession, newEndpoint := newPublicMcpEndpoint(token.Key, mcpType).NewEndpoint()
|
|
|
server := NewSSEServer(
|
|
|
s,
|
|
|
WithMessageEndpoint(newEndpoint),
|
|
|
)
|
|
|
|
|
|
+ // Store the session
|
|
|
+ store := getStore()
|
|
|
+ store.Set(newSession, string(mcpType))
|
|
|
+ defer func() {
|
|
|
+ store.Delete(newSession)
|
|
|
+ }()
|
|
|
+
|
|
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
|
|
defer cancel()
|
|
|
|
|
|
// Start message processing goroutine
|
|
|
- go processOpenAPIMessages(ctx, newSession, server)
|
|
|
+ go processMCPSseMpscMessages(ctx, newSession, server)
|
|
|
|
|
|
// Handle SSE connection
|
|
|
server.HandleSSE(c.Writer, c.Request)
|
|
|
@@ -247,9 +264,9 @@ func parseOpenAPIFromContent(config *model.MCPOpenAPIConfig, parser *convert.Par
|
|
|
return parser.Parse([]byte(config.OpenAPIContent))
|
|
|
}
|
|
|
|
|
|
-// processOpenAPIMessages handles message processing for OpenAPI
|
|
|
-func processOpenAPIMessages(ctx context.Context, sessionID string, server *SSEServer) {
|
|
|
- mpscInstance := getMpsc()
|
|
|
+// processMCPSseMpscMessages handles message processing for OpenAPI
|
|
|
+func processMCPSseMpscMessages(ctx context.Context, sessionID string, server *SSEServer) {
|
|
|
+ mpscInstance := getMCPMpsc()
|
|
|
for {
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
@@ -299,11 +316,11 @@ func processReusingParams(reusingParams map[string]model.ReusingParam, mcpID str
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// MCPMessage godoc
|
|
|
+// PublicMCPMessage godoc
|
|
|
//
|
|
|
// @Summary MCP SSE Proxy
|
|
|
-// @Router /mcp/message [post]
|
|
|
-func MCPMessage(c *gin.Context) {
|
|
|
+// @Router /mcp/public/message [post]
|
|
|
+func PublicMCPMessage(c *gin.Context) {
|
|
|
token := middleware.GetToken(c)
|
|
|
mcpTypeStr, _ := c.GetQuery("type")
|
|
|
if mcpTypeStr == "" {
|
|
|
@@ -321,26 +338,30 @@ func MCPMessage(c *gin.Context) {
|
|
|
c.Writer,
|
|
|
c.Request,
|
|
|
getStore(),
|
|
|
- newEndpoint(token.Key, mcpType),
|
|
|
+ newPublicMcpEndpoint(token.Key, mcpType),
|
|
|
)
|
|
|
- case model.PublicMCPTypeOpenAPI:
|
|
|
- backend, ok := getStore().Get(sessionID)
|
|
|
- if !ok || backend != "openapi" {
|
|
|
- return
|
|
|
- }
|
|
|
- mpscInstance := getMpsc()
|
|
|
- body, err := io.ReadAll(c.Request.Body)
|
|
|
- if err != nil {
|
|
|
- _ = c.AbortWithError(http.StatusInternalServerError, err)
|
|
|
- return
|
|
|
- }
|
|
|
- err = mpscInstance.send(c.Request.Context(), sessionID, body)
|
|
|
- if err != nil {
|
|
|
- _ = c.AbortWithError(http.StatusInternalServerError, err)
|
|
|
- return
|
|
|
- }
|
|
|
- c.Writer.WriteHeader(http.StatusAccepted)
|
|
|
+ default:
|
|
|
+ sendMCPSSEMessage(c, mcpTypeStr, sessionID)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func sendMCPSSEMessage(c *gin.Context, mcpType, sessionID string) {
|
|
|
+ backend, ok := getStore().Get(sessionID)
|
|
|
+ if !ok || backend != mcpType {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ mpscInstance := getMCPMpsc()
|
|
|
+ body, err := io.ReadAll(c.Request.Body)
|
|
|
+ if err != nil {
|
|
|
+ _ = c.AbortWithError(http.StatusInternalServerError, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ err = mpscInstance.send(c.Request.Context(), sessionID, body)
|
|
|
+ if err != nil {
|
|
|
+ _ = c.AbortWithError(http.StatusInternalServerError, err)
|
|
|
+ return
|
|
|
}
|
|
|
+ c.Writer.WriteHeader(http.StatusAccepted)
|
|
|
}
|
|
|
|
|
|
// Interface for multi-producer, single-consumer message passing
|
|
|
@@ -351,31 +372,31 @@ type mpsc interface {
|
|
|
|
|
|
// Global MPSC instances
|
|
|
var (
|
|
|
- memMpsc mpsc = newChannelMpsc()
|
|
|
- redisMpsc mpsc
|
|
|
- redisMpscOnce = &sync.Once{}
|
|
|
+ memMCPMpsc mpsc = newChannelMCPMpsc()
|
|
|
+ redisMCPMpsc mpsc
|
|
|
+ redisMCPMpscOnce = &sync.Once{}
|
|
|
)
|
|
|
|
|
|
-func getMpsc() mpsc {
|
|
|
+func getMCPMpsc() mpsc {
|
|
|
if common.RedisEnabled {
|
|
|
- redisMpscOnce.Do(func() {
|
|
|
- redisMpsc = newRedisMPSC(common.RDB)
|
|
|
+ redisMCPMpscOnce.Do(func() {
|
|
|
+ redisMCPMpsc = newRedisMCPMPSC(common.RDB)
|
|
|
})
|
|
|
- return redisMpsc
|
|
|
+ return redisMCPMpsc
|
|
|
}
|
|
|
- return memMpsc
|
|
|
+ return memMCPMpsc
|
|
|
}
|
|
|
|
|
|
// In-memory channel-based MPSC implementation
|
|
|
-type channelMpsc struct {
|
|
|
+type channelMCPMpsc struct {
|
|
|
channels map[string]chan []byte
|
|
|
lastAccess map[string]time.Time
|
|
|
channelMutex sync.RWMutex
|
|
|
}
|
|
|
|
|
|
-// newChannelMpsc creates a new channel-based mpsc implementation
|
|
|
-func newChannelMpsc() *channelMpsc {
|
|
|
- c := &channelMpsc{
|
|
|
+// newChannelMCPMpsc creates a new channel-based mpsc implementation
|
|
|
+func newChannelMCPMpsc() *channelMCPMpsc {
|
|
|
+ c := &channelMCPMpsc{
|
|
|
channels: make(map[string]chan []byte),
|
|
|
lastAccess: make(map[string]time.Time),
|
|
|
}
|
|
|
@@ -387,7 +408,7 @@ func newChannelMpsc() *channelMpsc {
|
|
|
}
|
|
|
|
|
|
// cleanupExpiredChannels periodically checks for and removes channels that haven't been accessed in 5 minutes
|
|
|
-func (c *channelMpsc) cleanupExpiredChannels() {
|
|
|
+func (c *channelMCPMpsc) cleanupExpiredChannels() {
|
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
@@ -409,7 +430,7 @@ func (c *channelMpsc) cleanupExpiredChannels() {
|
|
|
}
|
|
|
|
|
|
// getOrCreateChannel gets an existing channel or creates a new one for the session
|
|
|
-func (c *channelMpsc) getOrCreateChannel(id string) chan []byte {
|
|
|
+func (c *channelMCPMpsc) getOrCreateChannel(id string) chan []byte {
|
|
|
c.channelMutex.RLock()
|
|
|
ch, exists := c.channels[id]
|
|
|
c.channelMutex.RUnlock()
|
|
|
@@ -432,7 +453,7 @@ func (c *channelMpsc) getOrCreateChannel(id string) chan []byte {
|
|
|
}
|
|
|
|
|
|
// recv receives data for the specified session
|
|
|
-func (c *channelMpsc) recv(ctx context.Context, id string) ([]byte, error) {
|
|
|
+func (c *channelMCPMpsc) recv(ctx context.Context, id string) ([]byte, error) {
|
|
|
ch := c.getOrCreateChannel(id)
|
|
|
|
|
|
select {
|
|
|
@@ -447,7 +468,7 @@ func (c *channelMpsc) recv(ctx context.Context, id string) ([]byte, error) {
|
|
|
}
|
|
|
|
|
|
// send sends data to the specified session
|
|
|
-func (c *channelMpsc) send(ctx context.Context, id string, data []byte) error {
|
|
|
+func (c *channelMCPMpsc) send(ctx context.Context, id string, data []byte) error {
|
|
|
ch := c.getOrCreateChannel(id)
|
|
|
|
|
|
select {
|
|
|
@@ -461,17 +482,18 @@ func (c *channelMpsc) send(ctx context.Context, id string, data []byte) error {
|
|
|
}
|
|
|
|
|
|
// Redis-based MPSC implementation
|
|
|
-type redisMPSC struct {
|
|
|
+type redisMCPMPSC struct {
|
|
|
rdb *redis.Client
|
|
|
}
|
|
|
|
|
|
-// newRedisMPSC creates a new Redis MPSC instance
|
|
|
-func newRedisMPSC(rdb *redis.Client) *redisMPSC {
|
|
|
- return &redisMPSC{rdb: rdb}
|
|
|
+// newRedisMCPMPSC creates a new Redis MPSC instance
|
|
|
+func newRedisMCPMPSC(rdb *redis.Client) *redisMCPMPSC {
|
|
|
+ return &redisMCPMPSC{rdb: rdb}
|
|
|
}
|
|
|
|
|
|
-func (r *redisMPSC) send(ctx context.Context, id string, data []byte) error {
|
|
|
+func (r *redisMCPMPSC) send(ctx context.Context, id string, data []byte) error {
|
|
|
// Set expiration to 5 minutes when sending data
|
|
|
+ id = "mcp:mpsc:" + id
|
|
|
pipe := r.rdb.Pipeline()
|
|
|
pipe.LPush(ctx, id, data)
|
|
|
pipe.Expire(ctx, id, 5*time.Minute)
|
|
|
@@ -479,7 +501,8 @@ func (r *redisMPSC) send(ctx context.Context, id string, data []byte) error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
-func (r *redisMPSC) recv(ctx context.Context, id string) ([]byte, error) {
|
|
|
+func (r *redisMCPMPSC) recv(ctx context.Context, id string) ([]byte, error) {
|
|
|
+ id = "mcp:mpsc:" + id
|
|
|
for {
|
|
|
select {
|
|
|
case <-ctx.Done():
|