tun.go 7.2 KB


  1. /* SPDX-License-Identifier: MIT
  2. *
  3. * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
  4. */
  5. package wireguard
  6. import (
  7. "context"
  8. "fmt"
  9. "net"
  10. "net/netip"
  11. "os"
  12. "github.com/sagernet/wireguard-go/tun"
  13. "github.com/xtls/xray-core/features/dns"
  14. "gvisor.dev/gvisor/pkg/bufferv2"
  15. "gvisor.dev/gvisor/pkg/tcpip"
  16. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  17. "gvisor.dev/gvisor/pkg/tcpip/header"
  18. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  19. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  20. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  21. "gvisor.dev/gvisor/pkg/tcpip/stack"
  22. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  23. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  24. )
  25. type netTun struct {
  26. ep *channel.Endpoint
  27. stack *stack.Stack
  28. events chan tun.Event
  29. incomingPacket chan *bufferv2.View
  30. mtu int
  31. dnsClient dns.Client
  32. hasV4, hasV6 bool
  33. }
// Net is the dialing/listening view of a netTun. CreateNetTUN returns the
// same underlying value twice: once as the tun.Device and once converted to
// *Net. Both views therefore share one stack.
type Net netTun
  35. func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) {
  36. opts := stack.Options{
  37. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  38. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
  39. HandleLocal: true,
  40. }
  41. dev := &netTun{
  42. ep: channel.New(1024, uint32(mtu), ""),
  43. stack: stack.New(opts),
  44. events: make(chan tun.Event, 10),
  45. incomingPacket: make(chan *bufferv2.View),
  46. dnsClient: dnsClient,
  47. mtu: mtu,
  48. }
  49. dev.ep.AddNotify(dev)
  50. tcpipErr := dev.stack.CreateNIC(1, dev.ep)
  51. if tcpipErr != nil {
  52. return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
  53. }
  54. for _, ip := range localAddresses {
  55. var protoNumber tcpip.NetworkProtocolNumber
  56. if ip.Is4() {
  57. protoNumber = ipv4.ProtocolNumber
  58. } else if ip.Is6() {
  59. protoNumber = ipv6.ProtocolNumber
  60. }
  61. protoAddr := tcpip.ProtocolAddress{
  62. Protocol: protoNumber,
  63. AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
  64. }
  65. tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
  66. if tcpipErr != nil {
  67. return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
  68. }
  69. if ip.Is4() {
  70. dev.hasV4 = true
  71. } else if ip.Is6() {
  72. dev.hasV6 = true
  73. }
  74. }
  75. if dev.hasV4 {
  76. dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
  77. }
  78. if dev.hasV6 {
  79. dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
  80. }
  81. dev.events <- tun.EventUp
  82. return dev, (*Net)(dev), nil
  83. }
  84. func (tun *netTun) Name() (string, error) {
  85. return "go", nil
  86. }
  87. func (tun *netTun) File() *os.File {
  88. return nil
  89. }
  90. func (tun *netTun) Events() chan tun.Event {
  91. return tun.events
  92. }
  93. func (tun *netTun) Read(buf []byte, offset int) (int, error) {
  94. view, ok := <-tun.incomingPacket
  95. if !ok {
  96. return 0, os.ErrClosed
  97. }
  98. return view.Read(buf[offset:])
  99. }
  100. func (tun *netTun) Write(buf []byte, offset int) (int, error) {
  101. packet := buf[offset:]
  102. if len(packet) == 0 {
  103. return 0, nil
  104. }
  105. pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
  106. switch packet[0] >> 4 {
  107. case 4:
  108. tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
  109. case 6:
  110. tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
  111. }
  112. return len(buf), nil
  113. }
  114. func (tun *netTun) WriteNotify() {
  115. pkt := tun.ep.Read()
  116. if pkt == nil {
  117. return
  118. }
  119. view := pkt.ToView()
  120. pkt.DecRef()
  121. tun.incomingPacket <- view
  122. }
  123. func (tun *netTun) Flush() error {
  124. return nil
  125. }
  126. func (tun *netTun) Close() error {
  127. tun.stack.RemoveNIC(1)
  128. if tun.events != nil {
  129. close(tun.events)
  130. }
  131. tun.ep.Close()
  132. if tun.incomingPacket != nil {
  133. close(tun.incomingPacket)
  134. }
  135. return nil
  136. }
  137. func (tun *netTun) MTU() (int, error) {
  138. return tun.mtu, nil
  139. }
  140. func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
  141. var protoNumber tcpip.NetworkProtocolNumber
  142. if endpoint.Addr().Is4() {
  143. protoNumber = ipv4.ProtocolNumber
  144. } else {
  145. protoNumber = ipv6.ProtocolNumber
  146. }
  147. return tcpip.FullAddress{
  148. NIC: 1,
  149. Addr: tcpip.Address(endpoint.Addr().AsSlice()),
  150. Port: endpoint.Port(),
  151. }, protoNumber
  152. }
  153. func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
  154. fa, pn := convertToFullAddr(addr)
  155. return gonet.DialContextTCP(ctx, net.stack, fa, pn)
  156. }
  157. func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
  158. if addr == nil {
  159. return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
  160. }
  161. ip, _ := netip.AddrFromSlice(addr.IP)
  162. return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
  163. }
  164. func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
  165. fa, pn := convertToFullAddr(addr)
  166. return gonet.DialTCP(net.stack, fa, pn)
  167. }
  168. func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
  169. if addr == nil {
  170. return net.DialTCPAddrPort(netip.AddrPort{})
  171. }
  172. ip, _ := netip.AddrFromSlice(addr.IP)
  173. return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
  174. }
  175. func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
  176. fa, pn := convertToFullAddr(addr)
  177. return gonet.ListenTCP(net.stack, fa, pn)
  178. }
  179. func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
  180. if addr == nil {
  181. return net.ListenTCPAddrPort(netip.AddrPort{})
  182. }
  183. ip, _ := netip.AddrFromSlice(addr.IP)
  184. return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
  185. }
  186. func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
  187. var lfa, rfa *tcpip.FullAddress
  188. var pn tcpip.NetworkProtocolNumber
  189. if laddr.IsValid() || laddr.Port() > 0 {
  190. var addr tcpip.FullAddress
  191. addr, pn = convertToFullAddr(laddr)
  192. lfa = &addr
  193. }
  194. if raddr.IsValid() || raddr.Port() > 0 {
  195. var addr tcpip.FullAddress
  196. addr, pn = convertToFullAddr(raddr)
  197. rfa = &addr
  198. }
  199. return gonet.DialUDP(net.stack, lfa, rfa, pn)
  200. }
  201. func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
  202. return net.DialUDPAddrPort(laddr, netip.AddrPort{})
  203. }
  204. func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
  205. var la, ra netip.AddrPort
  206. if laddr != nil {
  207. ip, _ := netip.AddrFromSlice(laddr.IP)
  208. la = netip.AddrPortFrom(ip, uint16(laddr.Port))
  209. }
  210. if raddr != nil {
  211. ip, _ := netip.AddrFromSlice(raddr.IP)
  212. ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
  213. }
  214. return net.DialUDPAddrPort(la, ra)
  215. }
  216. func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
  217. return net.DialUDP(laddr, nil)
  218. }
  219. func (n *Net) HasV4() bool {
  220. return n.hasV4
  221. }
  222. func (n *Net) HasV6() bool {
  223. return n.hasV6
  224. }
  225. func IsDomainName(s string) bool {
  226. l := len(s)
  227. if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
  228. return false
  229. }
  230. last := byte('.')
  231. nonNumeric := false
  232. partlen := 0
  233. for i := 0; i < len(s); i++ {
  234. c := s[i]
  235. switch {
  236. default:
  237. return false
  238. case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
  239. nonNumeric = true
  240. partlen++
  241. case '0' <= c && c <= '9':
  242. partlen++
  243. case c == '-':
  244. if last == '.' {
  245. return false
  246. }
  247. partlen++
  248. nonNumeric = true
  249. case c == '.':
  250. if last == '.' || last == '-' {
  251. return false
  252. }
  253. if partlen > 63 || partlen == 0 {
  254. return false
  255. }
  256. partlen = 0
  257. }
  258. last = c
  259. }
  260. if last == '-' || partlen > 63 {
  261. return false
  262. }
  263. return nonNumeric
  264. }