wireguard.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package conf
  2. import (
  3. "encoding/base64"
  4. "encoding/hex"
  5. "strings"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/proxy/wireguard"
  8. "google.golang.org/protobuf/proto"
  9. )
// WireGuardPeerConfig is the JSON form of a single WireGuard peer. Build
// converts it into a wireguard.PeerConfig.
type WireGuardPeerConfig struct {
	// PublicKey is the peer's key, given as hex or base64. When non-empty it
	// is normalized by ParseWireGuardKey in Build; when empty it is left unset.
	PublicKey string `json:"publicKey"`

	// PreSharedKey is optional. When non-empty it is normalized by
	// ParseWireGuardKey in Build; when empty it is left unset.
	PreSharedKey string `json:"preSharedKey"`

	// Endpoint is copied unchanged into the built config.
	// NOTE(review): presumably the remote peer's "host:port" — confirm.
	Endpoint string `json:"endpoint"`

	// KeepAlive is copied unchanged into the built config and is 0 when
	// omitted.
	// NOTE(review): presumably a persistent-keepalive interval in seconds —
	// confirm.
	KeepAlive uint32 `json:"keepAlive"`

	// AllowedIPs lists the address ranges assigned to this peer. When the key
	// is omitted (nil), Build substitutes 0.0.0.0/0 and ::0/0; an explicit
	// empty list is kept empty.
	AllowedIPs []string `json:"allowedIPs,omitempty"`
}
  17. func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
  18. var err error
  19. config := new(wireguard.PeerConfig)
  20. if c.PublicKey != "" {
  21. config.PublicKey, err = ParseWireGuardKey(c.PublicKey)
  22. if err != nil {
  23. return nil, err
  24. }
  25. }
  26. if c.PreSharedKey != "" {
  27. config.PreSharedKey, err = ParseWireGuardKey(c.PreSharedKey)
  28. if err != nil {
  29. return nil, err
  30. }
  31. }
  32. config.Endpoint = c.Endpoint
  33. // default 0
  34. config.KeepAlive = c.KeepAlive
  35. if c.AllowedIPs == nil {
  36. config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
  37. } else {
  38. config.AllowedIps = c.AllowedIPs
  39. }
  40. return config, nil
  41. }
// WireGuardConfig is the JSON form of a WireGuard device: the local key and
// addresses, its peers, and tunnel parameters. Build converts it into a
// wireguard.DeviceConfig.
type WireGuardConfig struct {
	// IsClient is copied into the built config. Its empty JSON tag leaves the
	// JSON key as the Go field name.
	// NOTE(review): presumably set by the enclosing inbound/outbound builder
	// rather than by users — confirm.
	IsClient bool `json:""`

	// NoKernelTun is copied into the built config.
	// NOTE(review): presumably selects a userspace TUN instead of the kernel
	// one — confirm.
	NoKernelTun bool `json:"noKernelTun"`

	// SecretKey is the local key in hex or base64, normalized by
	// ParseWireGuardKey in Build.
	SecretKey string `json:"secretKey"`

	// Address becomes the device config's Endpoint list. When omitted (nil),
	// Build uses 10.0.0.1 and fd59:7153:2388:b5fd:0000:0000:0000:0001.
	Address []string `json:"address"`

	// Peers are each converted with WireGuardPeerConfig.Build.
	Peers []*WireGuardPeerConfig `json:"peers"`

	// MTU defaults to 1420 when omitted or 0.
	MTU int32 `json:"mtu"`

	// NumWorkers is passed through unchanged. No default is applied here
	// because wireguard-go has its own fallback.
	NumWorkers int32 `json:"workers"`

	// Reserved must be empty or exactly 3 bytes; Build rejects other lengths.
	Reserved []byte `json:"reserved"`

	// DomainStrategy is one of ForceIP (the default when empty), ForceIPv4,
	// ForceIPv6, ForceIPv4v6 or ForceIPv6v4, matched case-insensitively.
	DomainStrategy string `json:"domainStrategy"`
}
  53. func (c *WireGuardConfig) Build() (proto.Message, error) {
  54. config := new(wireguard.DeviceConfig)
  55. var err error
  56. config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
  57. if err != nil {
  58. return nil, err
  59. }
  60. if c.Address == nil {
  61. // bogon ips
  62. config.Endpoint = []string{"10.0.0.1", "fd59:7153:2388:b5fd:0000:0000:0000:0001"}
  63. } else {
  64. config.Endpoint = c.Address
  65. }
  66. if c.Peers != nil {
  67. config.Peers = make([]*wireguard.PeerConfig, len(c.Peers))
  68. for i, p := range c.Peers {
  69. msg, err := p.Build()
  70. if err != nil {
  71. return nil, err
  72. }
  73. config.Peers[i] = msg.(*wireguard.PeerConfig)
  74. }
  75. }
  76. if c.MTU == 0 {
  77. config.Mtu = 1420
  78. } else {
  79. config.Mtu = c.MTU
  80. }
  81. // these a fallback code exists in wireguard-go code,
  82. // we don't need to process fallback manually
  83. config.NumWorkers = c.NumWorkers
  84. if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
  85. return nil, errors.New(`"reserved" should be empty or 3 bytes`)
  86. }
  87. config.Reserved = c.Reserved
  88. switch strings.ToLower(c.DomainStrategy) {
  89. case "forceip", "":
  90. config.DomainStrategy = wireguard.DeviceConfig_FORCE_IP
  91. case "forceipv4":
  92. config.DomainStrategy = wireguard.DeviceConfig_FORCE_IP4
  93. case "forceipv6":
  94. config.DomainStrategy = wireguard.DeviceConfig_FORCE_IP6
  95. case "forceipv4v6":
  96. config.DomainStrategy = wireguard.DeviceConfig_FORCE_IP46
  97. case "forceipv6v4":
  98. config.DomainStrategy = wireguard.DeviceConfig_FORCE_IP64
  99. default:
  100. return nil, errors.New("unsupported domain strategy: ", c.DomainStrategy)
  101. }
  102. config.IsClient = c.IsClient
  103. config.NoKernelTun = c.NoKernelTun
  104. return config, nil
  105. }
  106. func ParseWireGuardKey(str string) (string, error) {
  107. var err error
  108. if len(str)%2 == 0 {
  109. _, err = hex.DecodeString(str)
  110. if err == nil {
  111. return str, nil
  112. }
  113. }
  114. var dat []byte
  115. str = strings.TrimSuffix(str, "=")
  116. if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') {
  117. dat, err = base64.RawStdEncoding.DecodeString(str)
  118. } else {
  119. dat, err = base64.RawURLEncoding.DecodeString(str)
  120. }
  121. if err == nil {
  122. return hex.EncodeToString(dat), nil
  123. }
  124. return "", errors.New("failed to deserialize key").Base(err)
  125. }