Bläddra i källkod

net/wsconn: add back custom wrapper for turning a websocket.Conn into a net.Conn

We removed it in #4806 in favor of the built-in functionality from the
nhooyr.io/websocket package. However, it has an issue with deadlines
that has not been fixed yet (see nhooyr/websocket#350). Temporarily
go back to using a custom wrapper (using the fix from our fork) so that
derpers will stop closing connections too aggressively.

Updates #5921

Signed-off-by: Mihai Parparita <[email protected]>
Mihai Parparita 3 år sedan
förälder
incheckning
9d04ffc782

+ 1 - 0
cmd/derper/depaware.txt

@@ -47,6 +47,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/net/tlsdial                                    from tailscale.com/derp/derphttp
         tailscale.com/net/tsaddr                                     from tailscale.com/ipn+
      💣 tailscale.com/net/tshttpproxy                                from tailscale.com/derp/derphttp+
+        tailscale.com/net/wsconn                                     from tailscale.com/cmd/derper+
         tailscale.com/paths                                          from tailscale.com/client/tailscale
         tailscale.com/safesocket                                     from tailscale.com/client/tailscale
         tailscale.com/syncs                                          from tailscale.com/cmd/derper+

+ 2 - 1
cmd/derper/websocket.go

@@ -13,6 +13,7 @@ import (
 
 	"nhooyr.io/websocket"
 	"tailscale.com/derp"
+	"tailscale.com/net/wsconn"
 )
 
 var counterWebSocketAccepts = expvar.NewInt("derp_websocket_accepts")
@@ -50,7 +51,7 @@ func addWebSocketSupport(s *derp.Server, base http.Handler) http.Handler {
 			return
 		}
 		counterWebSocketAccepts.Add(1)
-		wc := websocket.NetConn(r.Context(), c, websocket.MessageBinary)
+		wc := wsconn.NetConn(r.Context(), c, websocket.MessageBinary)
 		brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc))
 		s.Accept(r.Context(), wc, brw, r.RemoteAddr)
 	})

+ 1 - 0
cmd/tailscale/depaware.txt

@@ -70,6 +70,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/net/tlsdial                                    from tailscale.com/derp/derphttp+
         tailscale.com/net/tsaddr                                     from tailscale.com/net/interfaces+
      💣 tailscale.com/net/tshttpproxy                                from tailscale.com/derp/derphttp+
+        tailscale.com/net/wsconn                                     from tailscale.com/control/controlhttp+
         tailscale.com/paths                                          from tailscale.com/cmd/tailscale/cli+
         tailscale.com/safesocket                                     from tailscale.com/cmd/tailscale/cli+
         tailscale.com/syncs                                          from tailscale.com/net/netcheck+

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -241,6 +241,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
      💣 tailscale.com/net/tshttpproxy                                from tailscale.com/control/controlclient+
         tailscale.com/net/tstun                                      from tailscale.com/net/dns+
         tailscale.com/net/tunstats                                   from tailscale.com/net/tstun+
+        tailscale.com/net/wsconn                                     from tailscale.com/control/controlhttp+
         tailscale.com/paths                                          from tailscale.com/ipn/ipnlocal+
         tailscale.com/portlist                                       from tailscale.com/ipn/ipnlocal
         tailscale.com/safesocket                                     from tailscale.com/client/tailscale+

+ 2 - 1
control/controlhttp/client_js.go

@@ -13,6 +13,7 @@ import (
 
 	"nhooyr.io/websocket"
 	"tailscale.com/control/controlbase"
+	"tailscale.com/net/wsconn"
 )
 
 // Variant of Dial that tunnels the request over WebSockets, since we cannot do
@@ -51,7 +52,7 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) {
 	if err != nil {
 		return nil, err
 	}
-	netConn := websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary)
+	netConn := wsconn.NetConn(context.Background(), wsConn, websocket.MessageBinary)
 	cbConn, err := cont(ctx, netConn)
 	if err != nil {
 		netConn.Close()

+ 2 - 1
control/controlhttp/server.go

@@ -14,6 +14,7 @@ import (
 	"nhooyr.io/websocket"
 	"tailscale.com/control/controlbase"
 	"tailscale.com/net/netutil"
+	"tailscale.com/net/wsconn"
 	"tailscale.com/types/key"
 )
 
@@ -118,7 +119,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
 		return nil, fmt.Errorf("decoding base64 handshake parameter: %v", err)
 	}
 
-	conn := websocket.NetConn(ctx, c, websocket.MessageBinary)
+	conn := wsconn.NetConn(ctx, c, websocket.MessageBinary)
 	nc, err := controlbase.Server(ctx, conn, private, init)
 	if err != nil {
 		conn.Close()

+ 2 - 1
derp/derphttp/websocket.go

@@ -13,6 +13,7 @@ import (
 	"net"
 
 	"nhooyr.io/websocket"
+	"tailscale.com/net/wsconn"
 )
 
 func init() {
@@ -28,6 +29,6 @@ func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) {
 		return nil, err
 	}
 	log.Printf("websocket: connected to %v", urlStr)
-	netConn := websocket.NetConn(context.Background(), c, websocket.MessageBinary)
+	netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary)
 	return netConn, nil
 }

+ 213 - 0
net/wsconn/wsconn.go

