| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- // Package ping allows sending ICMP echo requests to a host in order to
- // determine network latency.
- package ping
- import (
- "bytes"
- "context"
- "crypto/rand"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "log"
- "net"
- "net/netip"
- "sync"
- "sync/atomic"
- "time"
- "golang.org/x/net/icmp"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
- "tailscale.com/syncs"
- "tailscale.com/types/logger"
- "tailscale.com/util/mak"
- )
// v4Type is the network name used for the IPv4 ICMP listener; it is passed to
// ListenPacketer.ListenPacket and is also a key of Pinger.conns.
const v4Type = "ip4:icmp"

// v6Type is the network name used for the IPv6 ICMP listener; it is passed to
// ListenPacketer.ListenPacket and is also a key of Pinger.conns.
const v6Type = "ip6:icmp"
type (
	// response is what a waiting Send receives: either the time the
	// matching echo reply was read (t), or an error (err) if the request
	// was abandoned, e.g. because the Pinger was closed.
	response struct {
		t   time.Time
		err error
	}

	// outstanding is one in-flight echo request: ch is where its response
	// is delivered, and data is the payload a reply must echo back to be
	// accepted as a match.
	outstanding struct {
		ch   chan response
		data []byte
	}
)
// ListenPacketer defines the interface a Pinger uses to open the packet
// connections on which it sends echo requests and reads replies.
type ListenPacketer interface {
	// ListenPacket opens a packet connection for network typ, which the
	// Pinger passes as either "ip4:icmp" or "ip6:icmp", on the local
	// address addr. The Pinger passes the context of the Send call that
	// triggered the listener's creation.
	ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error)
}
// Pinger represents a set of ICMP echo requests to be sent at a single time.
//
// A new instance should be created for each concurrent set of ping requests;
// this type should not be reused. Create one with New and release its
// listeners with Close.
type Pinger struct {
	// lp opens the ICMP listeners (see mkconn).
	lp ListenPacketer

	// closed guards against send incrementing the waitgroup concurrently with close.
	closed atomic.Bool

	// Logf, if non-nil, receives log output; otherwise the standard log
	// package is used.
	Logf logger.Logf

	// Verbose enables logging of packets that handleResponse drops
	// (malformed, unexpected type or ID, or matching no in-flight request).
	Verbose bool

	// timeNow returns the current time; New sets it to time.Now.
	// NOTE(review): presumably overridable as a test seam — confirm.
	timeNow func() time.Time

	// id is this Pinger's ICMP echo identifier, chosen randomly by New.
	// Replies carrying any other ID are ignored.
	id uint16 // uint16 per RFC 792

	// wg tracks the reader goroutines started by mkconn; Close waits on it.
	wg sync.WaitGroup

	// Following fields protected by mu
	mu syncs.Mutex
	// conns is a map of "type" to net.PacketConn, type is either
	// "ip4:icmp" or "ip6:icmp"
	conns map[string]net.PacketConn
	// seq is the most recently assigned echo sequence number.
	seq uint16 // uint16 per RFC 792
	// pings maps a sequence number to the in-flight request awaiting its reply.
	pings map[uint16]outstanding
}
- // New creates a new Pinger. The Context provided will be used to create
- // network listeners, and to set an absolute deadline (if any) on the net.Conn
- func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger {
- var id [2]byte
- if _, err := io.ReadFull(rand.Reader, id[:]); err != nil {
- panic("net/ping: New:" + err.Error())
- }
- return &Pinger{
- lp: lp,
- Logf: logf,
- timeNow: time.Now,
- id: binary.LittleEndian.Uint16(id[:]),
- pings: make(map[uint16]outstanding),
- }
- }
- func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) {
- if p.closed.Load() {
- return nil, net.ErrClosed
- }
- c, err := p.lp.ListenPacket(ctx, typ, addr)
- if err != nil {
- return nil, err
- }
- // Start by setting the deadline from the context; note that this
- // applies to all future I/O, so we only need to do it once.
- deadline, ok := ctx.Deadline()
- if ok {
- if err := c.SetReadDeadline(deadline); err != nil {
- return nil, err
- }
- }
- p.wg.Add(1)
- go p.run(ctx, c, typ)
- return c, err
- }
- // getConn creates or returns a conn matching typ which is ip4:icmp
- // or ip6:icmp.
- func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
- if c, ok := p.conns[typ]; ok {
- return c, nil
- }
- var addr = "0.0.0.0"
- if typ == v6Type {
- addr = "::"
- }
- c, err := p.mkconn(ctx, typ, addr)
- if err != nil {
- return nil, err
- }
- mak.Set(&p.conns, typ, c)
- return c, nil
- }
- func (p *Pinger) logf(format string, a ...any) {
- if p.Logf != nil {
- p.Logf(format, a...)
- } else {
- log.Printf(format, a...)
- }
- }
- func (p *Pinger) vlogf(format string, a ...any) {
- if p.Verbose {
- p.logf(format, a...)
- }
- }
- func (p *Pinger) Close() error {
- p.closed.Store(true)
- p.mu.Lock()
- conns := p.conns
- p.conns = nil
- p.mu.Unlock()
- var errs []error
- for _, c := range conns {
- if err := c.Close(); err != nil {
- errs = append(errs, err)
- }
- }
- p.wg.Wait()
- p.cleanupOutstanding()
- return errors.Join(errs...)
- }
- func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) {
- defer p.wg.Done()
- defer func() {
- conn.Close()
- p.mu.Lock()
- delete(p.conns, typ)
- p.mu.Unlock()
- }()
- buf := make([]byte, 1500)
- loop:
- for {
- select {
- case <-ctx.Done():
- break loop
- default:
- }
- n, _, err := conn.ReadFrom(buf)
- if err != nil {
- // Ignore temporary errors; everything else is fatal
- if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
- break
- }
- continue
- }
- p.handleResponse(buf[:n], p.timeNow(), typ)
- }
- }
- func (p *Pinger) cleanupOutstanding() {
- // Complete outstanding requests
- p.mu.Lock()
- defer p.mu.Unlock()
- for _, o := range p.pings {
- o.ch <- response{err: net.ErrClosed}
- }
- }
- func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) {
- // We need to handle responding to both IPv4
- // and IPv6.
- var icmpType icmp.Type
- switch typ {
- case v4Type:
- icmpType = ipv4.ICMPTypeEchoReply
- case v6Type:
- icmpType = ipv6.ICMPTypeEchoReply
- default:
- p.vlogf("handleResponse: unknown icmp.Type")
- return
- }
- m, err := icmp.ParseMessage(icmpType.Protocol(), buf)
- if err != nil {
- p.vlogf("handleResponse: invalid packet: %v", err)
- return
- }
- if m.Type != icmpType {
- p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type)
- return
- }
- resp, ok := m.Body.(*icmp.Echo)
- if !ok || resp == nil {
- p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body)
- return
- }
- // We assume we sent this if the ID in the response is ours.
- if uint16(resp.ID) != p.id {
- p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID)
- return
- }
- // Search for existing running echo request
- var o outstanding
- p.mu.Lock()
- if o, ok = p.pings[uint16(resp.Seq)]; ok {
- // Ensure that the data matches before we delete from our map,
- // so a future correct packet will be handled correctly.
- if bytes.Equal(resp.Data, o.data) {
- delete(p.pings, uint16(resp.Seq))
- } else {
- p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq)
- ok = false
- }
- } else {
- p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq)
- }
- p.mu.Unlock()
- if ok {
- o.ch <- response{t: now}
- }
- }
- // Send sends an ICMP Echo Request packet to the destination, waits for a
- // response, and returns the duration between when the request was sent and
- // when the reply was received.
- //
- // If provided, "data" is sent with the packet and is compared upon receiving a
- // reply.
- func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) {
- // Use sequential sequence numbers on the assumption that we will not
- // wrap around when using a single Pinger instance
- p.mu.Lock()
- p.seq++
- seq := p.seq
- p.mu.Unlock()
- // Check whether the address is IPv4 or IPv6 to
- // determine the icmp.Type and conn to use.
- var conn net.PacketConn
- var icmpType icmp.Type = ipv4.ICMPTypeEcho
- ap, err := netip.ParseAddr(dest.String())
- if err != nil {
- return 0, err
- }
- if ap.Is6() {
- icmpType = ipv6.ICMPTypeEchoRequest
- conn, err = p.getConn(ctx, v6Type)
- } else {
- conn, err = p.getConn(ctx, v4Type)
- }
- if err != nil {
- return 0, err
- }
- m := icmp.Message{
- Type: icmpType,
- Code: 0,
- Body: &icmp.Echo{
- ID: int(p.id),
- Seq: int(seq),
- Data: data,
- },
- }
- b, err := m.Marshal(nil)
- if err != nil {
- return 0, err
- }
- // Register our response before sending since we could otherwise race a
- // quick reply.
- ch := make(chan response, 1)
- p.mu.Lock()
- p.pings[seq] = outstanding{ch: ch, data: data}
- p.mu.Unlock()
- start := p.timeNow()
- n, err := conn.WriteTo(b, dest)
- if err != nil {
- return 0, err
- } else if n != len(b) {
- return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b))
- }
- select {
- case resp := <-ch:
- if resp.err != nil {
- return 0, resp.err
- }
- return resp.t.Sub(start), nil
- case <-ctx.Done():
- return 0, ctx.Err()
- }
- }
|