|
@@ -1,11 +1,11 @@
|
|
|
package v2raywebsocket
|
|
|
|
|
|
import (
|
|
|
+ "bufio"
|
|
|
"context"
|
|
|
"encoding/base64"
|
|
|
"io"
|
|
|
"net"
|
|
|
- "net/http"
|
|
|
"os"
|
|
|
"sync"
|
|
|
"time"
|
|
@@ -13,50 +13,96 @@ import (
|
|
|
C "github.com/sagernet/sing-box/constant"
|
|
|
"github.com/sagernet/sing/common"
|
|
|
"github.com/sagernet/sing/common/buf"
|
|
|
+ "github.com/sagernet/sing/common/debug"
|
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
|
- "github.com/sagernet/websocket"
|
|
|
+ "github.com/sagernet/ws"
|
|
|
+ "github.com/sagernet/ws/wsutil"
|
|
|
)
|
|
|
|
|
|
type WebsocketConn struct {
|
|
|
- *websocket.Conn
|
|
|
+ net.Conn
|
|
|
*Writer
|
|
|
- remoteAddr net.Addr
|
|
|
- reader io.Reader
|
|
|
+ state ws.State
|
|
|
+ reader *wsutil.Reader
|
|
|
+ controlHandler wsutil.FrameHandlerFunc
|
|
|
+ remoteAddr net.Addr
|
|
|
}
|
|
|
|
|
|
-func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
|
|
|
+func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn {
|
|
|
+ controlHandler := wsutil.ControlFrameHandler(conn, state)
|
|
|
+ var reader io.Reader
|
|
|
+ if br != nil && br.Buffered() > 0 {
|
|
|
+ reader = br
|
|
|
+ } else {
|
|
|
+ reader = conn
|
|
|
+ }
|
|
|
return &WebsocketConn{
|
|
|
- Conn: wsConn,
|
|
|
- remoteAddr: remoteAddr,
|
|
|
- Writer: NewWriter(wsConn, true),
|
|
|
+ Conn: conn,
|
|
|
+ state: state,
|
|
|
+ reader: &wsutil.Reader{
|
|
|
+ Source: reader,
|
|
|
+ State: state,
|
|
|
+ SkipHeaderCheck: !debug.Enabled,
|
|
|
+ OnIntermediate: controlHandler,
|
|
|
+ },
|
|
|
+ controlHandler: controlHandler,
|
|
|
+ remoteAddr: remoteAddr,
|
|
|
+ Writer: NewWriter(conn, state),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (c *WebsocketConn) Close() error {
|
|
|
- err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
|
|
|
- if err != nil {
|
|
|
- return c.Conn.Close()
|
|
|
+ c.Conn.SetWriteDeadline(time.Now().Add(C.TCPTimeout))
|
|
|
+ frame := ws.NewCloseFrame(ws.NewCloseFrameBody(
|
|
|
+ ws.StatusNormalClosure, "",
|
|
|
+ ))
|
|
|
+ if c.state == ws.StateClientSide {
|
|
|
+ frame = ws.MaskFrameInPlace(frame)
|
|
|
}
|
|
|
+ ws.WriteFrame(c.Conn, frame)
|
|
|
+ c.Conn.Close()
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
func (c *WebsocketConn) Read(b []byte) (n int, err error) {
|
|
|
+ var header ws.Header
|
|
|
for {
|
|
|
- if c.reader == nil {
|
|
|
- _, c.reader, err = c.NextReader()
|
|
|
+ n, err = c.reader.Read(b)
|
|
|
+ if n > 0 {
|
|
|
+ err = nil
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if !E.IsMulti(err, io.EOF, wsutil.ErrNoFrameAdvance) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ header, err = c.reader.NextFrame()
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if header.OpCode.IsControl() {
|
|
|
+ err = c.controlHandler(header, c.reader)
|
|
|
if err != nil {
|
|
|
- err = wrapError(err)
|
|
|
return
|
|
|
}
|
|
|
+ continue
|
|
|
}
|
|
|
- n, err = c.reader.Read(b)
|
|
|
- if E.IsMulti(err, io.EOF) {
|
|
|
- c.reader = nil
|
|
|
+ if header.OpCode&ws.OpBinary == 0 {
|
|
|
+ err = c.reader.Discard()
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
continue
|
|
|
}
|
|
|
- err = wrapError(err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *WebsocketConn) Write(p []byte) (n int, err error) {
|
|
|
+ err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p)
|
|
|
+ if err != nil {
|
|
|
return
|
|
|
}
|
|
|
+ n = len(p)
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
func (c *WebsocketConn) RemoteAddr() net.Addr {
|
|
@@ -83,11 +129,7 @@ func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
|
|
|
}
|
|
|
|
|
|
func (c *WebsocketConn) Upstream() any {
|
|
|
- return c.Conn.NetConn()
|
|
|
-}
|
|
|
-
|
|
|
-func (c *WebsocketConn) UpstreamWriter() any {
|
|
|
- return c.Writer
|
|
|
+ return c.Conn
|
|
|
}
|
|
|
|
|
|
type EarlyWebsocketConn struct {
|
|
@@ -113,8 +155,7 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
|
|
|
var (
|
|
|
earlyData []byte
|
|
|
lateData []byte
|
|
|
- conn *websocket.Conn
|
|
|
- response *http.Response
|
|
|
+ conn *WebsocketConn
|
|
|
err error
|
|
|
)
|
|
|
if len(content) > int(c.maxEarlyData) {
|
|
@@ -128,23 +169,26 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
|
|
|
if c.earlyDataHeaderName == "" {
|
|
|
requestURL := c.requestURL
|
|
|
requestURL.Path += earlyDataString
|
|
|
- conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
|
|
|
+ conn, err = c.dialContext(c.ctx, &requestURL, c.headers)
|
|
|
} else {
|
|
|
headers := c.headers.Clone()
|
|
|
headers.Set(c.earlyDataHeaderName, earlyDataString)
|
|
|
- conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
|
|
|
+ conn, err = c.dialContext(c.ctx, &c.requestURL, headers)
|
|
|
}
|
|
|
} else {
|
|
|
- conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
|
|
|
+ conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers)
|
|
|
}
|
|
|
if err != nil {
|
|
|
- return wrapDialError(response, err)
|
|
|
+ return err
|
|
|
}
|
|
|
- c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
|
|
|
if len(lateData) > 0 {
|
|
|
- _, err = c.conn.Write(lateData)
|
|
|
+ _, err = conn.Write(lateData)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
}
|
|
|
- return err
|
|
|
+ c.conn = conn
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
|
|
@@ -230,13 +274,3 @@ func (c *EarlyWebsocketConn) Upstream() any {
|
|
|
func (c *EarlyWebsocketConn) LazyHeadroom() bool {
|
|
|
return c.conn == nil
|
|
|
}
|
|
|
-
|
|
|
-func wrapError(err error) error {
|
|
|
- if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
|
- return io.EOF
|
|
|
- }
|
|
|
- if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
|
|
|
- return net.ErrClosed
|
|
|
- }
|
|
|
- return err
|
|
|
-}
|