瀏覽代碼

Allow to inject custom validator in VLESS controller (#3453)

* Make Validator an interface

* Move validator creation away from VLESS inbound controller
Torikki 1 年之前
父節點
當前提交
c259e4e4a6
共有 4 個文件被更改,包括 33 次插入23 次删除
  1. 1 1
      proxy/vless/encoding/encoding.go
  2. 3 3
      proxy/vless/encoding/encoding_test.go
  3. 18 14
      proxy/vless/inbound/inbound.go
  4. 11 5
      proxy/vless/validator.go

+ 1 - 1
proxy/vless/encoding/encoding.go

@@ -64,7 +64,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ
 }
 
 // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
-func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
+func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
 	buffer := buf.StackNew()
 	defer buffer.Release()
 

+ 3 - 3
proxy/vless/encoding/encoding_test.go

@@ -42,7 +42,7 @@ func TestRequestSerialization(t *testing.T) {
 	buffer := buf.StackNew()
 	common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
 
-	Validator := new(vless.Validator)
+	Validator := new(vless.MemoryValidator)
 	Validator.Add(user)
 
 	actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
@@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) {
 	buffer := buf.StackNew()
 	common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
 
-	Validator := new(vless.Validator)
+	Validator := new(vless.MemoryValidator)
 	Validator.Add(user)
 
 	_, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
@@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) {
 	buffer := buf.StackNew()
 	common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
 
-	Validator := new(vless.Validator)
+	Validator := new(vless.MemoryValidator)
 	Validator.Add(user)
 
 	actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)

+ 18 - 14
proxy/vless/inbound/inbound.go

@@ -45,7 +45,21 @@ func init() {
 		}); err != nil {
 			return nil, err
 		}
-		return New(ctx, config.(*Config), dc)
+
+		c := config.(*Config)
+
+		validator := new(vless.MemoryValidator)
+		for _, user := range c.Clients {
+			u, err := user.ToMemoryUser()
+			if err != nil {
+				return nil, errors.New("failed to get VLESS user").Base(err).AtError()
+			}
+			if err := validator.Add(u); err != nil {
+				return nil, errors.New("failed to initiate user").Base(err).AtError()
+			}
+		}
+
+		return New(ctx, c, dc, validator)
 	}))
 }
 
@@ -53,30 +67,20 @@ func init() {
 type Handler struct {
 	inboundHandlerManager feature_inbound.Manager
 	policyManager         policy.Manager
-	validator             *vless.Validator
+	validator             vless.Validator
 	dns                   dns.Client
 	fallbacks             map[string]map[string]map[string]*Fallback // or nil
 	// regexps               map[string]*regexp.Regexp       // or nil
 }
 
 // New creates a new VLess inbound handler.
-func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
+func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
 	v := core.MustFromContext(ctx)
 	handler := &Handler{
 		inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
 		policyManager:         v.GetFeature(policy.ManagerType()).(policy.Manager),
-		validator:             new(vless.Validator),
 		dns:                   dc,
-	}
-
-	for _, user := range config.Clients {
-		u, err := user.ToMemoryUser()
-		if err != nil {
-			return nil, errors.New("failed to get VLESS user").Base(err).AtError()
-		}
-		if err := handler.AddUser(ctx, u); err != nil {
-			return nil, errors.New("failed to initiate user").Base(err).AtError()
-		}
+		validator:             validator,
 	}
 
 	if config.Fallbacks != nil {

+ 11 - 5
proxy/vless/validator.go

@@ -9,15 +9,21 @@ import (
 	"github.com/xtls/xray-core/common/uuid"
 )
 
-// Validator stores valid VLESS users.
-type Validator struct {
+type Validator interface {
+	Get(id uuid.UUID) *protocol.MemoryUser
+	Add(u *protocol.MemoryUser) error
+	Del(email string) error
+}
+
+// MemoryValidator stores valid VLESS users.
+type MemoryValidator struct {
 	// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
 	email sync.Map
 	users sync.Map
 }
 
 // Add a VLESS user, Email must be empty or unique.
-func (v *Validator) Add(u *protocol.MemoryUser) error {
+func (v *MemoryValidator) Add(u *protocol.MemoryUser) error {
 	if u.Email != "" {
 		_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
 		if loaded {
@@ -29,7 +35,7 @@ func (v *Validator) Add(u *protocol.MemoryUser) error {
 }
 
 // Del a VLESS user with a non-empty Email.
-func (v *Validator) Del(e string) error {
+func (v *MemoryValidator) Del(e string) error {
 	if e == "" {
 		return errors.New("Email must not be empty.")
 	}
@@ -44,7 +50,7 @@ func (v *Validator) Del(e string) error {
 }
 
 // Get a VLESS user with UUID, nil if user doesn't exist.
-func (v *Validator) Get(id uuid.UUID) *protocol.MemoryUser {
+func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser {
 	u, _ := v.users.Load(id)
 	if u != nil {
 		return u.(*protocol.MemoryUser)