1
0

device_stack.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "context"
  5. "net"
  6. "net/netip"
  7. "os"
  8. "github.com/sagernet/gvisor/pkg/buffer"
  9. "github.com/sagernet/gvisor/pkg/tcpip"
  10. "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
  11. "github.com/sagernet/gvisor/pkg/tcpip/header"
  12. "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
  13. "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
  14. "github.com/sagernet/gvisor/pkg/tcpip/stack"
  15. "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
  16. "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
  17. "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
  18. "github.com/sagernet/sing-box/adapter"
  19. "github.com/sagernet/sing-tun"
  20. "github.com/sagernet/sing/common/buf"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. M "github.com/sagernet/sing/common/metadata"
  23. N "github.com/sagernet/sing/common/network"
  24. "github.com/sagernet/wireguard-go/device"
  25. wgTun "github.com/sagernet/wireguard-go/tun"
  26. )
  27. var _ NatDevice = (*stackDevice)(nil)
  28. type stackDevice struct {
  29. stack *stack.Stack
  30. mtu uint32
  31. events chan wgTun.Event
  32. outbound chan *stack.PacketBuffer
  33. packetOutbound chan *buf.Buffer
  34. done chan struct{}
  35. dispatcher stack.NetworkDispatcher
  36. addr4 tcpip.Address
  37. addr6 tcpip.Address
  38. mapping *tun.NatMapping
  39. writer *tun.NatWriter
  40. }
  41. func newStackDevice(options DeviceOptions) (*stackDevice, error) {
  42. tunDevice := &stackDevice{
  43. mtu: options.MTU,
  44. events: make(chan wgTun.Event, 1),
  45. outbound: make(chan *stack.PacketBuffer, 256),
  46. packetOutbound: make(chan *buf.Buffer, 256),
  47. done: make(chan struct{}),
  48. mapping: tun.NewNatMapping(true),
  49. }
  50. ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
  51. if err != nil {
  52. return nil, err
  53. }
  54. for _, prefix := range options.Address {
  55. addr := tun.AddressFromAddr(prefix.Addr())
  56. protoAddr := tcpip.ProtocolAddress{
  57. AddressWithPrefix: tcpip.AddressWithPrefix{
  58. Address: addr,
  59. PrefixLen: prefix.Bits(),
  60. },
  61. }
  62. if prefix.Addr().Is4() {
  63. tunDevice.addr4 = addr
  64. protoAddr.Protocol = ipv4.ProtocolNumber
  65. } else {
  66. tunDevice.addr6 = addr
  67. protoAddr.Protocol = ipv6.ProtocolNumber
  68. }
  69. gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
  70. if gErr != nil {
  71. return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
  72. }
  73. }
  74. tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address())
  75. tunDevice.stack = ipStack
  76. if options.Handler != nil {
  77. ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
  78. ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
  79. icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
  80. ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
  81. ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
  82. }
  83. return tunDevice, nil
  84. }
  85. func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  86. addr := tcpip.FullAddress{
  87. NIC: tun.DefaultNIC,
  88. Port: destination.Port,
  89. Addr: tun.AddressFromAddr(destination.Addr),
  90. }
  91. bind := tcpip.FullAddress{
  92. NIC: tun.DefaultNIC,
  93. }
  94. var networkProtocol tcpip.NetworkProtocolNumber
  95. if destination.IsIPv4() {
  96. networkProtocol = header.IPv4ProtocolNumber
  97. bind.Addr = w.addr4
  98. } else {
  99. networkProtocol = header.IPv6ProtocolNumber
  100. bind.Addr = w.addr6
  101. }
  102. switch N.NetworkName(network) {
  103. case N.NetworkTCP:
  104. tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
  105. if err != nil {
  106. return nil, err
  107. }
  108. return tcpConn, nil
  109. case N.NetworkUDP:
  110. udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
  111. if err != nil {
  112. return nil, err
  113. }
  114. return udpConn, nil
  115. default:
  116. return nil, E.Extend(N.ErrUnknownNetwork, network)
  117. }
  118. }
  119. func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  120. bind := tcpip.FullAddress{
  121. NIC: tun.DefaultNIC,
  122. }
  123. var networkProtocol tcpip.NetworkProtocolNumber
  124. if destination.IsIPv4() {
  125. networkProtocol = header.IPv4ProtocolNumber
  126. bind.Addr = w.addr4
  127. } else {
  128. networkProtocol = header.IPv6ProtocolNumber
  129. bind.Addr = w.addr6
  130. }
  131. udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
  132. if err != nil {
  133. return nil, err
  134. }
  135. return udpConn, nil
  136. }
  137. func (w *stackDevice) Inet4Address() netip.Addr {
  138. return netip.AddrFrom4(w.addr4.As4())
  139. }
  140. func (w *stackDevice) Inet6Address() netip.Addr {
  141. return netip.AddrFrom16(w.addr6.As16())
  142. }
  143. func (w *stackDevice) SetDevice(device *device.Device) {
  144. }
  145. func (w *stackDevice) Start() error {
  146. w.events <- wgTun.EventUp
  147. return nil
  148. }
  149. func (w *stackDevice) File() *os.File {
  150. return nil
  151. }
  152. func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
  153. select {
  154. case packet, ok := <-w.outbound:
  155. if !ok {
  156. return 0, os.ErrClosed
  157. }
  158. defer packet.DecRef()
  159. var copyN int
  160. /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
  161. copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
  162. })*/
  163. for _, view := range packet.AsSlices() {
  164. copyN += copy(bufs[0][offset+copyN:], view)
  165. }
  166. sizes[0] = copyN
  167. return 1, nil
  168. case packet := <-w.packetOutbound:
  169. defer packet.Release()
  170. sizes[0] = copy(bufs[0][offset:], packet.Bytes())
  171. return 1, nil
  172. case <-w.done:
  173. return 0, os.ErrClosed
  174. }
  175. }
  176. func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
  177. for _, b := range bufs {
  178. b = b[offset:]
  179. if len(b) == 0 {
  180. continue
  181. }
  182. handled, err := w.mapping.WritePacket(b)
  183. if handled {
  184. if err != nil {
  185. return count, err
  186. }
  187. count++
  188. continue
  189. }
  190. var networkProtocol tcpip.NetworkProtocolNumber
  191. switch header.IPVersion(b) {
  192. case header.IPv4Version:
  193. networkProtocol = header.IPv4ProtocolNumber
  194. case header.IPv6Version:
  195. networkProtocol = header.IPv6ProtocolNumber
  196. }
  197. packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
  198. Payload: buffer.MakeWithData(b),
  199. })
  200. w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
  201. packetBuffer.DecRef()
  202. count++
  203. }
  204. return
  205. }
  206. func (w *stackDevice) Flush() error {
  207. return nil
  208. }
  209. func (w *stackDevice) MTU() (int, error) {
  210. return int(w.mtu), nil
  211. }
  212. func (w *stackDevice) Name() (string, error) {
  213. return "sing-box", nil
  214. }
  215. func (w *stackDevice) Events() <-chan wgTun.Event {
  216. return w.events
  217. }
  218. func (w *stackDevice) Close() error {
  219. close(w.done)
  220. close(w.events)
  221. w.stack.Close()
  222. for _, endpoint := range w.stack.CleanupEndpoints() {
  223. endpoint.Abort()
  224. }
  225. w.stack.Wait()
  226. return nil
  227. }
  228. func (w *stackDevice) BatchSize() int {
  229. return 1
  230. }
  231. var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
  232. type wireEndpoint stackDevice
  233. func (ep *wireEndpoint) MTU() uint32 {
  234. return ep.mtu
  235. }
  236. func (ep *wireEndpoint) SetMTU(mtu uint32) {
  237. }
  238. func (ep *wireEndpoint) MaxHeaderLength() uint16 {
  239. return 0
  240. }
  241. func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
  242. return ""
  243. }
  244. func (ep *wireEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
  245. }
  246. func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  247. return stack.CapabilityRXChecksumOffload
  248. }
  249. func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  250. ep.dispatcher = dispatcher
  251. }
  252. func (ep *wireEndpoint) IsAttached() bool {
  253. return ep.dispatcher != nil
  254. }
  255. func (ep *wireEndpoint) Wait() {
  256. }
  257. func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
  258. return header.ARPHardwareNone
  259. }
  260. func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  261. }
  262. func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
  263. return true
  264. }
  265. func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  266. for _, packetBuffer := range list.AsSlice() {
  267. packetBuffer.IncRef()
  268. select {
  269. case <-ep.done:
  270. return 0, &tcpip.ErrClosedForSend{}
  271. case ep.outbound <- packetBuffer:
  272. }
  273. }
  274. return list.Len(), nil
  275. }
  276. func (ep *wireEndpoint) Close() {
  277. }
  278. func (ep *wireEndpoint) SetOnCloseAction(f func()) {
  279. }
  280. func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) {
  281. /* var wq waiter.Queue
  282. ep, err := raw.NewEndpoint(w.stack, ipv4.ProtocolNumber, icmp.ProtocolNumber4, &wq)
  283. if err != nil {
  284. return nil, E.Cause(gonet.TranslateNetstackError(err), "create endpoint")
  285. }
  286. err = ep.Connect(tcpip.FullAddress{
  287. NIC: tun.DefaultNIC,
  288. Port: metadata.Destination.Port,
  289. Addr: tun.AddressFromAddr(metadata.Destination.Addr),
  290. })
  291. if err != nil {
  292. ep.Close()
  293. return nil, E.Cause(gonet.TranslateNetstackError(err), "ICMP connect ", metadata.Destination)
  294. }
  295. fmt.Println("linked ", metadata.Network, " connection to ", metadata.Destination.AddrString())
  296. destination := &endpointNatDestination{
  297. ep: ep,
  298. wq: &wq,
  299. context: routeContext,
  300. }
  301. go destination.loopRead()
  302. return destination, nil*/
  303. session := tun.DirectRouteSession{
  304. Source: metadata.Source.Addr,
  305. Destination: metadata.Destination.Addr,
  306. }
  307. w.mapping.CreateSession(session, routeContext)
  308. return &stackNatDestination{
  309. device: w,
  310. session: session,
  311. }, nil
  312. }
  313. type stackNatDestination struct {
  314. device *stackDevice
  315. session tun.DirectRouteSession
  316. }
  317. func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error {
  318. if d.device.writer != nil {
  319. d.device.writer.RewritePacket(buffer.Bytes())
  320. }
  321. d.device.packetOutbound <- buffer
  322. return nil
  323. }
  324. func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error {
  325. if d.device.writer != nil {
  326. d.device.writer.RewritePacketBuffer(buffer)
  327. }
  328. d.device.outbound <- buffer
  329. return nil
  330. }
  331. func (d *stackNatDestination) Close() error {
  332. d.device.mapping.DeleteSession(d.session)
  333. return nil
  334. }
  335. func (d *stackNatDestination) Timeout() bool {
  336. return false
  337. }
  338. /*type endpointNatDestination struct {
  339. ep tcpip.Endpoint
  340. wq *waiter.Queue
  341. networkProto tcpip.NetworkProtocolNumber
  342. context tun.DirectRouteContext
  343. done chan struct{}
  344. }
  345. func (d *endpointNatDestination) loopRead() {
  346. for {
  347. println("start read")
  348. buffer, err := commonRead(d.ep, d.wq, d.done)
  349. if err != nil {
  350. log.Error(err)
  351. return
  352. }
  353. println("done read")
  354. ipHdr := header.IPv4(buffer.Bytes())
  355. if ipHdr.TransportProtocol() != header.ICMPv4ProtocolNumber {
  356. buffer.Release()
  357. continue
  358. }
  359. icmpHdr := header.ICMPv4(ipHdr.Payload())
  360. if icmpHdr.Type() != header.ICMPv4EchoReply {
  361. buffer.Release()
  362. continue
  363. }
  364. fmt.Println("read echo reply")
  365. _ = d.context.WritePacket(ipHdr)
  366. buffer.Release()
  367. }
  368. }
  369. func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, done chan struct{}) (*buf.Buffer, error) {
  370. buffer := buf.NewPacket()
  371. result, err := ep.Read(buffer, tcpip.ReadOptions{})
  372. if err != nil {
  373. if _, ok := err.(*tcpip.ErrWouldBlock); ok {
  374. waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents)
  375. wq.EventRegister(&waitEntry)
  376. defer wq.EventUnregister(&waitEntry)
  377. for {
  378. result, err = ep.Read(buffer, tcpip.ReadOptions{})
  379. if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
  380. break
  381. }
  382. select {
  383. case <-notifyCh:
  384. case <-done:
  385. buffer.Release()
  386. return nil, context.DeadlineExceeded
  387. }
  388. }
  389. }
  390. return nil, gonet.TranslateNetstackError(err)
  391. }
  392. buffer.Truncate(result.Count)
  393. return buffer, nil
  394. }
  395. func (d *endpointNatDestination) WritePacket(buffer *buf.Buffer) error {
  396. _, err := d.ep.Write(buffer, tcpip.WriteOptions{})
  397. if err != nil {
  398. return gonet.TranslateNetstackError(err)
  399. }
  400. return nil
  401. }
  402. func (d *endpointNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error {
  403. data := buffer.ToView().AsSlice()
  404. println("write echo request buffer :" + fmt.Sprint(data))
  405. _, err := d.ep.Write(bytes.NewReader(data), tcpip.WriteOptions{})
  406. if err != nil {
  407. log.Error(err)
  408. return gonet.TranslateNetstackError(err)
  409. }
  410. return nil
  411. }
  412. func (d *endpointNatDestination) Close() error {
  413. d.ep.Abort()
  414. close(d.done)
  415. return nil
  416. }
  417. func (d *endpointNatDestination) Timeout() bool {
  418. return false
  419. }
  420. */