Browse Source

Add WebSocket 0-RTT support (#375)

RPRX 4 years ago
parent
commit
60b06877bf

+ 13 - 0
infra/conf/transport_internet.go

@@ -3,6 +3,8 @@ package conf
 import (
 	"encoding/json"
 	"math"
+	"net/url"
+	"strconv"
 	"strings"
 
 	"github.com/golang/protobuf/proto"
@@ -155,9 +157,20 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
 			Value: value,
 		})
 	}
+	var ed uint32
+	if u, err := url.Parse(path); err == nil {
+		if q := u.Query(); q.Get("ed") != "" {
+			Ed, _ := strconv.Atoi(q.Get("ed"))
+			ed = uint32(Ed)
+			q.Del("ed")
+			u.RawQuery = q.Encode()
+			path = u.String()
+		}
+	}
 	config := &websocket.Config{
 		Path:   path,
 		Header: header,
+		Ed:     ed,
 	}
 	if c.AcceptProxyProtocol {
 		config.AcceptProxyProtocol = c.AcceptProxyProtocol

+ 12 - 3
transport/internet/websocket/config.pb.go

@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
 // 	protoc-gen-go v1.25.0
-// 	protoc        v3.14.0
+// 	protoc        v3.15.6
 // source: transport/internet/websocket/config.proto
 
 package websocket
@@ -89,6 +89,7 @@ type Config struct {
 	Path                string    `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"`
 	Header              []*Header `protobuf:"bytes,3,rep,name=header,proto3" json:"header,omitempty"`
 	AcceptProxyProtocol bool      `protobuf:"varint,4,opt,name=accept_proxy_protocol,json=acceptProxyProtocol,proto3" json:"accept_proxy_protocol,omitempty"`
+	Ed                  uint32    `protobuf:"varint,5,opt,name=ed,proto3" json:"ed,omitempty"`
 }
 
 func (x *Config) Reset() {
@@ -144,6 +145,13 @@ func (x *Config) GetAcceptProxyProtocol() bool {
 	return false
 }
 
+func (x *Config) GetEd() uint32 {
+	if x != nil {
+		return x.Ed
+	}
+	return 0
+}
+
 var File_transport_internet_websocket_config_proto protoreflect.FileDescriptor
 
 var file_transport_internet_websocket_config_proto_rawDesc = []byte{
@@ -155,7 +163,7 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
 	0x0a, 0x06, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
 	0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61,
 	0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
-	0x22, 0x99, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70,
+	0x22, 0xa9, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70,
 	0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12,
 	0x41, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
 	0x29, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74,
@@ -164,7 +172,8 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
 	0x65, 0x72, 0x12, 0x32, 0x0a, 0x15, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x5f, 0x70, 0x72, 0x6f,
 	0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28,
 	0x08, 0x52, 0x13, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72,
-	0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x85, 0x01, 0x0a,
+	0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0e, 0x0a, 0x02, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01,
+	0x28, 0x0d, 0x52, 0x02, 0x65, 0x64, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x85, 0x01, 0x0a,
 	0x25, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70,
 	0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x77, 0x65, 0x62,
 	0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,

+ 2 - 0
transport/internet/websocket/config.proto

@@ -20,4 +20,6 @@ message Config {
   repeated Header header = 3;
 
   bool accept_proxy_protocol = 4;
+
+  uint32 ed = 5;
 }

+ 2 - 1
transport/internet/websocket/connection.go

@@ -22,10 +22,11 @@ type connection struct {
 	remoteAddr net.Addr
 }
 
-func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
+func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
 	return &connection{
 		conn:       conn,
 		remoteAddr: remoteAddr,
+		reader:     extraReader,
 	}
 }
 

+ 81 - 7
transport/internet/websocket/dialer.go

@@ -2,9 +2,12 @@ package websocket
 
 import (
 	"context"
+	"encoding/base64"
+	"io"
 	"time"
 
 	"github.com/gorilla/websocket"
+
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/session"
@@ -15,10 +18,21 @@ import (
 // Dial dials a WebSocket connection to the given destination.
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
 	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
-
-	conn, err := dialWebsocket(ctx, dest, streamSettings)
-	if err != nil {
-		return nil, newError("failed to dial WebSocket").Base(err)
+	var conn net.Conn
+	if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
+		ctx, cancel := context.WithCancel(ctx)
+		conn = &delayDialConn{
+			dialed:         make(chan bool, 1),
+			cancel:         cancel,
+			ctx:            ctx,
+			dest:           dest,
+			streamSettings: streamSettings,
+		}
+	} else {
+		var err error
+		if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
+			return nil, newError("failed to dial WebSocket").Base(err)
+		}
 	}
 	return internet.Connection(conn), nil
 }
@@ -27,7 +41,7 @@ func init() {
 	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
 }
 
-func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
+func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
 	wsSettings := streamSettings.ProtocolSettings.(*Config)
 
 	dialer := &websocket.Dialer{
@@ -52,7 +66,12 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
 	}
 	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
 
-	conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
+	header := wsSettings.GetRequestHeader()
+	if ed != nil {
+		header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
+	}
+
+	conn, resp, err := dialer.Dial(uri, header)
 	if err != nil {
 		var reason string
 		if resp != nil {
@@ -61,5 +80,60 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
 		return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
 	}
 
-	return newConnection(conn, conn.RemoteAddr()), nil
+	return newConnection(conn, conn.RemoteAddr(), nil), nil
+}
+
+type delayDialConn struct {
+	net.Conn
+	closed         bool
+	dialed         chan bool
+	cancel         context.CancelFunc
+	ctx            context.Context
+	dest           net.Destination
+	streamSettings *internet.MemoryStreamConfig
+}
+
+func (d *delayDialConn) Write(b []byte) (int, error) {
+	if d.closed {
+		return 0, io.ErrClosedPipe
+	}
+	if d.Conn == nil {
+		ed := b
+		if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
+			ed = nil
+		}
+		var err error
+		if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
+			d.Close()
+			return 0, newError("failed to dial WebSocket").Base(err)
+		}
+		d.dialed <- true
+		if ed != nil {
+			return len(ed), nil
+		}
+	}
+	return d.Conn.Write(b)
+}
+
+func (d *delayDialConn) Read(b []byte) (int, error) {
+	if d.closed {
+		return 0, io.ErrClosedPipe
+	}
+	if d.Conn == nil {
+		select {
+		case <-d.ctx.Done():
+			return 0, io.ErrUnexpectedEOF
+		case <-d.dialed:
+		}
+	}
+	return d.Conn.Read(b)
+}
+
+func (d *delayDialConn) Close() error {
+	d.closed = true
+	d.cancel()
+	if d.Conn == nil {
+		return nil
+	}
+	return d.Conn.Close()
 }

+ 10 - 1
transport/internet/websocket/hub.go

@@ -1,8 +1,11 @@
 package websocket
 
 import (
+	"bytes"
 	"context"
 	"crypto/tls"
+	"encoding/base64"
+	"io"
 	"net/http"
 	"sync"
 	"time"
@@ -51,7 +54,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		}
 	}
 
-	h.ln.addConn(newConnection(conn, remoteAddr))
+	var extraReader io.Reader
+	if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" {
+		if ed, err := base64.StdEncoding.DecodeString(str); err == nil && len(ed) > 0 {
+			extraReader = bytes.NewReader(ed)
+		}
+	}
+	h.ln.addConn(newConnection(conn, remoteAddr, extraReader))
 }
 
 type Listener struct {

+ 2 - 2
transport/internet/websocket/ws.go

@@ -1,6 +1,6 @@
-/*Package websocket implements Websocket transport
+/*Package websocket implements WebSocket transport
 
-Websocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
+WebSocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
 */
 package websocket