server.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package server
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os/user"
	"runtime"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/backend"
	"github.com/charmbracelet/crush/internal/config"
	_ "github.com/charmbracelet/crush/internal/swagger"
	httpswagger "github.com/swaggo/http-swagger/v2"
)
// ErrServerClosed is returned when the server is closed. It is the same value
// as [http.ErrServerClosed], so a comparison against either one matches.
var ErrServerClosed = http.ErrServerClosed
  19. // ParseHostURL parses a host URL into a [url.URL].
  20. func ParseHostURL(host string) (*url.URL, error) {
  21. proto, addr, ok := strings.Cut(host, "://")
  22. if !ok {
  23. return nil, fmt.Errorf("invalid host format: %s", host)
  24. }
  25. var basePath string
  26. if proto == "tcp" {
  27. parsed, err := url.Parse("tcp://" + addr)
  28. if err != nil {
  29. return nil, fmt.Errorf("invalid tcp address: %v", err)
  30. }
  31. addr = parsed.Host
  32. basePath = parsed.Path
  33. }
  34. return &url.URL{
  35. Scheme: proto,
  36. Host: addr,
  37. Path: basePath,
  38. }, nil
  39. }
  40. // DefaultHost returns the default server host.
  41. func DefaultHost() string {
  42. sock := "crush.sock"
  43. usr, err := user.Current()
  44. if err == nil && usr.Uid != "" {
  45. sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
  46. }
  47. if runtime.GOOS == "windows" {
  48. return fmt.Sprintf("npipe:////./pipe/%s", sock)
  49. }
  50. return fmt.Sprintf("unix:///tmp/%s", sock)
  51. }
