|
|
@@ -16,23 +16,28 @@ import (
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"github.com/labring/aiproxy/core/common"
|
|
|
"github.com/labring/aiproxy/core/common/mcpproxy"
|
|
|
- statelessmcp "github.com/labring/aiproxy/core/common/stateless-mcp"
|
|
|
"github.com/labring/aiproxy/core/middleware"
|
|
|
"github.com/labring/aiproxy/core/model"
|
|
|
mcpservers "github.com/labring/aiproxy/mcp-servers"
|
|
|
"github.com/labring/aiproxy/openapi-mcp/convert"
|
|
|
+ "github.com/mark3labs/mcp-go/client/transport"
|
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
|
"github.com/mark3labs/mcp-go/server"
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
)
|
|
|
|
|
|
+type EndpointProvider interface {
|
|
|
+ NewEndpoint(newSession string) (newEndpoint string)
|
|
|
+ LoadEndpoint(endpoint string) (session string)
|
|
|
+}
|
|
|
+
|
|
|
// publicMcpEndpointProvider implements the EndpointProvider interface for MCP
|
|
|
type publicMcpEndpointProvider struct {
|
|
|
key string
|
|
|
t model.PublicMCPType
|
|
|
}
|
|
|
|
|
|
-func newPublicMcpEndpoint(key string, t model.PublicMCPType) mcpproxy.EndpointProvider {
|
|
|
+func newPublicMcpEndpoint(key string, t model.PublicMCPType) EndpointProvider {
|
|
|
return &publicMcpEndpointProvider{
|
|
|
key: key,
|
|
|
t: t,
|
|
|
@@ -117,6 +122,92 @@ func (r *redisStoreManager) Delete(session string) {
|
|
|
r.rdb.Del(ctx, "mcp:session:"+session)
|
|
|
}
|
|
|
|
|
|
+type mcpClient2Server struct {
|
|
|
+ client transport.Interface
|
|
|
+}
|
|
|
+
|
|
|
+type JSONRPCNoErrorResponse struct {
|
|
|
+ JSONRPC string `json:"jsonrpc"`
|
|
|
+ ID mcp.RequestId `json:"id"`
|
|
|
+ Result json.RawMessage `json:"result"`
|
|
|
+}
|
|
|
+
|
|
|
+func handleError(err error) mcp.JSONRPCMessage {
|
|
|
+ return mcp.JSONRPCError{
|
|
|
+ JSONRPC: mcp.JSONRPC_VERSION,
|
|
|
+ ID: mcp.NewRequestId(nil),
|
|
|
+ Error: struct {
|
|
|
+ Code int `json:"code"`
|
|
|
+ Message string `json:"message"`
|
|
|
+ Data any `json:"data,omitempty"`
|
|
|
+ }{
|
|
|
+ Code: mcp.INTERNAL_ERROR,
|
|
|
+ Message: err.Error(),
|
|
|
+ },
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (s *mcpClient2Server) HandleMessage(
|
|
|
+ ctx context.Context,
|
|
|
+ message json.RawMessage,
|
|
|
+) mcp.JSONRPCMessage {
|
|
|
+ methodNode, err := sonic.Get(message, "method")
|
|
|
+ if err != nil {
|
|
|
+ return handleError(err)
|
|
|
+ }
|
|
|
+ method, err := methodNode.String()
|
|
|
+ if err != nil {
|
|
|
+ return handleError(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ switch method {
|
|
|
+ case "notifications/initialized":
|
|
|
+ req := mcp.JSONRPCNotification{}
|
|
|
+ err := sonic.Unmarshal(message, &req)
|
|
|
+ if err != nil {
|
|
|
+ return handleError(err)
|
|
|
+ }
|
|
|
+ err = s.client.SendNotification(ctx, req)
|
|
|
+ if err != nil {
|
|
|
+ return handleError(err)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ default:
|
|
|
+ req := transport.JSONRPCRequest{}
|
|
|
+ err := sonic.Unmarshal(message, &req)
|
|
|
+ if err != nil {
|
|
|
+ return handleError(err)
|
|
|
+ }
|
|
|
+ resp, err := s.client.SendRequest(ctx, req)
|
|
|
+ if err != nil {
|
|
|
+ return mcp.JSONRPCError{
|
|
|
+ JSONRPC: mcp.JSONRPC_VERSION,
|
|
|
+ ID: mcp.NewRequestId(nil),
|
|
|
+ Error: struct {
|
|
|
+ Code int `json:"code"`
|
|
|
+ Message string `json:"message"`
|
|
|
+ Data any `json:"data,omitempty"`
|
|
|
+ }{
|
|
|
+ Code: mcp.INTERNAL_ERROR,
|
|
|
+ Message: err.Error(),
|
|
|
+ },
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if resp.Error != nil {
|
|
|
+ return resp
|
|
|
+ }
|
|
|
+ return &JSONRPCNoErrorResponse{
|
|
|
+ JSONRPC: resp.JSONRPC,
|
|
|
+ ID: resp.ID,
|
|
|
+ Result: resp.Result,
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func wrapMCPClient2Server(client transport.Interface) mcpproxy.MCPServer {
|
|
|
+ return &mcpClient2Server{client: client}
|
|
|
+}
|
|
|
+
|
|
|
// PublicMCPSseServer godoc
|
|
|
//
|
|
|
// @Summary Public MCP SSE Server
|
|
|
@@ -124,28 +215,54 @@ func (r *redisStoreManager) Delete(session string) {
|
|
|
// @Router /mcp/public/{id}/sse [get]
|
|
|
func PublicMCPSseServer(c *gin.Context) {
|
|
|
mcpID := c.Param("id")
|
|
|
+ if mcpID == "" {
|
|
|
+ http.Error(c.Writer, "mcp id is required", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
publicMcp, err := model.CacheGetPublicMCP(mcpID)
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- err.Error(),
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
if publicMcp.Status != model.PublicMCPStatusEnabled {
|
|
|
- c.JSON(http.StatusNotFound, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- "mcp is not enabled",
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, "mcp is not enabled", http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
switch publicMcp.Type {
|
|
|
case model.PublicMCPTypeProxySSE:
|
|
|
- handlePublicProxySSE(c, publicMcp.ID, publicMcp.ProxyConfig)
|
|
|
+ client, err := transport.NewSSE(
|
|
|
+ publicMcp.ProxyConfig.URL,
|
|
|
+ transport.WithHeaders(publicMcp.ProxyConfig.Headers),
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ err = client.Start(c.Request.Context())
|
|
|
+ if err != nil {
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer client.Close()
|
|
|
+ handleSSEMCPServer(c, wrapMCPClient2Server(client), model.PublicMCPTypeProxySSE)
|
|
|
+ case model.PublicMCPTypeProxyStreamable:
|
|
|
+ client, err := transport.NewStreamableHTTP(
|
|
|
+ publicMcp.ProxyConfig.URL,
|
|
|
+ transport.WithHTTPHeaders(publicMcp.ProxyConfig.Headers),
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ err = client.Start(c.Request.Context())
|
|
|
+ if err != nil {
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer client.Close()
|
|
|
+ handleSSEMCPServer(c, wrapMCPClient2Server(client), model.PublicMCPTypeProxyStreamable)
|
|
|
case model.PublicMCPTypeOpenAPI:
|
|
|
server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
|
|
|
if err != nil {
|
|
|
@@ -196,55 +313,6 @@ func handlePublicEmbedMCP(c *gin.Context, mcpID string, config *model.MCPEmbeddi
|
|
|
handleSSEMCPServer(c, server, model.PublicMCPTypeEmbed)
|
|
|
}
|
|
|
|
|
|
-// handlePublicProxySSE processes SSE proxy requests
|
|
|
-func handlePublicProxySSE(c *gin.Context, mcpID string, config *model.PublicMCPProxyConfig) {
|
|
|
- if config == nil || config.URL == "" {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- backendURL, err := url.Parse(config.URL)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- err.Error(),
|
|
|
- ))
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- headers := make(map[string]string)
|
|
|
- backendQuery := &url.Values{}
|
|
|
- group := middleware.GetGroup(c)
|
|
|
- token := middleware.GetToken(c)
|
|
|
-
|
|
|
- // Process reusing parameters if any
|
|
|
- if err := processReusingParams(config.ReusingParams, mcpID, group.ID, headers, backendQuery); err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- err.Error(),
|
|
|
- ))
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- for k, v := range config.Headers {
|
|
|
- headers[k] = v
|
|
|
- }
|
|
|
- for k, v := range config.Querys {
|
|
|
- backendQuery.Set(k, v)
|
|
|
- }
|
|
|
-
|
|
|
- backendURL.RawQuery = backendQuery.Encode()
|
|
|
- mcpproxy.SSEHandler(
|
|
|
- c.Writer,
|
|
|
- c.Request,
|
|
|
- getStore(),
|
|
|
- newPublicMcpEndpoint(token.Key, model.PublicMCPTypeProxySSE),
|
|
|
- backendURL.String(),
|
|
|
- headers,
|
|
|
- )
|
|
|
-}
|
|
|
-
|
|
|
// newOpenAPIMCPServer creates a new MCP server from OpenAPI configuration
|
|
|
func newOpenAPIMCPServer(config *model.MCPOpenAPIConfig) (*server.MCPServer, error) {
|
|
|
if config == nil || (config.OpenAPISpec == "" && config.OpenAPIContent == "") {
|
|
|
@@ -281,7 +349,7 @@ func newOpenAPIMCPServer(config *model.MCPOpenAPIConfig) (*server.MCPServer, err
|
|
|
}
|
|
|
|
|
|
// handleSSEMCPServer handles the SSE connection for an MCP server
|
|
|
-func handleSSEMCPServer(c *gin.Context, s *server.MCPServer, mcpType model.PublicMCPType) {
|
|
|
+func handleSSEMCPServer(c *gin.Context, s mcpproxy.MCPServer, mcpType model.PublicMCPType) {
|
|
|
token := middleware.GetToken(c)
|
|
|
|
|
|
// Store the session
|
|
|
@@ -289,9 +357,9 @@ func handleSSEMCPServer(c *gin.Context, s *server.MCPServer, mcpType model.Publi
|
|
|
newSession := store.New()
|
|
|
|
|
|
newEndpoint := newPublicMcpEndpoint(token.Key, mcpType).NewEndpoint(newSession)
|
|
|
- server := statelessmcp.NewSSEServer(
|
|
|
+ server := mcpproxy.NewSSEServer(
|
|
|
s,
|
|
|
- statelessmcp.WithMessageEndpoint(newEndpoint),
|
|
|
+ mcpproxy.WithMessageEndpoint(newEndpoint),
|
|
|
)
|
|
|
|
|
|
store.Set(newSession, string(mcpType))
|
|
|
@@ -306,7 +374,7 @@ func handleSSEMCPServer(c *gin.Context, s *server.MCPServer, mcpType model.Publi
|
|
|
go processMCPSseMpscMessages(ctx, newSession, server)
|
|
|
|
|
|
// Handle SSE connection
|
|
|
- server.HandleSSE(c.Writer, c.Request)
|
|
|
+ server.ServeHTTP(c.Writer, c.Request)
|
|
|
}
|
|
|
|
|
|
// parseOpenAPIFromURL parses OpenAPI spec from a URL
|
|
|
@@ -338,7 +406,7 @@ func parseOpenAPIFromContent(config *model.MCPOpenAPIConfig, parser *convert.Par
|
|
|
func processMCPSseMpscMessages(
|
|
|
ctx context.Context,
|
|
|
sessionID string,
|
|
|
- server *statelessmcp.SSEServer,
|
|
|
+ server *mcpproxy.SSEServer,
|
|
|
) {
|
|
|
mpscInstance := getMCPMpsc()
|
|
|
for {
|
|
|
@@ -401,75 +469,47 @@ func processReusingParams(
|
|
|
// @Security ApiKeyAuth
|
|
|
// @Router /mcp/public/message [post]
|
|
|
func PublicMCPMessage(c *gin.Context) {
|
|
|
- token := middleware.GetToken(c)
|
|
|
mcpTypeStr, _ := c.GetQuery("type")
|
|
|
if mcpTypeStr == "" {
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- "missing mcp type",
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, "missing mcp type", http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
mcpType := model.PublicMCPType(mcpTypeStr)
|
|
|
sessionID, _ := c.GetQuery("sessionId")
|
|
|
if sessionID == "" {
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- "missing sessionId",
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, "missing sessionId", http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
switch mcpType {
|
|
|
case model.PublicMCPTypeProxySSE:
|
|
|
- mcpproxy.SSEProxyHandler(
|
|
|
- c.Writer,
|
|
|
- c.Request,
|
|
|
- getStore(),
|
|
|
- newPublicMcpEndpoint(token.Key, mcpType),
|
|
|
- )
|
|
|
+ sendMCPSSEMessage(c, mcpTypeStr, sessionID)
|
|
|
+ case model.PublicMCPTypeProxyStreamable:
|
|
|
+ sendMCPSSEMessage(c, mcpTypeStr, sessionID)
|
|
|
case model.PublicMCPTypeOpenAPI:
|
|
|
sendMCPSSEMessage(c, mcpTypeStr, sessionID)
|
|
|
case model.PublicMCPTypeEmbed:
|
|
|
sendMCPSSEMessage(c, mcpTypeStr, sessionID)
|
|
|
default:
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- "unknown mcp type",
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, "unknown mcp type", http.StatusBadRequest)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func sendMCPSSEMessage(c *gin.Context, mcpType, sessionID string) {
|
|
|
backend, ok := getStore().Get(sessionID)
|
|
|
if !ok || backend != mcpType {
|
|
|
- c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INVALID_REQUEST,
|
|
|
- "invalid session",
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, "invalid session", http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
mpscInstance := getMCPMpsc()
|
|
|
body, err := io.ReadAll(c.Request.Body)
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INTERNAL_ERROR,
|
|
|
- err.Error(),
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
|
|
|
return
|
|
|
}
|
|
|
err = mpscInstance.send(c.Request.Context(), sessionID, body)
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, CreateMCPErrorResponse(
|
|
|
- mcp.NewRequestId(nil),
|
|
|
- mcp.INTERNAL_ERROR,
|
|
|
- err.Error(),
|
|
|
- ))
|
|
|
+ http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
|
|
|
return
|
|
|
}
|
|
|
c.Writer.WriteHeader(http.StatusAccepted)
|
|
|
@@ -578,12 +618,11 @@ func handlePublicProxyStreamable(c *gin.Context, mcpID string, config *model.Pub
|
|
|
}
|
|
|
|
|
|
headers := make(map[string]string)
|
|
|
- backendQuery := &url.Values{}
|
|
|
+ backendQuery := backendURL.Query()
|
|
|
group := middleware.GetGroup(c)
|
|
|
- token := middleware.GetToken(c)
|
|
|
|
|
|
// Process reusing parameters if any
|
|
|
- if err := processReusingParams(config.ReusingParams, mcpID, group.ID, headers, backendQuery); err != nil {
|
|
|
+ if err := processReusingParams(config.ReusingParams, mcpID, group.ID, headers, &backendQuery); err != nil {
|
|
|
c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
mcp.NewRequestId(nil),
|
|
|
mcp.INVALID_REQUEST,
|
|
|
@@ -600,15 +639,6 @@ func handlePublicProxyStreamable(c *gin.Context, mcpID string, config *model.Pub
|
|
|
}
|
|
|
|
|
|
backendURL.RawQuery = backendQuery.Encode()
|
|
|
- mcpproxy.SSEHandler(
|
|
|
- c.Writer,
|
|
|
- c.Request,
|
|
|
- getStore(),
|
|
|
- newPublicMcpEndpoint(token.Key, model.PublicMCPTypeProxySSE),
|
|
|
- backendURL.String(),
|
|
|
- headers,
|
|
|
- )
|
|
|
-
|
|
|
mcpproxy.NewStreamableProxy(backendURL.String(), headers, getStore()).
|
|
|
ServeHTTP(c.Writer, c.Request)
|
|
|
}
|
|
|
@@ -623,8 +653,8 @@ func handleStreamableMCPServer(c *gin.Context, s *server.MCPServer) {
|
|
|
))
|
|
|
return
|
|
|
}
|
|
|
- var rawMessage json.RawMessage
|
|
|
- if err := sonic.ConfigDefault.NewDecoder(c.Request.Body).Decode(&rawMessage); err != nil {
|
|
|
+ reqBody, err := io.ReadAll(c.Request.Body)
|
|
|
+ if err != nil {
|
|
|
c.JSON(http.StatusBadRequest, CreateMCPErrorResponse(
|
|
|
mcp.NewRequestId(nil),
|
|
|
mcp.PARSE_ERROR,
|
|
|
@@ -632,7 +662,12 @@ func handleStreamableMCPServer(c *gin.Context, s *server.MCPServer) {
|
|
|
))
|
|
|
return
|
|
|
}
|
|
|
- respMessage := s.HandleMessage(c.Request.Context(), rawMessage)
|
|
|
+ respMessage := s.HandleMessage(c.Request.Context(), reqBody)
|
|
|
+ if respMessage == nil {
|
|
|
+ // For notifications, just send 202 Accepted with no body
|
|
|
+ c.Status(http.StatusAccepted)
|
|
|
+ return
|
|
|
+ }
|
|
|
c.JSON(http.StatusOK, respMessage)
|
|
|
}
|
|
|
|
|
|
@@ -680,16 +715,16 @@ func newChannelMCPMpsc() *channelMCPMpsc {
|
|
|
}
|
|
|
|
|
|
// cleanupExpiredChannels periodically checks for and removes channels that haven't been accessed in
|
|
|
-// 5 minutes
|
|
|
+// 15 seconds
|
|
|
func (c *channelMCPMpsc) cleanupExpiredChannels() {
|
|
|
- ticker := time.NewTicker(1 * time.Minute)
|
|
|
+ ticker := time.NewTicker(15 * time.Second)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
for range ticker.C {
|
|
|
c.channelMutex.Lock()
|
|
|
now := time.Now()
|
|
|
for id, lastAccess := range c.lastAccess {
|
|
|
- if now.Sub(lastAccess) > 5*time.Minute {
|
|
|
+ if now.Sub(lastAccess) > 15*time.Second {
|
|
|
// Close and delete the channel
|
|
|
if ch, exists := c.channels[id]; exists {
|
|
|
close(ch)
|
|
|
@@ -765,11 +800,11 @@ func newRedisMCPMPSC(rdb *redis.Client) *redisMCPMPSC {
|
|
|
}
|
|
|
|
|
|
func (r *redisMCPMPSC) send(ctx context.Context, id string, data []byte) error {
|
|
|
- // Set expiration to 5 minutes when sending data
|
|
|
+ // Set expiration to 15 seconds when sending data
|
|
|
id = "mcp:mpsc:" + id
|
|
|
pipe := r.rdb.Pipeline()
|
|
|
pipe.LPush(ctx, id, data)
|
|
|
- pipe.Expire(ctx, id, 5*time.Minute)
|
|
|
+ pipe.Expire(ctx, id, 15*time.Second)
|
|
|
_, err := pipe.Exec(ctx)
|
|
|
return err
|
|
|
}
|