mcpproxy.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "runtime"
  10. "sync"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/uuid"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/mcpproxy"
  16. "github.com/labring/aiproxy/core/middleware"
  17. "github.com/labring/aiproxy/core/model"
  18. "github.com/labring/aiproxy/openapi-mcp/convert"
  19. "github.com/redis/go-redis/v9"
  20. )
  21. // mcpEndpointProvider implements the EndpointProvider interface for MCP
  22. type mcpEndpointProvider struct {
  23. key string
  24. t model.PublicMCPType
  25. }
  26. func newEndpoint(key string, t model.PublicMCPType) mcpproxy.EndpointProvider {
  27. return &mcpEndpointProvider{
  28. key: key,
  29. t: t,
  30. }
  31. }
  32. func (m *mcpEndpointProvider) NewEndpoint() (newSession string, newEndpoint string) {
  33. session := uuid.NewString()
  34. endpoint := fmt.Sprintf("/mcp/message?sessionId=%s&key=%s&type=%s", session, m.key, m.t)
  35. return session, endpoint
  36. }
  37. func (m *mcpEndpointProvider) LoadEndpoint(endpoint string) (session string) {
  38. parsedURL, err := url.Parse(endpoint)
  39. if err != nil {
  40. return ""
  41. }
  42. return parsedURL.Query().Get("sessionId")
  43. }
  44. // Global variables for session management
  45. var (
  46. memStore mcpproxy.SessionManager = mcpproxy.NewMemStore()
  47. redisStore mcpproxy.SessionManager
  48. redisStoreOnce = &sync.Once{}
  49. )
  50. func getStore() mcpproxy.SessionManager {
  51. if common.RedisEnabled {
  52. redisStoreOnce.Do(func() {
  53. redisStore = newRedisStoreManager(common.RDB)
  54. })
  55. return redisStore
  56. }
  57. return memStore
  58. }
  59. // Redis-based session manager
  60. type redisStoreManager struct {
  61. rdb *redis.Client
  62. }
  63. func newRedisStoreManager(rdb *redis.Client) mcpproxy.SessionManager {
  64. return &redisStoreManager{
  65. rdb: rdb,
  66. }
  67. }
  68. var redisStoreManagerScript = redis.NewScript(`
  69. local key = KEYS[1]
  70. local value = redis.call('GET', key)
  71. if not value then
  72. return nil
  73. end
  74. redis.call('EXPIRE', key, 300)
  75. return value
  76. `)
  77. func (r *redisStoreManager) Get(sessionID string) (string, bool) {
  78. ctx := context.Background()
  79. result, err := redisStoreManagerScript.Run(ctx, r.rdb, []string{"mcp:session:" + sessionID}).Result()
  80. if err != nil || result == nil {
  81. return "", false
  82. }
  83. return result.(string), true
  84. }
  85. func (r *redisStoreManager) Set(sessionID, endpoint string) {
  86. ctx := context.Background()
  87. r.rdb.Set(ctx, "mcp:session:"+sessionID, endpoint, time.Minute*5)
  88. }
  89. func (r *redisStoreManager) Delete(session string) {
  90. ctx := context.Background()
  91. r.rdb.Del(ctx, "mcp:session:"+session)
  92. }
  93. // MCPSseProxy godoc
  94. //
  95. // @Summary MCP SSE Proxy
  96. // @Router /mcp/public/{id}/sse [get]
  97. func MCPSseProxy(c *gin.Context) {
  98. mcpID := c.Param("id")
  99. publicMcp, err := model.GetPublicMCPByID(mcpID)
  100. if err != nil {
  101. middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
  102. return
  103. }
  104. switch publicMcp.Type {
  105. case model.PublicMCPTypeProxySSE:
  106. handleProxySSE(c, publicMcp)
  107. case model.PublicMCPTypeOpenAPI:
  108. handleOpenAPI(c, publicMcp)
  109. default:
  110. middleware.AbortLogWithMessage(c, http.StatusBadRequest, "unknow mcp type")
  111. return
  112. }
  113. }
  114. // handleProxySSE processes SSE proxy requests
  115. func handleProxySSE(c *gin.Context, publicMcp *model.PublicMCP) {
  116. config := publicMcp.ProxySSEConfig
  117. if config == nil || config.URL == "" {
  118. return
  119. }
  120. backendURL, err := url.Parse(config.URL)
  121. if err != nil {
  122. middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
  123. return
  124. }
  125. headers := make(map[string]string)
  126. backendQuery := &url.Values{}
  127. group := middleware.GetGroup(c)
  128. token := middleware.GetToken(c)
  129. // Process reusing parameters if any
  130. if err := processReusingParams(config.ReusingParams, publicMcp.ID, group.ID, headers, backendQuery); err != nil {
  131. middleware.AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
  132. return
  133. }
  134. backendURL.RawQuery = backendQuery.Encode()
  135. mcpproxy.SSEHandler(
  136. c.Writer,
  137. c.Request,
  138. getStore(),
  139. newEndpoint(token.Key, publicMcp.Type),
  140. backendURL.String(),
  141. headers,
  142. )
  143. }
  144. // handleOpenAPI processes OpenAPI requests
  145. func handleOpenAPI(c *gin.Context, publicMcp *model.PublicMCP) {
  146. config := publicMcp.OpenAPIConfig
  147. if config == nil || (config.OpenAPISpec == "" && config.OpenAPIContent == "") {
  148. return
  149. }
  150. // Parse OpenAPI specification
  151. parser := convert.NewParser()
  152. var err error
  153. var openAPIFrom string
  154. if config.OpenAPISpec != "" {
  155. openAPIFrom, err = parseOpenAPIFromURL(config, parser)
  156. } else {
  157. err = parseOpenAPIFromContent(config, parser)
  158. }
  159. if err != nil {
  160. return
  161. }
  162. // Convert to MCP server
  163. converter := convert.NewConverter(parser, convert.Options{
  164. OpenAPIFrom: openAPIFrom,
  165. })
  166. s, err := converter.Convert()
  167. if err != nil {
  168. return
  169. }
  170. token := middleware.GetToken(c)
  171. // Setup SSE server
  172. newSession, newEndpoint := newEndpoint(token.Key, publicMcp.Type).NewEndpoint()
  173. store := getStore()
  174. store.Set(newSession, "openapi")
  175. defer func() {
  176. store.Delete(newSession)
  177. }()
  178. server := NewSSEServer(
  179. s,
  180. WithMessageEndpoint(newEndpoint),
  181. )
  182. ctx, cancel := context.WithCancel(c.Request.Context())
  183. defer cancel()
  184. // Start message processing goroutine
  185. go processOpenAPIMessages(ctx, newSession, server)
  186. // Handle SSE connection
  187. server.HandleSSE(c.Writer, c.Request)
  188. }
  189. // parseOpenAPIFromURL parses OpenAPI spec from a URL
  190. func parseOpenAPIFromURL(config *model.MCPOpenAPIConfig, parser *convert.Parser) (string, error) {
  191. spec, err := url.Parse(config.OpenAPISpec)
  192. if err != nil || (spec.Scheme != "http" && spec.Scheme != "https") {
  193. return "", errors.New("invalid OpenAPI spec URL")
  194. }
  195. openAPIFrom := spec.String()
  196. if config.V2 {
  197. err = parser.ParseFileV2(openAPIFrom)
  198. } else {
  199. err = parser.ParseFile(openAPIFrom)
  200. }
  201. return openAPIFrom, err
  202. }
  203. // parseOpenAPIFromContent parses OpenAPI spec from content string
  204. func parseOpenAPIFromContent(config *model.MCPOpenAPIConfig, parser *convert.Parser) error {
  205. if config.V2 {
  206. return parser.ParseV2([]byte(config.OpenAPIContent))
  207. }
  208. return parser.Parse([]byte(config.OpenAPIContent))
  209. }
  210. // processOpenAPIMessages handles message processing for OpenAPI
  211. func processOpenAPIMessages(ctx context.Context, sessionID string, server *SSEServer) {
  212. mpscInstance := getMpsc()
  213. for {
  214. select {
  215. case <-ctx.Done():
  216. return
  217. default:
  218. data, err := mpscInstance.recv(ctx, sessionID)
  219. if err != nil {
  220. return
  221. }
  222. if err := server.HandleMessage(data); err != nil {
  223. return
  224. }
  225. }
  226. }
  227. }
  228. // processReusingParams handles the reusing parameters for MCP proxy
  229. func processReusingParams(reusingParams map[string]model.ReusingParam, mcpID string, groupID string, headers map[string]string, backendQuery *url.Values) error {
  230. if len(reusingParams) == 0 {
  231. return nil
  232. }
  233. param, err := model.GetGroupPublicMCPReusingParam(mcpID, groupID)
  234. if err != nil {
  235. return err
  236. }
  237. for k, v := range reusingParams {
  238. paramValue, ok := param.ReusingParams[k]
  239. if !ok {
  240. if v.Required {
  241. return fmt.Errorf("%s required", k)
  242. }
  243. continue
  244. }
  245. switch v.Type {
  246. case model.ParamTypeHeader:
  247. headers[k] = paramValue
  248. case model.ParamTypeQuery:
  249. backendQuery.Set(k, paramValue)
  250. default:
  251. return errors.New("unknow param type")
  252. }
  253. }
  254. return nil
  255. }
  256. // MCPMessage godoc
  257. //
  258. // @Summary MCP SSE Proxy
  259. // @Router /mcp/message [post]
  260. func MCPMessage(c *gin.Context) {
  261. token := middleware.GetToken(c)
  262. mcpTypeStr, _ := c.GetQuery("type")
  263. if mcpTypeStr == "" {
  264. return
  265. }
  266. mcpType := model.PublicMCPType(mcpTypeStr)
  267. sessionID, _ := c.GetQuery("sessionId")
  268. if sessionID == "" {
  269. return
  270. }
  271. switch mcpType {
  272. case model.PublicMCPTypeProxySSE:
  273. mcpproxy.ProxyHandler(
  274. c.Writer,
  275. c.Request,
  276. getStore(),
  277. newEndpoint(token.Key, mcpType),
  278. )
  279. case model.PublicMCPTypeOpenAPI:
  280. backend, ok := getStore().Get(sessionID)
  281. if !ok || backend != "openapi" {
  282. return
  283. }
  284. mpscInstance := getMpsc()
  285. body, err := io.ReadAll(c.Request.Body)
  286. if err != nil {
  287. _ = c.AbortWithError(http.StatusInternalServerError, err)
  288. return
  289. }
  290. err = mpscInstance.send(c.Request.Context(), sessionID, body)
  291. if err != nil {
  292. _ = c.AbortWithError(http.StatusInternalServerError, err)
  293. return
  294. }
  295. c.Writer.WriteHeader(http.StatusAccepted)
  296. }
  297. }
  // mpsc is a multi-producer, single-consumer message queue keyed by session ID.
