|
|
@@ -0,0 +1,377 @@
|
|
|
+package xicmp
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "io"
|
|
|
+ "net"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/xtls/xray-core/common/crypto"
|
|
|
+ "github.com/xtls/xray-core/common/errors"
|
|
|
+ "golang.org/x/net/icmp"
|
|
|
+ "golang.org/x/net/ipv4"
|
|
|
+ "golang.org/x/net/ipv6"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ idleTimeout = 2 * time.Minute
|
|
|
+ maxResponseDelay = 1 * time.Second
|
|
|
+)
|
|
|
+
|
|
|
+type record struct {
|
|
|
+ id int
|
|
|
+ seq int
|
|
|
+ needSeqByte bool
|
|
|
+ seqByte byte
|
|
|
+ addr net.Addr
|
|
|
+}
|
|
|
+
|
|
|
+type queue struct {
|
|
|
+ lash time.Time
|
|
|
+ queue chan []byte
|
|
|
+}
|
|
|
+
|
|
|
+type xicmpConnServer struct {
|
|
|
+ conn net.PacketConn
|
|
|
+ icmpConn *icmp.PacketConn
|
|
|
+
|
|
|
+ typ icmp.Type
|
|
|
+ proto int
|
|
|
+ config *Config
|
|
|
+
|
|
|
+ ch chan *record
|
|
|
+ readQueue chan *packet
|
|
|
+ writeQueueMap map[string]*queue
|
|
|
+
|
|
|
+ closed bool
|
|
|
+ mutex sync.Mutex
|
|
|
+}
|
|
|
+
|
|
|
+func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) {
|
|
|
+ if !end {
|
|
|
+ return nil, errors.New("xicmp requires being at the outermost level")
|
|
|
+ }
|
|
|
+
|
|
|
+ network := "ip4:icmp"
|
|
|
+ typ := icmp.Type(ipv4.ICMPTypeEchoReply)
|
|
|
+ proto := 1
|
|
|
+ if strings.Contains(c.Ip, ":") {
|
|
|
+ network = "ip6:ipv6-icmp"
|
|
|
+ typ = ipv6.ICMPTypeEchoReply
|
|
|
+ proto = 58
|
|
|
+ }
|
|
|
+
|
|
|
+ icmpConn, err := icmp.ListenPacket(network, c.Ip)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.New("xicmp listen err").Base(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ conn := &xicmpConnServer{
|
|
|
+ conn: raw,
|
|
|
+ icmpConn: icmpConn,
|
|
|
+
|
|
|
+ typ: typ,
|
|
|
+ proto: proto,
|
|
|
+ config: c,
|
|
|
+
|
|
|
+ ch: make(chan *record, 100),
|
|
|
+ readQueue: make(chan *packet, 128),
|
|
|
+ writeQueueMap: make(map[string]*queue),
|
|
|
+ }
|
|
|
+
|
|
|
+ go conn.clean()
|
|
|
+ go conn.recvLoop()
|
|
|
+ go conn.sendLoop()
|
|
|
+
|
|
|
+ return conn, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) clean() {
|
|
|
+ f := func() bool {
|
|
|
+ c.mutex.Lock()
|
|
|
+ defer c.mutex.Unlock()
|
|
|
+
|
|
|
+ if c.closed {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ now := time.Now()
|
|
|
+
|
|
|
+ for key, q := range c.writeQueueMap {
|
|
|
+ if now.Sub(q.lash) >= idleTimeout {
|
|
|
+ close(q.queue)
|
|
|
+ delete(c.writeQueueMap, key)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ time.Sleep(idleTimeout / 2)
|
|
|
+ if f() {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue {
|
|
|
+ c.mutex.Lock()
|
|
|
+ defer c.mutex.Unlock()
|
|
|
+
|
|
|
+ if c.closed {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ q, ok := c.writeQueueMap[addr.String()]
|
|
|
+ if !ok {
|
|
|
+ q = &queue{
|
|
|
+ queue: make(chan []byte, 128),
|
|
|
+ }
|
|
|
+ c.writeQueueMap[addr.String()] = q
|
|
|
+ }
|
|
|
+ q.lash = time.Now()
|
|
|
+
|
|
|
+ return q
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) encode(p []byte, id int, seq int, needSeqByte bool, seqByte byte) ([]byte, error) {
|
|
|
+ data := p
|
|
|
+ if needSeqByte {
|
|
|
+ b2 := c.randUntil(seqByte)
|
|
|
+ data = append([]byte{b2}, p...)
|
|
|
+ }
|
|
|
+
|
|
|
+ msg := icmp.Message{
|
|
|
+ Type: c.typ,
|
|
|
+ Code: 0,
|
|
|
+ Body: &icmp.Echo{
|
|
|
+ ID: id,
|
|
|
+ Seq: seq,
|
|
|
+ Data: data,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ buf, err := msg.Marshal(nil)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(buf) > 8192 {
|
|
|
+ return nil, errors.New("xicmp len(buf) > 8192")
|
|
|
+ }
|
|
|
+
|
|
|
+ return buf, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) randUntil(b1 byte) byte {
|
|
|
+ b2 := byte(crypto.RandBetween(0, 255))
|
|
|
+ for {
|
|
|
+ if b2 != b1 {
|
|
|
+ return b2
|
|
|
+ }
|
|
|
+ b2 = byte(crypto.RandBetween(0, 255))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) recvLoop() {
|
|
|
+ for {
|
|
|
+ if c.closed {
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ var buf [8192]byte
|
|
|
+
|
|
|
+ n, addr, err := c.icmpConn.ReadFrom(buf[:])
|
|
|
+ if err != nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ msg, err := icmp.ParseMessage(c.proto, buf[:n])
|
|
|
+ if err != nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if msg.Type != ipv4.ICMPTypeEcho && msg.Type != ipv6.ICMPTypeEchoRequest {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ echo, ok := msg.Body.(*icmp.Echo)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if c.config.Id != 0 && echo.ID != int(c.config.Id) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ needSeqByte := false
|
|
|
+ var seqByte byte
|
|
|
+
|
|
|
+ if len(echo.Data) > 0 {
|
|
|
+ needSeqByte = true
|
|
|
+ seqByte = echo.Data[0]
|
|
|
+
|
|
|
+ buf := make([]byte, len(echo.Data))
|
|
|
+ copy(buf, echo.Data)
|
|
|
+ select {
|
|
|
+ case c.readQueue <- &packet{
|
|
|
+ p: buf,
|
|
|
+ addr: &net.UDPAddr{
|
|
|
+ IP: addr.(*net.IPAddr).IP,
|
|
|
+ Port: echo.ID,
|
|
|
+ },
|
|
|
+ }:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case c.ch <- &record{
|
|
|
+ id: echo.ID,
|
|
|
+ seq: echo.Seq,
|
|
|
+ needSeqByte: needSeqByte,
|
|
|
+ seqByte: seqByte,
|
|
|
+ addr: &net.UDPAddr{
|
|
|
+ IP: addr.(*net.IPAddr).IP,
|
|
|
+ Port: echo.ID,
|
|
|
+ },
|
|
|
+ }:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ close(c.ch)
|
|
|
+ close(c.readQueue)
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) sendLoop() {
|
|
|
+ var nextRec *record
|
|
|
+ for {
|
|
|
+ rec := nextRec
|
|
|
+ nextRec = nil
|
|
|
+
|
|
|
+ if rec == nil {
|
|
|
+ var ok bool
|
|
|
+ rec, ok = <-c.ch
|
|
|
+ if !ok {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ queue := c.ensureQueue(rec.addr)
|
|
|
+ if queue == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var p []byte
|
|
|
+
|
|
|
+ timer := time.NewTimer(maxResponseDelay)
|
|
|
+
|
|
|
+ select {
|
|
|
+ case p = <-queue.queue:
|
|
|
+ default:
|
|
|
+ select {
|
|
|
+ case p = <-queue.queue:
|
|
|
+ case <-timer.C:
|
|
|
+ case nextRec = <-c.ch:
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ timer.Stop()
|
|
|
+
|
|
|
+ if len(p) == 0 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ buf, err := c.encode(p, rec.id, rec.seq, rec.needSeqByte, rec.seqByte)
|
|
|
+ if err != nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if c.closed {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = c.icmpConn.WriteTo(buf, &net.IPAddr{IP: rec.addr.(*net.UDPAddr).IP})
|
|
|
+ if err != nil {
|
|
|
+ errors.LogDebug(context.Background(), "xicmp writeto err ", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) Size() int32 {
|
|
|
+ return 0
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
|
+ packet, ok := <-c.readQueue
|
|
|
+ if !ok {
|
|
|
+ return 0, nil, io.EOF
|
|
|
+ }
|
|
|
+ n = copy(p, packet.p)
|
|
|
+ if n != len(packet.p) {
|
|
|
+ return 0, nil, io.ErrShortBuffer
|
|
|
+ }
|
|
|
+ return n, packet.addr, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
|
+ q := c.ensureQueue(addr)
|
|
|
+ if q == nil {
|
|
|
+ return 0, errors.New("xicmp closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ c.mutex.Lock()
|
|
|
+ defer c.mutex.Unlock()
|
|
|
+
|
|
|
+ if c.closed {
|
|
|
+ return 0, errors.New("xicmp closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ buf := make([]byte, len(p))
|
|
|
+ copy(buf, p)
|
|
|
+
|
|
|
+ select {
|
|
|
+ case q.queue <- buf:
|
|
|
+ return len(p), nil
|
|
|
+ default:
|
|
|
+ return 0, errors.New("xicmp queue full")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) Close() error {
|
|
|
+ c.mutex.Lock()
|
|
|
+ defer c.mutex.Unlock()
|
|
|
+
|
|
|
+ if c.closed {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ c.closed = true
|
|
|
+ for key, q := range c.writeQueueMap {
|
|
|
+ close(q.queue)
|
|
|
+ delete(c.writeQueueMap, key)
|
|
|
+ }
|
|
|
+
|
|
|
+ _ = c.icmpConn.Close()
|
|
|
+ return c.conn.Close()
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) LocalAddr() net.Addr {
|
|
|
+ return &net.UDPAddr{IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP}
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) SetDeadline(t time.Time) error {
|
|
|
+ return c.icmpConn.SetDeadline(t)
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) SetReadDeadline(t time.Time) error {
|
|
|
+ return c.icmpConn.SetReadDeadline(t)
|
|
|
+}
|
|
|
+
|
|
|
+func (c *xicmpConnServer) SetWriteDeadline(t time.Time) error {
|
|
|
+ return c.icmpConn.SetWriteDeadline(t)
|
|
|
+}
|