|
@@ -15,40 +15,44 @@ import (
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
"github.com/sagernet/sing/common/x/list"
|
|
"github.com/sagernet/sing/common/x/list"
|
|
|
-
|
|
|
|
|
- "github.com/hashicorp/yamux"
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
var _ N.Dialer = (*Client)(nil)
|
|
var _ N.Dialer = (*Client)(nil)
|
|
|
|
|
|
|
|
type Client struct {
|
|
type Client struct {
|
|
|
access sync.Mutex
|
|
access sync.Mutex
|
|
|
- connections list.List[*yamux.Session]
|
|
|
|
|
|
|
+ connections list.List[abstractSession]
|
|
|
ctx context.Context
|
|
ctx context.Context
|
|
|
dialer N.Dialer
|
|
dialer N.Dialer
|
|
|
|
|
+ protocol Protocol
|
|
|
maxConnections int
|
|
maxConnections int
|
|
|
minStreams int
|
|
minStreams int
|
|
|
maxStreams int
|
|
maxStreams int
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func NewClient(ctx context.Context, dialer N.Dialer, maxConnections int, minStreams int, maxStreams int) *Client {
|
|
|
|
|
|
|
+func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
|
|
|
return &Client{
|
|
return &Client{
|
|
|
ctx: ctx,
|
|
ctx: ctx,
|
|
|
dialer: dialer,
|
|
dialer: dialer,
|
|
|
|
|
+ protocol: protocol,
|
|
|
maxConnections: maxConnections,
|
|
maxConnections: maxConnections,
|
|
|
minStreams: minStreams,
|
|
minStreams: minStreams,
|
|
|
maxStreams: maxStreams,
|
|
maxStreams: maxStreams,
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) N.Dialer {
|
|
|
|
|
|
|
+func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) {
|
|
|
if !options.Enabled {
|
|
if !options.Enabled {
|
|
|
- return dialer
|
|
|
|
|
|
|
+ return dialer, nil
|
|
|
}
|
|
}
|
|
|
if options.MaxConnections == 0 && options.MaxStreams == 0 {
|
|
if options.MaxConnections == 0 && options.MaxStreams == 0 {
|
|
|
options.MinStreams = 8
|
|
options.MinStreams = 8
|
|
|
}
|
|
}
|
|
|
- return NewClient(ctx, dialer, options.MaxConnections, options.MinStreams, options.MaxStreams)
|
|
|
|
|
|
|
+ protocol, err := ParseProtocol(options.Protocol)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
@@ -80,8 +84,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
|
|
|
|
|
|
|
|
func (c *Client) openStream() (net.Conn, error) {
|
|
func (c *Client) openStream() (net.Conn, error) {
|
|
|
var (
|
|
var (
|
|
|
- session *yamux.Session
|
|
|
|
|
- stream *yamux.Stream
|
|
|
|
|
|
|
+ session abstractSession
|
|
|
|
|
+ stream net.Conn
|
|
|
err error
|
|
err error
|
|
|
)
|
|
)
|
|
|
for attempts := 0; attempts < 2; attempts++ {
|
|
for attempts := 0; attempts < 2; attempts++ {
|
|
@@ -89,7 +93,7 @@ func (c *Client) openStream() (net.Conn, error) {
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
|
- stream, err = session.OpenStream()
|
|
|
|
|
|
|
+ stream, err = session.Open()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
@@ -101,11 +105,11 @@ func (c *Client) openStream() (net.Conn, error) {
|
|
|
return &wrapStream{stream}, nil
|
|
return &wrapStream{stream}, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Client) offer() (*yamux.Session, error) {
|
|
|
|
|
|
|
+func (c *Client) offer() (abstractSession, error) {
|
|
|
c.access.Lock()
|
|
c.access.Lock()
|
|
|
defer c.access.Unlock()
|
|
defer c.access.Unlock()
|
|
|
|
|
|
|
|
- sessions := make([]*yamux.Session, 0, c.maxConnections)
|
|
|
|
|
|
|
+ sessions := make([]abstractSession, 0, c.maxConnections)
|
|
|
for element := c.connections.Front(); element != nil; {
|
|
for element := c.connections.Front(); element != nil; {
|
|
|
if element.Value.IsClosed() {
|
|
if element.Value.IsClosed() {
|
|
|
nextElement := element.Next()
|
|
nextElement := element.Next()
|
|
@@ -120,10 +124,7 @@ func (c *Client) offer() (*yamux.Session, error) {
|
|
|
if sLen == 0 {
|
|
if sLen == 0 {
|
|
|
return c.offerNew()
|
|
return c.offerNew()
|
|
|
}
|
|
}
|
|
|
- // session := common.MinBy(sessions, yamux.Session.NumStreams)
|
|
|
|
|
- session := common.MinBy(sessions, func(it *yamux.Session) int {
|
|
|
|
|
- return it.NumStreams()
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ session := common.MinBy(sessions, abstractSession.NumStreams)
|
|
|
numStreams := session.NumStreams()
|
|
numStreams := session.NumStreams()
|
|
|
if numStreams == 0 {
|
|
if numStreams == 0 {
|
|
|
return session, nil
|
|
return session, nil
|
|
@@ -140,12 +141,12 @@ func (c *Client) offer() (*yamux.Session, error) {
|
|
|
return c.offerNew()
|
|
return c.offerNew()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Client) offerNew() (*yamux.Session, error) {
|
|
|
|
|
|
|
+func (c *Client) offerNew() (abstractSession, error) {
|
|
|
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
|
|
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
- session, err := yamux.Client(conn, newMuxConfig())
|
|
|
|
|
|
|
+ session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol})
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
@@ -170,7 +171,7 @@ type ClientConn struct {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *ClientConn) readResponse() error {
|
|
func (c *ClientConn) readResponse() error {
|
|
|
- response, err := ReadResponse(c.Conn)
|
|
|
|
|
|
|
+ response, err := ReadStreamResponse(c.Conn)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -195,7 +196,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
|
|
|
if c.requestWrite {
|
|
if c.requestWrite {
|
|
|
return c.Conn.Write(b)
|
|
return c.Conn.Write(b)
|
|
|
}
|
|
}
|
|
|
- request := Request{
|
|
|
|
|
|
|
+ request := StreamRequest{
|
|
|
Network: N.NetworkTCP,
|
|
Network: N.NetworkTCP,
|
|
|
Destination: c.destination,
|
|
Destination: c.destination,
|
|
|
}
|
|
}
|
|
@@ -203,7 +204,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
|
|
|
defer common.KeepAlive(_buffer)
|
|
defer common.KeepAlive(_buffer)
|
|
|
buffer := common.Dup(_buffer)
|
|
buffer := common.Dup(_buffer)
|
|
|
defer buffer.Release()
|
|
defer buffer.Release()
|
|
|
- EncodeRequest(request, buffer)
|
|
|
|
|
|
|
+ EncodeStreamRequest(request, buffer)
|
|
|
buffer.Write(b)
|
|
buffer.Write(b)
|
|
|
_, err = c.Conn.Write(buffer.Bytes())
|
|
_, err = c.Conn.Write(buffer.Bytes())
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -255,7 +256,7 @@ type ClientPacketConn struct {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *ClientPacketConn) readResponse() error {
|
|
func (c *ClientPacketConn) readResponse() error {
|
|
|
- response, err := ReadResponse(c.ExtendedConn)
|
|
|
|
|
|
|
+ response, err := ReadStreamResponse(c.ExtendedConn)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -285,7 +286,7 @@ func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
|
func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
|
|
- request := Request{
|
|
|
|
|
|
|
+ request := StreamRequest{
|
|
|
Network: N.NetworkUDP,
|
|
Network: N.NetworkUDP,
|
|
|
Destination: c.destination,
|
|
Destination: c.destination,
|
|
|
}
|
|
}
|
|
@@ -297,7 +298,7 @@ func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
|
|
defer common.KeepAlive(_buffer)
|
|
defer common.KeepAlive(_buffer)
|
|
|
buffer := common.Dup(_buffer)
|
|
buffer := common.Dup(_buffer)
|
|
|
defer buffer.Release()
|
|
defer buffer.Release()
|
|
|
- EncodeRequest(request, buffer)
|
|
|
|
|
|
|
+ EncodeStreamRequest(request, buffer)
|
|
|
if len(payload) > 0 {
|
|
if len(payload) > 0 {
|
|
|
common.Must(
|
|
common.Must(
|
|
|
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
|
|
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
|
|
@@ -363,7 +364,7 @@ type ClientPacketAddrConn struct {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *ClientPacketAddrConn) readResponse() error {
|
|
func (c *ClientPacketAddrConn) readResponse() error {
|
|
|
- response, err := ReadResponse(c.ExtendedConn)
|
|
|
|
|
|
|
+ response, err := ReadStreamResponse(c.ExtendedConn)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -399,7 +400,7 @@ func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
|
|
func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
|
|
|
- request := Request{
|
|
|
|
|
|
|
+ request := StreamRequest{
|
|
|
Network: N.NetworkUDP,
|
|
Network: N.NetworkUDP,
|
|
|
Destination: c.destination,
|
|
Destination: c.destination,
|
|
|
PacketAddr: true,
|
|
PacketAddr: true,
|
|
@@ -412,7 +413,7 @@ func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
|
|
|
defer common.KeepAlive(_buffer)
|
|
defer common.KeepAlive(_buffer)
|
|
|
buffer := common.Dup(_buffer)
|
|
buffer := common.Dup(_buffer)
|
|
|
defer buffer.Release()
|
|
defer buffer.Release()
|
|
|
- EncodeRequest(request, buffer)
|
|
|
|
|
|
|
+ EncodeStreamRequest(request, buffer)
|
|
|
if len(payload) > 0 {
|
|
if len(payload) > 0 {
|
|
|
common.Must(
|
|
common.Must(
|
|
|
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
|
|
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
|