Prechádzať zdrojové kódy

Trojan-UoT & UDP-nameserver: Fix forgotten release buffer; UDP dispatcher: Simplified and optimized (#5050)

patterniha 3 mesiacov pred
rodič
commit
593ededd3e

+ 7 - 1
app/dns/nameserver_udp.go

@@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error {
 
 
 // HandleResponse handles udp response packet from remote DNS server.
 // HandleResponse handles udp response packet from remote DNS server.
 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
-	ipRec, err := parseResponse(packet.Payload.Bytes())
+	payload := packet.Payload
+	ipRec, err := parseResponse(payload.Bytes())
+	payload.Release()
 	if err != nil {
 	if err != nil {
 		errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
 		errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
 		return
 		return
@@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 			newReq.msg = &newMsg
 			newReq.msg = &newMsg
 			s.addPendingRequest(&newReq)
 			s.addPendingRequest(&newReq)
 			b, _ := dns.PackMessage(newReq.msg)
 			b, _ := dns.PackMessage(newReq.msg)
+			copyDest := net.UDPDestination(s.address.Address, s.address.Port)
+			b.UDP = &copyDest
 			s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
 			s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
 			return
 			return
 		}
 		}
@@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
 		}
 		}
 		s.addPendingRequest(udpReq)
 		s.addPendingRequest(udpReq)
 		b, _ := dns.PackMessage(req.msg)
 		b, _ := dns.PackMessage(req.msg)
+		copyDest := net.UDPDestination(s.address.Address, s.address.Port)
+		b.UDP = &copyDest
 		s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
 		s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
 	}
 	}
 }
 }

+ 8 - 2
app/proxyman/outbound/handler.go

@@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
 	}
 	}
 out:
 out:
 	err := h.proxy.Process(ctx, link, h)
 	err := h.proxy.Process(ctx, link, h)
+	var errC error
 	if err != nil {
 	if err != nil {
-		if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) {
+		errC = errors.Cause(err)
+		if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) {
 			err = nil
 			err = nil
 		}
 		}
 	}
 	}
@@ -251,7 +253,11 @@ out:
 		errors.LogInfo(ctx, err.Error())
 		errors.LogInfo(ctx, err.Error())
 		common.Interrupt(link.Writer)
 		common.Interrupt(link.Writer)
 	} else {
 	} else {
-		common.Close(link.Writer)
+		if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) {
+			common.Interrupt(link.Writer)
+		} else {
+			common.Close(link.Writer)
+		}
 	}
 	}
 	common.Interrupt(link.Reader)
 	common.Interrupt(link.Reader)
 }
 }

+ 11 - 6
common/mux/client.go

@@ -2,6 +2,7 @@ package mux
 
 
 import (
 import (
 	"context"
 	"context"
+	goerrors "errors"
 	"io"
 	"io"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
 		ctx := session.ContextWithOutbounds(context.Background(), outbounds)
 		ctx := session.ContextWithOutbounds(context.Background(), outbounds)
 		ctx, cancel := context.WithCancel(ctx)
 		ctx, cancel := context.WithCancel(ctx)
 
 
-		if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
-			errors.LogInfoInner(ctx, err, "failed to handler mux client connection")
+		if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil {
+			errC := errors.Cause(errP)
+			if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) {
+				errors.LogInfoInner(ctx, errP, "failed to handler mux client connection")
+			}
 		}
 		}
 		common.Must(c.Close())
 		common.Must(c.Close())
 		cancel()
 		cancel()
