host.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/url"
  6. "strings"
  7. "github.com/gin-gonic/gin"
  8. "github.com/labring/aiproxy/core/common"
  9. "github.com/labring/aiproxy/core/common/config"
  10. "github.com/labring/aiproxy/core/middleware"
  11. "github.com/labring/aiproxy/core/model"
  12. mcpservers "github.com/labring/aiproxy/mcp-servers"
  13. "github.com/mark3labs/mcp-go/mcp"
  14. )
  // hostMcpEndpointProvider is the EndpointProvider handed to host-routed SSE
// sessions. It remembers the caller's token key and the MCP type so that the
// message endpoint it advertises carries both alongside the session ID.
type hostMcpEndpointProvider struct {
	key, t string // token key and MCP type, echoed into the message endpoint URL
}
  20. func newHostMcpEndpoint(key, t string) EndpointProvider {
  21. return &hostMcpEndpointProvider{
  22. key: key,
  23. t: t,
  24. }
  25. }
  // NewEndpoint returns the relative message endpoint advertised to an SSE
// client for the given session. HostMCPMessage reads the "sessionId" and
// "type" parameters back out of the follow-up POST. The "key" parameter is
// not read there; NOTE(review): presumably consumed by auth middleware, confirm.
//
// Every value is query-escaped. Unescaped, a session, key or type containing
// '&', '=', '+', '#' or whitespace would corrupt the query string (or inject
// extra parameters), and the values decoded on the way back in would not match
// what was issued.
func (m *hostMcpEndpointProvider) NewEndpoint(session string) (newEndpoint string) {
	return fmt.Sprintf(
		"/message?sessionId=%s&key=%s&type=%s",
		url.QueryEscape(session),
		url.QueryEscape(m.key),
		url.QueryEscape(m.t),
	)
}
  // LoadEndpoint recovers the session ID from an endpoint of the form produced
// by NewEndpoint. It yields "" when the endpoint cannot be parsed as a URL or
// has no "sessionId" query parameter.
func (m *hostMcpEndpointProvider) LoadEndpoint(endpoint string) (session string) {
	u, err := url.Parse(endpoint)
	if err == nil {
		session = u.Query().Get("sessionId")
	}
	return session
}
  // routeHostMCP dispatches a host-routed MCP request by its Host header.
//
// A host of the form "<mcpID>.<public MCP host>" goes to publicHandler, and
// "<mcpID>.<group MCP host>" goes to groupHandler. In both cases mcpID is the
// part of the host preceding the configured base. Any other host, including
// the bare base host itself, gets a 404 "invalid host" response.
// An empty configured base disables that route.
func routeHostMCP(
	c *gin.Context,
	publicHandler, groupHandler func(c *gin.Context, mcpID string),
) {
	log := common.GetLogger(c)
	host := c.Request.Host
	log.Debugf("route host mcp: %s", host)

	if mcpID, ok := mcpIDFromHost(host, config.GetPublicMCPHost()); ok {
		publicHandler(c, mcpID)
		return
	}

	if mcpID, ok := mcpIDFromHost(host, config.GetGroupMCPHost()); ok {
		groupHandler(c, mcpID)
		return
	}

	http.Error(c.Writer, "invalid host", http.StatusNotFound)
}

// mcpIDFromHost returns the portion of host that precedes "."+baseHost.
// ok is false when baseHost is empty, when host is not a strict subdomain of
// baseHost, or when the extracted ID would be empty.
//
// The dot is part of the suffix test on purpose. A plain HasSuffix(host, base)
// check also accepts host == base and look-alike hosts such as
// "evilmcp.example.com" for base "mcp.example.com". Both previously reached a
// handler with the whole host as the MCP ID.
func mcpIDFromHost(host, baseHost string) (mcpID string, ok bool) {
	if baseHost == "" {
		return "", false
	}
	suffix := "." + baseHost
	if !strings.HasSuffix(host, suffix) {
		return "", false
	}
	mcpID = strings.TrimSuffix(host, suffix)
	if mcpID == "" {
		return "", false
	}
	return mcpID, true
}
  // HostMCPSSEServer godoc
