Browse Source

Improve websocket writer

世界 3 years ago
parent
commit
22ea878fe9

+ 1 - 1
docs/features.md

@@ -110,7 +110,7 @@
 | /                  |    TCP    |   HTTP    |  H2 TLS   | WebSocket TLS | gRPC TLS  |
 |--------------------|:---------:|:---------:|:---------:|:-------------:|:---------:|
 | v2ray-core (5.1.0) | 7.86 GBps | 2.86 Gbps | 1.83 Gbps |   2.36 Gbps   | 2.43 Gbps |
-| sing-box           | 7.96 Gbps | 8.09 Gbps | 6.11 Gbps |   2.69 Gbps   | 6.35 Gbps |
+| sing-box           | 7.96 Gbps | 8.09 Gbps | 6.11 Gbps |   8.02 Gbps   | 6.35 Gbps |
 
 #### License
 

+ 1 - 1
transport/v2raywebsocket/client.go

@@ -74,7 +74,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	if c.maxEarlyData <= 0 {
 		conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
 		if err == nil {
-			return &WebsocketConn{Conn: conn}, nil
+			return &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}}, nil
 		}
 		return nil, wrapDialError(response, err)
 	} else {

+ 62 - 12
transport/v2raywebsocket/conn.go

@@ -10,16 +10,26 @@ import (
 	"time"
 
 	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/websocket"
 )
 
 type WebsocketConn struct {
 	*websocket.Conn
+	*Writer
 	remoteAddr net.Addr
 	reader     io.Reader
 }
 
+func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
+	return &WebsocketConn{
+		Conn:       wsConn,
+		remoteAddr: remoteAddr,
+		Writer:     &Writer{wsConn, true},
+	}
+}
+
 func (c *WebsocketConn) Close() error {
 	err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
 	if err != nil {
@@ -47,14 +57,6 @@ func (c *WebsocketConn) Read(b []byte) (n int, err error) {
 	}
 }
 
-func (c *WebsocketConn) Write(b []byte) (n int, err error) {
-	err = wrapError(c.WriteMessage(websocket.BinaryMessage, b))
-	if err != nil {
-		return
-	}
-	return len(b), nil
-}
-
 func (c *WebsocketConn) RemoteAddr() net.Addr {
 	if c.remoteAddr != nil {
 		return c.remoteAddr
@@ -66,6 +68,10 @@ func (c *WebsocketConn) SetDeadline(t time.Time) error {
 	return os.ErrInvalid
 }
 
+func (c *WebsocketConn) FrontHeadroom() int {
+	return frontHeadroom
+}
+
 type EarlyWebsocketConn struct {
 	*Client
 	ctx    context.Context
@@ -90,9 +96,9 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
 		conn      *websocket.Conn
 		response  *http.Response
 	)
-	if len(earlyData) > int(c.maxEarlyData) {
-		earlyData = earlyData[:c.maxEarlyData]
-		lateData = lateData[c.maxEarlyData:]
+	if len(b) > int(c.maxEarlyData) {
+		earlyData = b[:c.maxEarlyData]
+		lateData = b[c.maxEarlyData:]
 	} else {
 		earlyData = b
 	}
@@ -111,7 +117,7 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
 	if err != nil {
 		return 0, wrapDialError(response, err)
 	}
-	c.conn = &WebsocketConn{Conn: conn}
+	c.conn = &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}}
 	close(c.create)
 	if len(lateData) > 0 {
 		_, err = c.conn.Write(lateData)
@@ -122,6 +128,46 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
 	return len(b), nil
 }
 
+func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
+	if c.conn != nil {
+		return c.conn.WriteBuffer(buffer)
+	}
+	var (
+		earlyData []byte
+		lateData  []byte
+		conn      *websocket.Conn
+		response  *http.Response
+		err       error
+	)
+	if buffer.Len() > int(c.maxEarlyData) {
+		earlyData = buffer.Bytes()[:c.maxEarlyData]
+		lateData = buffer.Bytes()[c.maxEarlyData:]
+	} else {
+		earlyData = buffer.Bytes()
+	}
+	if len(earlyData) > 0 {
+		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
+		if c.earlyDataHeaderName == "" {
+			conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
+		} else {
+			headers := c.headers.Clone()
+			headers.Set(c.earlyDataHeaderName, earlyDataString)
+			conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
+		}
+	} else {
+		conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
+	}
+	if err != nil {
+		return wrapDialError(response, err)
+	}
+	c.conn = &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}}
+	close(c.create)
+	if len(lateData) > 0 {
+		_, err = c.conn.Write(lateData)
+	}
+	return err
+}
+
 func (c *EarlyWebsocketConn) Close() error {
 	if c.conn == nil {
 		return nil
@@ -164,6 +210,10 @@ func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
 	return c.conn.SetWriteDeadline(t)
 }
 
+func (c *EarlyWebsocketConn) FrontHeadroom() int {
+	return frontHeadroom
+}
+
 func wrapError(err error) error {
 	if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
 		return io.EOF

+ 6 - 0
transport/v2raywebsocket/mask.go

@@ -0,0 +1,6 @@
+package v2raywebsocket
+
+import _ "unsafe"
+
+//go:linkname maskBytes github.com/sagernet/websocket.maskBytes
+func maskBytes(key [4]byte, pos int, b []byte) int

+ 1 - 4
transport/v2raywebsocket/server.go

@@ -108,10 +108,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	}
 	var metadata M.Metadata
 	metadata.Source = sHttp.SourceAddress(request)
-	conn = &WebsocketConn{
-		Conn:       wsConn,
-		remoteAddr: metadata.Source.TCPAddr(),
-	}
+	conn = NewServerConn(wsConn, metadata.Source.TCPAddr())
 	if len(earlyData) > 0 {
 		conn = bufio.NewCachedConn(conn, buf.As(earlyData))
 	}

+ 73 - 0
transport/v2raywebsocket/writer.go

@@ -0,0 +1,73 @@
+package v2raywebsocket
+
+import (
+	"encoding/binary"
+	"math/rand"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/websocket"
+)
+
+const frontHeadroom = 14
+
+type Writer struct {
+	*websocket.Conn
+	isServer bool
+}
+
+func (w *Writer) Write(p []byte) (n int, err error) {
+	err = w.Conn.WriteMessage(websocket.BinaryMessage, p)
+	if err != nil {
+		return
+	}
+	return len(p), nil
+}
+
+func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
+	defer buffer.Release()
+
+	var payloadBitLength int
+	dataLen := buffer.Len()
+	data := buffer.Bytes()
+	if dataLen < 126 {
+		payloadBitLength = 1
+	} else if dataLen < 65536 {
+		payloadBitLength = 3
+	} else {
+		payloadBitLength = 9
+	}
+
+	var headerLen int
+	headerLen += 1 // FIN / RSV / OPCODE
+	headerLen += payloadBitLength
+	if !w.isServer {
+		headerLen += 4 // MASK KEY
+	}
+
+	header := buffer.ExtendHeader(headerLen)
+	header[0] = websocket.BinaryMessage | 1<<7
+	if w.isServer {
+		header[1] = 0
+	} else {
+		header[1] = 1 << 7
+	}
+
+	if dataLen < 126 {
+		header[1] |= byte(dataLen)
+	} else if dataLen < 65536 {
+		header[1] |= 126
+		binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
+	} else {
+		header[1] |= 127
+		binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
+	}
+
+	if !w.isServer {
+		maskKey := rand.Uint32()
+		binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
+		maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
+	}
+
+	return common.Error(w.Conn.NetConn().Write(buffer.Bytes()))
+}