123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262 |
- package wireguard
- import (
- "context"
- "net"
- "net/netip"
- "sync"
- "time"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- "github.com/sagernet/sing/common/logger"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/sing/service"
- "github.com/sagernet/sing/service/pause"
- "github.com/sagernet/wireguard-go/conn"
- )
- var _ conn.Bind = (*ClientBind)(nil)
- type ClientBind struct {
- ctx context.Context
- logger logger.Logger
- pauseManager pause.Manager
- bindCtx context.Context
- bindDone context.CancelFunc
- dialer N.Dialer
- reservedForEndpoint map[netip.AddrPort][3]uint8
- connAccess sync.Mutex
- conn *wireConn
- done chan struct{}
- isConnect bool
- connectAddr netip.AddrPort
- reserved [3]uint8
- }
// NewClientBind builds a ClientBind that reaches the WireGuard server through
// dialer. No socket is opened here: receive and Send create one on demand.
//
// If isConnect is true, sockets are dialed to connectAddr; otherwise an
// unconnected socket is used and each packet is sent to the endpoint
// wireguard-go passes to Send. reserved is the default value Send writes into
// bytes 1..3 of outgoing packets, overridable per destination with
// SetReservedForEndpoint. The pause manager is taken from ctx.
func NewClientBind(ctx context.Context, logger logger.Logger, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
	bind := &ClientBind{
		ctx:                 ctx,
		logger:              logger,
		dialer:              dialer,
		isConnect:           isConnect,
		connectAddr:         connectAddr,
		reserved:            reserved,
		reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
		done:                make(chan struct{}),
	}
	bind.pauseManager = service.FromContext[pause.Manager](ctx)
	return bind
}
- func (c *ClientBind) connect() (*wireConn, error) {
- serverConn := c.conn
- if serverConn != nil {
- select {
- case <-serverConn.done:
- serverConn = nil
- default:
- return serverConn, nil
- }
- }
- c.connAccess.Lock()
- defer c.connAccess.Unlock()
- select {
- case <-c.done:
- return nil, net.ErrClosed
- default:
- }
- serverConn = c.conn
- if serverConn != nil {
- select {
- case <-serverConn.done:
- serverConn = nil
- default:
- return serverConn, nil
- }
- }
- if c.isConnect {
- udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
- if err != nil {
- return nil, err
- }
- c.conn = &wireConn{
- PacketConn: bufio.NewUnbindPacketConn(udpConn),
- done: make(chan struct{}),
- }
- } else {
- udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
- if err != nil {
- return nil, err
- }
- c.conn = &wireConn{
- PacketConn: bufio.NewPacketConn(udpConn),
- done: make(chan struct{}),
- }
- }
- return c.conn, nil
- }
// Open prepares the bind for use and returns its single receive function.
// No socket is opened here (connect does that lazily), so port is ignored and
// actualPort is always 0. Open derives a fresh bindCtx from the bind's context
// and, if an earlier Close left done closed, re-creates it.
func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
	select {
	case <-c.done:
		// Re-arm after a previous Close so connect stops reporting net.ErrClosed.
		c.done = make(chan struct{})
	default:
	}
	c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
	fns = []conn.ReceiveFunc{c.receive}
	return fns, 0, nil
}
// receive is the only conn.ReceiveFunc returned by Open. Each successful call
// reads one datagram into packets[0], stores its length in sizes[0] and its
// sender in eps[0], and returns count 1. Bytes 1..3 of any packet longer than
// three bytes are zeroed before it is returned; these are the positions Send
// fills with reserved bytes.
//
// While the bind is open, connect and read failures are logged and reported
// as (0, nil), so the caller simply calls again. A failed read closes the
// socket so the next call redials. A failed connect first waits for the pause
// manager and then backs off for up to a second. After Close, the underlying
// error (net.ErrClosed from connect, or the read error) is returned instead.
func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
	udpConn, err := c.connect()
	if err != nil {
		select {
		case <-c.done:
			return 0, err
		default:
		}
		c.logger.Error(E.Cause(err, "connect to server"))
		// The pause manager comes from service.FromContext and may be the zero
		// value when ctx carried none; calling through a nil interface panics.
		if c.pauseManager != nil {
			c.pauseManager.WaitActive()
		}
		// Back off before the next attempt, but wake as soon as the bind is
		// closed instead of always sleeping the full second.
		backoff := time.NewTimer(time.Second)
		select {
		case <-c.done:
		case <-backoff.C:
		}
		backoff.Stop()
		return 0, nil
	}
	n, addr, err := udpConn.ReadFrom(packets[0])
	if err != nil {
		// Closing marks the socket dead, so the next connect replaces it.
		udpConn.Close()
		select {
		case <-c.done:
			return 0, err
		default:
		}
		c.logger.Error(E.Cause(err, "read packet"))
		return 0, nil
	}
	sizes[0] = n
	if n > 3 {
		common.ClearArray(packets[0][1:4])
	}
	eps[0] = remoteEndpoint(M.SocksaddrFromNet(addr).Unwrap().AddrPort())
	return 1, nil
}
- func (c *ClientBind) Close() error {
- select {
- case <-c.done:
- default:
- close(c.done)
- }
- if c.bindDone != nil {
- c.bindDone()
- }
- c.connAccess.Lock()
- defer c.connAccess.Unlock()
- common.Close(common.PtrOrNil(c.conn))
- return nil
- }
- func (c *ClientBind) SetMark(mark uint32) error {
- return nil
- }
// Send writes every buffer in bufs, starting at offset, to ep. ep must be a
// remoteEndpoint, as produced by ParseEndpoint or receive; the type assertion
// panics otherwise. For buffers longer than three bytes, bytes 1..3 are
// overwritten in place with the reserved value registered for the destination
// via SetReservedForEndpoint, or with the bind-wide default.
//
// If no socket is available, Send returns the connect error. It returns
// immediately when the bind has been closed; otherwise it first waits for the
// pause manager and backs off for up to a second. A write error closes the
// socket (so the next call redials) and aborts the remaining buffers.
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint, offset int) error {
	udpConn, err := c.connect()
	if err != nil {
		select {
		case <-c.done:
			// Closed bind: report right away instead of stalling the sender.
			return err
		default:
		}
		// The pause manager comes from service.FromContext and may be the zero
		// value when ctx carried none; calling through a nil interface panics.
		if c.pauseManager != nil {
			c.pauseManager.WaitActive()
		}
		backoff := time.NewTimer(time.Second)
		select {
		case <-c.done:
		case <-backoff.C:
		}
		backoff.Stop()
		return err
	}
	destination := netip.AddrPort(ep.(remoteEndpoint))
	// The reserved bytes depend only on the destination, so look them up once
	// rather than once per buffer.
	reserved, loaded := c.reservedForEndpoint[destination]
	if !loaded {
		reserved = c.reserved
	}
	for _, buf := range bufs {
		if offset > 0 {
			buf = buf[offset:]
		}
		if len(buf) > 3 {
			copy(buf[1:4], reserved[:])
		}
		if _, err := udpConn.WriteToUDPAddrPort(buf, destination); err != nil {
			udpConn.Close()
			return err
		}
	}
	return nil
}
// ParseEndpoint parses s as a literal address and port ("ip:port", or
// "[ipv6]:port") with netip.ParseAddrPort and wraps it as a remoteEndpoint.
// No name resolution is performed.
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
	addrPort, parseErr := netip.ParseAddrPort(s)
	if parseErr != nil {
		return nil, parseErr
	}
	return remoteEndpoint(addrPort), nil
}
- func (c *ClientBind) BatchSize() int {
- return 1
- }
// SetReservedForEndpoint sets the three bytes Send writes into positions 1..3
// of packets addressed to destination, overriding the default passed to
// NewClientBind.
//
// NOTE(review): the map is written here and read by Send without any lock;
// presumably this must be called before traffic starts. Confirm with callers.
func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
	overrides := c.reservedForEndpoint
	overrides[destination] = reserved
}
- type wireConn struct {
- net.PacketConn
- conn net.Conn
- access sync.Mutex
- done chan struct{}
- }
// WriteToUDPAddrPort sends b to addr through the embedded PacketConn. If an
// explicit conn is set, b is written to it instead and addr is ignored.
func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
	if w.conn == nil {
		target := M.SocksaddrFromNetIP(addr).UDPAddr()
		return w.PacketConn.WriteTo(b, target)
	}
	return w.conn.Write(b)
}
// Close closes the underlying PacketConn and then closes done, which marks
// this socket as dead for ClientBind.connect. It returns the PacketConn's
// close error, or net.ErrClosed if w was already closed. access serialises
// concurrent calls so that done is closed only once.
func (w *wireConn) Close() error {
	w.access.Lock()
	defer w.access.Unlock()
	select {
	case <-w.done:
		return net.ErrClosed
	default:
	}
	closeErr := w.PacketConn.Close()
	// Mark the socket dead even if the underlying close failed, so connect
	// redials instead of reusing it.
	close(w.done)
	return closeErr
}
- var _ conn.Endpoint = (*remoteEndpoint)(nil)
- type remoteEndpoint netip.AddrPort
- func (e remoteEndpoint) ClearSrc() {
- }
- func (e remoteEndpoint) SrcToString() string {
- return ""
- }
- func (e remoteEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
- }
- func (e remoteEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
- return b
- }
- func (e remoteEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
- }
- func (e remoteEndpoint) SrcIP() netip.Addr {
- return netip.Addr{}
- }
|