| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 | 
							- package wireguard
 
- import (
 
- 	"context"
 
- 	goerrors "errors"
 
- 	"io"
 
- 	"github.com/xtls/xray-core/common"
 
- 	"github.com/xtls/xray-core/common/buf"
 
- 	"github.com/xtls/xray-core/common/errors"
 
- 	"github.com/xtls/xray-core/common/log"
 
- 	"github.com/xtls/xray-core/common/net"
 
- 	"github.com/xtls/xray-core/common/session"
 
- 	"github.com/xtls/xray-core/common/signal"
 
- 	"github.com/xtls/xray-core/common/task"
 
- 	"github.com/xtls/xray-core/core"
 
- 	"github.com/xtls/xray-core/features/dns"
 
- 	"github.com/xtls/xray-core/features/policy"
 
- 	"github.com/xtls/xray-core/features/routing"
 
- 	"github.com/xtls/xray-core/transport/internet/stat"
 
- )
 
- var nullDestination = net.TCPDestination(net.AnyIP, 0)
 
- type Server struct {
 
- 	bindServer *netBindServer
 
- 	info          routingInfo
 
- 	policyManager policy.Manager
 
- }
 
- type routingInfo struct {
 
- 	ctx         context.Context
 
- 	dispatcher  routing.Dispatcher
 
- 	inboundTag  *session.Inbound
 
- 	outboundTag *session.Outbound
 
- 	contentTag  *session.Content
 
- }
 
- func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
 
- 	v := core.MustFromContext(ctx)
 
- 	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
 
- 	if err != nil {
 
- 		return nil, err
 
- 	}
 
- 	server := &Server{
 
- 		bindServer: &netBindServer{
 
- 			netBind: netBind{
 
- 				dns: v.GetFeature(dns.ClientType()).(dns.Client),
 
- 				dnsOption: dns.IPOption{
 
- 					IPv4Enable: hasIPv4,
 
- 					IPv6Enable: hasIPv6,
 
- 				},
 
- 			},
 
- 		},
 
- 		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
 
- 	}
 
- 	tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
 
- 	if err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
 
- 		_ = tun.Close()
 
- 		return nil, err
 
- 	}
 
- 	return server, nil
 
- }
 
- // Network implements proxy.Inbound.
 
- func (*Server) Network() []net.Network {
 
- 	return []net.Network{net.Network_UDP}
 
- }
 
- // Process implements proxy.Inbound.
 
- func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
 
- 	inbound := session.InboundFromContext(ctx)
 
- 	inbound.Name = "wireguard"
 
- 	inbound.CanSpliceCopy = 3
 
- 	outbounds := session.OutboundsFromContext(ctx)
 
- 	ob := outbounds[len(outbounds)-1]
 
- 	s.info = routingInfo{
 
- 		ctx:         core.ToBackgroundDetachedContext(ctx),
 
- 		dispatcher:  dispatcher,
 
- 		inboundTag:  session.InboundFromContext(ctx),
 
- 		outboundTag: ob,
 
- 		contentTag:  session.ContentFromContext(ctx),
 
- 	}
 
- 	ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
 
- 	if err != nil {
 
- 		return err
 
- 	}
 
- 	nep := ep.(*netEndpoint)
 
- 	nep.conn = conn
 
- 	reader := buf.NewPacketReader(conn)
 
- 	for {
 
- 		mpayload, err := reader.ReadMultiBuffer()
 
- 		if err != nil {
 
- 			return err
 
- 		}
 
- 		for _, payload := range mpayload {
 
- 			v, ok := <-s.bindServer.readQueue
 
- 			if !ok {
 
- 				return nil
 
- 			}
 
- 			i, err := payload.Read(v.buff)
 
- 			v.bytes = i
 
- 			v.endpoint = nep
 
- 			v.err = err
 
- 			v.waiter.Done()
 
- 			if err != nil && goerrors.Is(err, io.EOF) {
 
- 				nep.conn = nil
 
- 				return nil
 
- 			}
 
- 		}
 
- 	}
 
- }
 
- func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 
- 	if s.info.dispatcher == nil {
 
- 		errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
 
- 		return
 
- 	}
 
- 	defer conn.Close()
 
- 	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
 
- 	plcy := s.policyManager.ForLevel(0)
 
- 	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
 
- 	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 
- 		From:   nullDestination,
 
- 		To:     dest,
 
- 		Status: log.AccessAccepted,
 
- 		Reason: "",
 
- 	})
 
- 	// what's this?
 
- 	// Session information should not be shared between different connections
 
- 	// why reuse them in server level? This will cause incorrect destoverride and unexpected routing behavior.
 
- 	// Disable it temporarily. Maybe s.info should be removed.
 
- 	//  if s.info.inboundTag != nil {
 
- 	//  ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
 
- 	//  }
 
- 	//	if s.info.outboundTag != nil {
 
- 	//		ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag})
 
- 	//	}
 
- 	//  if s.info.contentTag != nil {
 
- 	//	    ctx = session.ContextWithContent(ctx, s.info.contentTag)
 
- 	//  }
 
- 	link, err := s.info.dispatcher.Dispatch(ctx, dest)
 
- 	if err != nil {
 
- 		errors.LogErrorInner(s.info.ctx, err, "dispatch connection")
 
- 	}
 
- 	defer cancel()
 
- 	requestDone := func() error {
 
- 		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
 
- 		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
 
- 			return errors.New("failed to transport all TCP request").Base(err)
 
- 		}
 
- 		return nil
 
- 	}
 
- 	responseDone := func() error {
 
- 		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
 
- 		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
 
- 			return errors.New("failed to transport all TCP response").Base(err)
 
- 		}
 
- 		return nil
 
- 	}
 
- 	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
 
- 	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
 
- 		common.Interrupt(link.Reader)
 
- 		common.Interrupt(link.Writer)
 
- 		errors.LogDebugInner(s.info.ctx, err, "connection ends")
 
- 		return
 
- 	}
 
- }
 
 
  |