// In this file any number of MCPMessage requests may send to a session, while
// processOpenAPIMessages is the single receiver for that session.
type mpsc interface {
	// send enqueues data for session id.
	send(ctx context.Context, id string, data []byte) error
	// recv returns the next message for session id, or an error if ctx ends
	// first or the queue fails.
	recv(ctx context.Context, id string) ([]byte, error)
}
  303. // Global MPSC instances
  304. var (
  305. memMpsc mpsc = newChannelMpsc()
  306. redisMpsc mpsc
  307. redisMpscOnce = &sync.Once{}
  308. )
  309. func getMpsc() mpsc {
  310. if common.RedisEnabled {
  311. redisMpscOnce.Do(func() {
  312. redisMpsc = newRedisMPSC(common.RDB)
  313. })
  314. return redisMpsc
  315. }
  316. return memMpsc
  317. }
  // Tunables for channelMpsc.
const (
	// channelMpscBufferSize is how many undelivered messages a session can
	// hold before send starts failing with "channel buffer full".
	channelMpscBufferSize = 10
	// channelMpscIdleTTL is how long a session channel may go unused (no send
	// or recv in progress) before the janitor drops it.
	channelMpscIdleTTL = 5 * time.Minute
	// channelMpscCleanupInterval is how often the janitor scans for idle
	// session channels.
	channelMpscCleanupInterval = time.Minute
)