@@ -0,0 +1,213 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package wsconn contains an adapter type that turns
+// a websocket connection into a net.Conn. It a temporary fork of the
+// netconn.go file from the nhooyr.io/websocket package while we wait for
+// https://github.com/nhooyr/websocket/pull/350 to be merged.
+package wsconn
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"math"
+	"net"
+	"os"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"nhooyr.io/websocket"
+)
+
+// NetConn converts a *websocket.Conn into a net.Conn.
+//
+// It's for tunneling arbitrary protocols over WebSockets.
+// Few users of the library will need this but it's tricky to implement
+// correctly and so provided in the library.
+// See https://github.com/nhooyr/websocket/issues/100.
+//
+// Every Write to the net.Conn will correspond to a message write of
+// the given type on *websocket.Conn.
+//
+// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
+// all reads and writes on the net.Conn will be cancelled.
+//
+// If a message is read that is not of the correct type, the connection
+// will be closed with StatusUnsupportedData and an error will be returned.
+//
+// Close will close the *websocket.Conn with StatusNormalClosure.
+//
+// When a deadline is hit, the connection will be closed. This is
+// different from most net.Conn implementations where only the
+// reading/writing goroutines are interrupted but the connection is kept alive.
+//
+// The Addr methods will return a mock net.Addr that returns "websocket" for Network
+// and "websocket/unknown-addr" for String.
+//
+// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
+// io.EOF when reading.
+func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType) net.Conn {
+	nc := &netConn{
+		c:       c,
+		msgType: msgType,
+	}
+
+	var writeCancel context.CancelFunc
+	nc.writeContext, writeCancel = context.WithCancel(ctx)
+	nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
+		nc.afterWriteDeadline.Store(true)
+		if nc.writing.Load() {
+			writeCancel()
+		}
+	})
+	if !nc.writeTimer.Stop() {
+		<-nc.writeTimer.C
+	}
+
+	var readCancel context.CancelFunc
+	nc.readContext, readCancel = context.WithCancel(ctx)
+	nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
+		nc.afterReadDeadline.Store(true)
+		if nc.reading.Load() {
+			readCancel()
+		}
+	})
+	if !nc.readTimer.Stop() {
+		<-nc.readTimer.C
+	}
+
+	return nc
+}
+
+type netConn struct {
+	c       *websocket.Conn
+	msgType websocket.MessageType
+
+	writeTimer         *time.Timer
+	writeContext       context.Context
+	writing            atomic.Bool
+	afterWriteDeadline atomic.Bool
+
+	readTimer         *time.Timer
+	readContext       context.Context
+	reading           atomic.Bool
+	afterReadDeadline atomic.Bool
+
+	readMu sync.Mutex
+	eofed  bool
+	reader io.Reader
+}
+
+var _ net.Conn = &netConn{}
+
+func (c *netConn) Close() error {
+	return c.c.Close(websocket.StatusNormalClosure, "")
+}
+
+func (c *netConn) Write(p []byte) (int, error) {
+	if c.afterWriteDeadline.Load() {
+		return 0, os.ErrDeadlineExceeded
+	}
+
+	if swapped := c.writing.CompareAndSwap(false, true); !swapped {
+		panic("Concurrent writes not allowed")
+	}
+	defer c.writing.Store(false)
+
+	err := c.c.Write(c.writeContext, c.msgType, p)
+	if err != nil {
+		return 0, err
+	}
+
+	return len(p), nil
+}
+
+func (c *netConn) Read(p []byte) (int, error) {
+	if c.afterReadDeadline.Load() {
+		return 0, os.ErrDeadlineExceeded
+	}
+
+	c.readMu.Lock()
+	defer c.readMu.Unlock()
+	if swapped := c.reading.CompareAndSwap(false, true); !swapped {
+		panic("Concurrent reads not allowed")
+	}
+	defer c.reading.Store(false)
+
+	if c.eofed {
+		return 0, io.EOF
+	}
+
+	if c.reader == nil {
+		typ, r, err := c.c.Reader(c.readContext)
+		if err != nil {
+			switch websocket.CloseStatus(err) {
+			case websocket.StatusNormalClosure, websocket.StatusGoingAway:
+				c.eofed = true
+				return 0, io.EOF
+			}
+			return 0, err
+		}
+		if typ != c.msgType {
+			err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
+			c.c.Close(websocket.StatusUnsupportedData, err.Error())
+			return 0, err
+		}
+		c.reader = r
+	}
+
+	n, err := c.reader.Read(p)
+	if err == io.EOF {
+		c.reader = nil
+		err = nil
+	}
+	return n, err
+}
+
+type websocketAddr struct {
+}
+
+func (a websocketAddr) Network() string {
+	return "websocket"
+}
+
+func (a websocketAddr) String() string {
+	return "websocket/unknown-addr"
+}
+
+func (c *netConn) RemoteAddr() net.Addr {
+	return websocketAddr{}
+}
+
+func (c *netConn) LocalAddr() net.Addr {
+	return websocketAddr{}
+}
+
+func (c *netConn) SetDeadline(t time.Time) error {
+	c.SetWriteDeadline(t)
+	c.SetReadDeadline(t)
+	return nil
+}
+
+func (c *netConn) SetWriteDeadline(t time.Time) error {
+	if t.IsZero() {
+		c.writeTimer.Stop()
+	} else {
+		c.writeTimer.Reset(time.Until(t))
+	}
+	c.afterWriteDeadline.Store(false)
+	return nil
+}
+
+func (c *netConn) SetReadDeadline(t time.Time) error {
+	if t.IsZero() {
+		c.readTimer.Stop()
+	} else {
+		c.readTimer.Reset(time.Until(t))
+	}
+	c.afterReadDeadline.Store(false)
+	return nil
+}