client.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package client
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "net/url"
  10. stdpath "path"
  11. "path/filepath"
  12. "time"
  13. "github.com/charmbracelet/crush/internal/config"
  14. "github.com/charmbracelet/crush/internal/proto"
  15. "github.com/charmbracelet/crush/internal/server"
  16. )
  17. // DummyHost is used to satisfy the http.Client's requirement for a URL.
  18. const DummyHost = "api.crush.localhost"
  19. // Client represents an RPC client connected to a Crush server.
  20. type Client struct {
  21. h *http.Client
  22. path string
  23. network string
  24. addr string
  25. }
  26. // DefaultClient creates a new [Client] connected to the default server address.
  27. func DefaultClient(path string) (*Client, error) {
  28. host, err := server.ParseHostURL(server.DefaultHost())
  29. if err != nil {
  30. return nil, err
  31. }
  32. return NewClient(path, host.Scheme, host.Host)
  33. }
  34. // NewClient creates a new [Client] connected to the server at the given
  35. // network and address.
  36. func NewClient(path, network, address string) (*Client, error) {
  37. c := new(Client)
  38. c.path = filepath.Clean(path)
  39. c.network = network
  40. c.addr = address
  41. p := &http.Protocols{}
  42. p.SetHTTP1(true)
  43. p.SetUnencryptedHTTP2(true)
  44. tr := http.DefaultTransport.(*http.Transport).Clone()
  45. tr.Protocols = p
  46. tr.DialContext = c.dialer
  47. if c.network == "npipe" || c.network == "unix" {
  48. tr.DisableCompression = true
  49. }
  50. c.h = &http.Client{
  51. Transport: tr,
  52. Timeout: 0,
  53. }
  54. return c, nil
  55. }
  56. // Path returns the client's workspace filesystem path.
  57. func (c *Client) Path() string {
  58. return c.path
  59. }
  60. // GetGlobalConfig retrieves the server's configuration.
  61. func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
  62. var cfg config.Config
  63. rsp, err := c.get(ctx, "/config", nil, nil)
  64. if err != nil {
  65. return nil, err
  66. }
  67. defer rsp.Body.Close()
  68. if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
  69. return nil, err
  70. }
  71. return &cfg, nil
  72. }
  73. // Health checks the server's health status.
  74. func (c *Client) Health(ctx context.Context) error {
  75. rsp, err := c.get(ctx, "/health", nil, nil)
  76. if err != nil {
  77. return err
  78. }
  79. defer rsp.Body.Close()
  80. if rsp.StatusCode != http.StatusOK {
  81. return fmt.Errorf("server health check failed: %s", rsp.Status)
  82. }
  83. return nil
  84. }
  85. // VersionInfo retrieves the server's version information.
  86. func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
  87. var vi proto.VersionInfo
  88. rsp, err := c.get(ctx, "version", nil, nil)
  89. if err != nil {
  90. return nil, err
  91. }
  92. defer rsp.Body.Close()
  93. if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
  94. return nil, err
  95. }
  96. return &vi, nil
  97. }
  98. // ShutdownServer sends a shutdown request to the server.
  99. func (c *Client) ShutdownServer(ctx context.Context) error {
  100. rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
  101. Command: "shutdown",
  102. }), nil)
  103. if err != nil {
  104. return err
  105. }
  106. defer rsp.Body.Close()
  107. if rsp.StatusCode != http.StatusOK {
  108. return fmt.Errorf("server shutdown failed: %s", rsp.Status)
  109. }
  110. return nil
  111. }
  112. func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
  113. d := net.Dialer{
  114. Timeout: 30 * time.Second,
  115. KeepAlive: 30 * time.Second,
  116. }
  117. switch c.network {
  118. case "npipe":
  119. ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
  120. defer cancel()
  121. return dialPipeContext(ctx, c.addr)
  122. case "unix":
  123. return d.DialContext(ctx, "unix", c.addr)
  124. default:
  125. return d.DialContext(ctx, network, address)
  126. }
  127. }
  128. func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
  129. return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
  130. }
  131. func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
  132. return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
  133. }
  134. func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
  135. return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
  136. }
  137. func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
  138. return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
  139. }
  140. func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
  141. url := (&url.URL{
  142. Path: stdpath.Join("/v1", path),
  143. RawQuery: query.Encode(),
  144. }).String()
  145. req, err := c.buildReq(ctx, method, url, body, headers)
  146. if err != nil {
  147. return nil, err
  148. }
  149. rsp, err := c.h.Do(req)
  150. if err != nil {
  151. return nil, err
  152. }
  153. return rsp, nil
  154. }
  155. func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
  156. r, err := http.NewRequestWithContext(ctx, method, url, body)
  157. if err != nil {
  158. return nil, err
  159. }
  160. for k, v := range headers {
  161. r.Header[http.CanonicalHeaderKey(k)] = v
  162. }
  163. r.URL.Scheme = "http"
  164. r.URL.Host = c.addr
  165. if c.network == "npipe" || c.network == "unix" {
  166. r.Host = DummyHost
  167. }
  168. if body != nil && r.Header.Get("Content-Type") == "" {
  169. r.Header.Set("Content-Type", "text/plain")
  170. }
  171. return r, nil
  172. }