// channelMpsc is an in-process mpsc that keeps one buffered Go channel per
// session ID. Channels are created lazily by the first send or recv for a
// session and dropped by a background janitor once idle.
//
// A session only counts as idle when no send or recv is in progress for it.
// This matters for the long-lived receiver in processOpenAPIMessages. A
// receiver that waits longer than the TTL for its next message must keep its
// channel; otherwise later sends would land in a fresh channel nobody reads.
// Idle channels are deleted rather than closed. Nobody holds an unused
// channel, and closing one that a concurrent send could still reach would
// panic.
type channelMpsc struct {
	channels   map[string]chan []byte
	lastAccess map[string]time.Time
	// users counts in-flight send/recv calls per session ID.
	users        map[string]int
	channelMutex sync.RWMutex
}

// newChannelMpsc creates a channel-based mpsc and starts its janitor
// goroutine, which runs for the lifetime of the process.
func newChannelMpsc() *channelMpsc {
	c := &channelMpsc{
		channels:   make(map[string]chan []byte),
		lastAccess: make(map[string]time.Time),
		users:      make(map[string]int),
	}
	go c.cleanupExpiredChannels()
	return c
}

// cleanupExpiredChannels calls removeExpired every channelMpscCleanupInterval.
func (c *channelMpsc) cleanupExpiredChannels() {
	ticker := time.NewTicker(channelMpscCleanupInterval)
	defer ticker.Stop()
	for range ticker.C {
		c.removeExpired(time.Now())
	}
}

