Browse Source

Fix v2ray websocket transport

世界 2 years ago
parent
commit
d74abbd20e
2 changed files with 39 additions and 54 deletions
  1. 11 9
      transport/v2raywebsocket/client.go
  2. 28 45
      transport/v2raywebsocket/conn.go

+ 11 - 9
transport/v2raywebsocket/client.go

@@ -21,7 +21,8 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
 
 type Client struct {
 	dialer              *websocket.Dialer
-	uri                 string
+	requestURL          url.URL
+	requestURLString    string
 	headers             http.Header
 	maxEarlyData        uint32
 	earlyDataHeaderName string
@@ -57,15 +58,15 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 			return &deadConn{conn}, nil
 		}
 	}
-	var uri url.URL
+	var requestURL url.URL
 	if tlsConfig == nil {
-		uri.Scheme = "ws"
+		requestURL.Scheme = "ws"
 	} else {
-		uri.Scheme = "wss"
+		requestURL.Scheme = "wss"
 	}
-	uri.Host = serverAddr.String()
-	uri.Path = options.Path
-	err := sHTTP.URLSetPath(&uri, options.Path)
+	requestURL.Host = serverAddr.String()
+	requestURL.Path = options.Path
+	err := sHTTP.URLSetPath(&requestURL, options.Path)
 	if err != nil {
 		return nil
 	}
@@ -75,7 +76,8 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	}
 	return &Client{
 		wsDialer,
-		uri.String(),
+		requestURL,
+		requestURL.String(),
 		headers,
 		options.MaxEarlyData,
 		options.EarlyDataHeaderName,
@@ -84,7 +86,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 
 func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	if c.maxEarlyData <= 0 {
-		conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
+		conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
 		if err == nil {
 			return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
 		}

+ 28 - 45
transport/v2raywebsocket/conn.go

@@ -94,51 +94,64 @@ type EarlyWebsocketConn struct {
 	ctx    context.Context
 	conn   *WebsocketConn
 	create chan struct{}
+	err    error
 }
 
 func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
 	if c.conn == nil {
 		<-c.create
+		if c.err != nil {
+			return 0, c.err
+		}
 	}
 	return c.conn.Read(b)
 }
 
-func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
-	if c.conn != nil {
-		return c.conn.Write(b)
-	}
+func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
 	var (
 		earlyData []byte
 		lateData  []byte
 		conn      *websocket.Conn
 		response  *http.Response
+		err       error
 	)
-	if len(b) > int(c.maxEarlyData) {
-		earlyData = b[:c.maxEarlyData]
-		lateData = b[c.maxEarlyData:]
+	if len(content) > int(c.maxEarlyData) {
+		earlyData = content[:c.maxEarlyData]
+		lateData = content[c.maxEarlyData:]
 	} else {
-		earlyData = b
+		earlyData = content
 	}
 	if len(earlyData) > 0 {
 		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
 		if c.earlyDataHeaderName == "" {
-			conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
+			requestURL := c.requestURL
+			requestURL.Path += earlyDataString
+			conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
 		} else {
 			headers := c.headers.Clone()
 			headers.Set(c.earlyDataHeaderName, earlyDataString)
-			conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
+			conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
 		}
 	} else {
-		conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
+		conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
 	}
 	if err != nil {
-		return 0, wrapDialError(response, err)
+		return wrapDialError(response, err)
 	}
 	c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
-	close(c.create)
 	if len(lateData) > 0 {
 		_, err = c.conn.Write(lateData)
 	}
+	return err
+}
+
+func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
+	if c.conn != nil {
+		return c.conn.Write(b)
+	}
+	err = c.writeRequest(b)
+	c.err = err
+	close(c.create)
 	if err != nil {
 		return
 	}
@@ -149,39 +162,9 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
 	if c.conn != nil {
 		return c.conn.WriteBuffer(buffer)
 	}
-	var (
-		earlyData []byte
-		lateData  []byte
-		conn      *websocket.Conn
-		response  *http.Response
-		err       error
-	)
-	if buffer.Len() > int(c.maxEarlyData) {
-		earlyData = buffer.Bytes()[:c.maxEarlyData]
-		lateData = buffer.Bytes()[c.maxEarlyData:]
-	} else {
-		earlyData = buffer.Bytes()
-	}
-	if len(earlyData) > 0 {
-		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
-		if c.earlyDataHeaderName == "" {
-			conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
-		} else {
-			headers := c.headers.Clone()
-			headers.Set(c.earlyDataHeaderName, earlyDataString)
-			conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
-		}
-	} else {
-		conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
-	}
-	if err != nil {
-		return wrapDialError(response, err)
-	}
-	c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
+	err := c.writeRequest(buffer.Bytes())
+	c.err = err
 	close(c.create)
-	if len(lateData) > 0 {
-		_, err = c.conn.Write(lateData)
-	}
 	return err
 }