Browse Source

WireGuard: Improve config error handling; Prevent panic in case of errors during server initialization (#4566)

https://github.com/XTLS/Xray-core/pull/4566#issuecomment-2764779273
Ilya Gulya 10 months ago
parent
commit
17207fc5e4
4 changed files with 83 additions and 28 deletions
  1. 5 1
      infra/conf/wireguard.go
  2. 20 20
      infra/conf/xray.go
  3. 6 7
      proxy/wireguard/gvisortun/tun.go
  4. 52 0
      proxy/wireguard/server_test.go

+ 5 - 1
infra/conf/wireguard.go

@@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
 	var err error
 	config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
 	if err != nil {
-		return nil, err
+		return nil, errors.New("invalid WireGuard secret key: %w", err)
 	}
 
 	if c.Address == nil {
@@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
 func ParseWireGuardKey(str string) (string, error) {
 	var err error
 
+	if str == "" {
+		return "", errors.New("key must not be empty")
+	}
+
 	if len(str)%2 == 0 {
 		_, err = hex.DecodeString(str)
 		if err == nil {

+ 20 - 20
infra/conf/xray.go

@@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
 	}
 	rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol)
 	if err != nil {
-		return nil, errors.New("failed to load inbound detour config.").Base(err)
+		return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err)
 	}
 	if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok {
 		receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect
 	}
 	ts, err := rawConfig.(Buildable).Build()
 	if err != nil {
-		return nil, err
+		return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err)
 	}
 
 	return &core.InboundHandlerConfig{
@@ -303,7 +303,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
 	if c.StreamSetting != nil {
 		ss, err := c.StreamSetting.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build stream settings for outbound detour").Base(err)
 		}
 		senderSettings.StreamSettings = ss
 	}
@@ -311,7 +311,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
 	if c.ProxySettings != nil {
 		ps, err := c.ProxySettings.Build()
 		if err != nil {
-			return nil, errors.New("invalid outbound detour proxy settings.").Base(err)
+			return nil, errors.New("invalid outbound detour proxy settings").Base(err)
 		}
 		if ps.TransportLayerProxy {
 			if senderSettings.StreamSettings != nil {
@@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
 	if c.MuxSettings != nil {
 		ms, err := c.MuxSettings.Build()
 		if err != nil {
-			return nil, errors.New("failed to build Mux config.").Base(err)
+			return nil, errors.New("failed to build Mux config").Base(err)
 		}
 		senderSettings.MultiplexSettings = ms
 	}
@@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
 	}
 	rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol)
 	if err != nil {
-		return nil, errors.New("failed to parse to outbound detour config.").Base(err)
+		return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err)
 	}
 	ts, err := rawConfig.(Buildable).Build()
 	if err != nil {
-		return nil, err
+		return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err)
 	}
 
 	return &core.OutboundHandlerConfig{
@@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) {
 // Build implements Buildable.
 func (c *Config) Build() (*core.Config, error) {
 	if err := PostProcessConfigureFile(c); err != nil {
-		return nil, err
+		return nil, errors.New("failed to post-process configuration file").Base(err)
 	}
 
 	config := &core.Config{
@@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.API != nil {
 		apiConf, err := c.API.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build API configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(apiConf))
 	}
 	if c.Metrics != nil {
 		metricsConf, err := c.Metrics.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build metrics configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(metricsConf))
 	}
 	if c.Stats != nil {
 		statsConf, err := c.Stats.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build stats configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(statsConf))
 	}
@@ -536,7 +536,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.RouterConfig != nil {
 		routerConfig, err := c.RouterConfig.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build routing configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(routerConfig))
 	}
@@ -544,7 +544,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.DNSConfig != nil {
 		dnsApp, err := c.DNSConfig.Build()
 		if err != nil {
-			return nil, errors.New("failed to parse DNS config").Base(err)
+			return nil, errors.New("failed to build DNS configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(dnsApp))
 	}
@@ -552,7 +552,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.Policy != nil {
 		pc, err := c.Policy.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build policy configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(pc))
 	}
@@ -560,7 +560,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.Reverse != nil {
 		r, err := c.Reverse.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build reverse configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(r))
 	}
