瀏覽代碼

Migrate to gobwas/ws

世界 1 年之前
父節點
當前提交
4d23773a25

+ 9 - 1
common/dialer/detour.go

@@ -6,6 +6,7 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing/common/bufio/deadline"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -44,7 +45,14 @@ func (d *DetourDialer) DialContext(ctx context.Context, network string, destinat
 	if err != nil {
 		return nil, err
 	}
-	return dialer.DialContext(ctx, network, destination)
+	conn, err := dialer.DialContext(ctx, network, destination)
+	if err != nil {
+		return nil, err
+	}
+	if deadline.NeedAdditionalReadDeadline(conn) {
+		conn = deadline.NewConn(conn)
+	}
+	return conn, nil
 }
 
 func (d *DetourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {

+ 9 - 8
experimental/clashapi/api_meta.go

@@ -2,12 +2,14 @@ package clashapi
 
 import (
 	"bytes"
+	"net"
 	"net/http"
 	"time"
 
 	"github.com/sagernet/sing-box/common/json"
 	"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
+	"github.com/sagernet/ws/wsutil"
 
 	"github.com/go-chi/chi/v5"
 	"github.com/go-chi/render"
@@ -27,16 +29,16 @@ type Memory struct {
 
 func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
 	return func(w http.ResponseWriter, r *http.Request) {
-		var wsConn *websocket.Conn
-		if websocket.IsWebSocketUpgrade(r) {
+		var conn net.Conn
+		if r.Header.Get("Upgrade") == "websocket" {
 			var err error
-			wsConn, err = upgrader.Upgrade(w, r, nil)
+			conn, _, _, err = ws.UpgradeHTTP(r, w)
 			if err != nil {
 				return
 			}
 		}
 
-		if wsConn == nil {
+		if conn == nil {
 			w.Header().Set("Content-Type", "application/json")
 			render.Status(r, http.StatusOK)
 		}
@@ -63,13 +65,12 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
 			}); err != nil {
 				break
 			}
-			if wsConn == nil {
+			if conn == nil {
 				_, err = w.Write(buf.Bytes())
 				w.(http.Flusher).Flush()
 			} else {
-				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
+				err = wsutil.WriteServerText(conn, buf.Bytes())
 			}
-
 			if err != nil {
 				break
 			}

+ 5 - 4
experimental/clashapi/connections.go

@@ -9,7 +9,8 @@ import (
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/json"
 	"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
+	"github.com/sagernet/ws/wsutil"
 
 	"github.com/go-chi/chi/v5"
 	"github.com/go-chi/render"
@@ -25,13 +26,13 @@ func connectionRouter(router adapter.Router, trafficManager *trafficontrol.Manag
 
 func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
 	return func(w http.ResponseWriter, r *http.Request) {
-		if !websocket.IsWebSocketUpgrade(r) {
+		if r.Header.Get("Upgrade") != "websocket" {
 			snapshot := trafficManager.Snapshot()
 			render.JSON(w, r, snapshot)
 			return
 		}
 
-		conn, err := upgrader.Upgrade(w, r, nil)
+		conn, _, _, err := ws.UpgradeHTTP(r, w)
 		if err != nil {
 			return
 		}
@@ -56,7 +57,7 @@ func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseW
 			if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
 				return err
 			}
-			return conn.WriteMessage(websocket.TextMessage, buf.Bytes())
+			return wsutil.WriteServerText(conn, buf.Bytes())
 		}
 
 		if err = sendSnapshot(); err != nil {

+ 17 - 21
experimental/clashapi/server.go

@@ -25,7 +25,8 @@ import (
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/service"
 	"github.com/sagernet/sing/service/filemanager"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
+	"github.com/sagernet/ws/wsutil"
 
 	"github.com/go-chi/chi/v5"
 	"github.com/go-chi/cors"
@@ -314,7 +315,7 @@ func authentication(serverSecret string) func(next http.Handler) http.Handler {
 			}
 
 			// Browser websocket not support custom header
-			if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" {
+			if r.Header.Get("Upgrade") == "websocket" && r.URL.Query().Get("token") != "" {
 				token := r.URL.Query().Get("token")
 				if token != serverSecret {
 					render.Status(r, http.StatusUnauthorized)
@@ -351,12 +352,6 @@ func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-var upgrader = websocket.Upgrader{
-	CheckOrigin: func(r *http.Request) bool {
-		return true
-	},
-}
-
 type Traffic struct {
 	Up   int64 `json:"up"`
 	Down int64 `json:"down"`
@@ -364,16 +359,17 @@ type Traffic struct {
 
 func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
 	return func(w http.ResponseWriter, r *http.Request) {
-		var wsConn *websocket.Conn
-		if websocket.IsWebSocketUpgrade(r) {
+		var conn net.Conn
+		if r.Header.Get("Upgrade") == "websocket" {
 			var err error
-			wsConn, err = upgrader.Upgrade(w, r, nil)
+			conn, _, _, err = ws.UpgradeHTTP(r, w)
 			if err != nil {
 				return
 			}
+			defer conn.Close()
 		}
 
-		if wsConn == nil {
+		if conn == nil {
 			w.Header().Set("Content-Type", "application/json")
 			render.Status(r, http.StatusOK)
 		}
@@ -392,11 +388,11 @@ func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter,
 				break
 			}
 
-			if wsConn == nil {
+			if conn == nil {
 				_, err = w.Write(buf.Bytes())
 				w.(http.Flusher).Flush()
 			} else {
-				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
+				err = wsutil.WriteServerText(conn, buf.Bytes())
 			}
 
 			if err != nil {
@@ -432,16 +428,16 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
 		}
 		defer logFactory.UnSubscribe(subscription)
 
-		var wsConn *websocket.Conn
-		if websocket.IsWebSocketUpgrade(r) {
-			var err error
-			wsConn, err = upgrader.Upgrade(w, r, nil)
+		var conn net.Conn
+		if r.Header.Get("Upgrade") == "websocket" {
+			conn, _, _, err = ws.UpgradeHTTP(r, w)
 			if err != nil {
 				return
 			}
+			defer conn.Close()
 		}
 
-		if wsConn == nil {
+		if conn == nil {
 			w.Header().Set("Content-Type", "application/json")
 			render.Status(r, http.StatusOK)
 		}
@@ -465,11 +461,11 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
 			if err != nil {
 				break
 			}
-			if wsConn == nil {
+			if conn == nil {
 				_, err = w.Write(buf.Bytes())
 				w.(http.Flusher).Flush()
 			} else {
-				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
+				err = wsutil.WriteServerText(conn, buf.Bytes())
 			}
 
 			if err != nil {

+ 3 - 1
go.mod

@@ -38,8 +38,8 @@ require (
 	github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
 	github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6
 	github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2
-	github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e
 	github.com/sagernet/wireguard-go v0.0.0-20230807125731-5d4a7ef2dc5f
+	github.com/sagernet/ws v0.0.0-20231030053741-7d481eb31bed
 	github.com/spf13/cobra v1.8.0
 	github.com/stretchr/testify v1.8.4
 	go.uber.org/zap v1.26.0
@@ -61,6 +61,8 @@ require (
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/go-ole/go-ole v1.3.0 // indirect
 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
+	github.com/gobwas/httphead v0.1.0 // indirect
+	github.com/gobwas/pool v0.2.1 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect

+ 7 - 2
go.sum

@@ -31,6 +31,10 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
 github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
+github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
+github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
+github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
+github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
 github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M=
 github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -134,10 +138,10 @@ github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 h1:Px+hN4Vzgx+iCGV
 github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6/go.mod h1:zovq6vTvEM6ECiqE3Eeb9rpIylPpamPcmrJ9tv0Bt0M=
 github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 h1:kDUqhc9Vsk5HJuhfIATJ8oQwBmpOZJuozQG7Vk88lL4=
 github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM=
-github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs=
-github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY=
 github.com/sagernet/wireguard-go v0.0.0-20230807125731-5d4a7ef2dc5f h1:Kvo8w8Y9lzFGB/7z09MJ3TR99TFtfI/IuY87Ygcycho=
 github.com/sagernet/wireguard-go v0.0.0-20230807125731-5d4a7ef2dc5f/go.mod h1:mySs0abhpc/gLlvhoq7HP1RzOaRmIXVeZGCh++zoApk=
+github.com/sagernet/ws v0.0.0-20231030053741-7d481eb31bed h1:90a510OeE9siSJoYsI8nSjPmA+u5ROMDts/ZkdNsuXY=
+github.com/sagernet/ws v0.0.0-20231030053741-7d481eb31bed/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
 github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
 github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
 github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
@@ -189,6 +193,7 @@ golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
 golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

+ 1 - 1
transport/v2ray/transport.go

@@ -50,7 +50,7 @@ func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socks
 	case C.V2RayTransportTypeGRPC:
 		return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions, tlsConfig)
 	case C.V2RayTransportTypeWebsocket:
-		return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig), nil
+		return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig)
 	case C.V2RayTransportTypeQUIC:
 		if tlsConfig == nil {
 			return nil, C.ErrTLSRequired

+ 1 - 1
transport/v2rayhttp/client.go

@@ -81,7 +81,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	uri.Path = options.Path
 	err := sHTTP.URLSetPath(&uri, options.Path)
 	if err != nil {
-		return nil, E.New("failed to set path: " + err.Error())
+		return nil, E.Cause(err, "parse path")
 	}
 	client.url = &uri
 	return client, nil

+ 52 - 42
transport/v2raywebsocket/client.go

@@ -5,58 +5,37 @@ import (
 	"net"
 	"net/http"
 	"net/url"
+	"strings"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/tls"
+	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	sHTTP "github.com/sagernet/sing/protocol/http"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
 )
 
 var _ adapter.V2RayClientTransport = (*Client)(nil)
 
 type Client struct {
-	dialer              *websocket.Dialer
+	dialer              N.Dialer
+	tlsConfig           tls.Config
+	serverAddr          M.Socksaddr
 	requestURL          url.URL
-	requestURLString    string
 	headers             http.Header
 	maxEarlyData        uint32
 	earlyDataHeaderName string
 }
 
-func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
-	wsDialer := &websocket.Dialer{
-		ReadBufferSize:   4 * 1024,
-		WriteBufferSize:  4 * 1024,
-		HandshakeTimeout: time.Second * 8,
-	}
+func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
 	if tlsConfig != nil {
 		if len(tlsConfig.NextProtos()) == 0 {
 			tlsConfig.SetNextProtos([]string{"http/1.1"})
 		}
-		wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
-			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-			if err != nil {
-				return nil, err
-			}
-			tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
-			if err != nil {
-				return nil, err
-			}
-			return &deadConn{tlsConn}, nil
-		}
-	} else {
-		wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
-			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-			if err != nil {
-				return nil, err
-			}
-			return &deadConn{conn}, nil
-		}
 	}
 	var requestURL url.URL
 	if tlsConfig == nil {
@@ -68,37 +47,68 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	requestURL.Path = options.Path
 	err := sHTTP.URLSetPath(&requestURL, options.Path)
 	if err != nil {
-		return nil
+		return nil, E.Cause(err, "parse path")
+	}
+	if !strings.HasPrefix(requestURL.Path, "/") {
+		requestURL.Path = "/" + requestURL.Path
 	}
 	headers := make(http.Header)
 	for key, value := range options.Headers {
 		headers[key] = value
+		if key == "Host" {
+			if len(value) > 1 {
+				return nil, E.New("multiple Host headers")
+			}
+			requestURL.Host = value[0]
+		}
+	}
+	if headers.Get("User-Agent") == "" {
+		headers.Set("User-Agent", "Go-http-client/1.1")
 	}
 	return &Client{
-		wsDialer,
+		dialer,
+		tlsConfig,
+		serverAddr,
 		requestURL,
-		requestURL.String(),
 		headers,
 		options.MaxEarlyData,
 		options.EarlyDataHeaderName,
+	}, nil
+}
+
+func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers http.Header) (*WebsocketConn, error) {
+	conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
+	if err != nil {
+		return nil, err
+	}
+	if c.tlsConfig != nil {
+		conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
+		if err != nil {
+			return nil, err
+		}
+	}
+	conn.SetDeadline(time.Now().Add(C.TCPTimeout))
+	var protocols []string
+	if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" {
+		protocols = []string{protocolHeader}
+		headers.Del("Sec-WebSocket-Protocol")
 	}
+	reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(conn, requestURL)
+	conn.SetDeadline(time.Time{})
+	if err != nil {
+		return nil, err
+	}
+	return NewConn(conn, reader, nil, ws.StateClientSide), nil
 }
 
 func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	if c.maxEarlyData <= 0 {
-		conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
-		if err == nil {
-			return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
+		conn, err := c.dialContext(ctx, &c.requestURL, c.headers)
+		if err != nil {
+			return nil, err
 		}
-		return nil, wrapDialError(response, err)
+		return conn, nil
 	} else {
 		return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
 	}
 }
-
-func wrapDialError(response *http.Response, err error) error {
-	if response == nil {
-		return err
-	}
-	return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
-}

+ 77 - 43
transport/v2raywebsocket/conn.go

@@ -1,11 +1,11 @@
 package v2raywebsocket
 
 import (
+	"bufio"
 	"context"
 	"encoding/base64"
 	"io"
 	"net"
-	"net/http"
 	"os"
 	"sync"
 	"time"
@@ -13,50 +13,96 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/debug"
 	E "github.com/sagernet/sing/common/exceptions"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
+	"github.com/sagernet/ws/wsutil"
 )
 
 type WebsocketConn struct {
-	*websocket.Conn
+	net.Conn
 	*Writer
-	remoteAddr net.Addr
-	reader     io.Reader
+	state          ws.State
+	reader         *wsutil.Reader
+	controlHandler wsutil.FrameHandlerFunc
+	remoteAddr     net.Addr
 }
 
-func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
+func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn {
+	controlHandler := wsutil.ControlFrameHandler(conn, state)
+	var reader io.Reader
+	if br != nil && br.Buffered() > 0 {
+		reader = br
+	} else {
+		reader = conn
+	}
 	return &WebsocketConn{
-		Conn:       wsConn,
-		remoteAddr: remoteAddr,
-		Writer:     NewWriter(wsConn, true),
+		Conn:  conn,
+		state: state,
+		reader: &wsutil.Reader{
+			Source:          reader,
+			State:           state,
+			SkipHeaderCheck: !debug.Enabled,
+			OnIntermediate:  controlHandler,
+		},
+		controlHandler: controlHandler,
+		remoteAddr:     remoteAddr,
+		Writer:         NewWriter(conn, state),
 	}
 }
 
 func (c *WebsocketConn) Close() error {
-	err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
-	if err != nil {
-		return c.Conn.Close()
+	c.Conn.SetWriteDeadline(time.Now().Add(C.TCPTimeout))
+	frame := ws.NewCloseFrame(ws.NewCloseFrameBody(
+		ws.StatusNormalClosure, "",
+	))
+	if c.state == ws.StateClientSide {
+		frame = ws.MaskFrameInPlace(frame)
 	}
+	ws.WriteFrame(c.Conn, frame)
+	c.Conn.Close()
 	return nil
 }
 
 func (c *WebsocketConn) Read(b []byte) (n int, err error) {
+	var header ws.Header
 	for {
-		if c.reader == nil {
-			_, c.reader, err = c.NextReader()
+		n, err = c.reader.Read(b)
+		if n > 0 {
+			err = nil
+			return
+		}
+		if !E.IsMulti(err, io.EOF, wsutil.ErrNoFrameAdvance) {
+			return
+		}
+		header, err = c.reader.NextFrame()
+		if err != nil {
+			return
+		}
+		if header.OpCode.IsControl() {
+			err = c.controlHandler(header, c.reader)
 			if err != nil {
-				err = wrapError(err)
 				return
 			}
+			continue
 		}
-		n, err = c.reader.Read(b)
-		if E.IsMulti(err, io.EOF) {
-			c.reader = nil
+		if header.OpCode&ws.OpBinary == 0 {
+			err = c.reader.Discard()
+			if err != nil {
+				return
+			}
 			continue
 		}
-		err = wrapError(err)
+	}
+}
+
+func (c *WebsocketConn) Write(p []byte) (n int, err error) {
+	err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p)
+	if err != nil {
 		return
 	}
+	n = len(p)
+	return
 }
 
 func (c *WebsocketConn) RemoteAddr() net.Addr {
@@ -83,11 +129,7 @@ func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
 }
 
 func (c *WebsocketConn) Upstream() any {
-	return c.Conn.NetConn()
-}
-
-func (c *WebsocketConn) UpstreamWriter() any {
-	return c.Writer
+	return c.Conn
 }
 
 type EarlyWebsocketConn struct {
@@ -113,8 +155,7 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
 	var (
 		earlyData []byte
 		lateData  []byte
-		conn      *websocket.Conn
-		response  *http.Response
+		conn      *WebsocketConn
 		err       error
 	)
 	if len(content) > int(c.maxEarlyData) {
@@ -128,23 +169,26 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
 		if c.earlyDataHeaderName == "" {
 			requestURL := c.requestURL
 			requestURL.Path += earlyDataString
-			conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
+			conn, err = c.dialContext(c.ctx, &requestURL, c.headers)
 		} else {
 			headers := c.headers.Clone()
 			headers.Set(c.earlyDataHeaderName, earlyDataString)
-			conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
+			conn, err = c.dialContext(c.ctx, &c.requestURL, headers)
 		}
 	} else {
-		conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
+		conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers)
 	}
 	if err != nil {
-		return wrapDialError(response, err)
+		return err
 	}
-	c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
 	if len(lateData) > 0 {
-		_, err = c.conn.Write(lateData)
+		_, err = conn.Write(lateData)
+		if err != nil {
+			return err
+		}
 	}
-	return err
+	c.conn = conn
+	return nil
 }
 
 func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
@@ -230,13 +274,3 @@ func (c *EarlyWebsocketConn) Upstream() any {
 func (c *EarlyWebsocketConn) LazyHeadroom() bool {
 	return c.conn == nil
 }
-
-func wrapError(err error) error {
-	if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
-		return io.EOF
-	}
-	if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
-		return net.ErrClosed
-	}
-	return err
-}

+ 0 - 6
transport/v2raywebsocket/mask.go

@@ -1,6 +0,0 @@
-package v2raywebsocket
-
-import _ "unsafe"
-
-//go:linkname maskBytes github.com/sagernet/websocket.maskBytes
-func maskBytes(key [4]byte, pos int, b []byte) int

+ 3 - 10
transport/v2raywebsocket/server.go

@@ -20,7 +20,7 @@ import (
 	N "github.com/sagernet/sing/common/network"
 	aTLS "github.com/sagernet/sing/common/tls"
 	sHttp "github.com/sagernet/sing/protocol/http"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
 )
 
 var _ adapter.V2RayServerTransport = (*Server)(nil)
@@ -58,13 +58,6 @@ func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsCon
 	return server, nil
 }
 
-var upgrader = websocket.Upgrader{
-	HandshakeTimeout: C.TCPTimeout,
-	CheckOrigin: func(r *http.Request) bool {
-		return true
-	},
-}
-
 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" {
 		if request.URL.Path != s.path {
@@ -95,14 +88,14 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 		s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data"))
 		return
 	}
-	wsConn, err := upgrader.Upgrade(writer, request, nil)
+	wsConn, reader, _, err := ws.UpgradeHTTP(request, writer)
 	if err != nil {
 		s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection"))
 		return
 	}
 	var metadata M.Metadata
 	metadata.Source = sHttp.SourceAddress(request)
-	conn = NewServerConn(wsConn, metadata.Source.TCPAddr())
+	conn = NewConn(wsConn, reader.Reader, metadata.Source.TCPAddr(), ws.StateServerSide)
 	if len(earlyData) > 0 {
 		conn = bufio.NewCachedConn(conn, buf.As(earlyData))
 	}

+ 7 - 20
transport/v2raywebsocket/writer.go

@@ -2,36 +2,27 @@ package v2raywebsocket
 
 import (
 	"encoding/binary"
+	"io"
 	"math/rand"
 
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/websocket"
+	"github.com/sagernet/ws"
 )
 
 type Writer struct {
-	*websocket.Conn
 	writer   N.ExtendedWriter
 	isServer bool
 }
 
-func NewWriter(conn *websocket.Conn, isServer bool) *Writer {
+func NewWriter(writer io.Writer, state ws.State) *Writer {
 	return &Writer{
-		conn,
-		bufio.NewExtendedWriter(conn.NetConn()),
-		isServer,
+		bufio.NewExtendedWriter(writer),
+		state == ws.StateServerSide,
 	}
 }
 
-func (w *Writer) Write(p []byte) (n int, err error) {
-	err = w.Conn.WriteMessage(websocket.BinaryMessage, p)
-	if err != nil {
-		return
-	}
-	return len(p), nil
-}
-
 func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
 	var payloadBitLength int
 	dataLen := buffer.Len()
@@ -52,7 +43,7 @@ func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
 	}
 
 	header := buffer.ExtendHeader(headerLen)
-	header[0] = websocket.BinaryMessage | 1<<7
+	header[0] = byte(ws.OpBinary) | 0x80
 	if w.isServer {
 		header[1] = 0
 	} else {
@@ -72,16 +63,12 @@ func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
 	if !w.isServer {
 		maskKey := rand.Uint32()
 		binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
-		maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
+		ws.Cipher(data, *(*[4]byte)(header[1+payloadBitLength:]), 0)
 	}
 
 	return w.writer.WriteBuffer(buffer)
 }
 
-func (w *Writer) Upstream() any {
-	return w.Conn.NetConn()
-}
-
 func (w *Writer) FrontHeadroom() int {
 	return 14
 }