@@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() {
 		select {
 		select {
 		case <-m.done.Wait():
 		case <-m.done.Wait():
 			m.sessionManager.Close()
 			m.sessionManager.Close()
-			common.Close(m.link.Writer)
+			common.Interrupt(m.link.Writer)
 			common.Interrupt(m.link.Reader)
 			common.Interrupt(m.link.Reader)
 			return
 			return
 		case <-m.timer.C:
 		case <-m.timer.C:
@@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
 	return nil
 	return nil
 }
 }
 
 
-func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
+func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
 	outbounds := session.OutboundsFromContext(ctx)
 	outbounds := session.OutboundsFromContext(ctx)
 	ob := outbounds[len(outbounds)-1]
 	ob := outbounds[len(outbounds)-1]
 	transferType := protocol.TransferTypeStream
 	transferType := protocol.TransferTypeStream
@@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
 	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
 	defer s.Close(false)
 	defer s.Close(false)
 	defer writer.Close()
 	defer writer.Close()
+	defer timer.Reset(time.Second * 16)
 
 
 	errors.LogInfo(ctx, "dispatching request to ", ob.Target)
 	errors.LogInfo(ctx, "dispatching request to ", ob.Target)
 	if err := writeFirstPayload(s.input, writer); err != nil {
 	if err := writeFirstPayload(s.input, writer); err != nil {
@@ -308,9 +313,9 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
 	s.input = link.Reader
 	s.input = link.Reader
 	s.output = link.Writer
 	s.output = link.Writer
 	if _, ok := link.Reader.(*pipe.Reader); ok {
 	if _, ok := link.Reader.(*pipe.Reader); ok {
-		go fetchInput(ctx, s, m.link.Writer)
+		go fetchInput(ctx, s, m.link.Writer, m.timer)
 	} else {
 	} else {
-		fetchInput(ctx, s, m.link.Writer)
+		fetchInput(ctx, s, m.link.Writer, m.timer)
 	}
 	}
 	return true
 	return true
 }
 }

+ 1 - 1
common/mux/server.go

@@ -318,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) {
 	reader := &buf.BufferedReader{Reader: w.link.Reader}
 	reader := &buf.BufferedReader{Reader: w.link.Reader}
 
 
 	defer w.sessionManager.Close()
 	defer w.sessionManager.Close()
-	defer common.Close(w.link.Writer)
 	defer common.Interrupt(w.link.Reader)
 	defer common.Interrupt(w.link.Reader)
+	defer common.Interrupt(w.link.Writer)
 
 
 	for {
 	for {
 		select {
 		select {

+ 2 - 2
proxy/freedom/freedom.go

@@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool {
 	}
 	}
 
 
 	a := addr.AsAddress()
 	a := addr.AsAddress()
-	return a != net.AnyIP
+	return a != net.AnyIP && a != net.AnyIPv6
 }
 }
 
 
 // Process implements proxy.Outbound.
 // Process implements proxy.Outbound.
@@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 					}
 					}
 				}
 				}
 			}
 			}
-			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
+			destAddr := b.UDP.RawNetAddr()
 			if destAddr == nil {
 			if destAddr == nil {
 				b.Release()
 				b.Release()
 				continue
 				continue

+ 3 - 0
proxy/proxy.go

@@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
 			}
 			}
 		}
 		}
 		if err != nil {
 		if err != nil {
+			if errors.Cause(err) == io.EOF {
+				return nil
+			}
 			return err
 			return err
 		}
 		}
 	}
 	}

+ 3 - 3
proxy/shadowsocks/server.go

@@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
 func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
 	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		request := protocol.RequestHeaderFromContext(ctx)
 		request := protocol.RequestHeaderFromContext(ctx)
+		payload := packet.Payload
 		if request == nil {
 		if request == nil {
+			payload.Release()
 			return
 			return
 		}
 		}
 
 
-		payload := packet.Payload
-
 		if payload.UDP != nil {
 		if payload.UDP != nil {
 			request = &protocol.RequestHeader{
 			request = &protocol.RequestHeader{
 				User:    request.User,
 				User:    request.User,
@@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
 			errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
 			errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
 			return
 			return
 		}
 		}
-		defer data.Release()
 
 
 		conn.Write(data.Bytes())
 		conn.Write(data.Bytes())
+		data.Release()
 	})
 	})
 	defer udpServer.RemoveRay()
 	defer udpServer.RemoveRay()
 
 

+ 2 - 1
proxy/socks/server.go

@@ -196,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
 
 
 		request := protocol.RequestHeaderFromContext(ctx)
 		request := protocol.RequestHeaderFromContext(ctx)
 		if request == nil {
 		if request == nil {
+			payload.Release()
 			return
 			return
 		}
 		}
 
 
@@ -214,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
 			errors.LogWarningInner(ctx, err, "failed to write UDP response")
 			errors.LogWarningInner(ctx, err, "failed to write UDP response")
 			return
 			return
 		}
 		}
-		defer udpMessage.Release()
 
 
 		conn.Write(udpMessage.Bytes())
 		conn.Write(udpMessage.Bytes())
+		udpMessage.Release()
 	})
 	})
 	defer udpServer.RemoveRay()
 	defer udpServer.RemoveRay()
 
 

+ 2 - 0
proxy/trojan/protocol.go

@@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 			target = b.UDP
 			target = b.UDP
 		}
 		}
 		if _, err := w.writePacket(b.Bytes(), *target); err != nil {
 		if _, err := w.writePacket(b.Bytes(), *target); err != nil {
+			b.Release()
 			buf.ReleaseMulti(mb)
 			buf.ReleaseMulti(mb)
 			return err
 			return err
 		}
 		}
+		b.Release()
 	}
 	}
 	return nil
 	return nil
 }
 }

+ 36 - 24
transport/internet/udp/dispatcher.go

