|
@@ -7,6 +7,7 @@ import (
|
|
|
"encoding/base64"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
@@ -25,6 +26,8 @@ type requestHandler struct {
|
|
|
ln *Listener
|
|
|
}
|
|
|
|
|
|
+var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "")
|
|
|
+
|
|
|
var upgrader = &websocket.Upgrader{
|
|
|
ReadBufferSize: 4 * 1024,
|
|
|
WriteBufferSize: 4 * 1024,
|
|
@@ -39,7 +42,17 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
|
|
writer.WriteHeader(http.StatusNotFound)
|
|
|
return
|
|
|
}
|
|
|
- conn, err := upgrader.Upgrade(writer, request, nil)
|
|
|
+
|
|
|
+ var extraReader io.Reader
|
|
|
+ var responseHeader = http.Header{}
|
|
|
+ if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" {
|
|
|
+ if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 {
|
|
|
+ extraReader = bytes.NewReader(ed)
|
|
|
+ responseHeader.Set("Sec-WebSocket-Protocol", str)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ conn, err := upgrader.Upgrade(writer, request, responseHeader)
|
|
|
if err != nil {
|
|
|
newError("failed to convert to WebSocket connection").Base(err).WriteToLog()
|
|
|
return
|
|
@@ -54,12 +67,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- var extraReader io.Reader
|
|
|
- if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" {
|
|
|
- if ed, err := base64.StdEncoding.DecodeString(str); err == nil && len(ed) > 0 {
|
|
|
- extraReader = bytes.NewReader(ed)
|
|
|
- }
|
|
|
- }
|
|
|
h.ln.addConn(newConnection(conn, remoteAddr, extraReader))
|
|
|
}
|
|
|
|
|
@@ -128,7 +135,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
|
|
|
ln: l,
|
|
|
},
|
|
|
ReadHeaderTimeout: time.Second * 4,
|
|
|
- MaxHeaderBytes: 2048,
|
|
|
+ MaxHeaderBytes: 4096,
|
|
|
}
|
|
|
|
|
|
go func() {
|