| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- // The sniproxy is an outbound SNI proxy. It receives TLS connections over
- // Tailscale on one or more TCP ports and sends them out to the same SNI
- // hostname & port on the internet. It only does TCP.
- package main
- import (
- "context"
- "flag"
- "log"
- "net"
- "net/http"
- "strings"
- "time"
- "golang.org/x/net/dns/dnsmessage"
- "inet.af/tcpproxy"
- "tailscale.com/client/tailscale"
- "tailscale.com/hostinfo"
- "tailscale.com/net/netutil"
- "tailscale.com/tsnet"
- "tailscale.com/types/nettype"
- )
- var (
- ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
- wgPort = flag.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
- promoteHTTPS = flag.Bool("promote-https", true, "promote HTTP to HTTPS")
- )
- var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
- func main() {
- flag.Parse()
- if *ports == "" {
- log.Fatal("no ports")
- }
- hostinfo.SetApp("sniproxy")
- var s server
- s.ts.Port = uint16(*wgPort)
- defer s.ts.Close()
- lc, err := s.ts.LocalClient()
- if err != nil {
- log.Fatal(err)
- }
- s.lc = lc
- for _, portStr := range strings.Split(*ports, ",") {
- ln, err := s.ts.Listen("tcp", ":"+portStr)
- if err != nil {
- log.Fatal(err)
- }
- log.Printf("Serving on port %v ...", portStr)
- go s.serve(ln)
- }
- ln, err := s.ts.Listen("udp", ":53")
- if err != nil {
- log.Fatal(err)
- }
- go s.serveDNS(ln)
- if *promoteHTTPS {
- ln, err := s.ts.Listen("tcp", ":80")
- if err != nil {
- log.Fatal(err)
- }
- log.Printf("Promoting HTTP to HTTPS ...")
- go s.promoteHTTPS(ln)
- }
- select {}
- }
- type server struct {
- ts tsnet.Server
- lc *tailscale.LocalClient
- }
- func (s *server) serve(ln net.Listener) {
- for {
- c, err := ln.Accept()
- if err != nil {
- log.Fatal(err)
- }
- go s.serveConn(c)
- }
- }
- func (s *server) serveDNS(ln net.Listener) {
- for {
- c, err := ln.Accept()
- if err != nil {
- log.Fatal(err)
- }
- go s.serveDNSConn(c.(nettype.ConnPacketConn))
- }
- }
- func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
- defer c.Close()
- c.SetReadDeadline(time.Now().Add(5 * time.Second))
- buf := make([]byte, 1500)
- n, err := c.Read(buf)
- if err != nil {
- log.Printf("c.Read failed: %v\n ", err)
- return
- }
- var msg dnsmessage.Message
- err = msg.Unpack(buf[:n])
- if err != nil {
- log.Printf("dnsmessage unpack failed: %v\n ", err)
- return
- }
- buf, err = s.dnsResponse(&msg)
- if err != nil {
- log.Printf("s.dnsResponse failed: %v\n", err)
- return
- }
- _, err = c.Write(buf)
- if err != nil {
- log.Printf("c.Write failed: %v\n", err)
- return
- }
- }
- func (s *server) serveConn(c net.Conn) {
- addrPortStr := c.LocalAddr().String()
- _, port, err := net.SplitHostPort(addrPortStr)
- if err != nil {
- log.Printf("bogus addrPort %q", addrPortStr)
- c.Close()
- return
- }
- var dialer net.Dialer
- dialer.Timeout = 5 * time.Second
- var p tcpproxy.Proxy
- p.ListenFunc = func(net, laddr string) (net.Listener, error) {
- return netutil.NewOneConnListener(c, nil), nil
- }
- p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
- return &tcpproxy.DialProxy{
- Addr: net.JoinHostPort(sniName, port),
- DialContext: dialer.DialContext,
- }, true
- })
- p.Start()
- }
- func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
- resp := dnsmessage.NewBuilder(buf,
- dnsmessage.Header{
- ID: req.Header.ID,
- Response: true,
- Authoritative: true,
- })
- resp.EnableCompression()
- if len(req.Questions) == 0 {
- buf, _ = resp.Finish()
- return
- }
- q := req.Questions[0]
- err = resp.StartQuestions()
- if err != nil {
- return
- }
- resp.Question(q)
- ip4, ip6 := s.ts.TailscaleIPs()
- err = resp.StartAnswers()
- if err != nil {
- return
- }
- switch q.Type {
- case dnsmessage.TypeAAAA:
- err = resp.AAAAResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.AAAAResource{AAAA: ip6.As16()},
- )
- case dnsmessage.TypeA:
- err = resp.AResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.AResource{A: ip4.As4()},
- )
- case dnsmessage.TypeSOA:
- err = resp.SOAResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
- Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
- )
- case dnsmessage.TypeNS:
- err = resp.NSResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.NSResource{NS: tsMBox},
- )
- }
- if err != nil {
- return
- }
- return resp.Finish()
- }
- func (s *server) promoteHTTPS(ln net.Listener) {
- err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
- }))
- log.Fatalf("promoteHTTPS http.Serve: %v", err)
- }
|