Przeglądaj źródła

Fix deadline usage

世界 1 rok temu
rodzic
commit
01f6e70bc5

+ 1 - 9
common/dialer/detour.go

@@ -6,7 +6,6 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing/common/bufio/deadline"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -45,14 +44,7 @@ func (d *DetourDialer) DialContext(ctx context.Context, network string, destinat
 	if err != nil {
 		return nil, err
 	}
-	conn, err := dialer.DialContext(ctx, network, destination)
-	if err != nil {
-		return nil, err
-	}
-	if deadline.NeedAdditionalReadDeadline(conn) {
-		conn = deadline.NewConn(conn)
-	}
-	return conn, nil
+	return dialer.DialContext(ctx, network, destination)
 }
 
 func (d *DetourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {

+ 1 - 1
go.mod

@@ -26,7 +26,7 @@ require (
 	github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930
 	github.com/sagernet/quic-go v0.40.0
 	github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
-	github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc
+	github.com/sagernet/sing v0.2.18-0.20231201054122-bca74039ead5
 	github.com/sagernet/sing-dns v0.1.11
 	github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07
 	github.com/sagernet/sing-quic v0.1.5-0.20231123150216-00957d136203

+ 2 - 2
go.sum

@@ -110,8 +110,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL
 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
 github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
 github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
-github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc h1:vESVuxHgbd2EzHxd+TYTpNACIEGBOhp5n3KG7bgbcws=
-github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
+github.com/sagernet/sing v0.2.18-0.20231201054122-bca74039ead5 h1:luykfsWNqFh9sdLXlkCQtkuzLUPRd3BMsdQJt0REB1g=
+github.com/sagernet/sing v0.2.18-0.20231201054122-bca74039ead5/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
 github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE=
 github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE=
 github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 h1:ncKb5tVOsCQgCsv6UpsA0jinbNb5OQ5GMPJlyQP3EHM=

+ 24 - 4
transport/v2raywebsocket/client.go

@@ -12,6 +12,9 @@ import (
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	"github.com/sagernet/sing/common/bufio/deadline"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -87,18 +90,35 @@ func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers h
 			return nil, err
 		}
 	}
-	conn.SetDeadline(time.Now().Add(C.TCPTimeout))
+	var deadlineConn net.Conn
+	if deadline.NeedAdditionalReadDeadline(conn) {
+		deadlineConn = deadline.NewConn(conn)
+	} else {
+		deadlineConn = conn
+	}
+	err = deadlineConn.SetDeadline(time.Now().Add(C.TCPTimeout))
+	if err != nil {
+		return nil, E.Cause(err, "set read deadline")
+	}
 	var protocols []string
 	if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" {
 		protocols = []string{protocolHeader}
 		headers.Del("Sec-WebSocket-Protocol")
 	}
-	reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(conn, requestURL)
-	conn.SetDeadline(time.Time{})
+	reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(deadlineConn, requestURL)
+	deadlineConn.SetDeadline(time.Time{})
 	if err != nil {
 		return nil, err
 	}
-	return NewConn(conn, reader, nil, ws.StateClientSide), nil
+	if reader != nil {
+		buffer := buf.NewSize(reader.Buffered())
+		_, err = buffer.ReadFullFrom(reader, buffer.Len())
+		if err != nil {
+			return nil, err
+		}
+		conn = bufio.NewCachedConn(conn, buffer)
+	}
+	return NewConn(conn, nil, ws.StateClientSide), nil
 }
 
 func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {

+ 2 - 9
transport/v2raywebsocket/conn.go

@@ -1,7 +1,6 @@
 package v2raywebsocket
 
 import (
-	"bufio"
 	"context"
 	"encoding/base64"
 	"io"
@@ -28,19 +27,13 @@ type WebsocketConn struct {
 	remoteAddr     net.Addr
 }
 
-func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn {
+func NewConn(conn net.Conn, remoteAddr net.Addr, state ws.State) *WebsocketConn {
 	controlHandler := wsutil.ControlFrameHandler(conn, state)
-	var reader io.Reader
-	if br != nil && br.Buffered() > 0 {
-		reader = br
-	} else {
-		reader = conn
-	}
 	return &WebsocketConn{
 		Conn:  conn,
 		state: state,
 		reader: &wsutil.Reader{
-			Source:          reader,
+			Source:          conn,
 			State:           state,
 			SkipHeaderCheck: !debug.Enabled,
 			OnIntermediate:  controlHandler,

+ 2 - 2
transport/v2raywebsocket/server.go

@@ -88,14 +88,14 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 		s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data"))
 		return
 	}
-	wsConn, reader, _, err := ws.UpgradeHTTP(request, writer)
+	wsConn, _, _, err := ws.UpgradeHTTP(request, writer)
 	if err != nil {
 		s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection"))
 		return
 	}
 	var metadata M.Metadata
 	metadata.Source = sHttp.SourceAddress(request)
-	conn = NewConn(wsConn, reader.Reader, metadata.Source.TCPAddr(), ws.StateServerSide)
+	conn = NewConn(wsConn, metadata.Source.TCPAddr(), ws.StateServerSide)
 	if len(earlyData) > 0 {
 		conn = bufio.NewCachedConn(conn, buf.As(earlyData))
 	}