// Server represents a Crush server bound to a specific address.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	Addr string
	// network is the scheme that Addr belongs to (for example "tcp"); it is
	// passed to listen together with Addr.
	network string
	// h is the underlying HTTP server that serves requests.
	h *http.Server
	// ln is the active listener, if any; closeListener closes and clears it.
	ln net.Listener
	// backend is the backend created in NewServer and handed to the v1
	// controller.
	backend *backend.Backend
	// logger receives per-request log records; when nil, logDebug and
	// logError do nothing.
	logger *slog.Logger
}
// SetLogger sets the logger for the server. The logger is used by the
// per-request logging helpers; passing nil turns that logging off.
func (s *Server) SetLogger(logger *slog.Logger) {
	s.logger = logger
}
  66. // DefaultServer returns a new [Server] with the default address.
  67. func DefaultServer(cfg *config.ConfigStore) *Server {
  68. hostURL, err := ParseHostURL(DefaultHost())
  69. if err != nil {
  70. panic("invalid default host")
  71. }
  72. return NewServer(cfg, hostURL.Scheme, hostURL.Host)
  73. }
  74. // NewServer creates a new [Server] with the given network and address.
  75. func NewServer(cfg *config.ConfigStore, network, address string) *Server {
  76. s := new(Server)
  77. s.Addr = address
  78. s.network = network
  79. // The backend is created with a shutdown callback that triggers
  80. // a graceful server shutdown (e.g. when the last workspace is
  81. // removed).
  82. s.backend = backend.New(context.Background(), cfg, func() {
  83. go func() {
  84. slog.Info("Shutting down server...")
  85. if err := s.Shutdown(context.Background()); err != nil {
  86. slog.Error("Failed to shutdown server", "error", err)
  87. }
  88. }()
  89. })
  90. var p http.Protocols
  91. p.SetHTTP1(true)
  92. p.SetUnencryptedHTTP2(true)
  93. c := &controllerV1{backend: s.backend, server: s}
  94. mux := http.NewServeMux()
  95. mux.HandleFunc("GET /v1/health", c.handleGetHealth)
  96. mux.HandleFunc("GET /v1/version", c.handleGetVersion)
  97. mux.HandleFunc("GET /v1/config", c.handleGetConfig)
  98. mux.HandleFunc("POST /v1/control", c.handlePostControl)
  99. mux.HandleFunc("GET /v1/workspaces", c.handleGetWorkspaces)
  100. mux.HandleFunc("POST /v1/workspaces", c.handlePostWorkspaces)
  101. mux.HandleFunc("DELETE /v1/workspaces/{id}", c.handleDeleteWorkspaces)
  102. mux.HandleFunc("GET /v1/workspaces/{id}", c.handleGetWorkspace)
  103. mux.HandleFunc("GET /v1/workspaces/{id}/config", c.handleGetWorkspaceConfig)
  104. mux.HandleFunc("GET /v1/workspaces/{id}/events", c.handleGetWorkspaceEvents)
  105. mux.HandleFunc("GET /v1/workspaces/{id}/providers", c.handleGetWorkspaceProviders)
  106. mux.HandleFunc("GET /v1/workspaces/{id}/sessions", c.handleGetWorkspaceSessions)
  107. mux.HandleFunc("POST /v1/workspaces/{id}/sessions", c.handlePostWorkspaceSessions)
  108. mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession)
  109. mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession)
  110. mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession)
  111. mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory)
  112. mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages)
  113. mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages)
  114. mux.HandleFunc("GET /v1/workspaces/{id}/messages/user", c.handleGetWorkspaceAllUserMessages)
  115. mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/filetracker/files", c.handleGetWorkspaceSessionFileTrackerFiles)
  116. mux.HandleFunc("POST /v1/workspaces/{id}/filetracker/read", c.handlePostWorkspaceFileTrackerRead)
  117. mux.HandleFunc("GET /v1/workspaces/{id}/filetracker/lastread", c.handleGetWorkspaceFileTrackerLastRead)
  118. mux.HandleFunc("GET /v1/workspaces/{id}/lsps", c.handleGetWorkspaceLSPs)
  119. mux.HandleFunc("GET /v1/workspaces/{id}/lsps/{lsp}/diagnostics", c.handleGetWorkspaceLSPDiagnostics)
  120. mux.HandleFunc("POST /v1/workspaces/{id}/lsps/start", c.handlePostWorkspaceLSPStart)
  121. mux.HandleFunc("POST /v1/workspaces/{id}/lsps/stop", c.handlePostWorkspaceLSPStopAll)
  122. mux.HandleFunc("GET /v1/workspaces/{id}/permissions/skip", c.handleGetWorkspacePermissionsSkip)
  123. mux.HandleFunc("POST /v1/workspaces/{id}/permissions/skip", c.handlePostWorkspacePermissionsSkip)
  124. mux.HandleFunc("POST /v1/workspaces/{id}/permissions/grant", c.handlePostWorkspacePermissionsGrant)
  125. mux.HandleFunc("GET /v1/workspaces/{id}/agent", c.handleGetWorkspaceAgent)
  126. mux.HandleFunc("POST /v1/workspaces/{id}/agent", c.handlePostWorkspaceAgent)
  127. mux.HandleFunc("POST /v1/workspaces/{id}/agent/init", c.handlePostWorkspaceAgentInit)
  128. mux.HandleFunc("POST /v1/workspaces/{id}/agent/update", c.handlePostWorkspaceAgentUpdate)
  129. mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}", c.handleGetWorkspaceAgentSession)
  130. mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel", c.handlePostWorkspaceAgentSessionCancel)
  131. mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetWorkspaceAgentSessionPromptQueued)
  132. mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/list", c.handleGetWorkspaceAgentSessionPromptList)
  133. mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostWorkspaceAgentSessionPromptClear)
  134. mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/summarize", c.handlePostWorkspaceAgentSessionSummarize)
  135. mux.HandleFunc("GET /v1/workspaces/{id}/agent/default-small-model", c.handleGetWorkspaceAgentDefaultSmallModel)
  136. mux.HandleFunc("POST /v1/workspaces/{id}/config/set", c.handlePostWorkspaceConfigSet)
  137. mux.HandleFunc("POST /v1/workspaces/{id}/config/remove", c.handlePostWorkspaceConfigRemove)
  138. mux.HandleFunc("POST /v1/workspaces/{id}/config/model", c.handlePostWorkspaceConfigModel)
  139. mux.HandleFunc("POST /v1/workspaces/{id}/config/compact", c.handlePostWorkspaceConfigCompact)
  140. mux.HandleFunc("POST /v1/workspaces/{id}/config/provider-key", c.handlePostWorkspaceConfigProviderKey)
  141. mux.HandleFunc("POST /v1/workspaces/{id}/config/import-copilot", c.handlePostWorkspaceConfigImportCopilot)
  142. mux.HandleFunc("POST /v1/workspaces/{id}/config/refresh-oauth", c.handlePostWorkspaceConfigRefreshOAuth)
  143. mux.HandleFunc("GET /v1/workspaces/{id}/project/needs-init", c.handleGetWorkspaceProjectNeedsInit)
  144. mux.HandleFunc("POST /v1/workspaces/{id}/project/init", c.handlePostWorkspaceProjectInit)
  145. mux.HandleFunc("GET /v1/workspaces/{id}/project/init-prompt", c.handleGetWorkspaceProjectInitPrompt)
  146. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-tools", c.handlePostWorkspaceMCPRefreshTools)
  147. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/read-resource", c.handlePostWorkspaceMCPReadResource)
  148. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/get-prompt", c.handlePostWorkspaceMCPGetPrompt)
  149. mux.HandleFunc("GET /v1/workspaces/{id}/mcp/states", c.handleGetWorkspaceMCPStates)
  150. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-prompts", c.handlePostWorkspaceMCPRefreshPrompts)
  151. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-resources", c.handlePostWorkspaceMCPRefreshResources)
  152. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/docker/enable", c.handlePostWorkspaceMCPEnableDocker)
  153. mux.HandleFunc("POST /v1/workspaces/{id}/mcp/docker/disable", c.handlePostWorkspaceMCPDisableDocker)
  154. mux.Handle("/v1/docs/", httpswagger.WrapHandler)
  155. s.h = &http.Server{
  156. Protocols: &p,
  157. Handler: s.loggingHandler(mux),
  158. }
  159. if network == "tcp" {
  160. s.h.Addr = address
  161. }
  162. return s
  163. }
