server.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlhttp
  4. import (
  5. "context"
  6. "encoding/base64"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "net"
  11. "net/http"
  12. "strings"
  13. "time"
  14. "nhooyr.io/websocket"
  15. "tailscale.com/control/controlbase"
  16. "tailscale.com/net/netutil"
  17. "tailscale.com/net/wsconn"
  18. "tailscale.com/types/key"
  19. )
  20. // AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale
  21. // control protocol base transport connection.
  22. //
  23. // AcceptHTTP always writes an HTTP response to w. The caller must not attempt
  24. // their own response after calling AcceptHTTP.
  25. //
  26. // earlyWrite optionally specifies a func to write to the noise connection
  27. // (encrypted). It receives the negotiated version and a writer to write to, if
  28. // desired.
  29. func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (*controlbase.Conn, error) {
  30. return acceptHTTP(ctx, w, r, private, earlyWrite)
  31. }
  32. func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (_ *controlbase.Conn, retErr error) {
  33. next := strings.ToLower(r.Header.Get("Upgrade"))
  34. if next == "" {
  35. http.Error(w, "missing next protocol", http.StatusBadRequest)
  36. return nil, errors.New("no next protocol in HTTP request")
  37. }
  38. if next == "websocket" {
  39. return acceptWebsocket(ctx, w, r, private)
  40. }
  41. if next != upgradeHeaderValue {
  42. http.Error(w, "unknown next protocol", http.StatusBadRequest)
  43. return nil, fmt.Errorf("client requested unhandled next protocol %q", next)
  44. }
  45. initB64 := r.Header.Get(handshakeHeaderName)
  46. if initB64 == "" {
  47. http.Error(w, "missing Tailscale handshake header", http.StatusBadRequest)
  48. return nil, errors.New("no tailscale handshake header in HTTP request")
  49. }
  50. init, err := base64.StdEncoding.DecodeString(initB64)
  51. if err != nil {
  52. http.Error(w, "invalid tailscale handshake header", http.StatusBadRequest)
  53. return nil, fmt.Errorf("decoding base64 handshake header: %v", err)
  54. }
  55. hijacker, ok := w.(http.Hijacker)
  56. if !ok {
  57. http.Error(w, "make request over HTTP/1", http.StatusBadRequest)
  58. return nil, errors.New("can't hijack client connection")
  59. }
  60. w.Header().Set("Upgrade", upgradeHeaderValue)
  61. w.Header().Set("Connection", "upgrade")
  62. w.WriteHeader(http.StatusSwitchingProtocols)
  63. conn, brw, err := hijacker.Hijack()
  64. if err != nil {
  65. return nil, fmt.Errorf("hijacking client connection: %w", err)
  66. }
  67. defer func() {
  68. if retErr != nil {
  69. conn.Close()
  70. }
  71. }()
  72. if err := brw.Flush(); err != nil {
  73. return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err)
  74. }
  75. conn = netutil.NewDrainBufConn(conn, brw.Reader)
  76. cwc := newWriteCorkingConn(conn)
  77. nc, err := controlbase.Server(ctx, cwc, private, init)
  78. if err != nil {
  79. return nil, fmt.Errorf("noise handshake failed: %w", err)
  80. }
  81. if earlyWrite != nil {
  82. if deadline, ok := ctx.Deadline(); ok {
  83. if err := conn.SetDeadline(deadline); err != nil {
  84. return nil, fmt.Errorf("setting conn deadline: %w", err)
  85. }
  86. defer conn.SetDeadline(time.Time{})
  87. }
  88. if err := earlyWrite(nc.ProtocolVersion(), nc); err != nil {
  89. return nil, err
  90. }
  91. }
  92. if err := cwc.uncork(); err != nil {
  93. return nil, err
  94. }
  95. return nc, nil
  96. }
  97. // acceptWebsocket upgrades a WebSocket connection (from a client that cannot
  98. // speak HTTP) to a Tailscale control protocol base transport connection.
  99. func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) {
  100. c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
  101. Subprotocols: []string{upgradeHeaderValue},
  102. OriginPatterns: []string{"*"},
  103. // Disable compression because we transmit Noise messages that are not
  104. // compressible.
  105. // Additionally, Safari has a broken implementation of compression
  106. // (see https://github.com/nhooyr/websocket/issues/218) that makes
  107. // enabling it actively harmful.
  108. CompressionMode: websocket.CompressionDisabled,
  109. })
  110. if err != nil {
  111. return nil, fmt.Errorf("Could not accept WebSocket connection %v", err)
  112. }
  113. if c.Subprotocol() != upgradeHeaderValue {
  114. c.Close(websocket.StatusPolicyViolation, "client must speak the control subprotocol")
  115. return nil, fmt.Errorf("Unexpected subprotocol %q", c.Subprotocol())
  116. }
  117. if err := r.ParseForm(); err != nil {
  118. c.Close(websocket.StatusPolicyViolation, "Could not parse parameters")
  119. return nil, fmt.Errorf("parse query parameters: %v", err)
  120. }
  121. initB64 := r.Form.Get(handshakeHeaderName)
  122. if initB64 == "" {
  123. c.Close(websocket.StatusPolicyViolation, "missing Tailscale handshake parameter")
  124. return nil, errors.New("no tailscale handshake parameter in HTTP request")
  125. }
  126. init, err := base64.StdEncoding.DecodeString(initB64)
  127. if err != nil {
  128. c.Close(websocket.StatusPolicyViolation, "invalid tailscale handshake parameter")
  129. return nil, fmt.Errorf("decoding base64 handshake parameter: %v", err)
  130. }
  131. conn := wsconn.NetConn(ctx, c, websocket.MessageBinary, r.RemoteAddr)
  132. nc, err := controlbase.Server(ctx, conn, private, init)
  133. if err != nil {
  134. conn.Close()
  135. return nil, fmt.Errorf("noise handshake failed: %w", err)
  136. }
  137. return nc, nil
  138. }
  139. // corkConn is a net.Conn wrapper that initially buffers all writes until uncork
  140. // is called. If the conn is corked and a Read occurs, the Read will flush any
  141. // buffered (corked) write.
  142. //
  143. // Until uncorked, Read/Write/uncork may be not called concurrently.
  144. //
  145. // Deadlines still work, but a corked write ignores deadlines until a Read or
  146. // uncork goes to do that Write.
  147. //
  148. // Use newWriteCorkingConn to create one.
  149. type corkConn struct {
  150. net.Conn
  151. corked bool
  152. buf []byte // corked data
  153. }
  154. func newWriteCorkingConn(c net.Conn) *corkConn {
  155. return &corkConn{Conn: c, corked: true}
  156. }
  157. func (c *corkConn) Write(b []byte) (int, error) {
  158. if c.corked {
  159. c.buf = append(c.buf, b...)
  160. return len(b), nil
  161. }
  162. return c.Conn.Write(b)
  163. }
  164. func (c *corkConn) Read(b []byte) (int, error) {
  165. if c.corked {
  166. if err := c.flush(); err != nil {
  167. return 0, err
  168. }
  169. }
  170. return c.Conn.Read(b)
  171. }
  172. // uncork flushes any buffered data and uncorks the connection so future Writes
  173. // don't buffer. It may not be called concurrently with reads or writes and
  174. // may only be called once.
  175. func (c *corkConn) uncork() error {
  176. if !c.corked {
  177. panic("usage error; uncork called twice") // worth panicking to catch misuse
  178. }
  179. err := c.flush()
  180. c.corked = false
  181. return err
  182. }
  183. func (c *corkConn) flush() error {
  184. if len(c.buf) == 0 {
  185. return nil
  186. }
  187. _, err := c.Conn.Write(c.buf)
  188. c.buf = nil
  189. return err
  190. }