connectproxy.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Package connectproxy contains some CONNECT proxy code.
  4. package connectproxy
  5. import (
  6. "context"
  7. "io"
  8. "log"
  9. "net"
  10. "net/http"
  11. "time"
  12. "tailscale.com/net/netx"
  13. "tailscale.com/types/logger"
  14. )
  15. // Handler is an HTTP CONNECT proxy handler.
  16. type Handler struct {
  17. // Dial, if non-nil, is an alternate dialer to use
  18. // instead of the default dialer.
  19. Dial netx.DialFunc
  20. // Logf, if non-nil, is an alterate logger to
  21. // use instead of log.Printf.
  22. Logf logger.Logf
  23. // Check, if non-nil, validates the CONNECT target.
  24. Check func(hostPort string) error
  25. }
  26. func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  27. ctx := r.Context()
  28. if r.Method != "CONNECT" {
  29. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  30. return
  31. }
  32. dial := h.Dial
  33. if dial == nil {
  34. var d net.Dialer
  35. dial = d.DialContext
  36. }
  37. logf := h.Logf
  38. if logf == nil {
  39. logf = log.Printf
  40. }
  41. hostPort := r.RequestURI
  42. if h.Check != nil {
  43. if err := h.Check(hostPort); err != nil {
  44. logf("CONNECT target %q not allowed: %v", hostPort, err)
  45. http.Error(w, "Invalid CONNECT target", http.StatusForbidden)
  46. return
  47. }
  48. }
  49. ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
  50. defer cancel()
  51. back, err := dial(ctx, "tcp", hostPort)
  52. if err != nil {
  53. logf("error CONNECT dialing %v: %v", hostPort, err)
  54. http.Error(w, "Connect failure", http.StatusBadGateway)
  55. return
  56. }
  57. defer back.Close()
  58. hj, ok := w.(http.Hijacker)
  59. if !ok {
  60. http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError)
  61. return
  62. }
  63. c, br, err := hj.Hijack()
  64. if err != nil {
  65. logf("CONNECT hijack: %v", err)
  66. return
  67. }
  68. defer c.Close()
  69. io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n")
  70. errc := make(chan error, 2)
  71. go func() {
  72. _, err := io.Copy(c, back)
  73. errc <- err
  74. }()
  75. go func() {
  76. _, err := io.Copy(back, br)
  77. errc <- err
  78. }()
  79. <-errc
  80. }