|
|
@@ -0,0 +1,213 @@
|
|
|
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
|
|
+// Use of this source code is governed by a BSD-style
|
|
|
+// license that can be found in the LICENSE file.
|
|
|
+
|
|
|
+// Package wsconn contains an adapter type that turns
|
|
|
+// a websocket connection into a net.Conn. It a temporary fork of the
|
|
|
+// netconn.go file from the nhooyr.io/websocket package while we wait for
|
|
|
+// https://github.com/nhooyr/websocket/pull/350 to be merged.
|
|
|
+package wsconn
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "math"
|
|
|
+ "net"
|
|
|
+ "os"
|
|
|
+ "sync"
|
|
|
+ "sync/atomic"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "nhooyr.io/websocket"
|
|
|
+)
|
|
|
+
|
|
|
+// NetConn converts a *websocket.Conn into a net.Conn.
|
|
|
+//
|
|
|
+// It's for tunneling arbitrary protocols over WebSockets.
|
|
|
+// Few users of the library will need this but it's tricky to implement
|
|
|
+// correctly and so provided in the library.
|
|
|
+// See https://github.com/nhooyr/websocket/issues/100.
|
|
|
+//
|
|
|
+// Every Write to the net.Conn will correspond to a message write of
|
|
|
+// the given type on *websocket.Conn.
|
|
|
+//
|
|
|
+// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
|
|
|
+// all reads and writes on the net.Conn will be cancelled.
|
|
|
+//
|
|
|
+// If a message is read that is not of the correct type, the connection
|
|
|
+// will be closed with StatusUnsupportedData and an error will be returned.
|
|
|
+//
|
|
|
+// Close will close the *websocket.Conn with StatusNormalClosure.
|
|
|
+//
|
|
|
+// When a deadline is hit, the connection will be closed. This is
|
|
|
+// different from most net.Conn implementations where only the
|
|
|
+// reading/writing goroutines are interrupted but the connection is kept alive.
|
|
|
+//
|
|
|
+// The Addr methods will return a mock net.Addr that returns "websocket" for Network
|
|
|
+// and "websocket/unknown-addr" for String.
|
|
|
+//
|
|
|
+// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
|
|
|
+// io.EOF when reading.
|
|
|
+func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType) net.Conn {
|
|
|
+ nc := &netConn{
|
|
|
+ c: c,
|
|
|
+ msgType: msgType,
|
|
|
+ }
|
|
|
+
|
|
|
+ var writeCancel context.CancelFunc
|
|
|
+ nc.writeContext, writeCancel = context.WithCancel(ctx)
|
|
|
+ nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
|
|
|
+ nc.afterWriteDeadline.Store(true)
|
|
|
+ if nc.writing.Load() {
|
|
|
+ writeCancel()
|
|
|
+ }
|
|
|
+ })
|
|
|
+ if !nc.writeTimer.Stop() {
|
|
|
+ <-nc.writeTimer.C
|
|
|
+ }
|
|
|
+
|
|
|
+ var readCancel context.CancelFunc
|
|
|
+ nc.readContext, readCancel = context.WithCancel(ctx)
|
|
|
+ nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
|
|
|
+ nc.afterReadDeadline.Store(true)
|
|
|
+ if nc.reading.Load() {
|
|
|
+ readCancel()
|
|
|
+ }
|
|
|
+ })
|
|
|
+ if !nc.readTimer.Stop() {
|
|
|
+ <-nc.readTimer.C
|
|
|
+ }
|
|
|
+
|
|
|
+ return nc
|
|
|
+}
|
|
|
+
|
|
|
+type netConn struct {
|
|
|
+ c *websocket.Conn
|
|
|
+ msgType websocket.MessageType
|
|
|
+
|
|
|
+ writeTimer *time.Timer
|
|
|
+ writeContext context.Context
|
|
|
+ writing atomic.Bool
|
|
|
+ afterWriteDeadline atomic.Bool
|
|
|
+
|
|
|
+ readTimer *time.Timer
|
|
|
+ readContext context.Context
|
|
|
+ reading atomic.Bool
|
|
|
+ afterReadDeadline atomic.Bool
|
|
|
+
|
|
|
+ readMu sync.Mutex
|
|
|
+ eofed bool
|
|
|
+ reader io.Reader
|
|
|
+}
|
|
|
+
|
|
|
+var _ net.Conn = &netConn{}
|
|
|
+
|
|
|
+func (c *netConn) Close() error {
|
|
|
+ return c.c.Close(websocket.StatusNormalClosure, "")
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) Write(p []byte) (int, error) {
|
|
|
+ if c.afterWriteDeadline.Load() {
|
|
|
+ return 0, os.ErrDeadlineExceeded
|
|
|
+ }
|
|
|
+
|
|
|
+ if swapped := c.writing.CompareAndSwap(false, true); !swapped {
|
|
|
+ panic("Concurrent writes not allowed")
|
|
|
+ }
|
|
|
+ defer c.writing.Store(false)
|
|
|
+
|
|
|
+ err := c.c.Write(c.writeContext, c.msgType, p)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return len(p), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) Read(p []byte) (int, error) {
|
|
|
+ if c.afterReadDeadline.Load() {
|
|
|
+ return 0, os.ErrDeadlineExceeded
|
|
|
+ }
|
|
|
+
|
|
|
+ c.readMu.Lock()
|
|
|
+ defer c.readMu.Unlock()
|
|
|
+ if swapped := c.reading.CompareAndSwap(false, true); !swapped {
|
|
|
+ panic("Concurrent reads not allowed")
|
|
|
+ }
|
|
|
+ defer c.reading.Store(false)
|
|
|
+
|
|
|
+ if c.eofed {
|
|
|
+ return 0, io.EOF
|
|
|
+ }
|
|
|
+
|
|
|
+ if c.reader == nil {
|
|
|
+ typ, r, err := c.c.Reader(c.readContext)
|
|
|
+ if err != nil {
|
|
|
+ switch websocket.CloseStatus(err) {
|
|
|
+ case websocket.StatusNormalClosure, websocket.StatusGoingAway:
|
|
|
+ c.eofed = true
|
|
|
+ return 0, io.EOF
|
|
|
+ }
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ if typ != c.msgType {
|
|
|
+ err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
|
|
|
+ c.c.Close(websocket.StatusUnsupportedData, err.Error())
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ c.reader = r
|
|
|
+ }
|
|
|
+
|
|
|
+ n, err := c.reader.Read(p)
|
|
|
+ if err == io.EOF {
|
|
|
+ c.reader = nil
|
|
|
+ err = nil
|
|
|
+ }
|
|
|
+ return n, err
|
|
|
+}
|
|
|
+
|
|
|
+type websocketAddr struct {
|
|
|
+}
|
|
|
+
|
|
|
+func (a websocketAddr) Network() string {
|
|
|
+ return "websocket"
|
|
|
+}
|
|
|
+
|
|
|
+func (a websocketAddr) String() string {
|
|
|
+ return "websocket/unknown-addr"
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) RemoteAddr() net.Addr {
|
|
|
+ return websocketAddr{}
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) LocalAddr() net.Addr {
|
|
|
+ return websocketAddr{}
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) SetDeadline(t time.Time) error {
|
|
|
+ c.SetWriteDeadline(t)
|
|
|
+ c.SetReadDeadline(t)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) SetWriteDeadline(t time.Time) error {
|
|
|
+ if t.IsZero() {
|
|
|
+ c.writeTimer.Stop()
|
|
|
+ } else {
|
|
|
+ c.writeTimer.Reset(time.Until(t))
|
|
|
+ }
|
|
|
+ c.afterWriteDeadline.Store(false)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *netConn) SetReadDeadline(t time.Time) error {
|
|
|
+ if t.IsZero() {
|
|
|
+ c.readTimer.Stop()
|
|
|
+ } else {
|
|
|
+ c.readTimer.Reset(time.Until(t))
|
|
|
+ }
|
|
|
+ c.afterReadDeadline.Store(false)
|
|
|
+ return nil
|
|
|
+}
|