@@ -568,7 +568,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.FakeDNS != nil {
 		r, err := c.FakeDNS.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build fake DNS configuration").Base(err)
 		}
 		config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
 	}
@@ -576,7 +576,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.Observatory != nil {
 		r, err := c.Observatory.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build observatory configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(r))
 	}
@@ -584,7 +584,7 @@ func (c *Config) Build() (*core.Config, error) {
 	if c.BurstObservatory != nil {
 		r, err := c.BurstObservatory.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build burst observatory configuration").Base(err)
 		}
 		config.App = append(config.App, serial.ToTypedMessage(r))
 	}
@@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) {
 	for _, rawInboundConfig := range inbounds {
 		ic, err := rawInboundConfig.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err)
 		}
 		config.Inbound = append(config.Inbound, ic)
 	}
@@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) {
 	for _, rawOutboundConfig := range outbounds {
 		oc, err := rawOutboundConfig.Build()
 		if err != nil {
-			return nil, err
+			return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err)
 		}
 		config.Outbound = append(config.Outbound, oc)
 	}

+ 6 - 7
proxy/wireguard/gvisortun/tun.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"net/netip"
 	"os"
+	"sync"
 	"syscall"
 
 	"golang.zx2c4.com/wireguard/tun"
@@ -33,6 +34,7 @@ type netTun struct {
 	incomingPacket chan *buffer.View
 	mtu            int
 	hasV4, hasV6   bool
+	closeOnce      sync.Once
 }
 
 type Net netTun
@@ -174,18 +176,15 @@ func (tun *netTun) Flush() error {
 
 // Close implements tun.Device
 func (tun *netTun) Close() error {
-	tun.stack.RemoveNIC(1)
+	tun.closeOnce.Do(func() {
+		tun.stack.RemoveNIC(1)
 
-	if tun.events != nil {
 		close(tun.events)
-	}
 
-	tun.ep.Close()
+		tun.ep.Close()
 
-	if tun.incomingPacket != nil {
 		close(tun.incomingPacket)
-	}
-
+	})
 	return nil
 }
 

+ 52 - 0
proxy/wireguard/server_test.go

@@ -0,0 +1,52 @@
+package wireguard_test
+
+import (
+	"context"
+	"github.com/stretchr/testify/assert"
+	"runtime/debug"
+	"testing"
+
+	"github.com/xtls/xray-core/core"
+	"github.com/xtls/xray-core/proxy/wireguard"
+)
+
+// TestWireGuardServerInitializationError verifies that an error during TUN initialization
+// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead.
+func TestWireGuardServerInitializationError(t *testing.T) {
+	// Create a minimal core instance with default features
+	config := &core.Config{}
+	instance, err := core.New(config)
+	if err != nil {
+		t.Fatalf("Failed to create core instance: %v", err)
+	}
+	// Set the Xray instance in the context
+	ctx := context.WithValue(context.Background(), core.XrayKey(1), instance)
+
+	// Define the server configuration with an empty SecretKey to trigger error
+	conf := &wireguard.DeviceConfig{
+		IsClient:  false,
+		Endpoint:  []string{"10.0.0.1/32"},
+		Mtu:       1420,
+		SecretKey: "", // Empty SecretKey to trigger error
+		Peers: []*wireguard.PeerConfig{
+			{
+				PublicKey:  "some_public_key",
+				AllowedIps: []string{"10.0.0.2/32"},
+			},
+		},
+	}
+
+	// Use defer to catch any panic and fail the test explicitly
+	defer func() {
+		if r := recover(); r != nil {
+			t.Errorf("TUN initialization panicked: %v", r)
+			debug.PrintStack()
+		}
+	}()
+
+	// Attempt to initialize the WireGuard server
+	_, err = wireguard.NewServer(ctx, conf)
+
+	// Check that an error is returned
+	assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice")
+}