http.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package tsconsensus
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strconv"
	"time"

	"tailscale.com/util/httpm"
)
  16. type joinRequest struct {
  17. RemoteHost string
  18. RemoteID string
  19. }
  20. type commandClient struct {
  21. port uint16
  22. httpClient *http.Client
  23. }
  24. func (rac *commandClient) url(host string, path string) string {
  25. return fmt.Sprintf("http://%s:%d%s", host, rac.port, path)
  26. }
  27. const maxBodyBytes = 1024 * 1024
  28. func readAllMaxBytes(r io.Reader) ([]byte, error) {
  29. return io.ReadAll(io.LimitReader(r, maxBodyBytes+1))
  30. }
  31. func (rac *commandClient) join(host string, jr joinRequest) error {
  32. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  33. defer cancel()
  34. rBs, err := json.Marshal(jr)
  35. if err != nil {
  36. return err
  37. }
  38. url := rac.url(host, "/join")
  39. req, err := http.NewRequestWithContext(ctx, httpm.POST, url, bytes.NewReader(rBs))
  40. if err != nil {
  41. return err
  42. }
  43. resp, err := rac.httpClient.Do(req)
  44. if err != nil {
  45. return err
  46. }
  47. defer resp.Body.Close()
  48. if resp.StatusCode != 200 {
  49. respBs, err := readAllMaxBytes(resp.Body)
  50. if err != nil {
  51. return err
  52. }
  53. return fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs))
  54. }
  55. return nil
  56. }
  57. func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult, error) {
  58. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  59. defer cancel()
  60. url := rac.url(host, "/executeCommand")
  61. req, err := http.NewRequestWithContext(ctx, httpm.POST, url, bytes.NewReader(bs))
  62. if err != nil {
  63. return CommandResult{}, err
  64. }
  65. resp, err := rac.httpClient.Do(req)
  66. if err != nil {
  67. return CommandResult{}, err
  68. }
  69. defer resp.Body.Close()
  70. respBs, err := readAllMaxBytes(resp.Body)
  71. if err != nil {
  72. return CommandResult{}, err
  73. }
  74. if resp.StatusCode != 200 {
  75. return CommandResult{}, fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs))
  76. }
  77. var cr CommandResult
  78. if err = json.Unmarshal(respBs, &cr); err != nil {
  79. return CommandResult{}, err
  80. }
  81. return cr, nil
  82. }
  83. type authedHandler struct {
  84. auth *authorization
  85. handler http.Handler
  86. }
  87. func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  88. err := h.auth.Refresh(r.Context())
  89. if err != nil {
  90. log.Printf("error authedHandler ServeHTTP refresh auth: %v", err)
  91. http.Error(w, "", http.StatusInternalServerError)
  92. return
  93. }
  94. a, err := addrFromServerAddress(r.RemoteAddr)
  95. if err != nil {
  96. log.Printf("error authedHandler ServeHTTP refresh auth: %v", err)
  97. http.Error(w, "", http.StatusInternalServerError)
  98. return
  99. }
  100. allowed := h.auth.AllowsHost(a)
  101. if !allowed {
  102. http.Error(w, "peer not allowed", http.StatusForbidden)
  103. return
  104. }
  105. h.handler.ServeHTTP(w, r)
  106. }
  107. func (c *Consensus) handleJoinHTTP(w http.ResponseWriter, r *http.Request) {
  108. defer r.Body.Close()
  109. decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes+1))
  110. var jr joinRequest
  111. err := decoder.Decode(&jr)
  112. if err != nil {
  113. http.Error(w, err.Error(), http.StatusBadRequest)
  114. return
  115. }
  116. _, err = decoder.Token()
  117. if !errors.Is(err, io.EOF) {
  118. http.Error(w, "Request body must only contain a single JSON object", http.StatusBadRequest)
  119. return
  120. }
  121. if jr.RemoteHost == "" {
  122. http.Error(w, "Required: remoteAddr", http.StatusBadRequest)
  123. return
  124. }
  125. if jr.RemoteID == "" {
  126. http.Error(w, "Required: remoteID", http.StatusBadRequest)
  127. return
  128. }
  129. err = c.handleJoin(jr)
  130. if err != nil {
  131. log.Printf("join handler error: %v", err)
  132. http.Error(w, "", http.StatusInternalServerError)
  133. return
  134. }
  135. }
  136. func (c *Consensus) handleExecuteCommandHTTP(w http.ResponseWriter, r *http.Request) {
  137. defer r.Body.Close()
  138. decoder := json.NewDecoder(r.Body)
  139. var cmd Command
  140. err := decoder.Decode(&cmd)
  141. if err != nil {
  142. http.Error(w, err.Error(), http.StatusInternalServerError)
  143. return
  144. }
  145. result, err := c.executeCommandLocally(cmd)
  146. if err != nil {
  147. http.Error(w, err.Error(), http.StatusInternalServerError)
  148. return
  149. }
  150. if err := json.NewEncoder(w).Encode(result); err != nil {
  151. log.Printf("error encoding execute command result: %v", err)
  152. return
  153. }
  154. }
  155. func (c *Consensus) makeCommandMux() *http.ServeMux {
  156. mux := http.NewServeMux()
  157. mux.HandleFunc("POST /join", c.handleJoinHTTP)
  158. mux.HandleFunc("POST /executeCommand", c.handleExecuteCommandHTTP)
  159. return mux
  160. }
  161. func (c *Consensus) makeCommandHandler(auth *authorization) http.Handler {
  162. return authedHandler{
  163. handler: c.makeCommandMux(),
  164. auth: auth,
  165. }
  166. }