publicmcp-server.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/url"
  6. "github.com/gin-gonic/gin"
  7. "github.com/labring/aiproxy/core/mcpproxy"
  8. "github.com/labring/aiproxy/core/middleware"
  9. "github.com/labring/aiproxy/core/model"
  10. mcpservers "github.com/labring/aiproxy/mcp-servers"
  11. "github.com/mark3labs/mcp-go/client/transport"
  12. "github.com/mark3labs/mcp-go/mcp"
  13. )
  14. // PublicMCPSSEServer godoc
  15. //
  16. // @Summary Public MCP SSE Server
  17. // @Security ApiKeyAuth
  18. // @Router /mcp/public/{id}/sse [get]
  19. func PublicMCPSSEServer(c *gin.Context) {
  20. mcpID := c.Param("id")
  21. if mcpID == "" {
  22. http.Error(c.Writer, "mcp id is required", http.StatusBadRequest)
  23. return
  24. }
  25. publicMcp, err := model.CacheGetPublicMCP(mcpID)
  26. if err != nil {
  27. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  28. return
  29. }
  30. if publicMcp.Status != model.PublicMCPStatusEnabled {
  31. http.Error(c.Writer, "mcp is not enabled", http.StatusBadRequest)
  32. return
  33. }
  34. group := middleware.GetGroup(c)
  35. paramsFunc := newGroupParams(publicMcp.ID, group.ID)
  36. handlePublicSSEMCP(c, publicMcp, paramsFunc, sseEndpoint)
  37. }
  38. func handlePublicSSEMCP(
  39. c *gin.Context,
  40. publicMcp *model.PublicMCPCache,
  41. paramsFunc ParamsFunc,
  42. endpoint EndpointProvider,
  43. ) {
  44. switch publicMcp.Type {
  45. case model.PublicMCPTypeProxySSE:
  46. if err := handlePublicProxySSE(c, publicMcp, paramsFunc, endpoint); err != nil {
  47. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  48. return
  49. }
  50. case model.PublicMCPTypeProxyStreamable:
  51. if err := handlePublicProxyStreamableSSE(c, publicMcp, paramsFunc, endpoint); err != nil {
  52. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  53. return
  54. }
  55. case model.PublicMCPTypeOpenAPI:
  56. server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
  57. if err != nil {
  58. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  59. return
  60. }
  61. handleSSEMCPServer(c, server, string(model.PublicMCPTypeOpenAPI), endpoint)
  62. case model.PublicMCPTypeEmbed:
  63. handleEmbedSSEMCP(c, publicMcp.ID, publicMcp.EmbedConfig, paramsFunc, endpoint)
  64. default:
  65. http.Error(c.Writer, "unknown mcp type", http.StatusBadRequest)
  66. }
  67. }
  68. // handlePublicProxySSE 处理公共代理SSE
  69. func handlePublicProxySSE(
  70. c *gin.Context,
  71. publicMcp *model.PublicMCPCache,
  72. paramsFunc ParamsFunc,
  73. endpoint EndpointProvider,
  74. ) error {
  75. client, err := createProxySSEClient(c, publicMcp, paramsFunc)
  76. if err != nil {
  77. return err
  78. }
  79. defer client.Close()
  80. handleSSEMCPServer(
  81. c,
  82. mcpservers.WrapMCPClient2Server(client),
  83. string(model.PublicMCPTypeProxySSE),
  84. endpoint,
  85. )
  86. return nil
  87. }
  88. // handlePublicProxyStreamableSSE 处理公共代理Streamable SSE
  89. func handlePublicProxyStreamableSSE(
  90. c *gin.Context,
  91. publicMcp *model.PublicMCPCache,
  92. paramsFunc ParamsFunc,
  93. endpoint EndpointProvider,
  94. ) error {
  95. client, err := createProxyStreamableClient(c, publicMcp, paramsFunc)
  96. if err != nil {
  97. return err
  98. }
  99. defer client.Close()
  100. handleSSEMCPServer(
  101. c,
  102. mcpservers.WrapMCPClient2Server(client),
  103. string(model.PublicMCPTypeProxyStreamable),
  104. endpoint,
  105. )
  106. return nil
  107. }
  108. // createProxySSEClient 创建代理SSE客户端
  109. func createProxySSEClient(
  110. c *gin.Context,
  111. publicMcp *model.PublicMCPCache,
  112. paramsFunc ParamsFunc,
  113. ) (transport.Interface, error) {
  114. url, headers, err := prepareProxyConfig(publicMcp, paramsFunc)
  115. if err != nil {
  116. return nil, err
  117. }
  118. client, err := transport.NewSSE(url, transport.WithHeaders(headers))
  119. if err != nil {
  120. return nil, err
  121. }
  122. if err := client.Start(c.Request.Context()); err != nil {
  123. return nil, err
  124. }
  125. return client, nil
  126. }
  127. // createProxyStreamableClient 创建代理Streamable客户端
  128. func createProxyStreamableClient(
  129. c *gin.Context,
  130. publicMcp *model.PublicMCPCache,
  131. paramsFunc ParamsFunc,
  132. ) (transport.Interface, error) {
  133. url, headers, err := prepareProxyConfig(publicMcp, paramsFunc)
  134. if err != nil {
  135. return nil, err
  136. }
  137. client, err := transport.NewStreamableHTTP(url, transport.WithHTTPHeaders(headers))
  138. if err != nil {
  139. return nil, err
  140. }
  141. if err := client.Start(c.Request.Context()); err != nil {
  142. return nil, err
  143. }
  144. return client, nil
  145. }
  146. // prepareProxyConfig 准备代理配置
  147. func prepareProxyConfig(
  148. publicMcp *model.PublicMCPCache,
  149. paramsFunc ParamsFunc,
  150. ) (string, map[string]string, error) {
  151. url, err := url.Parse(publicMcp.ProxyConfig.URL)
  152. if err != nil {
  153. return "", nil, fmt.Errorf("invalid proxy URL: %w", err)
  154. }
  155. headers := make(map[string]string)
  156. backendQuery := url.Query()
  157. if len(publicMcp.ProxyConfig.Reusing) > 0 {
  158. processor := NewReusingParamProcessor(publicMcp.ID, paramsFunc)
  159. if err := processor.ProcessProxyReusingParams(
  160. publicMcp.ProxyConfig.Reusing,
  161. headers,
  162. &backendQuery,
  163. ); err != nil {
  164. return "", nil, err
  165. }
  166. }
  167. for k, v := range publicMcp.ProxyConfig.Headers {
  168. headers[k] = v
  169. }
  170. url.RawQuery = backendQuery.Encode()
  171. return url.String(), headers, nil
  172. }
  173. // processProxyReusingParams handles the reusing parameters for MCP proxy
  174. func processProxyReusingParams(
  175. reusingParams map[string]model.PublicMCPProxyReusingParam,
  176. paramsFunc ParamsFunc,
  177. headers map[string]string,
  178. backendQuery *url.Values,
  179. ) error {
  180. if len(reusingParams) == 0 {
  181. return nil
  182. }
  183. params, err := paramsFunc.GetParams()
  184. if err != nil {
  185. return err
  186. }
  187. for k, v := range reusingParams {
  188. paramValue, ok := params[k]
  189. if !ok {
  190. if v.Required {
  191. return fmt.Errorf("required reusing parameter %s is missing", k)
  192. }
  193. continue
  194. }
  195. switch v.Type {
  196. case model.ParamTypeHeader:
  197. headers[k] = paramValue
  198. case model.ParamTypeQuery:
  199. backendQuery.Set(k, paramValue)
  200. case model.ParamTypeURL:
  201. return fmt.Errorf("URL parameter %s cannot be set via reusing", k)
  202. default:
  203. return fmt.Errorf("unknown param type: %s", v.Type)
  204. }
  205. }
  206. return nil
  207. }
  208. // PublicMCPStreamable godoc
  209. //
  210. // @Summary Public MCP Streamable Server
  211. // @Security ApiKeyAuth
  212. // @Router /mcp/public/{id} [get]
  213. // @Router /mcp/public/{id} [post]
  214. // @Router /mcp/public/{id} [delete]
  215. func PublicMCPStreamable(c *gin.Context) {
  216. mcpID := c.Param("id")
  217. publicMcp, err := model.CacheGetPublicMCP(mcpID)
  218. if err != nil {
  219. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  220. mcp.NewRequestId(nil),
  221. mcp.INVALID_REQUEST,
  222. err.Error(),
  223. ))
  224. return
  225. }
  226. if publicMcp.Status != model.PublicMCPStatusEnabled {
  227. c.JSON(http.StatusNotFound, mcpservers.CreateMCPErrorResponse(
  228. mcp.NewRequestId(nil),
  229. mcp.INVALID_REQUEST,
  230. "mcp is not enabled",
  231. ))
  232. return
  233. }
  234. group := middleware.GetGroup(c)
  235. paramsFunc := newGroupParams(publicMcp.ID, group.ID)
  236. handlePublicStreamable(c, publicMcp, paramsFunc)
  237. }
  238. func handlePublicStreamable(
  239. c *gin.Context,
  240. publicMcp *model.PublicMCPCache,
  241. paramsFunc ParamsFunc,
  242. ) {
  243. switch publicMcp.Type {
  244. case model.PublicMCPTypeProxySSE:
  245. client, err := createProxySSEClient(c, publicMcp, paramsFunc)
  246. if err != nil {
  247. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  248. return
  249. }
  250. defer client.Close()
  251. mcpproxy.NewStatelessStreamableHTTPServer(
  252. mcpservers.WrapMCPClient2Server(client),
  253. ).ServeHTTP(c.Writer, c.Request)
  254. case model.PublicMCPTypeProxyStreamable:
  255. handlePublicProxyStreamable(c, paramsFunc, publicMcp.ProxyConfig)
  256. case model.PublicMCPTypeOpenAPI:
  257. server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
  258. if err != nil {
  259. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  260. mcp.NewRequestId(nil),
  261. mcp.INVALID_REQUEST,
  262. err.Error(),
  263. ))
  264. return
  265. }
  266. handleStreamableMCPServer(c, server)
  267. case model.PublicMCPTypeEmbed:
  268. handlePublicEmbedStreamable(c, publicMcp.ID, paramsFunc, publicMcp.EmbedConfig)
  269. default:
  270. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  271. mcp.NewRequestId(nil),
  272. mcp.INVALID_REQUEST,
  273. "unknown mcp type",
  274. ))
  275. }
  276. }
  277. func handlePublicEmbedStreamable(
  278. c *gin.Context,
  279. mcpID string,
  280. paramsFunc ParamsFunc,
  281. config *model.MCPEmbeddingConfig,
  282. ) {
  283. var reusingConfig map[string]string
  284. if len(config.Reusing) != 0 {
  285. params, err := paramsFunc.GetParams()
  286. if err != nil {
  287. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  288. mcp.NewRequestId(nil),
  289. mcp.INVALID_REQUEST,
  290. err.Error(),
  291. ))
  292. return
  293. }
  294. reusingConfig = params
  295. }
  296. server, err := mcpservers.GetMCPServer(mcpID, config.Init, reusingConfig)
  297. if err != nil {
  298. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  299. mcp.NewRequestId(nil),
  300. mcp.INVALID_REQUEST,
  301. err.Error(),
  302. ))
  303. return
  304. }
  305. handleStreamableMCPServer(c, server)
  306. }
  307. // handlePublicProxyStreamable processes Streamable proxy requests
  308. func handlePublicProxyStreamable(
  309. c *gin.Context,
  310. paramsFunc ParamsFunc,
  311. config *model.PublicMCPProxyConfig,
  312. ) {
  313. if config == nil || config.URL == "" {
  314. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  315. mcp.NewRequestId(nil),
  316. mcp.INVALID_REQUEST,
  317. "invalid proxy configuration",
  318. ))
  319. return
  320. }
  321. backendURL, err := url.Parse(config.URL)
  322. if err != nil {
  323. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  324. mcp.NewRequestId(nil),
  325. mcp.INVALID_REQUEST,
  326. err.Error(),
  327. ))
  328. return
  329. }
  330. headers := make(map[string]string)
  331. backendQuery := backendURL.Query()
  332. // Process reusing parameters if any
  333. if err := processProxyReusingParams(config.Reusing, paramsFunc, headers, &backendQuery); err != nil {
  334. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  335. mcp.NewRequestId(nil),
  336. mcp.INVALID_REQUEST,
  337. err.Error(),
  338. ))
  339. return
  340. }
  341. for k, v := range config.Headers {
  342. headers[k] = v
  343. }
  344. for k, v := range config.Querys {
  345. backendQuery.Set(k, v)
  346. }
  347. backendURL.RawQuery = backendQuery.Encode()
  348. mcpproxy.NewStreamableProxy(backendURL.String(), headers, getStore()).
  349. ServeHTTP(c.Writer, c.Request)
  350. }
  351. // TestPublicMCPSSEServer godoc
  352. //
  353. // @Summary Test Public MCP SSE Server
  354. // @Security ApiKeyAuth
  355. // @Param group path string true "Group ID"
  356. // @Param id path string true "MCP ID"
  357. // @Router /api/test-publicmcp/{group}/{id}/sse [get]
  358. func TestPublicMCPSSEServer(c *gin.Context) {
  359. mcpID := c.Param("id")
  360. if mcpID == "" {
  361. http.Error(c.Writer, "mcp id is required", http.StatusBadRequest)
  362. return
  363. }
  364. groupID := c.Param("group")
  365. if groupID == "" {
  366. http.Error(c.Writer, "group id is required", http.StatusBadRequest)
  367. return
  368. }
  369. publicMcp, err := model.CacheGetPublicMCP(mcpID)
  370. if err != nil {
  371. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  372. return
  373. }
  374. if publicMcp.Status != model.PublicMCPStatusEnabled {
  375. http.Error(c.Writer, "mcp is not enabled", http.StatusBadRequest)
  376. return
  377. }
  378. paramsFunc := newGroupParams(publicMcp.ID, groupID)
  379. handlePublicSSEMCP(c, publicMcp, paramsFunc, sseEndpoint)
  380. }