Browse Source

Refactor: Add Shadowsocks Validator (#233)

秋のかえで 4 years ago
parent
commit
df39991bb3

+ 27 - 0
proxy/shadowsocks/config.go

@@ -7,6 +7,8 @@ import (
 	"crypto/md5"
 	"crypto/sha1"
 	"io"
+	"reflect"
+	"strconv"
 
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/crypto/hkdf"
@@ -31,6 +33,31 @@ func (a *MemoryAccount) Equals(another protocol.Account) bool {
 	return false
 }
 
+func (a *MemoryAccount) GetCipherName() string {
+	switch a.Cipher.(type) {
+	case *AesCfb:
+		keyBytes := a.Cipher.(*AesCfb).KeyBytes
+		return "AES_" + strconv.FormatInt(int64(keyBytes*8), 10) + "_CFB"
+	case *ChaCha20:
+		if a.Cipher.(*ChaCha20).IVBytes == 8 {
+			return "CHACHA20"
+		}
+		return "CHACHA20_IETF"
+	case *AEADCipher:
+		switch reflect.ValueOf(a.Cipher.(*AEADCipher).AEADAuthCreator).Pointer() {
+		case reflect.ValueOf(createAesGcm).Pointer():
+			keyBytes := a.Cipher.(*AEADCipher).KeyBytes
+			return "AES_" + strconv.FormatInt(int64(keyBytes*8), 10) + "_GCM"
+		case reflect.ValueOf(createChacha20Poly1305).Pointer():
+			return "CHACHA20_POLY1305"
+		}
+	case *NoneCipher:
+		return "NONE"
+	}
+
+	return ""
+}
+
 func createAesGcm(key []byte) cipher.AEAD {
 	block, err := aes.NewCipher(key)
 	common.Must(err)

+ 54 - 75
proxy/shadowsocks/protocol.go

@@ -54,12 +54,9 @@ func (r *FullReader) Read(p []byte) (n int, err error) {
 }
 
 // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
-func ReadTCPSession(users []*protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
-	user := users[0]
-	account := user.Account.(*MemoryAccount)
+func ReadTCPSession(validator *Validator, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
 
 	hashkdf := hmac.New(sha256.New, []byte("SSBSKDF"))
-	hashkdf.Write(account.Key)
 
 	behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))
 
@@ -71,10 +68,20 @@ func ReadTCPSession(users []*protocol.MemoryUser, reader io.Reader) (*protocol.R
 	readSizeRemain := DrainSize
 
 	var r2 buf.Reader
+	buffer := buf.New()
+	defer buffer.Release()
 
-	if len(users) > 1 {
-		buffer := buf.New()
-		defer buffer.Release()
+	var user *protocol.MemoryUser
+	var ivLen int32
+	var err error
+
+	count := validator.Count()
+	if count == 0 {
+		readSizeRemain -= int(buffer.Len())
+		DrainConnN(reader, readSizeRemain)
+		return nil, nil, newError("invalid user")
+	} else if count > 1 {
+		var aead cipher.AEAD
 
 		if _, err := buffer.ReadFullFrom(reader, 50); err != nil {
 			readSizeRemain -= int(buffer.Len())
@@ -83,45 +90,26 @@ func ReadTCPSession(users []*protocol.MemoryUser, reader io.Reader) (*protocol.R
 		}
 
 		bs := buffer.Bytes()
+		user, aead, _, ivLen, err = validator.Get(bs, protocol.RequestCommandTCP)
 
-		var aeadCipher *AEADCipher
-		var ivLen int32
-		subkey := make([]byte, 32)
-		length := make([]byte, 16)
-		var aead cipher.AEAD
-		var err error
-		for _, user = range users {
-			account = user.Account.(*MemoryAccount)
-			aeadCipher = account.Cipher.(*AEADCipher)
-			ivLen = aeadCipher.IVSize()
-			subkey = subkey[:aeadCipher.KeyBytes]
-			hkdfSHA1(account.Key, bs[:ivLen], subkey)
-			aead = aeadCipher.AEADAuthCreator(subkey)
-			_, err = aead.Open(length[:0], length[4:16], bs[ivLen:ivLen+18], nil)
-			if err == nil {
-				reader = &FullReader{reader, bs[ivLen:]}
-				auth := &crypto.AEADAuthenticator{
-					AEAD:           aead,
-					NonceGenerator: crypto.GenerateInitialAEADNonce(),
-				}
-				r2 = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{
-					Auth: auth,
-				}, reader, protocol.TransferTypeStream, nil)
-				break
+		if user != nil {
+			reader = &FullReader{reader, bs[ivLen:]}
+			auth := &crypto.AEADAuthenticator{
+				AEAD:           aead,
+				NonceGenerator: crypto.GenerateInitialAEADNonce(),
 			}
-		}
-		if err != nil {
+			r2 = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{
+				Auth: auth,
+			}, reader, protocol.TransferTypeStream, nil)
+		} else {
 			readSizeRemain -= int(buffer.Len())
 			DrainConnN(reader, readSizeRemain)
 			return nil, nil, newError("failed to match an user").Base(err)
 		}
-	}
-
-	buffer := buf.New()
-	defer buffer.Release()
-
-	if r2 == nil {
-		ivLen := account.Cipher.IVSize()
+	} else {
+		user, ivLen = validator.GetOnlyUser()
+		account := user.Account.(*MemoryAccount)
+		hashkdf.Write(account.Key)
 		var iv []byte
 		if ivLen > 0 {
 			if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
@@ -261,40 +249,31 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
 	return buffer, nil
 }
 
-func DecodeUDPPacket(users []*protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
+func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
+	bs := payload.Bytes()
+	if len(bs) <= 32 {
+		return nil, nil, newError("len(bs) <= 32")
+	}
+
 	var user *protocol.MemoryUser
-	var account *MemoryAccount
 	var err error
 
-	if len(users) > 1 {
-		bs := payload.Bytes()
-		if len(bs) <= 32 {
-			return nil, nil, newError("len(bs) <= 32")
-		}
-
-		var aeadCipher *AEADCipher
-		var ivLen int32
-		subkey := make([]byte, 32)
-		data := make([]byte, 8192)
-		var aead cipher.AEAD
+	count := validator.Count()
+	if count == 0 {
+		return nil, nil, newError("invalid user")
+	} else if count > 1 {
 		var d []byte
-		for _, user = range users {
-			account = user.Account.(*MemoryAccount)
-			aeadCipher = account.Cipher.(*AEADCipher)
-			ivLen = aeadCipher.IVSize()
-			subkey = subkey[:aeadCipher.KeyBytes]
-			hkdfSHA1(account.Key, bs[:ivLen], subkey)
-			aead = aeadCipher.AEADAuthCreator(subkey)
-			d, err = aead.Open(data[:0], data[8180:8192], bs[ivLen:], nil)
-			if err == nil {
-				payload.Clear()
-				payload.Write(d)
-				break
-			}
+		user, _, d, _, err = validator.Get(bs, protocol.RequestCommandUDP)
+
+		if user != nil {
+			payload.Clear()
+			payload.Write(d)
+		} else {
+			return nil, nil, newError("failed to decrypt UDP payload").Base(err)
 		}
 	} else {
-		user = users[0]
-		account = user.Account.(*MemoryAccount)
+		user, _ = validator.GetOnlyUser()
+		account := user.Account.(*MemoryAccount)
 
 		var iv []byte
 		if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
@@ -302,12 +281,9 @@ func DecodeUDPPacket(users []*protocol.MemoryUser, payload *buf.Buffer) (*protoc
 			iv = make([]byte, account.Cipher.IVSize())
 			copy(iv, payload.BytesTo(account.Cipher.IVSize()))
 		}
-
-		err = account.Cipher.DecodePacket(account.Key, payload)
-	}
-
-	if err != nil {
-		return nil, nil, newError("failed to decrypt UDP payload").Base(err)
+		if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
+			return nil, nil, newError("failed to decrypt UDP payload").Base(err)
+		}
 	}
 
 	request := &protocol.RequestHeader{
@@ -341,7 +317,10 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		buffer.Release()
 		return nil, err
 	}
-	u, payload, err := DecodeUDPPacket([]*protocol.MemoryUser{v.User}, buffer)
+	validator := new(Validator)
+	validator.Add(v.User)
+
+	u, payload, err := DecodeUDPPacket(validator, buffer)
 	if err != nil {
 		buffer.Release()
 		return nil, err

+ 6 - 2
proxy/shadowsocks/protocol_test.go

@@ -38,7 +38,9 @@ func TestUDPEncoding(t *testing.T) {
 	encodedData, err := EncodeUDPPacket(request, data.Bytes())
 	common.Must(err)
 
-	decodedRequest, decodedData, err := DecodeUDPPacket([]*protocol.MemoryUser{request.User}, encodedData)
+	validator := new(Validator)
+	validator.Add(request.User)
+	decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData)
 	common.Must(err)
 
 	if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" {
@@ -117,7 +119,9 @@ func TestTCPRequest(t *testing.T) {
 
 		common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{data}))
 
-		decodedRequest, reader, err := ReadTCPSession([]*protocol.MemoryUser{request.User}, cache)
+		validator := new(Validator)
+		validator.Add(request.User)
+		decodedRequest, reader, err := ReadTCPSession(validator, cache)
 		common.Must(err)
 		if r := cmp.Diff(decodedRequest, request, cmp.Comparer(func(a1, a2 protocol.Account) bool { return a1.Equals(a2) })); r != "" {
 			t.Error("request: ", r)

+ 30 - 16
proxy/shadowsocks/server.go

@@ -22,35 +22,46 @@ import (
 
 type Server struct {
 	config        *ServerConfig
-	users         []*protocol.MemoryUser
+	validator     *Validator
 	policyManager policy.Manager
 	cone          bool
 }
 
 // NewServer create a new Shadowsocks server.
 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
-	if config.Users == nil {
-		return nil, newError("empty users")
+	validator := new(Validator)
+	for _, user := range config.Users {
+		u, err := user.ToMemoryUser()
+		if err != nil {
+			return nil, newError("failed to get shadowsocks user").Base(err).AtError()
+		}
+
+		if err := validator.Add(u); err != nil {
+			return nil, newError("failed to add user").Base(err).AtError()
+		}
 	}
 
 	v := core.MustFromContext(ctx)
 	s := &Server{
 		config:        config,
+		validator:     validator,
 		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
 		cone:          ctx.Value("cone").(bool),
 	}
 
-	for _, user := range config.Users {
-		u, err := user.ToMemoryUser()
-		if err != nil {
-			return nil, newError("failed to parse user account").Base(err)
-		}
-		s.users = append(s.users, u)
-	}
-
 	return s, nil
 }
 
+// AddUser implements proxy.UserManager.AddUser().
+func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
+	return s.validator.Add(u)
+}
+
+// RemoveUser implements proxy.UserManager.RemoveUser().
+func (s *Server) RemoveUser(ctx context.Context, e string) error {
+	return s.validator.Del(e)
+}
+
 func (s *Server) Network() []net.Network {
 	list := s.config.Network
 	if len(list) == 0 {
@@ -102,8 +113,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 	if inbound == nil {
 		panic("no inbound metadata")
 	}
-	if len(s.users) == 1 {
-		inbound.User = s.users[0]
+
+	if s.validator.Count() == 1 {
+		inbound.User, _ = s.validator.GetOnlyUser()
 	}
 
 	var dest *net.Destination
@@ -121,9 +133,11 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 			var err error
 
 			if inbound.User != nil {
-				request, data, err = DecodeUDPPacket([]*protocol.MemoryUser{inbound.User}, payload)
+				validator := new(Validator)
+				validator.Add(inbound.User)
+				request, data, err = DecodeUDPPacket(validator, payload)
 			} else {
-				request, data, err = DecodeUDPPacket(s.users, payload)
+				request, data, err = DecodeUDPPacket(s.validator, payload)
 				if err == nil {
 					inbound.User = request.User
 				}
@@ -178,7 +192,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 	}
 
 	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
-	request, bodyReader, err := ReadTCPSession(s.users, &bufferedReader)
+	request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader)
 	if err != nil {
 		log.Record(&log.AccessMessage{
 			From:   conn.RemoteAddr(),

+ 113 - 0
proxy/shadowsocks/validator.go

@@ -0,0 +1,113 @@
+package shadowsocks
+
+import (
+	"crypto/cipher"
+	"strings"
+	"sync"
+
+	"github.com/xtls/xray-core/common/protocol"
+)
+
+// Validator stores valid Shadowsocks users.
+type Validator struct {
+	// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
+	email sync.Map
+	users sync.Map
+}
+
+// Add a Shadowsocks user, Email must be empty or unique.
+func (v *Validator) Add(u *protocol.MemoryUser) error {
+	account := u.Account.(*MemoryAccount)
+
+	if !account.Cipher.IsAEAD() && v.Count() > 0 {
+		return newError("The cipher do not support Single-port Multi-user")
+	}
+
+	if u.Email != "" {
+		_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
+		if loaded {
+			return newError("User ", u.Email, " already exists.")
+		}
+	}
+
+	v.users.Store(string(account.Key)+"&"+account.GetCipherName(), u)
+	return nil
+}
+
+// Del a Shadowsocks user with a non-empty Email.
+func (v *Validator) Del(e string) error {
+	if e == "" {
+		return newError("Email must not be empty.")
+	}
+	le := strings.ToLower(e)
+	u, _ := v.email.Load(le)
+	if u == nil {
+		return newError("User ", e, " not found.")
+	}
+	account := u.(*protocol.MemoryUser).Account.(*MemoryAccount)
+	v.email.Delete(le)
+	v.users.Delete(string(account.Key) + "&" + account.GetCipherName())
+	return nil
+}
+
+// Count the number of Shadowsocks users
+func (v *Validator) Count() int {
+	length := 0
+	v.users.Range(func(_, _ interface{}) bool {
+		length++
+
+		return true
+	})
+	return length
+}
+
+// Get a Shadowsocks user and the user's cipher.
+func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol.MemoryUser, aead cipher.AEAD, ret []byte, ivLen int32, err error) {
+	var dataSize int
+
+	switch command {
+	case protocol.RequestCommandTCP:
+		dataSize = 16
+	case protocol.RequestCommandUDP:
+		dataSize = 8192
+	}
+
+	var aeadCipher *AEADCipher
+	subkey := make([]byte, 32)
+	data := make([]byte, dataSize)
+
+	v.users.Range(func(key, user interface{}) bool {
+		account := user.(*protocol.MemoryUser).Account.(*MemoryAccount)
+		aeadCipher = account.Cipher.(*AEADCipher)
+		ivLen = aeadCipher.IVSize()
+		subkey = subkey[:aeadCipher.KeyBytes]
+		hkdfSHA1(account.Key, bs[:ivLen], subkey)
+		aead = aeadCipher.AEADAuthCreator(subkey)
+
+		switch command {
+		case protocol.RequestCommandTCP:
+			ret, err = aead.Open(data[:0], data[4:16], bs[ivLen:ivLen+18], nil)
+		case protocol.RequestCommandUDP:
+			ret, err = aead.Open(data[:0], data[8180:8192], bs[ivLen:], nil)
+		}
+
+		if err == nil {
+			u = user.(*protocol.MemoryUser)
+			return false
+		}
+		return true
+	})
+
+	return
+}
+
+// Get the only user without authentication
+func (v *Validator) GetOnlyUser() (u *protocol.MemoryUser, ivLen int32) {
+	v.users.Range(func(_, user interface{}) bool {
+		u = user.(*protocol.MemoryUser)
+		return false
+	})
+	ivLen = u.Account.(*MemoryAccount).Cipher.IVSize()
+
+	return
+}