瀏覽代碼

fix(proxy): removed the udp payload length check when encryption is disabled

cty123 2 年之前
父節點
當前提交
a343d68944
共有 3 個文件被更改,包括 106 次插入52 次删除
  1. 34 28
      proxy/shadowsocks/protocol.go
  2. 67 24
      proxy/shadowsocks/protocol_test.go
  3. 5 0
      proxy/shadowsocks/validator.go

+ 34 - 28
proxy/shadowsocks/protocol.go

@@ -4,6 +4,7 @@ import (
 	"crypto/hmac"
 	"crypto/rand"
 	"crypto/sha256"
+	"errors"
 	"hash/crc32"
 	"io"
 
@@ -236,37 +237,37 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
 }
 
 func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
-	bs := payload.Bytes()
-	if len(bs) <= 32 {
-		return nil, nil, newError("len(bs) <= 32")
-	}
+	rawPayload := payload.Bytes()
+	user, _, d, _, err := validator.Get(rawPayload, protocol.RequestCommandUDP)
 
-	user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP)
-	switch err {
-	case ErrIVNotUnique:
+	if errors.Is(err, ErrIVNotUnique) {
 		return nil, nil, newError("failed iv check").Base(err)
-	case ErrNotFound:
+	}
+
+	if errors.Is(err, ErrNotFound) {
 		return nil, nil, newError("failed to match an user").Base(err)
-	default:
-		account := user.Account.(*MemoryAccount)
-		if account.Cipher.IsAEAD() {
-			payload.Clear()
-			payload.Write(d)
-		} else {
-			if account.Cipher.IVSize() > 0 {
-				iv := make([]byte, account.Cipher.IVSize())
-				copy(iv, payload.BytesTo(account.Cipher.IVSize()))
-			}
-			if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
-				return nil, nil, newError("failed to decrypt UDP payload").Base(err)
-			}
-		}
 	}
 
-	request := &protocol.RequestHeader{
-		Version: Version,
-		User:    user,
-		Command: protocol.RequestCommandUDP,
+	if err != nil {
+		return nil, nil, newError("unexpected error").Base(err)
+	}
+
+	account, ok := user.Account.(*MemoryAccount)
+	if !ok {
+		return nil, nil, newError("expected MemoryAccount returned from validator")
+	}
+
+	if account.Cipher.IsAEAD() {
+		payload.Clear()
+		payload.Write(d)
+	} else {
+		if account.Cipher.IVSize() > 0 {
+			iv := make([]byte, account.Cipher.IVSize())
+			copy(iv, payload.BytesTo(account.Cipher.IVSize()))
+		}
+		if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
+			return nil, nil, newError("failed to decrypt UDP payload").Base(err)
+		}
 	}
 
 	payload.SetByte(0, payload.Byte(0)&0x0F)
@@ -276,8 +277,13 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
 		return nil, nil, newError("failed to parse address").Base(err)
 	}
 
-	request.Address = addr
-	request.Port = port
+	request := &protocol.RequestHeader{
+		Version: Version,
+		User:    user,
+		Command: protocol.RequestCommandUDP,
+		Address: addr,
+		Port:    port,
+	}
 
 	return request, payload, nil
 }

+ 67 - 24
proxy/shadowsocks/protocol_test.go

@@ -23,37 +23,80 @@ func equalRequestHeader(x, y *protocol.RequestHeader) bool {
 	}))
 }
 
-func TestUDPEncoding(t *testing.T) {
-	request := &protocol.RequestHeader{
-		Version: Version,
-		Command: protocol.RequestCommandUDP,
-		Address: net.LocalHostIP,
-		Port:    1234,
-		User: &protocol.MemoryUser{
-			Email: "[email protected]",
-			Account: toAccount(&Account{
-				Password:   "password",
-				CipherType: CipherType_AES_128_GCM,
-			}),
+func TestUDPEncodingDecoding(t *testing.T) {
+	testRequests := []protocol.RequestHeader{
+		{
+			Version: Version,
+			Command: protocol.RequestCommandUDP,
+			Address: net.LocalHostIP,
+			Port:    1234,
+			User: &protocol.MemoryUser{
+				Email: "[email protected]",
+				Account: toAccount(&Account{
+					Password:   "password",
+					CipherType: CipherType_AES_128_GCM,
+				}),
+			},
+		},
+		{
+			Version: Version,
+			Command: protocol.RequestCommandUDP,
+			Address: net.LocalHostIP,
+			Port:    1234,
+			User: &protocol.MemoryUser{
+				Email: "[email protected]",
+				Account: toAccount(&Account{
+					Password:   "123",
+					CipherType: CipherType_NONE,
+				}),
+			},
 		},
 	}
 
-	data := buf.New()
-	common.Must2(data.WriteString("test string"))
-	encodedData, err := EncodeUDPPacket(request, data.Bytes())
-	common.Must(err)
+	for _, request := range testRequests {
+		data := buf.New()
+		common.Must2(data.WriteString("test string"))
+		encodedData, err := EncodeUDPPacket(&request, data.Bytes())
+		common.Must(err)
 
-	validator := new(Validator)
-	validator.Add(request.User)
-	decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData)
-	common.Must(err)
+		validator := new(Validator)
+		validator.Add(request.User)
+		decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData)
+		common.Must(err)
 
-	if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" {
-		t.Error("data: ", r)
+		if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" {
+			t.Error("data: ", r)
+		}
+
+		if equalRequestHeader(decodedRequest, &request) == false {
+			t.Error("different request")
+		}
 	}
+}
 
-	if equalRequestHeader(decodedRequest, request) == false {
-		t.Error("different request")
+func TestUDPDecodingWithPayloadTooShort(t *testing.T) {
+	testAccounts := []protocol.Account{
+		toAccount(&Account{
+			Password:   "password",
+			CipherType: CipherType_AES_128_GCM,
+		}),
+		toAccount(&Account{
+			Password:   "password",
+			CipherType: CipherType_NONE,
+		}),
+	}
+
+	for _, account := range testAccounts {
+		data := buf.New()
+		data.WriteString("short payload")
+		validator := new(Validator)
+		validator.Add(&protocol.MemoryUser{
+			Account: account,
+		})
+		_, _, err := DecodeUDPPacket(validator, data)
+		if err == nil {
+			t.Fatal("expected error")
+		}
 	}
 }
 

+ 5 - 0
proxy/shadowsocks/validator.go

@@ -80,6 +80,11 @@ func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol
 
 	for _, user := range v.users {
 		if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() {
+			// AEAD payload decoding requires the payload to be over 32 bytes
+			if len(bs) < 32 {
+				continue
+			}
+
 			aeadCipher := account.Cipher.(*AEADCipher)
 			ivLen = aeadCipher.IVSize()
 			iv := bs[:ivLen]