Browse Source

Added utls to websocket (#1256)

* Added utls to websocket

* Slightly better code

One less allocation
Hirbod Behnam 3 years ago
parent
commit
1f93cbbc5d
2 changed files with 53 additions and 1 deletions
  1. 27 0
      transport/internet/tls/tls.go
  2. 26 1
      transport/internet/websocket/dialer.go

+ 27 - 0
transport/internet/tls/tls.go

@@ -66,6 +66,33 @@ func (c *UConn) HandshakeAddress() net.Address {
 	return net.ParseAddress(state.ServerName)
 }
 
+// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
+// http/1.1 in its ALPN.
+func (c *UConn) WebsocketHandshake() error {
+	// Build the handshake state. This will apply every variable of the TLS of the
+	// fingerprint in the UConn
+	if err := c.BuildHandshakeState(); err != nil {
+		return err
+	}
+	// Iterate over extensions and check for utls.ALPNExtension
+	hasALPNExtension := false
+	for _, extension := range c.Extensions {
+		if alpn, ok := extension.(*utls.ALPNExtension); ok {
+			hasALPNExtension = true
+			alpn.AlpnProtocols = []string{"http/1.1"}
+			break
+		}
+	}
+	if !hasALPNExtension { // Append extension if doesn't exists
+		c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}})
+	}
+	// Rebuild the client hello and do the handshake
+	if err := c.BuildHandshakeState(); err != nil {
+		return err
+	}
+	return c.Handshake()
+}
+
 func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
 	state := c.ConnectionState()
 	return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual

+ 26 - 1
transport/internet/websocket/dialer.go

@@ -6,6 +6,7 @@ import (
 	"encoding/base64"
 	"fmt"
 	"io"
+	gonet "net"
 	"net/http"
 	"os"
 	"time"
@@ -83,7 +84,31 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
 
 	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
 		protocol = "wss"
-		dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
+		tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
+		dialer.TLSClientConfig = tlsConfig
+		if fingerprint, exists := tls.Fingerprints[config.Fingerprint]; exists {
+			dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
+				// Like the NetDial in the dialer
+				pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
+				if err != nil {
+					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
+					return nil, err
+				}
+				// TLS and apply the handshake
+				cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
+				if err := cn.WebsocketHandshake(); err != nil {
+					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
+					return nil, err
+				}
+				if !tlsConfig.InsecureSkipVerify {
+					if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
+						newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
+						return nil, err
+					}
+				}
+				return cn, nil
+			}
+		}
 	}
 
 	host := dest.NetAddr()