| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342 |
- //go:build with_gvisor
- package wireguard
- import (
- "context"
- "net"
- "net/netip"
- "os"
- "time"
- "github.com/sagernet/gvisor/pkg/buffer"
- "github.com/sagernet/gvisor/pkg/tcpip"
- "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
- "github.com/sagernet/gvisor/pkg/tcpip/header"
- "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
- "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
- "github.com/sagernet/gvisor/pkg/tcpip/stack"
- "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
- "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
- "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
- "github.com/sagernet/sing-box/adapter"
- "github.com/sagernet/sing-box/log"
- "github.com/sagernet/sing-tun"
- "github.com/sagernet/sing-tun/ping"
- "github.com/sagernet/sing/common/buf"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/wireguard-go/device"
- wgTun "github.com/sagernet/wireguard-go/tun"
- )
- var _ NatDevice = (*stackDevice)(nil)
- type stackDevice struct {
- ctx context.Context
- logger log.ContextLogger
- stack *stack.Stack
- mtu uint32
- events chan wgTun.Event
- outbound chan *stack.PacketBuffer
- packetOutbound chan *buf.Buffer
- done chan struct{}
- dispatcher stack.NetworkDispatcher
- inet4Address netip.Addr
- inet6Address netip.Addr
- }
- func newStackDevice(options DeviceOptions) (*stackDevice, error) {
- tunDevice := &stackDevice{
- ctx: options.Context,
- logger: options.Logger,
- mtu: options.MTU,
- events: make(chan wgTun.Event, 1),
- outbound: make(chan *stack.PacketBuffer, 256),
- packetOutbound: make(chan *buf.Buffer, 256),
- done: make(chan struct{}),
- }
- ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true)
- if err != nil {
- return nil, err
- }
- var (
- inet4Address netip.Addr
- inet6Address netip.Addr
- )
- for _, prefix := range options.Address {
- addr := tun.AddressFromAddr(prefix.Addr())
- protoAddr := tcpip.ProtocolAddress{
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: prefix.Bits(),
- },
- }
- if prefix.Addr().Is4() {
- inet4Address = prefix.Addr()
- tunDevice.inet4Address = inet4Address
- protoAddr.Protocol = ipv4.ProtocolNumber
- } else {
- inet6Address = prefix.Addr()
- tunDevice.inet6Address = inet6Address
- protoAddr.Protocol = ipv6.ProtocolNumber
- }
- gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
- if gErr != nil {
- return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
- }
- }
- tunDevice.stack = ipStack
- if options.Handler != nil {
- ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
- ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
- icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
- icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
- ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
- ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
- }
- return tunDevice, nil
- }
- func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
- addr := tcpip.FullAddress{
- NIC: tun.DefaultNIC,
- Port: destination.Port,
- Addr: tun.AddressFromAddr(destination.Addr),
- }
- bind := tcpip.FullAddress{
- NIC: tun.DefaultNIC,
- }
- var networkProtocol tcpip.NetworkProtocolNumber
- if destination.IsIPv4() {
- if !w.inet4Address.IsValid() {
- return nil, E.New("missing IPv4 local address")
- }
- networkProtocol = header.IPv4ProtocolNumber
- bind.Addr = tun.AddressFromAddr(w.inet4Address)
- } else {
- if !w.inet6Address.IsValid() {
- return nil, E.New("missing IPv6 local address")
- }
- networkProtocol = header.IPv6ProtocolNumber
- bind.Addr = tun.AddressFromAddr(w.inet6Address)
- }
- switch N.NetworkName(network) {
- case N.NetworkTCP:
- tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
- if err != nil {
- return nil, err
- }
- return tcpConn, nil
- case N.NetworkUDP:
- udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
- if err != nil {
- return nil, err
- }
- return udpConn, nil
- default:
- return nil, E.Extend(N.ErrUnknownNetwork, network)
- }
- }
- func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
- bind := tcpip.FullAddress{
- NIC: tun.DefaultNIC,
- }
- var networkProtocol tcpip.NetworkProtocolNumber
- if destination.IsIPv4() {
- networkProtocol = header.IPv4ProtocolNumber
- bind.Addr = tun.AddressFromAddr(w.inet4Address)
- } else {
- networkProtocol = header.IPv6ProtocolNumber
- bind.Addr = tun.AddressFromAddr(w.inet4Address)
- }
- udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
- if err != nil {
- return nil, err
- }
- return udpConn, nil
- }
- func (w *stackDevice) Inet4Address() netip.Addr {
- return w.inet4Address
- }
- func (w *stackDevice) Inet6Address() netip.Addr {
- return w.inet6Address
- }
- func (w *stackDevice) SetDevice(device *device.Device) {
- }
- func (w *stackDevice) Start() error {
- w.events <- wgTun.EventUp
- return nil
- }
- func (w *stackDevice) File() *os.File {
- return nil
- }
- func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
- select {
- case packet, ok := <-w.outbound:
- if !ok {
- return 0, os.ErrClosed
- }
- defer packet.DecRef()
- var copyN int
- /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
- copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
- })*/
- for _, view := range packet.AsSlices() {
- copyN += copy(bufs[0][offset+copyN:], view)
- }
- sizes[0] = copyN
- return 1, nil
- case packet := <-w.packetOutbound:
- defer packet.Release()
- sizes[0] = copy(bufs[0][offset:], packet.Bytes())
- return 1, nil
- case <-w.done:
- return 0, os.ErrClosed
- }
- }
- func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
- for _, b := range bufs {
- b = b[offset:]
- if len(b) == 0 {
- continue
- }
- var networkProtocol tcpip.NetworkProtocolNumber
- switch header.IPVersion(b) {
- case header.IPv4Version:
- networkProtocol = header.IPv4ProtocolNumber
- case header.IPv6Version:
- networkProtocol = header.IPv6ProtocolNumber
- }
- packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Payload: buffer.MakeWithData(b),
- })
- w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
- packetBuffer.DecRef()
- count++
- }
- return
- }
- func (w *stackDevice) Flush() error {
- return nil
- }
- func (w *stackDevice) MTU() (int, error) {
- return int(w.mtu), nil
- }
- func (w *stackDevice) Name() (string, error) {
- return "sing-box", nil
- }
- func (w *stackDevice) Events() <-chan wgTun.Event {
- return w.events
- }
- func (w *stackDevice) Close() error {
- close(w.done)
- close(w.events)
- w.stack.Close()
- for _, endpoint := range w.stack.CleanupEndpoints() {
- endpoint.Abort()
- }
- w.stack.Wait()
- return nil
- }
- func (w *stackDevice) BatchSize() int {
- return 1
- }
- func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
- ctx := log.ContextWithNewID(w.ctx)
- destination, err := ping.ConnectGVisor(
- ctx, w.logger,
- metadata.Source.Addr, metadata.Destination.Addr,
- routeContext,
- w.stack,
- w.inet4Address, w.inet6Address,
- timeout,
- )
- if err != nil {
- return nil, err
- }
- w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
- return destination, nil
- }
- var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
- type wireEndpoint stackDevice
- func (ep *wireEndpoint) MTU() uint32 {
- return ep.mtu
- }
- func (ep *wireEndpoint) SetMTU(mtu uint32) {
- }
- func (ep *wireEndpoint) MaxHeaderLength() uint16 {
- return 0
- }
- func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
- }
- func (ep *wireEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
- }
- func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityRXChecksumOffload
- }
- func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
- ep.dispatcher = dispatcher
- }
- func (ep *wireEndpoint) IsAttached() bool {
- return ep.dispatcher != nil
- }
- func (ep *wireEndpoint) Wait() {
- }
- func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareNone
- }
- func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
- }
- func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
- return true
- }
- func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
- for _, packetBuffer := range list.AsSlice() {
- packetBuffer.IncRef()
- select {
- case <-ep.done:
- return 0, &tcpip.ErrClosedForSend{}
- case ep.outbound <- packetBuffer:
- }
- }
- return list.Len(), nil
- }
- func (ep *wireEndpoint) Close() {
- }
- func (ep *wireEndpoint) SetOnCloseAction(f func()) {
- }
|