浏览代码

Disable VMess drain when not pure connection

RPRX 4 年之前
父节点
当前提交
f390047b37
共有 3 个文件被更改,包括 16 次插入7 次删除
  1. 4 4
      proxy/vmess/encoding/encoding_test.go
  2. 2 2
      proxy/vmess/encoding/server.go
  3. 10 1
      proxy/vmess/inbound/inbound.go

+ 4 - 4
proxy/vmess/encoding/encoding_test.go

@@ -57,14 +57,14 @@ func TestRequestSerialization(t *testing.T) {
 	defer common.Close(userValidator)
 
 	server := NewServerSession(userValidator, sessionHistory)
-	actualRequest, err := server.DecodeRequestHeader(buffer)
+	actualRequest, err := server.DecodeRequestHeader(buffer, false)
 	common.Must(err)
 
 	if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
 		t.Error(r)
 	}
 
-	_, err = server.DecodeRequestHeader(buffer2)
+	_, err = server.DecodeRequestHeader(buffer2, false)
 	// anti replay attack
 	if err == nil {
 		t.Error("nil error")
@@ -107,7 +107,7 @@ func TestInvalidRequest(t *testing.T) {
 	defer common.Close(userValidator)
 
 	server := NewServerSession(userValidator, sessionHistory)
-	_, err := server.DecodeRequestHeader(buffer)
+	_, err := server.DecodeRequestHeader(buffer, false)
 	if err == nil {
 		t.Error("nil error")
 	}
@@ -148,7 +148,7 @@ func TestMuxRequest(t *testing.T) {
 	defer common.Close(userValidator)
 
 	server := NewServerSession(userValidator, sessionHistory)
-	actualRequest, err := server.DecodeRequestHeader(buffer)
+	actualRequest, err := server.DecodeRequestHeader(buffer, false)
 	common.Must(err)
 
 	if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {

+ 2 - 2
proxy/vmess/encoding/server.go

@@ -131,7 +131,7 @@ func parseSecurityType(b byte) protocol.SecurityType {
 }
 
 // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
-func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
+func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) {
 	buffer := buf.New()
 	behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed()))
 	BaseDrainSize := behaviorRand.Roll(3266)
@@ -143,7 +143,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	drainConnection := func(e error) error {
 		// We read a deterministic generated length of data before closing the connection to offset padding read pattern
 		readSizeRemain -= int(buffer.Len())
-		if readSizeRemain > 0 {
+		if readSizeRemain > 0 && isDrain {
 			err := s.DrainConnN(reader, readSizeRemain)
 			if err != nil {
 				return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e)

+ 10 - 1
proxy/vmess/inbound/inbound.go

@@ -220,9 +220,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 		return newError("unable to set read deadline").Base(err).AtWarning()
 	}
 
+	iConn := connection
+	if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
+		iConn = statConn.Connection
+	}
+	_, isDrain := iConn.(*net.TCPConn)
+	if !isDrain {
+		_, isDrain = iConn.(*net.UnixConn)
+	}
+
 	reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
 	svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
-	request, err := svrSession.DecodeRequestHeader(reader)
+	request, err := svrSession.DecodeRequestHeader(reader, isDrain)
 	if err != nil {
 		if errors.Cause(err) != io.EOF {
 			log.Record(&log.AccessMessage{