1
0
Эх сурвалжийг харах

Fix a concurrency issue in fakedns

In rare cases different domains asking for dns will return the same IP. Add a mutex.
yuhan6665 3 жил өмнө
parent
commit
c1a54ae58e

+ 7 - 1
app/dns/fakedns/fake.go

@@ -5,6 +5,7 @@ import (
 	"math"
 	"math/big"
 	gonet "net"
+	"sync"
 	"time"
 
 	"github.com/xtls/xray-core/common"
@@ -16,6 +17,7 @@ import (
 type Holder struct {
 	domainToIP cache.Lru
 	ipRange    *gonet.IPNet
+	mu         *sync.Mutex
 
 	config *FakeDnsPool
 }
@@ -49,6 +51,7 @@ func (fkdns *Holder) Start() error {
 func (fkdns *Holder) Close() error {
 	fkdns.domainToIP = nil
 	fkdns.ipRange = nil
+	fkdns.mu = nil
 	return nil
 }
 
@@ -67,7 +70,7 @@ func NewFakeDNSHolder() (*Holder, error) {
 }
 
 func NewFakeDNSHolderConfigOnly(conf *FakeDnsPool) (*Holder, error) {
-	return &Holder{nil, nil, conf}, nil
+	return &Holder{nil, nil, nil, conf}, nil
 }
 
 func (fkdns *Holder) initializeFromConfig() error {
@@ -89,11 +92,14 @@ func (fkdns *Holder) initialize(ipPoolCidr string, lruSize int) error {
 	}
 	fkdns.domainToIP = cache.NewLru(lruSize)
 	fkdns.ipRange = ipRange
+	fkdns.mu = new(sync.Mutex)
 	return nil
 }
 
 // GetFakeIPForDomain checks and generates a fake IP for a domain name
 func (fkdns *Holder) GetFakeIPForDomain(domain string) []net.Address {
+	fkdns.mu.Lock()
+	defer fkdns.mu.Unlock()
 	if v, ok := fkdns.domainToIP.Get(domain); ok {
 		return []net.Address{v.(net.Address)}
 	}

+ 28 - 0
app/dns/fakedns/fakedns_test.go

@@ -2,10 +2,13 @@ package fakedns
 
 import (
 	gonet "net"
+	"strconv"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 
+	"golang.org/x/sync/errgroup"
+
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/uuid"
@@ -66,6 +69,31 @@ func TestFakeDnsHolderCreateMappingManySingleDomain(t *testing.T) {
 	assert.Equal(t, addr[0].IP().String(), addr2[0].IP().String())
 }
 
+func TestGetFakeIPForDomainConcurrently(t *testing.T) {
+	fkdns, err := NewFakeDNSHolder()
+	common.Must(err)
+
+	total := 200
+	addr := make([][]net.Address, total)
+	var errg errgroup.Group
+	for i := 0; i < total; i++ {
+		errg.Go(testGetFakeIP(i, addr, fkdns))
+	}
+	errg.Wait()
+	for i := 0; i < total; i++ {
+		for j := i + 1; j < total; j++ {
+			assert.NotEqual(t, addr[i][0].IP().String(), addr[j][0].IP().String())
+		}
+	}
+}
+
+func testGetFakeIP(index int, addr [][]net.Address, fkdns *Holder) func() error {
+	return func() error {
+		addr[index] = fkdns.GetFakeIPForDomain("fakednstest" + strconv.Itoa(index) + ".example.com")
+		return nil
+	}
+}
+
 func TestFakeDnsHolderCreateMappingAndRollOver(t *testing.T) {
 	fkdns, err := NewFakeDNSHolderConfigOnly(&FakeDnsPool{
 		IpPool:  dns.FakeIPv4Pool,