浏览代码

Add WSS Browser Dialer support (#421)

RPRX 4 年之前
父节点
当前提交
d46af8b5d4
共有 3 个文件被更改,包括 119 次插入8 次删除
  1. 49 0
      transport/internet/websocket/dialer.go
  2. 55 0
      transport/internet/websocket/dialer.html
  3. 15 8
      transport/internet/websocket/hub.go

+ 49 - 0
transport/internet/websocket/dialer.go

@@ -2,8 +2,12 @@ package websocket
 
 import (
 	"context"
+	_ "embed"
 	"encoding/base64"
+	"fmt"
 	"io"
+	"net/http"
+	"os"
 	"time"
 
 	"github.com/gorilla/websocket"
@@ -15,6 +19,27 @@ import (
 	"github.com/xtls/xray-core/transport/internet/tls"
 )
 
+//go:embed dialer.html
+var webpage []byte
+var conns chan *websocket.Conn
+
+func init() {
+	if addr := os.Getenv("XRAY_BROWSER_DIALER"); addr != "" {
+		conns = make(chan *websocket.Conn, 256)
+		go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if r.URL.Path == "/websocket" {
+				if conn, err := upgrader.Upgrade(w, r, nil); err == nil {
+					conns <- conn
+				} else {
+					fmt.Println("unexpected error")
+				}
+			} else {
+				w.Write(webpage)
+			}
+		}))
+	}
+}
+
 // Dial dials a WebSocket connection to the given destination.
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
 	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
@@ -66,6 +91,30 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
 	}
 	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
 
+	if conns != nil {
+		data := []byte(uri)
+		if ed != nil {
+			data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
+		}
+		var conn *websocket.Conn
+		for {
+			conn = <-conns
+			if conn.WriteMessage(websocket.TextMessage, data) != nil {
+				conn.Close()
+			} else {
+				break
+			}
+		}
+		if _, p, err := conn.ReadMessage(); err != nil {
+			conn.Close()
+			return nil, err
+		} else if s := string(p); s != "ok" {
+			conn.Close()
+			return nil, newError(s)
+		}
+		return newConnection(conn, conn.RemoteAddr(), nil), nil
+	}
+
 	header := wsSettings.GetRequestHeader()
 	if ed != nil {
 		header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))

+ 55 - 0
transport/internet/websocket/dialer.html

@@ -0,0 +1,55 @@
+<!DOCTYPE html>
+<html>
+<head>
+	<title>Browser Dialer</title>
+</head>
+<body></body>
+<script>
+	// Copyright (c) 2021 XRAY. Mozilla Public License 2.0.
+	var url = "ws://" + window.location.host + "/websocket"
+	var count = 0
+	setInterval(check, 1000)
+	function check() {
+		if (count <= 0) {
+			count += 1
+			console.log("Prepare", url)
+			var ws = new WebSocket(url)
+			var wss = undefined
+			var first = true
+			ws.onmessage = function (event) {
+				if (first) {
+					first = false
+					count -= 1
+					var arr = event.data.split(" ")
+					console.log("Dial", arr[0], arr[1])
+					wss = new WebSocket(arr[0], arr[1])
+					var opened = false
+					wss.onopen = function (event) {
+						opened = true
+						ws.send("ok")
+					}
+					wss.onmessage = function (event) {
+						ws.send(event.data)
+					}
+					wss.onclose = function (event) {
+						ws.close()
+					}
+					wss.onerror = function (event) {
+						!opened && ws.send("fail")
+						wss.close()
+					}
+					check()
+				} else wss.send(event.data)
+			}
+			ws.onclose = function (event) {
+				if (first) count -= 1
+				else wss.close()
+			}
+			ws.onerror = function (event) {
+				ws.close()
+			}
+		}
+	}
+</script>
+</body>
+</html>

+ 15 - 8
transport/internet/websocket/hub.go

@@ -7,6 +7,7 @@ import (
 	"encoding/base64"
 	"io"
 	"net/http"
+	"strings"
 	"sync"
 	"time"
 
@@ -25,6 +26,8 @@ type requestHandler struct {
 	ln   *Listener
 }
 
+var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "")
+
 var upgrader = &websocket.Upgrader{
 	ReadBufferSize:   4 * 1024,
 	WriteBufferSize:  4 * 1024,
@@ -39,7 +42,17 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		writer.WriteHeader(http.StatusNotFound)
 		return
 	}
-	conn, err := upgrader.Upgrade(writer, request, nil)
+
+	var extraReader io.Reader
+	var responseHeader = http.Header{}
+	if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" {
+		if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 {
+			extraReader = bytes.NewReader(ed)
+			responseHeader.Set("Sec-WebSocket-Protocol", str)
+		}
+	}
+
+	conn, err := upgrader.Upgrade(writer, request, responseHeader)
 	if err != nil {
 		newError("failed to convert to WebSocket connection").Base(err).WriteToLog()
 		return
@@ -54,12 +67,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		}
 	}
 
-	var extraReader io.Reader
-	if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" {
-		if ed, err := base64.StdEncoding.DecodeString(str); err == nil && len(ed) > 0 {
-			extraReader = bytes.NewReader(ed)
-		}
-	}
 	h.ln.addConn(newConnection(conn, remoteAddr, extraReader))
 }
 
@@ -128,7 +135,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 			ln:   l,
 		},
 		ReadHeaderTimeout: time.Second * 4,
-		MaxHeaderBytes:    2048,
+		MaxHeaderBytes:    4096,
 	}
 
 	go func() {