소스 검색

Change to TypedSyncMap

风扇滑翔翼 6 달 전
부모
커밋
20825f6f1a
6개의 변경된 파일29개의 추가작업 그리고 31개의 파일을 삭제
  1. 2 2
      common/utils/typed_sync_map.go
  2. 1 1
      proxy/freedom/freedom.go
  3. 9 9
      proxy/trojan/validator.go
  4. 9 9
      proxy/vless/validator.go
  5. 1 4
      transport/internet/grpc/dial.go
  6. 7 6
      transport/internet/splithttp/hub.go

+ 2 - 2
common/utils/typed_sync_map.go

@@ -15,8 +15,8 @@ type TypedSyncMap[K, V any] struct {
 // K is key type, V is value type
 // It is recommended to use pointer types for V because sync.Map might return nil
 // If sync.Map methods really returned nil, it will return the zero value of the type V
-func NewTypedSyncMap[K any, V any]() *TypedSyncMap[K, V] {
-	return &TypedSyncMap[K, V]{
+func NewTypedSyncMap[K any, V any]() TypedSyncMap[K, V] {
+	return TypedSyncMap[K, V]{
 		syncMap: &sync.Map{},
 	}
 }

+ 1 - 1
proxy/freedom/freedom.go

@@ -363,7 +363,7 @@ func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride
 			Handler:           h,
 			Context:           ctx,
 			UDPOverride:       UDPOverride,
-			resolvedUDPAddr:   resolvedUDPAddr,
+			resolvedUDPAddr:   &resolvedUDPAddr,
 		}
 
 	}

+ 9 - 9
proxy/trojan/validator.go

@@ -2,17 +2,17 @@ package trojan
 
 import (
 	"strings"
-	"sync"
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/utils"
 )
 
 // Validator stores valid trojan users.
 type Validator struct {
 	// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
-	email sync.Map
-	users sync.Map
+	email utils.TypedSyncMap[string, *protocol.MemoryUser]
+	users utils.TypedSyncMap[string, *protocol.MemoryUser]
 }
 
 // Add a trojan user, Email must be empty or unique.
@@ -38,7 +38,7 @@ func (v *Validator) Del(e string) error {
 		return errors.New("User ", e, " not found.")
 	}
 	v.email.Delete(le)
-	v.users.Delete(hexString(u.(*protocol.MemoryUser).Account.(*MemoryAccount).Key))
+	v.users.Delete(hexString(u.Account.(*MemoryAccount).Key))
 	return nil
 }
 