//
// HostMCPSSEServer opens an MCP SSE stream for the MCP named by the request's
// subdomain. Under the public MCP host the MCP is looked up among public MCPs.
// Under the group MCP host it is looked up among the calling group's MCPs.
// The advertised message endpoint embeds the caller's token key and the MCP
// type. Unknown or disabled MCPs are rejected before any stream is opened.
//
// @Summary Host MCP SSE Server
// @Security ApiKeyAuth
// @Router /sse [get]
func HostMCPSSEServer(c *gin.Context) {
	servePublic := func(c *gin.Context, mcpID string) {
		pm, err := model.CacheGetPublicMCP(mcpID)
		switch {
		case err != nil:
			http.Error(c.Writer, err.Error(), http.StatusBadRequest)
		case pm.Status != model.PublicMCPStatusEnabled:
			http.Error(c.Writer, "mcp is not enabled", http.StatusBadRequest)
		default:
			key := middleware.GetToken(c).Key
			handlePublicSSEMCP(c, pm, newHostMcpEndpoint(key, string(pm.Type)))
		}
	}

	serveGroup := func(c *gin.Context, mcpID string) {
		// NOTE(review): the group path answers 404 where the public path
		// answers 400 for the same conditions — confirm this is intended.
		gm, err := model.CacheGetGroupMCP(middleware.GetGroup(c).ID, mcpID)
		switch {
		case err != nil:
			http.Error(c.Writer, err.Error(), http.StatusNotFound)
		case gm.Status != model.GroupMCPStatusEnabled:
			http.Error(c.Writer, "mcp is not enabled", http.StatusNotFound)
		default:
			key := middleware.GetToken(c).Key
			handleGroupSSEMCPServer(c, gm, newHostMcpEndpoint(key, string(gm.Type)))
		}
	}

	routeHostMCP(c, servePublic, serveGroup)
}
  // HostMCPMessage godoc
//
// HostMCPMessage accepts a client message POSTed to the "/message" endpoint
// advertised by HostMCPSSEServer. The target session is identified only by the
// "type" and "sessionId" query parameters. The MCP ID in the subdomain merely
// selects the public or group handler and is otherwise ignored. A missing
// parameter yields a 400 response.
//
// @Summary Host MCP SSE Message
// @Security ApiKeyAuth
// @Router /message [post]
func HostMCPMessage(c *gin.Context) {
	// messageQuery reads the required "type" and "sessionId" query parameters,
	// checking "type" first. When one is missing it writes a 400 response and
	// reports ok == false.
	messageQuery := func(c *gin.Context) (mcpType, sessionID string, ok bool) {
		if mcpType, _ = c.GetQuery("type"); mcpType == "" {
			http.Error(c.Writer, "missing mcp type", http.StatusBadRequest)
			return "", "", false
		}
		if sessionID, _ = c.GetQuery("sessionId"); sessionID == "" {
			http.Error(c.Writer, "missing sessionId", http.StatusBadRequest)
			return "", "", false
		}
		return mcpType, sessionID, true
	}

	routeHostMCP(c, func(c *gin.Context, _ string) {
		if mcpType, sessionID, ok := messageQuery(c); ok {
			handlePublicSSEMessage(c, model.PublicMCPType(mcpType), sessionID)
		}
	}, func(c *gin.Context, _ string) {
		if mcpType, sessionID, ok := messageQuery(c); ok {
			handleGroupSSEMessage(c, model.GroupMCPType(mcpType), sessionID)
		}
	})
}
  // HostMCPStreamable godoc
//
// HostMCPStreamable serves the streamable MCP endpoint for the MCP named by
// the request's subdomain, dispatching to the public or group handler.
// Lookup failures and disabled MCPs are answered with an MCP INVALID_REQUEST
// error response carrying a nil request ID.
//
// @Summary Host MCP Streamable Server
// @Security ApiKeyAuth
// @Router /mcp [get]
// @Router /mcp [post]
// @Router /mcp [delete]
func HostMCPStreamable(c *gin.Context) {
	// reject writes an MCP INVALID_REQUEST error body with the given HTTP status.
	reject := func(c *gin.Context, status int, msg string) {
		c.JSON(status, mcpservers.CreateMCPErrorResponse(
			mcp.NewRequestId(nil),
			mcp.INVALID_REQUEST,
			msg,
		))
	}

	servePublic := func(c *gin.Context, mcpID string) {
		pm, err := model.CacheGetPublicMCP(mcpID)
		switch {
		case err != nil:
			// NOTE(review): 400 here, but 404 on the group path — confirm intended.
			reject(c, http.StatusBadRequest, err.Error())
		case pm.Status != model.PublicMCPStatusEnabled:
			reject(c, http.StatusNotFound, "mcp is not enabled")
		default:
			handlePublicSSEStreamable(c, pm)
		}
	}

	serveGroup := func(c *gin.Context, mcpID string) {
		gm, err := model.CacheGetGroupMCP(middleware.GetGroup(c).ID, mcpID)
		switch {
		case err != nil:
			reject(c, http.StatusNotFound, err.Error())
		case gm.Status != model.GroupMCPStatusEnabled:
			reject(c, http.StatusNotFound, "mcp is not enabled")
		default:
			handleGroupSSEStreamable(c, gm)
		}
	}

	routeHostMCP(c, servePublic, serveGroup)
}