| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- // The pgproxy server is a proxy for the Postgres wire protocol.
- package main
- import (
- "context"
- "crypto/ecdsa"
- "crypto/elliptic"
- crand "crypto/rand"
- "crypto/tls"
- "crypto/x509"
- "crypto/x509/pkix"
- "expvar"
- "flag"
- "fmt"
- "io"
- "log"
- "math/big"
- "net"
- "net/http"
- "os"
- "strings"
- "time"
- "tailscale.com/client/local"
- "tailscale.com/metrics"
- "tailscale.com/tsnet"
- "tailscale.com/tsweb"
- )
- var (
- hostname = flag.String("hostname", "", "Tailscale hostname to serve on")
- port = flag.Int("port", 5432, "Listening port for client connections")
- debugPort = flag.Int("debug-port", 80, "Listening port for debug/metrics endpoint")
- upstreamAddr = flag.String("upstream-addr", "", "Address of the upstream Postgres server, in host:port format")
- upstreamCA = flag.String("upstream-ca-file", "", "File containing the PEM-encoded CA certificate for the upstream server")
- tailscaleDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state")
- )
- func main() {
- flag.Parse()
- if *hostname == "" {
- log.Fatal("missing --hostname")
- }
- if *upstreamAddr == "" {
- log.Fatal("missing --upstream-addr")
- }
- if *upstreamCA == "" {
- log.Fatal("missing --upstream-ca-file")
- }
- if *tailscaleDir == "" {
- log.Fatal("missing --state-dir")
- }
- ts := &tsnet.Server{
- Dir: *tailscaleDir,
- Hostname: *hostname,
- }
- if os.Getenv("TS_AUTHKEY") == "" {
- log.Print("Note: you need to run this with TS_AUTHKEY=... the first time, to join your tailnet of choice.")
- }
- tsclient, err := ts.LocalClient()
- if err != nil {
- log.Fatalf("getting tsnet API client: %v", err)
- }
- p, err := newProxy(*upstreamAddr, *upstreamCA, tsclient)
- if err != nil {
- log.Fatal(err)
- }
- expvar.Publish("pgproxy", p.Expvar())
- if *debugPort != 0 {
- mux := http.NewServeMux()
- tsweb.Debugger(mux)
- srv := &http.Server{
- Handler: mux,
- }
- dln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
- if err != nil {
- log.Fatal(err)
- }
- go func() {
- log.Fatal(srv.Serve(dln))
- }()
- }
- ln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *port))
- if err != nil {
- log.Fatal(err)
- }
- log.Printf("serving access to %s on port %d", *upstreamAddr, *port)
- log.Fatal(p.Serve(ln))
- }
- // proxy is a postgres wire protocol proxy, which strictly enforces
- // the security of the TLS connection to its upstream regardless of
- // what the client's TLS configuration is.
- type proxy struct {
- upstreamAddr string // "my.database.com:5432"
- upstreamHost string // "my.database.com"
- upstreamCertPool *x509.CertPool
- downstreamCert []tls.Certificate
- client *local.Client
- activeSessions expvar.Int
- startedSessions expvar.Int
- errors metrics.LabelMap
- }
- // newProxy returns a proxy that forwards connections to
- // upstreamAddr. The upstream's TLS session is verified using the CA
- // cert(s) in upstreamCAPath.
- func newProxy(upstreamAddr, upstreamCAPath string, client *local.Client) (*proxy, error) {
- bs, err := os.ReadFile(upstreamCAPath)
- if err != nil {
- return nil, err
- }
- upstreamCertPool := x509.NewCertPool()
- if !upstreamCertPool.AppendCertsFromPEM(bs) {
- return nil, fmt.Errorf("invalid CA cert in %q", upstreamCAPath)
- }
- h, _, err := net.SplitHostPort(upstreamAddr)
- if err != nil {
- return nil, err
- }
- downstreamCert, err := mkSelfSigned(h)
- if err != nil {
- return nil, err
- }
- return &proxy{
- upstreamAddr: upstreamAddr,
- upstreamHost: h,
- upstreamCertPool: upstreamCertPool,
- downstreamCert: []tls.Certificate{downstreamCert},
- client: client,
- errors: metrics.LabelMap{Label: "kind"},
- }, nil
- }
- // Expvar returns p's monitoring metrics.
- func (p *proxy) Expvar() expvar.Var {
- ret := &metrics.Set{}
- ret.Set("sessions_active", &p.activeSessions)
- ret.Set("sessions_started", &p.startedSessions)
- ret.Set("session_errors", &p.errors)
- return ret
- }
- // Serve accepts postgres client connections on ln and proxies them to
- // the configured upstream. ln can be any net.Listener, but all client
- // connections must originate from tailscale IPs that can be verified
- // with WhoIs.
- func (p *proxy) Serve(ln net.Listener) error {
- var lastSessionID int64
- for {
- c, err := ln.Accept()
- if err != nil {
- return err
- }
- id := time.Now().UnixNano()
- if id == lastSessionID {
- // Bluntly enforce SID uniqueness, even if collisions are
- // fantastically unlikely (but OSes vary in how much timer
- // precision they expose to the OS, so id might be rounded
- // e.g. to the same millisecond)
- id++
- }
- lastSessionID = id
- go func(sessionID int64) {
- if err := p.serve(sessionID, c); err != nil {
- log.Printf("%d: session ended with error: %v", sessionID, err)
- }
- }(id)
- }
- }
- var (
- // sslStart is the magic bytes that postgres clients use to indicate
- // that they want to do a TLS handshake. Servers should respond with
- // the single byte "S" before starting a normal TLS handshake.
- sslStart = [8]byte{0, 0, 0, 8, 0x04, 0xd2, 0x16, 0x2f}
- // plaintextStart is the magic bytes that postgres clients use to
- // indicate that they're starting a plaintext authentication
- // handshake.
- plaintextStart = [8]byte{0, 0, 0, 86, 0, 3, 0, 0}
- )
- // serve proxies the postgres client on c to the proxy's upstream,
- // enforcing strict TLS to the upstream.
- func (p *proxy) serve(sessionID int64, c net.Conn) error {
- defer c.Close()
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
- whois, err := p.client.WhoIs(ctx, c.RemoteAddr().String())
- if err != nil {
- p.errors.Add("whois-failed", 1)
- return fmt.Errorf("getting client identity: %v", err)
- }
- // Before anything else, log the connection attempt.
- user, machine := "", ""
- if whois.Node != nil {
- if whois.Node.Hostinfo.ShareeNode() {
- machine = "external-device"
- } else {
- machine = strings.TrimSuffix(whois.Node.Name, ".")
- }
- }
- if whois.UserProfile != nil {
- user = whois.UserProfile.LoginName
- if user == "tagged-devices" && whois.Node != nil {
- user = strings.Join(whois.Node.Tags, ",")
- }
- }
- if user == "" || machine == "" {
- p.errors.Add("no-ts-identity", 1)
- return fmt.Errorf("couldn't identify source user and machine (user %q, machine %q)", user, machine)
- }
- log.Printf("%d: session start, from %s (machine %s, user %s)", sessionID, c.RemoteAddr(), machine, user)
- start := time.Now()
- defer func() {
- elapsed := time.Since(start)
- log.Printf("%d: session end, from %s (machine %s, user %s), lasted %s", sessionID, c.RemoteAddr(), machine, user, elapsed.Round(time.Millisecond))
- }()
- // Read the client's opening message, to figure out if it's trying
- // to TLS or not.
- var buf [8]byte
- if _, err := io.ReadFull(c, buf[:len(sslStart)]); err != nil {
- p.errors.Add("network-error", 1)
- return fmt.Errorf("initial magic read: %v", err)
- }
- var clientIsTLS bool
- switch {
- case buf == sslStart:
- clientIsTLS = true
- case buf == plaintextStart:
- clientIsTLS = false
- default:
- p.errors.Add("client-bad-protocol", 1)
- return fmt.Errorf("unrecognized initial packet = % 02x", buf)
- }
- // Dial & verify upstream connection.
- var d net.Dialer
- d.Timeout = 10 * time.Second
- upc, err := d.Dial("tcp", p.upstreamAddr)
- if err != nil {
- p.errors.Add("network-error", 1)
- return fmt.Errorf("upstream dial: %v", err)
- }
- defer upc.Close()
- if _, err := upc.Write(sslStart[:]); err != nil {
- p.errors.Add("network-error", 1)
- return fmt.Errorf("upstream write of start-ssl magic: %v", err)
- }
- if _, err := io.ReadFull(upc, buf[:1]); err != nil {
- p.errors.Add("network-error", 1)
- return fmt.Errorf("reading upstream start-ssl response: %v", err)
- }
- if buf[0] != 'S' {
- p.errors.Add("upstream-bad-protocol", 1)
- return fmt.Errorf("upstream didn't acknowledge start-ssl, said %q", buf[0])
- }
- tlsConf := &tls.Config{
- ServerName: p.upstreamHost,
- RootCAs: p.upstreamCertPool,
- MinVersion: tls.VersionTLS12,
- }
- uptc := tls.Client(upc, tlsConf)
- if err = uptc.HandshakeContext(ctx); err != nil {
- p.errors.Add("upstream-tls", 1)
- return fmt.Errorf("upstream TLS handshake: %v", err)
- }
- // Accept the client conn and set it up the way the client wants.
- var clientConn net.Conn
- if clientIsTLS {
- io.WriteString(c, "S") // yeah, we're good to speak TLS
- s := tls.Server(c, &tls.Config{
- ServerName: p.upstreamHost,
- Certificates: p.downstreamCert,
- MinVersion: tls.VersionTLS12,
- })
- if err = uptc.HandshakeContext(ctx); err != nil {
- p.errors.Add("client-tls", 1)
- return fmt.Errorf("client TLS handshake: %v", err)
- }
- clientConn = s
- } else {
- // Repeat the header we read earlier up to the server.
- if _, err := uptc.Write(plaintextStart[:]); err != nil {
- p.errors.Add("network-error", 1)
- return fmt.Errorf("sending initial client bytes to upstream: %v", err)
- }
- clientConn = c
- }
- // Finally, proxy the client to the upstream.
- errc := make(chan error, 1)
- go func() {
- _, err := io.Copy(uptc, clientConn)
- errc <- err
- }()
- go func() {
- _, err := io.Copy(clientConn, uptc)
- errc <- err
- }()
- if err := <-errc; err != nil {
- // Don't increment error counts here, because the most common
- // cause of termination is client or server closing the
- // connection normally, and it'll obscure "interesting"
- // handshake errors.
- return fmt.Errorf("session terminated with error: %v", err)
- }
- return nil
- }
- // mkSelfSigned creates and returns a self-signed TLS certificate for
- // hostname.
- func mkSelfSigned(hostname string) (tls.Certificate, error) {
- priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
- if err != nil {
- return tls.Certificate{}, err
- }
- pub := priv.Public()
- template := x509.Certificate{
- SerialNumber: big.NewInt(1),
- Subject: pkix.Name{
- Organization: []string{"pgproxy"},
- },
- DNSNames: []string{hostname},
- NotBefore: time.Now(),
- NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
- KeyUsage: x509.KeyUsageDigitalSignature,
- ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
- BasicConstraintsValid: true,
- }
- derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, pub, priv)
- if err != nil {
- return tls.Certificate{}, err
- }
- cert, err := x509.ParseCertificate(derBytes)
- if err != nil {
- return tls.Certificate{}, err
- }
- return tls.Certificate{
- Certificate: [][]byte{derBytes},
- PrivateKey: priv,
- Leaf: cert,
- }, nil
- }
|