1
0
Эх сурвалжийг харах

Fix concurrent map writes error in ohm.Select(). (#2943)

* Add unit test for ohm.tagsCache.

* Fix concurrent map writes in ohm.Select().

---------

Co-authored-by: nobody <[email protected]>
nobody 1 жил өмнө
parent
commit
d20a835016

+ 93 - 0
app/proxyman/outbound/handler_test.go

@@ -2,9 +2,14 @@ package outbound_test
 
 import (
 	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
 	"testing"
+	"time"
 
 	"github.com/xtls/xray-core/app/policy"
+	"github.com/xtls/xray-core/app/proxyman"
 	. "github.com/xtls/xray-core/app/proxyman/outbound"
 	"github.com/xtls/xray-core/app/stats"
 	"github.com/xtls/xray-core/common/net"
@@ -78,3 +83,91 @@ func TestOutboundWithStatCounter(t *testing.T) {
 		t.Errorf("Expected conn to be CounterConnection")
 	}
 }
+
+func TestTagsCache(t *testing.T) {
+
+	test_duration := 10 * time.Second
+	threads_num := 50
+	delay := 10 * time.Millisecond
+	tags_prefix := "node"
+
+	tags := sync.Map{}
+	counter := atomic.Uint64{}
+
+	ohm, err := New(context.Background(), &proxyman.OutboundConfig{})
+	if err != nil {
+		t.Error("failed to create outbound handler manager")
+	}
+	config := &core.Config{
+		App: []*serial.TypedMessage{},
+	}
+	v, _ := core.New(config)
+	v.AddFeature(ohm)
+	ctx := context.WithValue(context.Background(), xrayKey, v)
+
+	stop_add_rm := false
+	wg_add_rm := sync.WaitGroup{}
+	addHandlers := func() {
+		defer wg_add_rm.Done()
+		for !stop_add_rm {
+			time.Sleep(delay)
+			idx := counter.Add(1)
+			tag := fmt.Sprintf("%s%d", tags_prefix, idx)
+			cfg := &core.OutboundHandlerConfig{
+				Tag:           tag,
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			}
+			if h, err := NewHandler(ctx, cfg); err == nil {
+				if err := ohm.AddHandler(ctx, h); err == nil {
+					// t.Log("add handler:", tag)
+					tags.Store(tag, nil)
+				} else {
+					t.Error("failed to add handler:", tag)
+				}
+			} else {
+				t.Error("failed to create handler:", tag)
+			}
+		}
+	}
+
+	rmHandlers := func() {
+		defer wg_add_rm.Done()
+		for !stop_add_rm {
+			time.Sleep(delay)
+			tags.Range(func(key interface{}, value interface{}) bool {
+				if _, ok := tags.LoadAndDelete(key); ok {
+					// t.Log("remove handler:", key)
+					ohm.RemoveHandler(ctx, key.(string))
+					return false
+				}
+				return true
+			})
+		}
+	}
+
+	selectors := []string{tags_prefix}
+	wg_get := sync.WaitGroup{}
+	stop_get := false
+	getTags := func() {
+		defer wg_get.Done()
+		for !stop_get {
+			time.Sleep(delay)
+			_ = ohm.Select(selectors)
+			// t.Logf("get tags: %v", tag)
+		}
+	}
+
+	for i := 0; i < threads_num; i++ {
+		wg_add_rm.Add(2)
+		go rmHandlers()
+		go addHandlers()
+		wg_get.Add(1)
+		go getTags()
+	}
+
+	time.Sleep(test_duration)
+	stop_add_rm = true
+	wg_add_rm.Wait()
+	stop_get = true
+	wg_get.Wait()
+}

+ 10 - 9
app/proxyman/outbound/outbound.go

@@ -22,14 +22,14 @@ type Manager struct {
 	taggedHandler    map[string]outbound.Handler
 	untaggedHandlers []outbound.Handler
 	running          bool
-	tagsCache        map[string][]string
+	tagsCache        *sync.Map
 }
 
 // New creates a new Manager.
 func New(ctx context.Context, config *proxyman.OutboundConfig) (*Manager, error) {
 	m := &Manager{
 		taggedHandler: make(map[string]outbound.Handler),
-		tagsCache:     make(map[string][]string),
+		tagsCache:     &sync.Map{},
 	}
 	return m, nil
 }
@@ -106,7 +106,7 @@ func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) erro
 	m.access.Lock()
 	defer m.access.Unlock()
 
-	m.tagsCache = make(map[string][]string)
+	m.tagsCache = &sync.Map{}
 
 	if m.defaultHandler == nil {
 		m.defaultHandler = handler
@@ -137,7 +137,7 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
 	m.access.Lock()
 	defer m.access.Unlock()
 
-	m.tagsCache = make(map[string][]string)
+	m.tagsCache = &sync.Map{}
 
 	delete(m.taggedHandler, tag)
 	if m.defaultHandler != nil && m.defaultHandler.Tag() == tag {
@@ -149,14 +149,15 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
 
 // Select implements outbound.HandlerSelector.
 func (m *Manager) Select(selectors []string) []string {
-	m.access.RLock()
-	defer m.access.RUnlock()
 
 	key := strings.Join(selectors, ",")
-	if cache, ok := m.tagsCache[key]; ok {
-		return cache
+	if cache, ok := m.tagsCache.Load(key); ok {
+		return cache.([]string)
 	}
 
+	m.access.RLock()
+	defer m.access.RUnlock()
+
 	tags := make([]string, 0, len(selectors))
 
 	for tag := range m.taggedHandler {
@@ -169,7 +170,7 @@ func (m *Manager) Select(selectors []string) []string {
 	}
 
 	sort.Strings(tags)
-	m.tagsCache[key] = tags
+	m.tagsCache.Store(key, tags)
 
 	return tags
 }