stateless-streamable.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package mcpproxy
  import (
	"fmt"
	"io"
	"mime"
	"net/http"
	"strconv"

	"github.com/bytedance/sonic"
	mcpservers "github.com/labring/aiproxy/mcp-servers"
	"github.com/mark3labs/mcp-go/mcp"
)
  11. type StreamableHTTPOption func(*StreamableHTTPServer)
  12. type StreamableHTTPServer struct {
  13. server mcpservers.Server
  14. }
  15. // NewStatelessStreamableHTTPServer creates a new streamable-http server instance
  16. func NewStatelessStreamableHTTPServer(
  17. server mcpservers.Server,
  18. opts ...StreamableHTTPOption,
  19. ) *StreamableHTTPServer {
  20. s := &StreamableHTTPServer{
  21. server: server,
  22. }
  23. for _, opt := range opts {
  24. opt(s)
  25. }
  26. return s
  27. }
  28. // ServeHTTP implements the http.Handler interface.
  29. func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  30. switch r.Method {
  31. case http.MethodPost:
  32. s.handlePost(w, r)
  33. case http.MethodGet:
  34. s.handleGet(w, r)
  35. case http.MethodDelete:
  36. s.handleDelete(w, r)
  37. default:
  38. http.NotFound(w, r)
  39. }
  40. }
  41. func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
  42. // post request carry request/notification message
  43. // Check content type
  44. contentType := r.Header.Get("Content-Type")
  45. if contentType != "application/json" {
  46. http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
  47. return
  48. }
  49. // Check the request body is valid json, meanwhile, get the request Method
  50. rawData, err := io.ReadAll(r.Body)
  51. if err != nil {
  52. s.writeJSONRPCError(
  53. w,
  54. nil,
  55. mcp.PARSE_ERROR,
  56. fmt.Sprintf("read request body error: %v", err),
  57. )
  58. return
  59. }
  60. var baseMessage struct {
  61. Method mcp.MCPMethod `json:"method"`
  62. }
  63. if err := sonic.Unmarshal(rawData, &baseMessage); err != nil {
  64. s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
  65. return
  66. }
  67. // Process message through MCPServer
  68. response := s.server.HandleMessage(r.Context(), rawData)
  69. if response == nil {
  70. // For notifications, just send 202 Accepted with no body
  71. w.WriteHeader(http.StatusAccepted)
  72. return
  73. }
  74. w.Header().Set("Content-Type", "application/json")
  75. w.WriteHeader(http.StatusOK)
  76. _ = sonic.ConfigDefault.NewEncoder(w).Encode(response)
  77. }
  78. func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, _ *http.Request) {
  79. http.Error(w, "get request is not supported", http.StatusMethodNotAllowed)
  80. }
  81. func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, _ *http.Request) {
  82. http.Error(w, "delete request is not supported", http.StatusMethodNotAllowed)
  83. }
  84. func (s *StreamableHTTPServer) writeJSONRPCError(
  85. w http.ResponseWriter,
  86. id any,
  87. code int,
  88. message string,
  89. ) {
  90. response := mcpservers.CreateMCPErrorResponse(id, code, message)
  91. jsonBody, err := sonic.Marshal(response)
  92. if err != nil {
  93. http.Error(w, err.Error(), http.StatusInternalServerError)
  94. return
  95. }
  96. w.Header().Set("Content-Type", "application/json")
  97. w.Header().Set("Content-Length", strconv.Itoa(len(jsonBody)))
  98. w.WriteHeader(http.StatusBadRequest)
  99. _, _ = w.Write(jsonBody)
  100. }