风扇滑翔翼 преди 4 месеца
родител
ревизия
fed6b77902
променени са 2 файла, в които са добавени 84 реда и са изтрити 6 реда
  1. 77 0
      common/utils/TypedSyncMap.go
  2. 7 6
      transport/internet/splithttp/hub.go

+ 77 - 0
common/utils/TypedSyncMap.go

@@ -0,0 +1,77 @@
+package utils
+
+import (
+	"sync"
+)
+
+// TypedSyncMap is a wrapper of sync.Map that provides type-safe for keys and values.
+// No need to use type assertions every time, so you can have more time to enjoy other things like GochiUsa
+// If sync.Map returned nil, it will return the zero value of the type V.
+type TypedSyncMap[K, V any] struct {
+	syncMap *sync.Map
+}
+
+func NewTypedSyncMap[K any, V any]() *TypedSyncMap[K, V] {
+	return &TypedSyncMap[K, V]{
+		syncMap: &sync.Map{},
+	}
+}
+
+func (m *TypedSyncMap[K, V]) Clear() {
+	m.syncMap.Clear()
+}
+
+func (m *TypedSyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
+	return m.syncMap.CompareAndDelete(key, old)
+}
+
+func (m *TypedSyncMap[K, V]) CompareAndSwap(key K, old V, new V) (swapped bool) {
+	return m.syncMap.CompareAndSwap(key, old, new)
+}
+
+func (m *TypedSyncMap[K, V]) Delete(key K) {
+	m.syncMap.Delete(key)
+}
+
+func (m *TypedSyncMap[K, V]) Load(key K) (value V, ok bool) {
+	anyValue, ok := m.syncMap.Load(key)
+	// anyValue might be nil
+	if anyValue != nil {
+		value = anyValue.(V)
+	}
+	return value, ok
+}
+
+func (m *TypedSyncMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
+	anyValue, loaded := m.syncMap.LoadAndDelete(key)
+	if anyValue != nil {
+		value = anyValue.(V)
+	}
+	return value, loaded
+}
+
+func (m *TypedSyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
+	anyActual, loaded := m.syncMap.LoadOrStore(key, value)
+	if anyActual != nil {
+		actual = anyActual.(V)
+	}
+	return actual, loaded
+}
+
+func (m *TypedSyncMap[K, V]) Range(f func(key K, value V) bool) {
+	m.syncMap.Range(func(key, value any) bool {
+		return f(key.(K), value.(V))
+	})
+}
+
+func (m *TypedSyncMap[K, V]) Store(key K, value V) {
+	m.syncMap.Store(key, value)
+}
+
+func (m *TypedSyncMap[K, V]) Swap(key K, value V) (previous V, loaded bool) {
+	anyPrevious, loaded := m.syncMap.Swap(key, value)
+	if anyPrevious != nil {
+		previous = anyPrevious.(V)
+	}
+	return previous, loaded
+}

+ 7 - 6
transport/internet/splithttp/hub.go

@@ -24,6 +24,7 @@ import (
 	"github.com/xtls/xray-core/transport/internet/reality"
 	"github.com/xtls/xray-core/transport/internet/stat"
 	"github.com/xtls/xray-core/transport/internet/tls"
+	"github.com/xtls/xray-core/common/utils"
 )
 
 type requestHandler struct {
@@ -32,7 +33,7 @@ type requestHandler struct {
 	path      string
 	ln        *Listener
 	sessionMu *sync.Mutex
-	sessions  sync.Map
+	sessions  *utils.TypedSyncMap[string, *httpSession]
 	localAddr net.Addr
 }
 
@@ -47,18 +48,18 @@ type httpSession struct {
 
 func (h *requestHandler) upsertSession(sessionId string) *httpSession {
 	// fast path
-	currentSessionAny, ok := h.sessions.Load(sessionId)
+	currentSession, ok := h.sessions.Load(sessionId)
 	if ok {
-		return currentSessionAny.(*httpSession)
+		return currentSession
 	}
 
 	// slow path
 	h.sessionMu.Lock()
 	defer h.sessionMu.Unlock()
 
-	currentSessionAny, ok = h.sessions.Load(sessionId)
+	currentSession, ok = h.sessions.Load(sessionId)
 	if ok {
-		return currentSessionAny.(*httpSession)
+		return currentSession
 	}
 
 	s := &httpSession{
@@ -361,7 +362,7 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet
 		path:      l.config.GetNormalizedPath(),
 		ln:        l,
 		sessionMu: &sync.Mutex{},
-		sessions:  sync.Map{},
+		sessions:  utils.NewTypedSyncMap[string, *httpSession](),
 	}
 	tlsConfig := getTLSConfig(streamSettings)
 	l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"