| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- package controller
- import (
- "context"
- "errors"
- "net/http"
- "net/url"
- "github.com/gin-gonic/gin"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/mcpproxy"
- "github.com/labring/aiproxy/core/model"
- mcpservers "github.com/labring/aiproxy/mcp-servers"
- "github.com/labring/aiproxy/openapi-mcp/convert"
- "github.com/mark3labs/mcp-go/mcp"
- )
- type EndpointProvider interface {
- NewEndpoint(newSession string) (newEndpoint string)
- LoadEndpoint(endpoint string) (session string)
- }
- // handleSSEMCPServer handles the SSE connection for an MCP server
- func handleSSEMCPServer(
- c *gin.Context,
- s mcpservers.Server,
- mcpType string,
- endpoint EndpointProvider,
- ) {
- // Store the session
- store := getStore()
- newSession := store.New()
- newEndpoint := endpoint.NewEndpoint(newSession)
- server := mcpproxy.NewSSEServer(
- s,
- mcpproxy.WithMessageEndpoint(newEndpoint),
- )
- store.Set(newSession, mcpType)
- defer func() {
- store.Delete(newSession)
- }()
- ctx, cancel := context.WithCancel(c.Request.Context())
- defer cancel()
- // Start message processing goroutine
- go processMCPSSEMpscMessages(ctx, newSession, server)
- // Handle SSE connection
- server.ServeHTTP(c.Writer, c.Request)
- }
- // processMCPSSEMpscMessages handles message processing for OpenAPI
- func processMCPSSEMpscMessages(
- ctx context.Context,
- sessionID string,
- server *mcpproxy.SSEServer,
- ) {
- mpscInstance := getMCPMpsc()
- for {
- select {
- case <-ctx.Done():
- return
- default:
- data, err := mpscInstance.recv(ctx, sessionID)
- if err != nil {
- return
- }
- if err := server.HandleMessage(ctx, data); err != nil {
- continue
- }
- }
- }
- }
- func handleEmbedSSEMCP(
- c *gin.Context,
- mcpID string,
- config *model.MCPEmbeddingConfig,
- paramsFunc ParamsFunc,
- endpoint EndpointProvider,
- ) {
- reusingConfig, err := prepareEmbedReusingConfig(mcpID, paramsFunc, config.Reusing)
- if err != nil {
- http.Error(c.Writer, err.Error(), http.StatusBadRequest)
- return
- }
- server, err := mcpservers.GetMCPServer(mcpID, config.Init, reusingConfig)
- if err != nil {
- http.Error(c.Writer, err.Error(), http.StatusBadRequest)
- return
- }
- handleSSEMCPServer(c, server, string(model.PublicMCPTypeEmbed), endpoint)
- }
- // prepareEmbedReusingConfig 准备嵌入MCP的reusing配置
- func prepareEmbedReusingConfig(
- mcpID string,
- paramsFunc ParamsFunc,
- reusingParams map[string]model.ReusingParam,
- ) (map[string]string, error) {
- if len(reusingParams) == 0 {
- return nil, nil
- }
- return NewReusingParamProcessor(mcpID, paramsFunc).
- ProcessEmbedReusingParams(reusingParams)
- }
- func sendMCPSSEMessage(c *gin.Context, sessionID string) {
- _, ok := getStore().Get(sessionID)
- if !ok {
- http.Error(c.Writer, "invalid session", http.StatusBadRequest)
- return
- }
- mpscInstance := getMCPMpsc()
- body, err := common.GetRequestBody(c.Request)
- if err != nil {
- http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
- return
- }
- err = mpscInstance.send(c.Request.Context(), sessionID, body)
- if err != nil {
- http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
- return
- }
- c.Writer.WriteHeader(http.StatusAccepted)
- }
- // handleStreamableMCPServer handles the streamable connection for an MCP server
- func handleStreamableMCPServer(c *gin.Context, s mcpservers.Server) {
- if c.Request.Method != http.MethodPost {
- c.JSON(http.StatusMethodNotAllowed, mcpservers.CreateMCPErrorResponse(
- mcp.NewRequestId(nil),
- mcp.METHOD_NOT_FOUND,
- "method not allowed",
- ))
- return
- }
- reqBody, err := common.GetRequestBody(c.Request)
- if err != nil {
- c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
- mcp.NewRequestId(nil),
- mcp.PARSE_ERROR,
- err.Error(),
- ))
- return
- }
- respMessage := s.HandleMessage(c.Request.Context(), reqBody)
- if respMessage == nil {
- // For notifications, just send 202 Accepted with no body
- c.Status(http.StatusAccepted)
- return
- }
- c.JSON(http.StatusOK, respMessage)
- }
- func handleGroupStreamable(c *gin.Context, groupMcp *model.GroupMCPCache) {
- switch groupMcp.Type {
- case model.GroupMCPTypeProxyStreamable:
- handleGroupProxyStreamable(c, groupMcp.ProxyConfig)
- case model.GroupMCPTypeOpenAPI:
- server, err := newOpenAPIMCPServer(groupMcp.OpenAPIConfig)
- if err != nil {
- c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
- mcp.NewRequestId(nil),
- mcp.INVALID_REQUEST,
- err.Error(),
- ))
- return
- }
- handleStreamableMCPServer(c, server)
- default:
- c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
- mcp.NewRequestId(nil),
- mcp.INVALID_REQUEST,
- "unsupported mcp type",
- ))
- }
- }
- // newOpenAPIMCPServer creates a new MCP server from OpenAPI configuration
- func newOpenAPIMCPServer(config *model.MCPOpenAPIConfig) (mcpservers.Server, error) {
- if config == nil || (config.OpenAPISpec == "" && config.OpenAPIContent == "") {
- return nil, errors.New("invalid OpenAPI configuration")
- }
- // Parse OpenAPI specification
- parser := convert.NewParser()
- var (
- err error
- openAPIFrom string
- )
- if config.OpenAPISpec != "" {
- openAPIFrom, err = parseOpenAPIFromURL(config, parser)
- } else {
- err = parseOpenAPIFromContent(config, parser)
- }
- if err != nil {
- return nil, err
- }
- // Convert to MCP server
- converter := convert.NewConverter(parser, convert.Options{
- OpenAPIFrom: openAPIFrom,
- ServerAddr: config.ServerAddr,
- Authorization: config.Authorization,
- })
- s, err := converter.Convert()
- if err != nil {
- return nil, err
- }
- return s, nil
- }
- // parseOpenAPIFromURL parses OpenAPI spec from a URL
- func parseOpenAPIFromURL(config *model.MCPOpenAPIConfig, parser *convert.Parser) (string, error) {
- spec, err := url.Parse(config.OpenAPISpec)
- if err != nil || (spec.Scheme != "http" && spec.Scheme != "https") {
- return "", errors.New("invalid OpenAPI spec URL")
- }
- openAPIFrom := spec.String()
- if config.V2 {
- err = parser.ParseFileV2(openAPIFrom)
- } else {
- err = parser.ParseFile(openAPIFrom)
- }
- return openAPIFrom, err
- }
- // parseOpenAPIFromContent parses OpenAPI spec from content string
- func parseOpenAPIFromContent(config *model.MCPOpenAPIConfig, parser *convert.Parser) error {
- if config.V2 {
- return parser.ParseV2([]byte(config.OpenAPIContent))
- }
- return parser.Parse([]byte(config.OpenAPIContent))
- }
- // sseEndpointProvider implements the EndpointProvider interface for MCP
- type sseEndpointProvider struct{}
- var sseEndpoint = &sseEndpointProvider{}
- func (m *sseEndpointProvider) NewEndpoint(session string) (newEndpoint string) {
- endpoint := "/message?sessionId=" + session
- return endpoint
- }
- func (m *sseEndpointProvider) LoadEndpoint(endpoint string) (session string) {
- parsedURL, err := url.Parse(endpoint)
- if err != nil {
- return ""
- }
- return parsedURL.Query().Get("sessionId")
- }
- // MCPMessage godoc
- //
- // @Summary MCP SSE Message
- // @Router /message [post]
- func MCPMessage(c *gin.Context) {
- sessionID, _ := c.GetQuery("sessionId")
- if sessionID == "" {
- http.Error(c.Writer, "missing sessionId", http.StatusBadRequest)
- return
- }
- sendMCPSSEMessage(c, sessionID)
- }
|