|
@@ -20,6 +20,7 @@ import (
|
|
|
)
|
|
|
|
|
|
func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
|
|
|
+ defer conn.Close()
|
|
|
ctx = adapter.WithContext(ctx, &metadata)
|
|
|
var outConn net.Conn
|
|
|
var err error
|
|
@@ -40,6 +41,7 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a
|
|
|
}
|
|
|
|
|
|
func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
|
|
|
+ defer conn.Close()
|
|
|
ctx = adapter.WithContext(ctx, &metadata)
|
|
|
var outConn net.Conn
|
|
|
var err error
|
|
@@ -67,29 +69,49 @@ func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dial
|
|
|
}
|
|
|
|
|
|
func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
|
|
|
+ defer conn.Close()
|
|
|
ctx = adapter.WithContext(ctx, &metadata)
|
|
|
- var outConn net.PacketConn
|
|
|
- var destinationAddress netip.Addr
|
|
|
- var err error
|
|
|
- if len(metadata.DestinationAddresses) > 0 {
|
|
|
- outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
|
|
|
+ var (
|
|
|
+ outPacketConn net.PacketConn
|
|
|
+ outConn net.Conn
|
|
|
+ destinationAddress netip.Addr
|
|
|
+ err error
|
|
|
+ )
|
|
|
+ if metadata.UDPConnect {
|
|
|
+ if len(metadata.DestinationAddresses) > 0 {
|
|
|
+ outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
|
|
|
+ } else {
|
|
|
+ outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return N.ReportHandshakeFailure(conn, err)
|
|
|
+ }
|
|
|
+ outPacketConn = bufio.NewUnbindPacketConn(outConn)
|
|
|
+ connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
|
|
|
+ if connRemoteAddr != metadata.Destination.Addr {
|
|
|
+ destinationAddress = connRemoteAddr
|
|
|
+ }
|
|
|
} else {
|
|
|
- outConn, err = this.ListenPacket(ctx, metadata.Destination)
|
|
|
- }
|
|
|
- if err != nil {
|
|
|
- return N.ReportHandshakeFailure(conn, err)
|
|
|
+ if len(metadata.DestinationAddresses) > 0 {
|
|
|
+ outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
|
|
|
+ } else {
|
|
|
+ outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return N.ReportHandshakeFailure(conn, err)
|
|
|
+ }
|
|
|
}
|
|
|
- err = N.ReportPacketConnHandshakeSuccess(conn, outConn)
|
|
|
+ err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
|
|
|
if err != nil {
|
|
|
- outConn.Close()
|
|
|
+ outPacketConn.Close()
|
|
|
return err
|
|
|
}
|
|
|
if destinationAddress.IsValid() {
|
|
|
if metadata.Destination.IsFqdn() {
|
|
|
if metadata.UDPDisableDomainUnmapping {
|
|
|
- outConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
|
|
|
+ outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
|
|
|
} else {
|
|
|
- outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
|
|
|
+ outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
|
|
|
}
|
|
|
}
|
|
|
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
|
|
@@ -104,37 +126,63 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
|
|
|
case C.ProtocolDNS:
|
|
|
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
|
|
|
}
|
|
|
- return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
|
|
|
+ return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
|
|
|
}
|
|
|
|
|
|
func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error {
|
|
|
+ defer conn.Close()
|
|
|
ctx = adapter.WithContext(ctx, &metadata)
|
|
|
- var outConn net.PacketConn
|
|
|
- var destinationAddress netip.Addr
|
|
|
- var err error
|
|
|
- if len(metadata.DestinationAddresses) > 0 {
|
|
|
- outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
|
|
|
- } else if metadata.Destination.IsFqdn() {
|
|
|
- var destinationAddresses []netip.Addr
|
|
|
- destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
|
|
|
+ var (
|
|
|
+ outPacketConn net.PacketConn
|
|
|
+ outConn net.Conn
|
|
|
+ destinationAddress netip.Addr
|
|
|
+ err error
|
|
|
+ )
|
|
|
+ if metadata.UDPConnect {
|
|
|
+ if len(metadata.DestinationAddresses) > 0 {
|
|
|
+ outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
|
|
|
+ } else if metadata.Destination.IsFqdn() {
|
|
|
+ var destinationAddresses []netip.Addr
|
|
|
+ destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
|
|
|
+ if err != nil {
|
|
|
+ return N.ReportHandshakeFailure(conn, err)
|
|
|
+ }
|
|
|
+ outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, destinationAddresses)
|
|
|
+ } else {
|
|
|
+ outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
|
|
|
+ }
|
|
|
if err != nil {
|
|
|
return N.ReportHandshakeFailure(conn, err)
|
|
|
}
|
|
|
- outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
|
|
|
+ connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr())
|
|
|
+ if connRemoteAddr != metadata.Destination.Addr {
|
|
|
+ destinationAddress = connRemoteAddr
|
|
|
+ }
|
|
|
} else {
|
|
|
- outConn, err = this.ListenPacket(ctx, metadata.Destination)
|
|
|
- }
|
|
|
- if err != nil {
|
|
|
- return N.ReportHandshakeFailure(conn, err)
|
|
|
+ if len(metadata.DestinationAddresses) > 0 {
|
|
|
+ outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
|
|
|
+ } else if metadata.Destination.IsFqdn() {
|
|
|
+ var destinationAddresses []netip.Addr
|
|
|
+ destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
|
|
|
+ if err != nil {
|
|
|
+ return N.ReportHandshakeFailure(conn, err)
|
|
|
+ }
|
|
|
+ outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
|
|
|
+ } else {
|
|
|
+ outPacketConn, err = this.ListenPacket(ctx, metadata.Destination)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return N.ReportHandshakeFailure(conn, err)
|
|
|
+ }
|
|
|
}
|
|
|
- err = N.ReportPacketConnHandshakeSuccess(conn, outConn)
|
|
|
+ err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn)
|
|
|
if err != nil {
|
|
|
- outConn.Close()
|
|
|
+ outPacketConn.Close()
|
|
|
return err
|
|
|
}
|
|
|
if destinationAddress.IsValid() {
|
|
|
if metadata.Destination.IsFqdn() {
|
|
|
- outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
|
|
|
+ outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
|
|
|
}
|
|
|
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
|
|
|
natConn.UpdateDestination(destinationAddress)
|
|
@@ -148,7 +196,7 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this
|
|
|
case C.ProtocolDNS:
|
|
|
ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
|
|
|
}
|
|
|
- return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
|
|
|
+ return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn))
|
|
|
}
|
|
|
|
|
|
func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
|