@@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet)
 
 
 type connEntry struct {
 type connEntry struct {
 	link   *transport.Link
 	link   *transport.Link
-	timer  signal.ActivityUpdater
+	timer  *signal.ActivityTimer
 	cancel context.CancelFunc
 	cancel context.CancelFunc
+	closed bool
+}
+
+func (c *connEntry) Close() error {
+	c.timer.SetTimeout(0)
+	return nil
+}
+
+func (c *connEntry) terminate() {
+	if c.closed {
+		panic("terminate called more than once")
+	}
+	c.closed = true
+	c.cancel()
+	common.Interrupt(c.link.Reader)
+	common.Interrupt(c.link.Writer)
 }
 }
 
 
 type Dispatcher struct {
 type Dispatcher struct {
@@ -32,6 +48,7 @@ type Dispatcher struct {
 	dispatcher routing.Dispatcher
 	dispatcher routing.Dispatcher
 	callback   ResponseCallback
 	callback   ResponseCallback
 	callClose  func() error
 	callClose  func() error
+	closed     bool
 }
 }
 
 
 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
@@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di
 func (v *Dispatcher) RemoveRay() {
 func (v *Dispatcher) RemoveRay() {
 	v.Lock()
 	v.Lock()
 	defer v.Unlock()
 	defer v.Unlock()
-	v.removeRay()
-}
-
-func (v *Dispatcher) removeRay() {
+	v.closed = true
 	if v.conn != nil {
 	if v.conn != nil {
-		common.Interrupt(v.conn.link.Reader)
-		common.Close(v.conn.link.Writer)
+		v.conn.Close()
 		v.conn = nil
 		v.conn = nil
 	}
 	}
 }
 }
@@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
 	v.Lock()
 	v.Lock()
 	defer v.Unlock()
 	defer v.Unlock()
 
 
+	if v.closed {
+		return nil, errors.New("dispatcher is closed")
+	}
+
 	if v.conn != nil {
 	if v.conn != nil {
-		return v.conn, nil
+		if v.conn.closed {
+			v.conn = nil
+		} else {
+			return v.conn, nil
+		}
 	}
 	}
 
 
 	errors.LogInfo(ctx, "establishing new connection for ", dest)
 	errors.LogInfo(ctx, "establishing new connection for ", dest)
 
 
 	ctx, cancel := context.WithCancel(ctx)
 	ctx, cancel := context.WithCancel(ctx)
-	entry := &connEntry{}
-	removeRay := func() {
-		v.Lock()
-		defer v.Unlock()
-		// sometimes the entry is already removed by others, don't close again
-		if entry == v.conn {
-			cancel()
-			v.removeRay()
-		}
-	}
-	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
 
 
 	link, err := v.dispatcher.Dispatch(ctx, dest)
 	link, err := v.dispatcher.Dispatch(ctx, dest)
 	if err != nil {
 	if err != nil {
+		cancel()
 		return nil, errors.New("failed to dispatch request to ", dest).Base(err)
 		return nil, errors.New("failed to dispatch request to ", dest).Base(err)
 	}
 	}
 
 
-	*entry = connEntry{
+	entry := &connEntry{
 		link:   link,
 		link:   link,
-		timer:  timer,
-		cancel: removeRay,
+		cancel: cancel,
 	}
 	}
+
+	entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute)
 	v.conn = entry
 	v.conn = entry
 	go handleInput(ctx, entry, dest, v.callback, v.callClose)
 	go handleInput(ctx, entry, dest, v.callback, v.callClose)
 	return entry, nil
 	return entry, nil
@@ -106,7 +118,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
 	if outputStream != nil {
 	if outputStream != nil {
 		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
 		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
 			errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
 			errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
-			conn.cancel()
+			conn.Close()
 			return
 			return
 		}
 		}
 	}
 	}
@@ -114,7 +126,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
 
 
 func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
 func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
 	defer func() {
 	defer func() {
-		conn.cancel()
+		conn.Close()
 		if callClose != nil {
 		if callClose != nil {
 			callClose()
 			callClose()
 		}
 		}

+ 8 - 5
transport/pipe/impl.go

@@ -200,16 +200,19 @@ func (p *pipe) Interrupt() {
 	p.Lock()
 	p.Lock()
 	defer p.Unlock()
 	defer p.Unlock()
 
 
+	if !p.data.IsEmpty() {
+		buf.ReleaseMulti(p.data)
+		p.data = nil
+		if p.state == closed {
+			p.state = errord
+		}
+	}
+
 	if p.state == closed || p.state == errord {
 	if p.state == closed || p.state == errord {
 		return
 		return
 	}
 	}
 
 
 	p.state = errord
 	p.state = errord
 
 
-	if !p.data.IsEmpty() {
-		buf.ReleaseMulti(p.data)
-		p.data = nil
-	}
-
 	common.Must(p.done.Close())
 	common.Must(p.done.Close())
 }
 }