Browse Source

Fix QUIC defragger

世界 2 years ago
parent
commit
ff209471d8

+ 0 - 62
common/baderror/baderror.go

@@ -1,62 +0,0 @@
-package baderror
-
-import (
-	"context"
-	"io"
-	"net"
-	"strings"
-
-	E "github.com/sagernet/sing/common/exceptions"
-)
-
-func Contains(err error, msgList ...string) bool {
-	for _, msg := range msgList {
-		if strings.Contains(err.Error(), msg) {
-			return true
-		}
-	}
-	return false
-}
-
-func WrapH2(err error) error {
-	if err == nil {
-		return nil
-	}
-	err = E.Unwrap(err)
-	if err == io.ErrUnexpectedEOF {
-		return io.EOF
-	}
-	if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
-		return net.ErrClosed
-	}
-	return err
-}
-
-func WrapGRPC(err error) error {
-	// grpc uses stupid internal error types
-	if err == nil {
-		return nil
-	}
-	if Contains(err, "EOF") {
-		return io.EOF
-	}
-	if Contains(err, "Canceled") {
-		return context.Canceled
-	}
-	if Contains(err,
-		"the client connection is closing",
-		"server closed the stream without sending trailers") {
-		return net.ErrClosed
-	}
-	return err
-}
-
-func WrapQUIC(err error) error {
-	if err == nil {
-		return nil
-	}
-	if Contains(err, "canceled by local with error code 0") {
-		return net.ErrClosed
-	}
-	return err
-}

+ 1 - 1
transport/hysteria/wrap.go

@@ -6,8 +6,8 @@ import (
 	"syscall"
 
 	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/common/baderror"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/baderror"
 )
 
 type PacketConnWrapper struct {

+ 1 - 1
transport/tuic/client.go

@@ -10,8 +10,8 @@ import (
 	"time"
 
 	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/common/baderror"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/baderror"
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"

+ 2 - 2
transport/tuic/client_packet.go

@@ -34,7 +34,7 @@ func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
 	}
 	switch data[1] {
 	case CommandPacket:
-		message := udpMessagePool.Get().(*udpMessage)
+		message := allocMessage()
 		err := decodeUDPMessage(message, data[2:])
 		if err != nil {
 			message.release()
@@ -82,7 +82,7 @@ func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.Receive
 		return E.New("unknown command ", command)
 	}
 	reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
-	message := udpMessagePool.Get().(*udpMessage)
+	message := allocMessage()
 	err = readUDPMessage(message, reader)
 	if err != nil {
 		message.release()

+ 16 - 6
transport/tuic/packet.go

@@ -27,11 +27,16 @@ var udpMessagePool = sync.Pool{
 	},
 }
 
+func allocMessage() *udpMessage {
+	message := udpMessagePool.Get().(*udpMessage)
+	message.referenced = true
+	return message
+}
+
 func releaseMessages(messages []*udpMessage) {
 	for _, message := range messages {
 		if message != nil {
-			*message = udpMessage{}
-			udpMessagePool.Put(message)
+			message.release()
 		}
 	}
 }
@@ -43,9 +48,13 @@ type udpMessage struct {
 	fragmentID    uint8
 	destination   M.Socksaddr
 	data          *buf.Buffer
+	referenced    bool
 }
 
 func (m *udpMessage) release() {
+	if !m.referenced {
+		return
+	}
 	*m = udpMessage{}
 	udpMessagePool.Put(m)
 }
@@ -83,7 +92,7 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
 	originPacket := message.data.Bytes()
 	udpMTU := maxPacketSize - message.headerSize()
 	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
-		fragment := udpMessagePool.Get().(*udpMessage)
+		fragment := allocMessage()
 		*fragment = *message
 		if remaining > udpMTU {
 			fragment.data = buf.As(originPacket[:udpMTU])
@@ -214,7 +223,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
 		c.packetId.Store(0)
 		packetId = 0
 	}
-	message := udpMessagePool.Get().(*udpMessage)
+	message := allocMessage()
 	*message = udpMessage{
 		sessionID:     c.sessionID,
 		packetID:      uint16(packetId),
@@ -259,7 +268,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
 		c.packetId.Store(0)
 		packetId = 0
 	}
-	message := udpMessagePool.Get().(*udpMessage)
+	message := allocMessage()
 	*message = udpMessage{
 		sessionID:     c.sessionID,
 		packetID:      uint16(packetId),
@@ -431,7 +440,7 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
 	if int(item.count) != len(item.messages) {
 		return nil
 	}
-	newMessage := udpMessagePool.Get().(*udpMessage)
+	newMessage := allocMessage()
 	*newMessage = *item.messages[0]
 	var dataLength uint16
 	for _, message := range item.messages {
@@ -446,6 +455,7 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
 		item.messages = nil
 		return newMessage
 	}
+	item.messages = nil
 	return nil
 }
 

+ 2 - 2
transport/tuic/server.go

@@ -13,9 +13,9 @@ import (
 	"time"
 
 	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/common/baderror"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
+	"github.com/sagernet/sing/common/baderror"
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -264,7 +264,7 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
 			return s.connErr
 		case <-s.authDone:
 		}
-		message := udpMessagePool.Get().(*udpMessage)
+		message := allocMessage()
 		err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
 		if err != nil {
 			message.release()

+ 1 - 1
transport/tuic/server_packet.go

@@ -35,7 +35,7 @@ func (s *serverSession) handleMessage(data []byte) error {
 	}
 	switch data[1] {
 	case CommandPacket:
-		message := udpMessagePool.Get().(*udpMessage)
+		message := allocMessage()
 		err := decodeUDPMessage(message, data[2:])
 		if err != nil {
 			message.release()

+ 1 - 1
transport/v2raygrpc/conn.go

@@ -5,8 +5,8 @@ import (
 	"os"
 	"time"
 
-	"github.com/sagernet/sing-box/common/baderror"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/baderror"
 	M "github.com/sagernet/sing/common/metadata"
 	"github.com/sagernet/sing/common/rw"
 )

+ 1 - 1
transport/v2raygrpclite/conn.go

@@ -11,8 +11,8 @@ import (
 	"sync"
 	"time"
 
-	"github.com/sagernet/sing-box/common/baderror"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/baderror"
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	M "github.com/sagernet/sing/common/metadata"

+ 1 - 1
transport/v2rayhttp/conn.go

@@ -10,8 +10,8 @@ import (
 	"sync"
 	"time"
 
-	"github.com/sagernet/sing-box/common/baderror"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/baderror"
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"