sniproxy.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // The sniproxy is an outbound SNI proxy. It receives TLS connections over
  4. // Tailscale on one or more TCP ports and sends them out to the same SNI
  5. // hostname & port on the internet. It only does TCP.
  6. package main
  7. import (
  8. "context"
  9. "flag"
  10. "log"
  11. "net"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "golang.org/x/net/dns/dnsmessage"
  16. "inet.af/tcpproxy"
  17. "tailscale.com/client/tailscale"
  18. "tailscale.com/hostinfo"
  19. "tailscale.com/net/netutil"
  20. "tailscale.com/tsnet"
  21. "tailscale.com/types/nettype"
  22. )
  23. var (
  24. ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
  25. wgPort = flag.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
  26. promoteHTTPS = flag.Bool("promote-https", true, "promote HTTP to HTTPS")
  27. )
  28. var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
  29. func main() {
  30. flag.Parse()
  31. if *ports == "" {
  32. log.Fatal("no ports")
  33. }
  34. hostinfo.SetApp("sniproxy")
  35. var s server
  36. s.ts.Port = uint16(*wgPort)
  37. defer s.ts.Close()
  38. lc, err := s.ts.LocalClient()
  39. if err != nil {
  40. log.Fatal(err)
  41. }
  42. s.lc = lc
  43. for _, portStr := range strings.Split(*ports, ",") {
  44. ln, err := s.ts.Listen("tcp", ":"+portStr)
  45. if err != nil {
  46. log.Fatal(err)
  47. }
  48. log.Printf("Serving on port %v ...", portStr)
  49. go s.serve(ln)
  50. }
  51. ln, err := s.ts.Listen("udp", ":53")
  52. if err != nil {
  53. log.Fatal(err)
  54. }
  55. go s.serveDNS(ln)
  56. if *promoteHTTPS {
  57. ln, err := s.ts.Listen("tcp", ":80")
  58. if err != nil {
  59. log.Fatal(err)
  60. }
  61. log.Printf("Promoting HTTP to HTTPS ...")
  62. go s.promoteHTTPS(ln)
  63. }
  64. select {}
  65. }
  66. type server struct {
  67. ts tsnet.Server
  68. lc *tailscale.LocalClient
  69. }
  70. func (s *server) serve(ln net.Listener) {
  71. for {
  72. c, err := ln.Accept()
  73. if err != nil {
  74. log.Fatal(err)
  75. }
  76. go s.serveConn(c)
  77. }
  78. }
  79. func (s *server) serveDNS(ln net.Listener) {
  80. for {
  81. c, err := ln.Accept()
  82. if err != nil {
  83. log.Fatal(err)
  84. }
  85. go s.serveDNSConn(c.(nettype.ConnPacketConn))
  86. }
  87. }
  88. func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
  89. defer c.Close()
  90. c.SetReadDeadline(time.Now().Add(5 * time.Second))
  91. buf := make([]byte, 1500)
  92. n, err := c.Read(buf)
  93. if err != nil {
  94. log.Printf("c.Read failed: %v\n ", err)
  95. return
  96. }
  97. var msg dnsmessage.Message
  98. err = msg.Unpack(buf[:n])
  99. if err != nil {
  100. log.Printf("dnsmessage unpack failed: %v\n ", err)
  101. return
  102. }
  103. buf, err = s.dnsResponse(&msg)
  104. if err != nil {
  105. log.Printf("s.dnsResponse failed: %v\n", err)
  106. return
  107. }
  108. _, err = c.Write(buf)
  109. if err != nil {
  110. log.Printf("c.Write failed: %v\n", err)
  111. return
  112. }
  113. }
  114. func (s *server) serveConn(c net.Conn) {
  115. addrPortStr := c.LocalAddr().String()
  116. _, port, err := net.SplitHostPort(addrPortStr)
  117. if err != nil {
  118. log.Printf("bogus addrPort %q", addrPortStr)
  119. c.Close()
  120. return
  121. }
  122. var dialer net.Dialer
  123. dialer.Timeout = 5 * time.Second
  124. var p tcpproxy.Proxy
  125. p.ListenFunc = func(net, laddr string) (net.Listener, error) {
  126. return netutil.NewOneConnListener(c, nil), nil
  127. }
  128. p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
  129. return &tcpproxy.DialProxy{
  130. Addr: net.JoinHostPort(sniName, port),
  131. DialContext: dialer.DialContext,
  132. }, true
  133. })
  134. p.Start()
  135. }
  136. func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
  137. resp := dnsmessage.NewBuilder(buf,
  138. dnsmessage.Header{
  139. ID: req.Header.ID,
  140. Response: true,
  141. Authoritative: true,
  142. })
  143. resp.EnableCompression()
  144. if len(req.Questions) == 0 {
  145. buf, _ = resp.Finish()
  146. return
  147. }
  148. q := req.Questions[0]
  149. err = resp.StartQuestions()
  150. if err != nil {
  151. return
  152. }
  153. resp.Question(q)
  154. ip4, ip6 := s.ts.TailscaleIPs()
  155. err = resp.StartAnswers()
  156. if err != nil {
  157. return
  158. }
  159. switch q.Type {
  160. case dnsmessage.TypeAAAA:
  161. err = resp.AAAAResource(
  162. dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
  163. dnsmessage.AAAAResource{AAAA: ip6.As16()},
  164. )
  165. case dnsmessage.TypeA:
  166. err = resp.AResource(
  167. dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
  168. dnsmessage.AResource{A: ip4.As4()},
  169. )
  170. case dnsmessage.TypeSOA:
  171. err = resp.SOAResource(
  172. dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
  173. dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
  174. Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
  175. )
  176. case dnsmessage.TypeNS:
  177. err = resp.NSResource(
  178. dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
  179. dnsmessage.NSResource{NS: tsMBox},
  180. )
  181. }
  182. if err != nil {
  183. return
  184. }
  185. return resp.Finish()
  186. }
  187. func (s *server) promoteHTTPS(ln net.Listener) {
  188. err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  189. http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
  190. }))
  191. log.Fatalf("promoteHTTPS http.Serve: %v", err)
  192. }