pgproxy.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // The pgproxy server is a proxy for the Postgres wire protocol.
  5. package main
  6. import (
  7. "context"
  8. "crypto/ecdsa"
  9. "crypto/elliptic"
  10. crand "crypto/rand"
  11. "crypto/tls"
  12. "crypto/x509"
  13. "crypto/x509/pkix"
  14. "expvar"
  15. "flag"
  16. "fmt"
  17. "io"
  18. "log"
  19. "math/big"
  20. "net"
  21. "net/http"
  22. "os"
  23. "strings"
  24. "time"
  25. "tailscale.com/client/tailscale"
  26. "tailscale.com/metrics"
  27. "tailscale.com/tsnet"
  28. "tailscale.com/tsweb"
  29. "tailscale.com/types/logger"
  30. )
  31. var (
  32. hostname = flag.String("hostname", "", "Tailscale hostname to serve on")
  33. port = flag.Int("port", 5432, "Listening port for client connections")
  34. debugPort = flag.Int("debug-port", 80, "Listening port for debug/metrics endpoint")
  35. upstreamAddr = flag.String("upstream-addr", "", "Address of the upstream Postgres server, in host:port format")
  36. upstreamCA = flag.String("upstream-ca-file", "", "File containing the PEM-encoded CA certificate for the upstream server")
  37. tailscaleDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state")
  38. )
  39. func main() {
  40. flag.Parse()
  41. if *hostname == "" {
  42. log.Fatal("missing --hostname")
  43. }
  44. if *upstreamAddr == "" {
  45. log.Fatal("missing --upstream-addr")
  46. }
  47. if *upstreamCA == "" {
  48. log.Fatal("missing --upstream-ca-file")
  49. }
  50. if *tailscaleDir == "" {
  51. log.Fatal("missing --state-dir")
  52. }
  53. ts := &tsnet.Server{
  54. Dir: *tailscaleDir,
  55. Hostname: *hostname,
  56. // Make the stdout logs a clean audit log of connections.
  57. Logf: logger.Discard,
  58. }
  59. if os.Getenv("TS_AUTHKEY") == "" {
  60. log.Print("Note: you need to run this with TS_AUTHKEY=... the first time, to join your tailnet of choice.")
  61. }
  62. tsclient, err := ts.LocalClient()
  63. if err != nil {
  64. log.Fatalf("getting tsnet API client: %v", err)
  65. }
  66. p, err := newProxy(*upstreamAddr, *upstreamCA, tsclient)
  67. if err != nil {
  68. log.Fatal(err)
  69. }
  70. expvar.Publish("pgproxy", p.Expvar())
  71. if *debugPort != 0 {
  72. mux := http.NewServeMux()
  73. tsweb.Debugger(mux)
  74. srv := &http.Server{
  75. Handler: mux,
  76. }
  77. dln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
  78. if err != nil {
  79. log.Fatal(err)
  80. }
  81. go func() {
  82. log.Fatal(srv.Serve(dln))
  83. }()
  84. }
  85. ln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *port))
  86. if err != nil {
  87. log.Fatal(err)
  88. }
  89. log.Printf("serving access to %s on port %d", *upstreamAddr, *port)
  90. log.Fatal(p.Serve(ln))
  91. }
  92. // proxy is a postgres wire protocol proxy, which strictly enforces
  93. // the security of the TLS connection to its upstream regardless of
  94. // what the client's TLS configuration is.
  95. type proxy struct {
  96. upstreamAddr string // "my.database.com:5432"
  97. upstreamHost string // "my.database.com"
  98. upstreamCertPool *x509.CertPool
  99. downstreamCert []tls.Certificate
  100. client *tailscale.LocalClient
  101. activeSessions expvar.Int
  102. startedSessions expvar.Int
  103. errors metrics.LabelMap
  104. }
  105. // newProxy returns a proxy that forwards connections to
  106. // upstreamAddr. The upstream's TLS session is verified using the CA
  107. // cert(s) in upstreamCAPath.
  108. func newProxy(upstreamAddr, upstreamCAPath string, client *tailscale.LocalClient) (*proxy, error) {
  109. bs, err := os.ReadFile(upstreamCAPath)
  110. if err != nil {
  111. return nil, err
  112. }
  113. upstreamCertPool := x509.NewCertPool()
  114. if !upstreamCertPool.AppendCertsFromPEM(bs) {
  115. return nil, fmt.Errorf("invalid CA cert in %q", upstreamCAPath)
  116. }
  117. h, _, err := net.SplitHostPort(upstreamAddr)
  118. if err != nil {
  119. return nil, err
  120. }
  121. downstreamCert, err := mkSelfSigned(h)
  122. if err != nil {
  123. return nil, err
  124. }
  125. return &proxy{
  126. upstreamAddr: upstreamAddr,
  127. upstreamHost: h,
  128. upstreamCertPool: upstreamCertPool,
  129. downstreamCert: []tls.Certificate{downstreamCert},
  130. client: client,
  131. errors: metrics.LabelMap{Label: "kind"},
  132. }, nil
  133. }
  134. // Expvar returns p's monitoring metrics.
  135. func (p *proxy) Expvar() expvar.Var {
  136. ret := &metrics.Set{}
  137. ret.Set("sessions_active", &p.activeSessions)
  138. ret.Set("sessions_started", &p.startedSessions)
  139. ret.Set("session_errors", &p.errors)
  140. return ret
  141. }
  142. // Serve accepts postgres client connections on ln and proxies them to
  143. // the configured upstream. ln can be any net.Listener, but all client
  144. // connections must originate from tailscale IPs that can be verified
  145. // with WhoIs.
  146. func (p *proxy) Serve(ln net.Listener) error {
  147. var lastSessionID int64
  148. for {
  149. c, err := ln.Accept()
  150. if err != nil {
  151. return err
  152. }
  153. id := time.Now().UnixNano()
  154. if id == lastSessionID {
  155. // Bluntly enforce SID uniqueness, even if collisions are
  156. // fantastically unlikely (but OSes vary in how much timer
  157. // precision they expose to the OS, so id might be rounded
  158. // e.g. to the same millisecond)
  159. id++
  160. }
  161. lastSessionID = id
  162. go func(sessionID int64) {
  163. if err := p.serve(sessionID, c); err != nil {
  164. log.Printf("%d: session ended with error: %v", sessionID, err)
  165. }
  166. }(id)
  167. }
  168. }
  169. var (
  170. // sslStart is the magic bytes that postgres clients use to indicate
  171. // that they want to do a TLS handshake. Servers should respond with
  172. // the single byte "S" before starting a normal TLS handshake.
  173. sslStart = [8]byte{0, 0, 0, 8, 0x04, 0xd2, 0x16, 0x2f}
  174. // plaintextStart is the magic bytes that postgres clients use to
  175. // indicate that they're starting a plaintext authentication
  176. // handshake.
  177. plaintextStart = [8]byte{0, 0, 0, 86, 0, 3, 0, 0}
  178. )
  179. // serve proxies the postgres client on c to the proxy's upstream,
  180. // enforcing strict TLS to the upstream.
  181. func (p *proxy) serve(sessionID int64, c net.Conn) error {
  182. defer c.Close()
  183. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  184. defer cancel()
  185. whois, err := p.client.WhoIs(ctx, c.RemoteAddr().String())
  186. if err != nil {
  187. p.errors.Add("whois-failed", 1)
  188. return fmt.Errorf("getting client identity: %v", err)
  189. }
  190. // Before anything else, log the connection attempt.
  191. user, machine := "", ""
  192. if whois.Node != nil {
  193. if whois.Node.Hostinfo.ShareeNode() {
  194. machine = "external-device"
  195. } else {
  196. machine = strings.TrimSuffix(whois.Node.Name, ".")
  197. }
  198. }
  199. if whois.UserProfile != nil {
  200. user = whois.UserProfile.LoginName
  201. if user == "tagged-devices" && whois.Node != nil {
  202. user = strings.Join(whois.Node.Tags, ",")
  203. }
  204. }
  205. if user == "" || machine == "" {
  206. p.errors.Add("no-ts-identity", 1)
  207. return fmt.Errorf("couldn't identify source user and machine (user %q, machine %q)", user, machine)
  208. }
  209. log.Printf("%d: session start, from %s (machine %s, user %s)", sessionID, c.RemoteAddr(), machine, user)
  210. start := time.Now()
  211. defer func() {
  212. elapsed := time.Since(start)
  213. log.Printf("%d: session end, from %s (machine %s, user %s), lasted %s", sessionID, c.RemoteAddr(), machine, user, elapsed.Round(time.Millisecond))
  214. }()
  215. // Read the client's opening message, to figure out if it's trying
  216. // to TLS or not.
  217. var buf [8]byte
  218. if _, err := io.ReadFull(c, buf[:len(sslStart)]); err != nil {
  219. p.errors.Add("network-error", 1)
  220. return fmt.Errorf("initial magic read: %v", err)
  221. }
  222. var clientIsTLS bool
  223. switch {
  224. case buf == sslStart:
  225. clientIsTLS = true
  226. case buf == plaintextStart:
  227. clientIsTLS = false
  228. default:
  229. p.errors.Add("client-bad-protocol", 1)
  230. return fmt.Errorf("unrecognized initial packet = % 02x", buf)
  231. }
  232. // Dial & verify upstream connection.
  233. var d net.Dialer
  234. d.Timeout = 10 * time.Second
  235. upc, err := d.Dial("tcp", p.upstreamAddr)
  236. if err != nil {
  237. p.errors.Add("network-error", 1)
  238. return fmt.Errorf("upstream dial: %v", err)
  239. }
  240. defer upc.Close()
  241. if _, err := upc.Write(sslStart[:]); err != nil {
  242. p.errors.Add("network-error", 1)
  243. return fmt.Errorf("upstream write of start-ssl magic: %v", err)
  244. }
  245. if _, err := io.ReadFull(upc, buf[:1]); err != nil {
  246. p.errors.Add("network-error", 1)
  247. return fmt.Errorf("reading upstream start-ssl response: %v", err)
  248. }
  249. if buf[0] != 'S' {
  250. p.errors.Add("upstream-bad-protocol", 1)
  251. return fmt.Errorf("upstream didn't acknowldge start-ssl, said %q", buf[0])
  252. }
  253. tlsConf := &tls.Config{
  254. ServerName: p.upstreamHost,
  255. RootCAs: p.upstreamCertPool,
  256. MinVersion: tls.VersionTLS12,
  257. }
  258. uptc := tls.Client(upc, tlsConf)
  259. if err = uptc.HandshakeContext(ctx); err != nil {
  260. p.errors.Add("upstream-tls", 1)
  261. return fmt.Errorf("upstream TLS handshake: %v", err)
  262. }
  263. // Accept the client conn and set it up the way the client wants.
  264. var clientConn net.Conn
  265. if clientIsTLS {
  266. io.WriteString(c, "S") // yeah, we're good to speak TLS
  267. s := tls.Server(c, &tls.Config{
  268. ServerName: p.upstreamHost,
  269. Certificates: p.downstreamCert,
  270. MinVersion: tls.VersionTLS12,
  271. })
  272. if err = uptc.HandshakeContext(ctx); err != nil {
  273. p.errors.Add("client-tls", 1)
  274. return fmt.Errorf("client TLS handshake: %v", err)
  275. }
  276. clientConn = s
  277. } else {
  278. // Repeat the header we read earlier up to the server.
  279. if _, err := uptc.Write(plaintextStart[:]); err != nil {
  280. p.errors.Add("network-error", 1)
  281. return fmt.Errorf("sending initial client bytes to upstream: %v", err)
  282. }
  283. clientConn = c
  284. }
  285. // Finally, proxy the client to the upstream.
  286. errc := make(chan error, 1)
  287. go func() {
  288. _, err := io.Copy(uptc, clientConn)
  289. errc <- err
  290. }()
  291. go func() {
  292. _, err := io.Copy(clientConn, uptc)
  293. errc <- err
  294. }()
  295. if err := <-errc; err != nil {
  296. // Don't increment error counts here, because the most common
  297. // cause of termination is client or server closing the
  298. // connection normally, and it'll obscure "interesting"
  299. // handshake errors.
  300. return fmt.Errorf("session terminated with error: %v", err)
  301. }
  302. return nil
  303. }
  304. // mkSelfSigned creates and returns a self-signed TLS certificate for
  305. // hostname.
  306. func mkSelfSigned(hostname string) (tls.Certificate, error) {
  307. priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
  308. if err != nil {
  309. return tls.Certificate{}, err
  310. }
  311. pub := priv.Public()
  312. template := x509.Certificate{
  313. SerialNumber: big.NewInt(1),
  314. Subject: pkix.Name{
  315. Organization: []string{"pgproxy"},
  316. },
  317. DNSNames: []string{hostname},
  318. NotBefore: time.Now(),
  319. NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
  320. KeyUsage: x509.KeyUsageDigitalSignature,
  321. ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
  322. BasicConstraintsValid: true,
  323. }
  324. derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, pub, priv)
  325. if err != nil {
  326. return tls.Certificate{}, err
  327. }
  328. cert, err := x509.ParseCertificate(derBytes)
  329. if err != nil {
  330. return tls.Certificate{}, err
  331. }
  332. return tls.Certificate{
  333. Certificate: [][]byte{derBytes},
  334. PrivateKey: priv,
  335. Leaf: cert,
  336. }, nil
  337. }