publicmcp-server.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. package controller
  2. import (
  3. "fmt"
  4. "maps"
  5. "net/http"
  6. "net/url"
  7. "github.com/gin-gonic/gin"
  8. "github.com/labring/aiproxy/core/mcpproxy"
  9. "github.com/labring/aiproxy/core/middleware"
  10. "github.com/labring/aiproxy/core/model"
  11. mcpservers "github.com/labring/aiproxy/mcp-servers"
  12. "github.com/mark3labs/mcp-go/client/transport"
  13. "github.com/mark3labs/mcp-go/mcp"
  14. )
  15. // PublicMCPSSEServer godoc
  16. //
  17. // @Summary Public MCP SSE Server
  18. // @Security ApiKeyAuth
  19. // @Router /mcp/public/{id}/sse [get]
  20. func PublicMCPSSEServer(c *gin.Context) {
  21. mcpID := c.Param("id")
  22. if mcpID == "" {
  23. http.Error(c.Writer, "mcp id is required", http.StatusBadRequest)
  24. return
  25. }
  26. publicMcp, err := model.CacheGetPublicMCP(mcpID)
  27. if err != nil {
  28. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  29. return
  30. }
  31. if publicMcp.Status != model.PublicMCPStatusEnabled {
  32. http.Error(c.Writer, "mcp is not enabled", http.StatusBadRequest)
  33. return
  34. }
  35. group := middleware.GetGroup(c)
  36. paramsFunc := newGroupParams(publicMcp.ID, group.ID)
  37. handlePublicSSEMCP(c, publicMcp, paramsFunc, sseEndpoint)
  38. }
  39. func handlePublicSSEMCP(
  40. c *gin.Context,
  41. publicMcp *model.PublicMCPCache,
  42. paramsFunc ParamsFunc,
  43. endpoint EndpointProvider,
  44. ) {
  45. switch publicMcp.Type {
  46. case model.PublicMCPTypeProxySSE:
  47. if err := handlePublicProxySSE(c, publicMcp, paramsFunc, endpoint); err != nil {
  48. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  49. return
  50. }
  51. case model.PublicMCPTypeProxyStreamable:
  52. if err := handlePublicProxyStreamableSSE(c, publicMcp, paramsFunc, endpoint); err != nil {
  53. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  54. return
  55. }
  56. case model.PublicMCPTypeOpenAPI:
  57. server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
  58. if err != nil {
  59. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  60. return
  61. }
  62. handleSSEMCPServer(c, server, string(model.PublicMCPTypeOpenAPI), endpoint)
  63. case model.PublicMCPTypeEmbed:
  64. handleEmbedSSEMCP(c, publicMcp.ID, publicMcp.EmbedConfig, paramsFunc, endpoint)
  65. default:
  66. http.Error(c.Writer, "unknown mcp type", http.StatusBadRequest)
  67. }
  68. }
  69. // handlePublicProxySSE 处理公共代理SSE
  70. func handlePublicProxySSE(
  71. c *gin.Context,
  72. publicMcp *model.PublicMCPCache,
  73. paramsFunc ParamsFunc,
  74. endpoint EndpointProvider,
  75. ) error {
  76. client, err := createProxySSEClient(c, publicMcp, paramsFunc)
  77. if err != nil {
  78. return err
  79. }
  80. defer client.Close()
  81. handleSSEMCPServer(
  82. c,
  83. mcpservers.WrapMCPClient2Server(client),
  84. string(model.PublicMCPTypeProxySSE),
  85. endpoint,
  86. )
  87. return nil
  88. }
  89. // handlePublicProxyStreamableSSE 处理公共代理Streamable SSE
  90. func handlePublicProxyStreamableSSE(
  91. c *gin.Context,
  92. publicMcp *model.PublicMCPCache,
  93. paramsFunc ParamsFunc,
  94. endpoint EndpointProvider,
  95. ) error {
  96. client, err := createProxyStreamableClient(c, publicMcp, paramsFunc)
  97. if err != nil {
  98. return err
  99. }
  100. defer client.Close()
  101. handleSSEMCPServer(
  102. c,
  103. mcpservers.WrapMCPClient2Server(client),
  104. string(model.PublicMCPTypeProxyStreamable),
  105. endpoint,
  106. )
  107. return nil
  108. }
  109. // createProxySSEClient 创建代理SSE客户端
  110. func createProxySSEClient(
  111. c *gin.Context,
  112. publicMcp *model.PublicMCPCache,
  113. paramsFunc ParamsFunc,
  114. ) (transport.Interface, error) {
  115. url, headers, err := prepareProxyConfig(publicMcp, paramsFunc)
  116. if err != nil {
  117. return nil, err
  118. }
  119. client, err := transport.NewSSE(url, transport.WithHeaders(headers))
  120. if err != nil {
  121. return nil, err
  122. }
  123. if err := client.Start(c.Request.Context()); err != nil {
  124. return nil, err
  125. }
  126. return client, nil
  127. }
  128. // createProxyStreamableClient 创建代理Streamable客户端
  129. func createProxyStreamableClient(
  130. c *gin.Context,
  131. publicMcp *model.PublicMCPCache,
  132. paramsFunc ParamsFunc,
  133. ) (transport.Interface, error) {
  134. url, headers, err := prepareProxyConfig(publicMcp, paramsFunc)
  135. if err != nil {
  136. return nil, err
  137. }
  138. client, err := transport.NewStreamableHTTP(url, transport.WithHTTPHeaders(headers))
  139. if err != nil {
  140. return nil, err
  141. }
  142. if err := client.Start(c.Request.Context()); err != nil {
  143. return nil, err
  144. }
  145. return client, nil
  146. }
  147. // prepareProxyConfig 准备代理配置
  148. func prepareProxyConfig(
  149. publicMcp *model.PublicMCPCache,
  150. paramsFunc ParamsFunc,
  151. ) (string, map[string]string, error) {
  152. url, err := url.Parse(publicMcp.ProxyConfig.URL)
  153. if err != nil {
  154. return "", nil, fmt.Errorf("invalid proxy URL: %w", err)
  155. }
  156. headers := make(map[string]string)
  157. backendQuery := url.Query()
  158. if len(publicMcp.ProxyConfig.Reusing) > 0 {
  159. processor := NewReusingParamProcessor(publicMcp.ID, paramsFunc)
  160. if err := processor.ProcessProxyReusingParams(
  161. publicMcp.ProxyConfig.Reusing,
  162. headers,
  163. &backendQuery,
  164. ); err != nil {
  165. return "", nil, err
  166. }
  167. }
  168. maps.Copy(headers, publicMcp.ProxyConfig.Headers)
  169. url.RawQuery = backendQuery.Encode()
  170. return url.String(), headers, nil
  171. }
  172. // processProxyReusingParams handles the reusing parameters for MCP proxy
  173. func processProxyReusingParams(
  174. reusingParams map[string]model.PublicMCPProxyReusingParam,
  175. paramsFunc ParamsFunc,
  176. headers map[string]string,
  177. backendQuery *url.Values,
  178. ) error {
  179. if len(reusingParams) == 0 {
  180. return nil
  181. }
  182. params, err := paramsFunc.GetParams()
  183. if err != nil {
  184. return err
  185. }
  186. for k, v := range reusingParams {
  187. paramValue, ok := params[k]
  188. if !ok {
  189. if v.Required {
  190. return fmt.Errorf("required reusing parameter %s is missing", k)
  191. }
  192. continue
  193. }
  194. switch v.Type {
  195. case model.ParamTypeHeader:
  196. headers[k] = paramValue
  197. case model.ParamTypeQuery:
  198. backendQuery.Set(k, paramValue)
  199. case model.ParamTypeURL:
  200. return fmt.Errorf("URL parameter %s cannot be set via reusing", k)
  201. default:
  202. return fmt.Errorf("unknown param type: %s", v.Type)
  203. }
  204. }
  205. return nil
  206. }
  207. // PublicMCPStreamable godoc
  208. //
  209. // @Summary Public MCP Streamable Server
  210. // @Security ApiKeyAuth
  211. // @Router /mcp/public/{id} [get]
  212. // @Router /mcp/public/{id} [post]
  213. // @Router /mcp/public/{id} [delete]
  214. func PublicMCPStreamable(c *gin.Context) {
  215. mcpID := c.Param("id")
  216. publicMcp, err := model.CacheGetPublicMCP(mcpID)
  217. if err != nil {
  218. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  219. mcp.NewRequestId(nil),
  220. mcp.INVALID_REQUEST,
  221. err.Error(),
  222. ))
  223. return
  224. }
  225. if publicMcp.Status != model.PublicMCPStatusEnabled {
  226. c.JSON(http.StatusNotFound, mcpservers.CreateMCPErrorResponse(
  227. mcp.NewRequestId(nil),
  228. mcp.INVALID_REQUEST,
  229. "mcp is not enabled",
  230. ))
  231. return
  232. }
  233. group := middleware.GetGroup(c)
  234. paramsFunc := newGroupParams(publicMcp.ID, group.ID)
  235. handlePublicStreamable(c, publicMcp, paramsFunc)
  236. }
  237. func handlePublicStreamable(
  238. c *gin.Context,
  239. publicMcp *model.PublicMCPCache,
  240. paramsFunc ParamsFunc,
  241. ) {
  242. switch publicMcp.Type {
  243. case model.PublicMCPTypeProxySSE:
  244. client, err := createProxySSEClient(c, publicMcp, paramsFunc)
  245. if err != nil {
  246. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  247. return
  248. }
  249. defer client.Close()
  250. mcpproxy.NewStatelessStreamableHTTPServer(
  251. mcpservers.WrapMCPClient2Server(client),
  252. ).ServeHTTP(c.Writer, c.Request)
  253. case model.PublicMCPTypeProxyStreamable:
  254. handlePublicProxyStreamable(c, paramsFunc, publicMcp.ProxyConfig)
  255. case model.PublicMCPTypeOpenAPI:
  256. server, err := newOpenAPIMCPServer(publicMcp.OpenAPIConfig)
  257. if err != nil {
  258. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  259. mcp.NewRequestId(nil),
  260. mcp.INVALID_REQUEST,
  261. err.Error(),
  262. ))
  263. return
  264. }
  265. handleStreamableMCPServer(c, server)
  266. case model.PublicMCPTypeEmbed:
  267. handlePublicEmbedStreamable(c, publicMcp.ID, paramsFunc, publicMcp.EmbedConfig)
  268. default:
  269. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  270. mcp.NewRequestId(nil),
  271. mcp.INVALID_REQUEST,
  272. "unknown mcp type",
  273. ))
  274. }
  275. }
  276. func handlePublicEmbedStreamable(
  277. c *gin.Context,
  278. mcpID string,
  279. paramsFunc ParamsFunc,
  280. config *model.MCPEmbeddingConfig,
  281. ) {
  282. var reusingConfig map[string]string
  283. if len(config.Reusing) != 0 {
  284. params, err := paramsFunc.GetParams()
  285. if err != nil {
  286. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  287. mcp.NewRequestId(nil),
  288. mcp.INVALID_REQUEST,
  289. err.Error(),
  290. ))
  291. return
  292. }
  293. reusingConfig = params
  294. }
  295. server, err := mcpservers.GetMCPServer(mcpID, config.Init, reusingConfig)
  296. if err != nil {
  297. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  298. mcp.NewRequestId(nil),
  299. mcp.INVALID_REQUEST,
  300. err.Error(),
  301. ))
  302. return
  303. }
  304. handleStreamableMCPServer(c, server)
  305. }
  306. // handlePublicProxyStreamable processes Streamable proxy requests
  307. func handlePublicProxyStreamable(
  308. c *gin.Context,
  309. paramsFunc ParamsFunc,
  310. config *model.PublicMCPProxyConfig,
  311. ) {
  312. if config == nil || config.URL == "" {
  313. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  314. mcp.NewRequestId(nil),
  315. mcp.INVALID_REQUEST,
  316. "invalid proxy configuration",
  317. ))
  318. return
  319. }
  320. backendURL, err := url.Parse(config.URL)
  321. if err != nil {
  322. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  323. mcp.NewRequestId(nil),
  324. mcp.INVALID_REQUEST,
  325. err.Error(),
  326. ))
  327. return
  328. }
  329. headers := make(map[string]string)
  330. backendQuery := backendURL.Query()
  331. // Process reusing parameters if any
  332. if err := processProxyReusingParams(config.Reusing, paramsFunc, headers, &backendQuery); err != nil {
  333. c.JSON(http.StatusBadRequest, mcpservers.CreateMCPErrorResponse(
  334. mcp.NewRequestId(nil),
  335. mcp.INVALID_REQUEST,
  336. err.Error(),
  337. ))
  338. return
  339. }
  340. maps.Copy(headers, config.Headers)
  341. for k, v := range config.Querys {
  342. backendQuery.Set(k, v)
  343. }
  344. backendURL.RawQuery = backendQuery.Encode()
  345. mcpproxy.NewStreamableProxy(backendURL.String(), headers, getStore()).
  346. ServeHTTP(c.Writer, c.Request)
  347. }
  348. // TestPublicMCPSSEServer godoc
  349. //
  350. // @Summary Test Public MCP SSE Server
  351. // @Security ApiKeyAuth
  352. // @Param group path string true "Group ID"
  353. // @Param id path string true "MCP ID"
  354. // @Router /api/test-publicmcp/{group}/{id}/sse [get]
  355. func TestPublicMCPSSEServer(c *gin.Context) {
  356. mcpID := c.Param("id")
  357. if mcpID == "" {
  358. http.Error(c.Writer, "mcp id is required", http.StatusBadRequest)
  359. return
  360. }
  361. groupID := c.Param("group")
  362. if groupID == "" {
  363. http.Error(c.Writer, "group id is required", http.StatusBadRequest)
  364. return
  365. }
  366. publicMcp, err := model.CacheGetPublicMCP(mcpID)
  367. if err != nil {
  368. http.Error(c.Writer, err.Error(), http.StatusBadRequest)
  369. return
  370. }
  371. if publicMcp.Status != model.PublicMCPStatusEnabled {
  372. http.Error(c.Writer, "mcp is not enabled", http.StatusBadRequest)
  373. return
  374. }
  375. paramsFunc := newGroupParams(publicMcp.ID, groupID)
  376. handlePublicSSEMCP(c, publicMcp, paramsFunc, sseEndpoint)
  377. }