@@ -46,7 +46,7 @@ func (v *Validator) Del(e string) error {
 func (v *Validator) Get(hash string) *protocol.MemoryUser {
 	u, _ := v.users.Load(hash)
 	if u != nil {
-		return u.(*protocol.MemoryUser)
+		return u
 	}
 	return nil
 }
@@ -56,7 +56,7 @@ func (v *Validator) GetByEmail(email string) *protocol.MemoryUser {
 	email = strings.ToLower(email)
 	u, _ := v.email.Load(email)
 	if u != nil {
-		return u.(*protocol.MemoryUser)
+		return u
 	}
 	return nil
 }
@@ -64,8 +64,8 @@ func (v *Validator) GetByEmail(email string) *protocol.MemoryUser {
 // Get all users
 func (v *Validator) GetAll() []*protocol.MemoryUser {
 	var u = make([]*protocol.MemoryUser, 0, 100)
-	v.email.Range(func(key, value interface{}) bool {
-		u = append(u, value.(*protocol.MemoryUser))
+	v.email.Range(func(key string, value *protocol.MemoryUser) bool {
+		u = append(u, value)
 		return true
 	})
 	return u
@@ -74,7 +74,7 @@ func (v *Validator) GetAll() []*protocol.MemoryUser {
 // Get users count
 func (v *Validator) GetCount() int64 {
 	var c int64 = 0
-	v.email.Range(func(key, value interface{}) bool {
+	v.email.Range(func(key string, value *protocol.MemoryUser) bool {
 		c++
 		return true
 	})

+ 9 - 9
proxy/vless/validator.go

@@ -2,10 +2,10 @@ package vless
 
 import (
 	"strings"
-	"sync"
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/utils"
 	"github.com/xtls/xray-core/common/uuid"
 )
 
@@ -21,8 +21,8 @@ type Validator interface {
 // MemoryValidator stores valid VLESS users.
 type MemoryValidator struct {
 	// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
-	email sync.Map
-	users sync.Map
+	email utils.TypedSyncMap[string, *protocol.MemoryUser]
+	users utils.TypedSyncMap[uuid.UUID, *protocol.MemoryUser]
 }
 
 // Add a VLESS user, Email must be empty or unique.
@@ -48,7 +48,7 @@ func (v *MemoryValidator) Del(e string) error {
 		return errors.New("User ", e, " not found.")
 	}
 	v.email.Delete(le)
-	v.users.Delete(u.(*protocol.MemoryUser).Account.(*MemoryAccount).ID.UUID())
+	v.users.Delete(u.Account.(*MemoryAccount).ID.UUID())
 	return nil
 }
 
@@ -56,7 +56,7 @@ func (v *MemoryValidator) Del(e string) error {
 func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser {
 	u, _ := v.users.Load(id)
 	if u != nil {
-		return u.(*protocol.MemoryUser)
+		return u
 	}
 	return nil
 }
@@ -66,7 +66,7 @@ func (v *MemoryValidator) GetByEmail(email string) *protocol.MemoryUser {
 	email = strings.ToLower(email)
 	u, _ := v.email.Load(email)
 	if u != nil {
-		return u.(*protocol.MemoryUser)
+		return u
 	}
 	return nil
 }
@@ -74,8 +74,8 @@ func (v *MemoryValidator) GetByEmail(email string) *protocol.MemoryUser {
 // Get all users
 func (v *MemoryValidator) GetAll() []*protocol.MemoryUser {
 	var u = make([]*protocol.MemoryUser, 0, 100)
-	v.email.Range(func(key, value interface{}) bool {
-		u = append(u, value.(*protocol.MemoryUser))
+	v.email.Range(func(key string, value *protocol.MemoryUser) bool {
+		u = append(u, value)
 		return true
 	})
 	return u
@@ -84,7 +84,7 @@ func (v *MemoryValidator) GetAll() []*protocol.MemoryUser {
 // Get users count
 func (v *MemoryValidator) GetCount() int64 {
 	var c int64 = 0
-	v.email.Range(func(key, value interface{}) bool {
+	v.email.Range(func(key string, value *protocol.MemoryUser) bool {
 		c++
 		return true
 	})

+ 1 - 4
transport/internet/grpc/dial.go

@@ -43,7 +43,7 @@ type dialerConf struct {
 }
 
 var (
-	globalDialerMap    map[dialerConf]*grpc.ClientConn
+	globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
 	globalDialerAccess sync.Mutex
 )
 
@@ -77,9 +77,6 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
 	globalDialerAccess.Lock()
 	defer globalDialerAccess.Unlock()
 
-	if globalDialerMap == nil {
-		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
-	}
 	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
 	sockopt := streamSettings.SocketSettings

+ 7 - 6
transport/internet/splithttp/hub.go

@@ -20,6 +20,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	http_proto "github.com/xtls/xray-core/common/protocol/http"
 	"github.com/xtls/xray-core/common/signal/done"
+	"github.com/xtls/xray-core/common/utils"
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet/reality"
 	"github.com/xtls/xray-core/transport/internet/stat"
@@ -32,7 +33,7 @@ type requestHandler struct {
 	path      string
 	ln        *Listener
 	sessionMu *sync.Mutex
-	sessions  sync.Map
+	sessions  utils.TypedSyncMap[string, *httpSession]
 	localAddr net.Addr
 }
 
@@ -47,18 +48,18 @@ type httpSession struct {
 
 func (h *requestHandler) upsertSession(sessionId string) *httpSession {
 	// fast path
-	currentSessionAny, ok := h.sessions.Load(sessionId)
+	currentSession, ok := h.sessions.Load(sessionId)
 	if ok {
-		return currentSessionAny.(*httpSession)
+		return currentSession
 	}
 
 	// slow path
 	h.sessionMu.Lock()
 	defer h.sessionMu.Unlock()
 
-	currentSessionAny, ok = h.sessions.Load(sessionId)
+	currentSession, ok = h.sessions.Load(sessionId)
 	if ok {
-		return currentSessionAny.(*httpSession)
+		return currentSession
 	}
 
 	s := &httpSession{
@@ -361,7 +362,7 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet
 		path:      l.config.GetNormalizedPath(),
 		ln:        l,
 		sessionMu: &sync.Mutex{},
-		sessions:  sync.Map{},
+		sessions:  utils.NewTypedSyncMap[string, *httpSession](),
 	}
 	tlsConfig := getTLSConfig(streamSettings)
 	l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"