stateless-streamable.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package mcpproxy
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "github.com/bytedance/sonic"
  7. "github.com/labring/aiproxy/core/common"
  8. mcpservers "github.com/labring/aiproxy/mcp-servers"
  9. "github.com/mark3labs/mcp-go/mcp"
  10. )
  11. type StreamableHTTPOption func(*StreamableHTTPServer)
  12. type StreamableHTTPServer struct {
  13. server mcpservers.Server
  14. }
  15. // NewStatelessStreamableHTTPServer creates a new streamable-http server instance
  16. func NewStatelessStreamableHTTPServer(
  17. server mcpservers.Server,
  18. opts ...StreamableHTTPOption,
  19. ) *StreamableHTTPServer {
  20. s := &StreamableHTTPServer{
  21. server: server,
  22. }
  23. for _, opt := range opts {
  24. opt(s)
  25. }
  26. return s
  27. }
  28. // ServeHTTP implements the http.Handler interface.
  29. func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  30. switch r.Method {
  31. case http.MethodPost:
  32. s.handlePost(w, r)
  33. case http.MethodGet:
  34. s.handleGet(w, r)
  35. case http.MethodDelete:
  36. s.handleDelete(w, r)
  37. default:
  38. http.NotFound(w, r)
  39. }
  40. }
  41. func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
  42. // post request carry request/notification message
  43. // Check content type
  44. contentType := r.Header.Get("Content-Type")
  45. if !common.IsJSONContentType(contentType) {
  46. http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
  47. return
  48. }
  49. // Check the request body is valid json, meanwhile, get the request Method
  50. rawData, err := common.GetRequestBody(r)
  51. if err != nil {
  52. s.writeJSONRPCError(
  53. w,
  54. nil,
  55. mcp.PARSE_ERROR,
  56. fmt.Sprintf("read request body error: %v", err),
  57. )
  58. return
  59. }
  60. var baseMessage struct {
  61. Method mcp.MCPMethod `json:"method"`
  62. }
  63. if err := sonic.Unmarshal(rawData, &baseMessage); err != nil {
  64. s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
  65. return
  66. }
  67. // Process message through MCPServer
  68. response := s.server.HandleMessage(r.Context(), rawData)
  69. if response == nil {
  70. // For notifications, just send 202 Accepted with no body
  71. w.WriteHeader(http.StatusAccepted)
  72. return
  73. }
  74. jsonBody, err := sonic.Marshal(response)
  75. if err != nil {
  76. s.writeJSONRPCError(
  77. w,
  78. nil,
  79. mcp.INTERNAL_ERROR,
  80. fmt.Sprintf("marshal response body error: %v", err),
  81. )
  82. return
  83. }
  84. w.Header().Set("Content-Type", "application/json")
  85. w.Header().Set("Content-Length", strconv.Itoa(len(jsonBody)))
  86. w.WriteHeader(http.StatusOK)
  87. _, _ = w.Write(jsonBody)
  88. }
  89. func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, _ *http.Request) {
  90. http.Error(w, "get request is not supported", http.StatusMethodNotAllowed)
  91. }
  92. func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, _ *http.Request) {
  93. http.Error(w, "delete request is not supported", http.StatusMethodNotAllowed)
  94. }
  95. func (s *StreamableHTTPServer) writeJSONRPCError(
  96. w http.ResponseWriter,
  97. id any,
  98. code int,
  99. message string,
  100. ) {
  101. response := mcpservers.CreateMCPErrorResponse(id, code, message)
  102. jsonBody, err := sonic.Marshal(response)
  103. if err != nil {
  104. http.Error(w, err.Error(), http.StatusInternalServerError)
  105. return
  106. }
  107. w.Header().Set("Content-Type", "application/json")
  108. w.Header().Set("Content-Length", strconv.Itoa(len(jsonBody)))
  109. w.WriteHeader(http.StatusBadRequest)
  110. _, _ = w.Write(jsonBody)
  111. }