|
@@ -22,7 +22,9 @@ package wireguard
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "fmt"
|
|
|
"net/netip"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
|
|
|
"github.com/xtls/xray-core/common"
|
|
@@ -49,7 +51,6 @@ type Handler struct {
|
|
|
policyManager policy.Manager
|
|
|
dns dns.Client
|
|
|
// cached configuration
|
|
|
- ipc string
|
|
|
endpoints []netip.Addr
|
|
|
hasIPv4, hasIPv6 bool
|
|
|
wgLock sync.Mutex
|
|
@@ -69,7 +70,6 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
|
|
conf: conf,
|
|
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
|
|
dns: d,
|
|
|
- ipc: createIPCRequest(conf),
|
|
|
endpoints: endpoints,
|
|
|
hasIPv4: hasIPv4,
|
|
|
hasIPv6: hasIPv6,
|
|
@@ -247,9 +247,76 @@ func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
|
|
bind.dnsOption.IPv4Enable = h.hasIPv4
|
|
|
bind.dnsOption.IPv6Enable = h.hasIPv6
|
|
|
|
|
|
- if err = t.BuildDevice(h.ipc, bind); err != nil {
|
|
|
+ if err = t.BuildDevice(h.createIPCRequest(bind, h.conf), bind); err != nil {
|
|
|
_ = t.Close()
|
|
|
return nil, err
|
|
|
}
|
|
|
return t, nil
|
|
|
}
|
|
|
+
|
|
|
+
|
|
|
+// serialize the config into an IPC request
|
|
|
+func (h *Handler) createIPCRequest(bind *netBindClient, conf *DeviceConfig) string {
|
|
|
+ var request strings.Builder
|
|
|
+
|
|
|
+ request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
|
|
+
|
|
|
+ if !conf.IsClient {
|
|
|
+ // placeholder, we'll handle actual port listening on Xray
|
|
|
+ request.WriteString("listen_port=1337\n")
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, peer := range conf.Peers {
|
|
|
+ if peer.PublicKey != "" {
|
|
|
+ request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
|
|
|
+ }
|
|
|
+
|
|
|
+ if peer.PreSharedKey != "" {
|
|
|
+ request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
|
|
|
+ }
|
|
|
+
|
|
|
+ split := strings.Split(peer.Endpoint, ":")
|
|
|
+ addr := net.ParseAddress(split[0])
|
|
|
+ if addr.Family().IsDomain() {
|
|
|
+ dialerIp := bind.dialer.DestIpAddress()
|
|
|
+ if dialerIp != nil {
|
|
|
+ addr = net.ParseAddress(dialerIp.String())
|
|
|
+ newError("createIPCRequest use dialer dest ip: ", addr).WriteToLog()
|
|
|
+ } else {
|
|
|
+ ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
|
|
+ IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
|
|
|
+ IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
|
|
|
+ })
|
|
|
+ { // Resolve fallback
|
|
|
+ if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
|
|
+ ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
|
|
+ IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
|
|
|
+ IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ newError("createIPCRequest failed to lookup DNS").Base(err).WriteToLog()
|
|
|
+ } else if len(ips) == 0 {
|
|
|
+ newError("createIPCRequest empty lookup DNS").WriteToLog()
|
|
|
+ } else {
|
|
|
+ addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if peer.Endpoint != "" {
|
|
|
+ request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, split[1]))
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, ip := range peer.AllowedIps {
|
|
|
+ request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
|
|
+ }
|
|
|
+
|
|
|
+ if peer.KeepAlive != 0 {
|
|
|
+ request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return request.String()[:request.Len()]
|
|
|
+}
|