// Serve accepts incoming connections on the listener. It delegates to the
// underlying HTTP server and returns whatever error that server reports.
func (s *Server) Serve(ln net.Listener) error {
	return s.h.Serve(ln)
}
  168. // ListenAndServe starts the server and begins accepting connections.
  169. func (s *Server) ListenAndServe() error {
  170. if s.ln != nil {
  171. return fmt.Errorf("server already started")
  172. }
  173. ln, err := listen(s.network, s.Addr)
  174. if err != nil {
  175. return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
  176. }
  177. return s.Serve(ln)
  178. }
  179. func (s *Server) closeListener() {
  180. if s.ln != nil {
  181. s.ln.Close()
  182. s.ln = nil
  183. }
  184. }
  185. // Close force closes all listeners and connections.
  186. func (s *Server) Close() error {
  187. defer func() { s.closeListener() }()
  188. return s.h.Close()
  189. }
  190. // Shutdown gracefully shuts down the server without interrupting active
  191. // connections.
  192. func (s *Server) Shutdown(ctx context.Context) error {
  193. defer func() { s.closeListener() }()
  194. return s.h.Shutdown(ctx)
  195. }
  196. func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
  197. if s.logger != nil {
  198. s.logger.With(
  199. slog.String("method", r.Method),
  200. slog.String("url", r.URL.String()),
  201. slog.String("remote_addr", r.RemoteAddr),
  202. ).Debug(msg, args...)
  203. }
  204. }
  205. func (s *Server) logError(r *http.Request, msg string, args ...any) {
  206. if s.logger != nil {
  207. s.logger.With(
  208. slog.String("method", r.Method),
  209. slog.String("url", r.URL.String()),
  210. slog.String("remote_addr", r.RemoteAddr),
  211. ).Error(msg, args...)
  212. }
  213. }