// removeExpired drops every session channel that has no in-flight send/recv
// and was last used more than channelMpscIdleTTL before now. Any messages
// still buffered in a dropped channel are discarded.
func (c *channelMpsc) removeExpired(now time.Time) {
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	for id, last := range c.lastAccess {
		if c.users[id] > 0 || now.Sub(last) <= channelMpscIdleTTL {
			continue
		}
		delete(c.channels, id)
		delete(c.lastAccess, id)
	}
}

// touchLocked returns the channel for id, creating it if needed, and records
// the access time. The caller must hold channelMutex for writing.
func (c *channelMpsc) touchLocked(id string) chan []byte {
	ch, ok := c.channels[id]
	if !ok {
		ch = make(chan []byte, channelMpscBufferSize)
		c.channels[id] = ch
	}
	c.lastAccess[id] = time.Now()
	return ch
}

// getOrCreateChannel gets or creates the channel for id and marks it accessed.
// It does not register the caller as a user, so the channel may be dropped by
// the janitor while held; send and recv use acquire/release instead.
func (c *channelMpsc) getOrCreateChannel(id string) chan []byte {
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	return c.touchLocked(id)
}

// acquire returns the channel for id and registers an in-flight user, which
// keeps the janitor from dropping it until the matching release.
func (c *channelMpsc) acquire(id string) chan []byte {
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	c.users[id]++
	return c.touchLocked(id)
}

// release ends an in-flight use started by acquire.
func (c *channelMpsc) release(id string) {
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	// Restart the idle clock from the end of the call, so a receiver that
	// blocked for a long time is not expired right after it returns.
	c.lastAccess[id] = time.Now()
	if c.users[id] <= 1 {
		delete(c.users, id)
	} else {
		c.users[id]--
	}
}

// recv blocks until a message for session id arrives or ctx is done.
func (c *channelMpsc) recv(ctx context.Context, id string) ([]byte, error) {
	ch := c.acquire(id)
	defer c.release(id)
	select {
	case data, ok := <-ch:
		if !ok {
			// Defensive: channels are never closed by this type.
			return nil, fmt.Errorf("channel closed for session %s", id)
		}
		return data, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// send enqueues data for session id without blocking. It fails when the
// session's buffer is full (or, if both are ready, possibly when ctx is done).
func (c *channelMpsc) send(ctx context.Context, id string, data []byte) error {
	ch := c.acquire(id)
	defer c.release(id)
	select {
	case ch <- data:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	default:
		return fmt.Errorf("channel buffer full for session %s", id)
	}
}
  399. // Redis-based MPSC implementation
  400. type redisMPSC struct {
  401. rdb *redis.Client
  402. }
  403. // newRedisMPSC creates a new Redis MPSC instance
  404. func newRedisMPSC(rdb *redis.Client) *redisMPSC {
  405. return &redisMPSC{rdb: rdb}
  406. }
  407. func (r *redisMPSC) send(ctx context.Context, id string, data []byte) error {
  408. // Set expiration to 5 minutes when sending data
  409. pipe := r.rdb.Pipeline()
  410. pipe.LPush(ctx, id, data)
  411. pipe.Expire(ctx, id, 5*time.Minute)
  412. _, err := pipe.Exec(ctx)
  413. return err
  414. }
  415. func (r *redisMPSC) recv(ctx context.Context, id string) ([]byte, error) {
  416. for {
  417. select {
  418. case <-ctx.Done():
  419. return nil, ctx.Err()
  420. default:
  421. result, err := r.rdb.BRPop(ctx, time.Second, id).Result()
  422. if err != nil {
  423. if errors.Is(err, redis.Nil) {
  424. runtime.Gosched()
  425. continue
  426. }
  427. return nil, err
  428. }
  429. if len(result) != 2 {
  430. return nil, errors.New("invalid BRPop result")
  431. }
  432. return []byte(result[1]), nil
  433. }
  434. }
  435. }