Przeglądaj źródła

Merge remote-tracking branch 'origin/main' into HEAD

* origin/main:
  net/packet: documentation pass.
  net/packet: remove NewIP, offer only a netaddr constructor.
  net/packet: documentation cleanups.
  net/packet: fix panic on invalid IHL field.
  net/packet: remove {get,put}{16,32} indirection to encoding/binary.
  net/packet: support full IPv6 decoding.
  net/packet: add IPv6 source and destination IPs to Parsed.
Avery Pennarun 5 lat temu
rodzic
commit
563d43b2a5

+ 16 - 9
net/packet/header.go

@@ -16,27 +16,34 @@ const tcpHeaderLength = 20
 const maxPacketLength = math.MaxUint16
 
 var (
+	// errSmallBuffer is returned when Marshal receives a buffer
+	// too small to contain the header to marshal.
 	errSmallBuffer = errors.New("buffer too small")
+	// errLargePacket is returned when Marshal receives a payload
+	// larger than the maximum representable size in header
+	// fields.
 	errLargePacket = errors.New("packet too large")
 )
 
-// Header is a packet header capable of marshaling itself into a byte buffer.
+// Header is a packet header capable of marshaling itself into a byte
+// buffer.
 type Header interface {
-	// Len returns the length of the header after marshaling.
+	// Len returns the length of the marshaled packet.
 	Len() int
-	// Marshal serializes the header into buf in wire format.
-	// It clobbers the header region, which is the first h.Length() bytes of buf.
-	// It explicitly initializes every byte of the header region,
-	// so pre-zeroing it on reuse is not required. It does not allocate memory.
-	// It fails if and only if len(buf) < Length().
+	// Marshal serializes the header into buf, which must be at
+	// least Len() bytes long. Implementations of Marshal assume
+	// that bytes after the first Len() are payload bytes for the
+	// purpose of computing length and checksum fields. Marshal
+	// implementations must not allocate memory.
 	Marshal(buf []byte) error
 	// ToResponse transforms the header into one for a response packet.
 	// For instance, this swaps the source and destination IPs.
 	ToResponse()
 }
 
-// Generate generates a new packet with the given header and payload.
-// Unlike Header.Marshal, this does allocate memory.
+// Generate generates a new packet with the given Header and
+// payload. This function allocates memory, see Header.Marshal for an
+// allocation-free option.
 func Generate(h Header, payload []byte) []byte {
 	hlen := h.Len()
 	buf := make([]byte, hlen+len(payload))

+ 22 - 12
net/packet/icmp4.go

@@ -4,6 +4,15 @@
 
 package packet
 
+import "encoding/binary"
+
+// icmp4HeaderLength is the size of the ICMPv4 packet header, not
+// including the outer IP layer or the variable "response data"
+// trailer.
+const icmp4HeaderLength = 4
+
+// ICMP4Type is an ICMPv4 type, as specified in
+// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
 type ICMP4Type uint8
 
 const (
@@ -28,49 +37,50 @@ func (t ICMP4Type) String() string {
 	}
 }
 
+// ICMP4Code is an ICMPv4 code, as specified in
+// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
 type ICMP4Code uint8
 
 const (
 	ICMP4NoCode ICMP4Code = 0
 )
 
-// ICMPHeader represents an ICMP packet header.
+// ICMP4Header is an IPv4+ICMPv4 header.
 type ICMP4Header struct {
 	IP4Header
 	Type ICMP4Type
 	Code ICMP4Code
 }
 
-const (
-	icmpHeaderLength = 4
-	// icmpTotalHeaderLength is the length of all headers in a ICMP packet.
-	icmpAllHeadersLength = ipHeaderLength + icmpHeaderLength
-)
-
-func (ICMP4Header) Len() int {
-	return icmpAllHeadersLength
+// Len implements Header.
+func (h ICMP4Header) Len() int {
+	return h.IP4Header.Len() + icmp4HeaderLength
 }
 
+// Marshal implements Header.
 func (h ICMP4Header) Marshal(buf []byte) error {
-	if len(buf) < icmpAllHeadersLength {
+	if len(buf) < h.Len() {
 		return errSmallBuffer
 	}
 	if len(buf) > maxPacketLength {
 		return errLargePacket
 	}
 	// The caller does not need to set this.
-	h.IPProto = ICMP
+	h.IPProto = ICMPv4
 
 	buf[20] = uint8(h.Type)
 	buf[21] = uint8(h.Code)
 
 	h.IP4Header.Marshal(buf)
 
-	put16(buf[22:24], ipChecksum(buf))
+	binary.BigEndian.PutUint16(buf[22:24], ip4Checksum(buf))
 
 	return nil
 }
 
+// ToResponse implements Header. TODO: it doesn't implement it
+// correctly, instead it statically generates an ICMP Echo Reply
+// packet.
 func (h *ICMP4Header) ToResponse() {
 	// TODO: this doesn't implement ToResponse correctly, as it
 	// assumes the ICMP request type.

+ 44 - 0
net/packet/icmp6.go

@@ -0,0 +1,44 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+// icmp6HeaderLength is the size of the ICMPv6 packet header, not
+// including the outer IP layer or the variable "response data"
+// trailer.
+const icmp6HeaderLength = 4
+
+// ICMP6Type is an ICMPv6 type, as specified in
+// https://www.iana.org/assignments/icmpv6-parameters/icmpv6-parameters.xhtml
+type ICMP6Type uint8
+
+const (
+	ICMP6Unreachable  ICMP6Type = 1
+	ICMP6TimeExceeded ICMP6Type = 3
+	ICMP6EchoRequest  ICMP6Type = 128
+	ICMP6EchoReply    ICMP6Type = 129
+)
+
+func (t ICMP6Type) String() string {
+	switch t {
+	case ICMP6Unreachable:
+		return "Unreachable"
+	case ICMP6TimeExceeded:
+		return "TimeExceeded"
+	case ICMP6EchoRequest:
+		return "EchoRequest"
+	case ICMP6EchoReply:
+		return "EchoReply"
+	default:
+		return "Unknown"
+	}
+}
+
+// ICMP6Code is an ICMPv6 code, as specified in
+// https://www.iana.org/assignments/icmpv6-parameters/icmpv6-parameters.xhtml
+type ICMP6Code uint8
+
+const (
+	ICMP6NoCode ICMP6Code = 0
+)

+ 53 - 0
net/packet/ip.go

@@ -0,0 +1,53 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+// IPProto is an IP subprotocol as defined by the IANA protocol
+// numbers list
+// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml),
+// or the special values Unknown or Fragment.
+type IPProto uint8
+
+const (
+	// Unknown represents an unknown or unsupported protocol; it's
+	// deliberately the zero value. Strictly speaking the zero
+	// value is IPv6 hop-by-hop extensions, but we don't support
+	// those, so this is still technically correct.
+	Unknown IPProto = 0x00
+
+	// Values from the IANA registry.
+	ICMPv4 IPProto = 0x01
+	IGMP   IPProto = 0x02
+	ICMPv6 IPProto = 0x3a
+	TCP    IPProto = 0x06
+	UDP    IPProto = 0x11
+
+	// Fragment represents any non-first IP fragment, for which we
+	// don't have the sub-protocol header (and therefore can't
+	// figure out what the sub-protocol is).
+	//
+	// 0xFF is reserved in the IANA registry, so we steal it for
+	// internal use.
+	Fragment IPProto = 0xFF
+)
+
+func (p IPProto) String() string {
+	switch p {
+	case Fragment:
+		return "Frag"
+	case ICMPv4:
+		return "ICMPv4"
+	case IGMP:
+		return "IGMP"
+	case ICMPv6:
+		return "ICMPv6"
+	case UDP:
+		return "UDP"
+	case TCP:
+		return "TCP"
+	default:
+		return "Unknown"
+	}
+}

+ 72 - 83
net/packet/ip4.go

@@ -5,8 +5,8 @@
 package packet
 
 import (
+	"encoding/binary"
 	"fmt"
-	"net"
 
 	"inet.af/netaddr"
 )
@@ -14,23 +14,13 @@ import (
 // IP4 is an IPv4 address.
 type IP4 uint32
 
-// NewIP converts a standard library IP address into an IP.
-// It panics if b is not an IPv4 address.
-func NewIP4(b net.IP) IP4 {
-	b4 := b.To4()
-	if b4 == nil {
-		panic(fmt.Sprintf("To4(%v) failed", b))
-	}
-	return IP4(get32(b4))
-}
-
-// IPFromNetaddr converts a netaddr.IP to an IP.
+// IPFromNetaddr converts a netaddr.IP to an IP4. Panics if !ip.Is4.
 func IP4FromNetaddr(ip netaddr.IP) IP4 {
 	ipbytes := ip.As4()
-	return IP4(get32(ipbytes[:]))
+	return IP4(binary.BigEndian.Uint32(ipbytes[:]))
 }
 
-// Netaddr converts an IP to a netaddr.IP.
+// Netaddr converts ip to a netaddr.IP.
 func (ip IP4) Netaddr() netaddr.IP {
 	return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
 }
@@ -39,110 +29,109 @@ func (ip IP4) String() string {
 	return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
 }
 
+// IsMulticast returns whether ip is a multicast address.
 func (ip IP4) IsMulticast() bool {
 	return byte(ip>>24)&0xf0 == 0xe0
 }
 
+// IsLinkLocalUnicast returns whether ip is a link-local unicast
+// address.
 func (ip IP4) IsLinkLocalUnicast() bool {
 	return byte(ip>>24) == 169 && byte(ip>>16) == 254
 }
 
-// IP4Proto is either a real IP protocol (TCP, UDP, ...) or an special
-// value like Unknown.  If it is a real IP protocol, its value
-// corresponds to its IP protocol number.
-type IP4Proto uint8
-
-const (
-	// Unknown represents an unknown or unsupported protocol; it's deliberately the zero value.
-	Unknown IP4Proto = 0x00
-	ICMP    IP4Proto = 0x01
-	IGMP    IP4Proto = 0x02
-	ICMPv6  IP4Proto = 0x3a
-	TCP     IP4Proto = 0x06
-	UDP     IP4Proto = 0x11
-	// Fragment is a special value. It's not really an IPProto value
-	// so we're using the unassigned 0xFF value.
-	// TODO(dmytro): special values should be taken out of here.
-	Fragment IP4Proto = 0xFF
-)
-
-func (p IP4Proto) String() string {
-	switch p {
-	case Fragment:
-		return "Frag"
-	case ICMP:
-		return "ICMP"
-	case UDP:
-		return "UDP"
-	case TCP:
-		return "TCP"
-	default:
-		return "Unknown"
-	}
-}
+// ip4HeaderLength is the length of an IPv4 header with no IP options.
+const ip4HeaderLength = 20
 
-// IPHeader represents an IP packet header.
+// IP4Header represents an IPv4 packet header.
 type IP4Header struct {
-	IPProto IP4Proto
+	IPProto IPProto
 	IPID    uint16
 	SrcIP   IP4
 	DstIP   IP4
 }
 
-const ipHeaderLength = 20
-
-func (IP4Header) Len() int {
-	return ipHeaderLength
+// Len implements Header.
+func (h IP4Header) Len() int {
+	return ip4HeaderLength
 }
 
+// Marshal implements Header.
 func (h IP4Header) Marshal(buf []byte) error {
-	if len(buf) < ipHeaderLength {
+	if len(buf) < h.Len() {
 		return errSmallBuffer
 	}
 	if len(buf) > maxPacketLength {
 		return errLargePacket
 	}
 
-	buf[0] = 0x40 | (ipHeaderLength >> 2) // IPv4
-	buf[1] = 0x00                         // DHCP, ECN
-	put16(buf[2:4], uint16(len(buf)))
-	put16(buf[4:6], h.IPID)
-	put16(buf[6:8], 0) // flags, offset
-	buf[8] = 64        // TTL
-	buf[9] = uint8(h.IPProto)
-	put16(buf[10:12], 0) // blank IP header checksum
-	put32(buf[12:16], uint32(h.SrcIP))
-	put32(buf[16:20], uint32(h.DstIP))
-
-	put16(buf[10:12], ipChecksum(buf[0:20]))
+	buf[0] = 0x40 | (byte(h.Len() >> 2))                   // IPv4 + IHL
+	buf[1] = 0x00                                          // DSCP + ECN
+	binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length
+	binary.BigEndian.PutUint16(buf[4:6], h.IPID)           // ID
+	binary.BigEndian.PutUint16(buf[6:8], 0)                // Flags + fragment offset
+	buf[8] = 64                                            // TTL
+	buf[9] = uint8(h.IPProto)                              // Inner protocol
+	// Blank checksum. This is necessary even though we overwrite
+	// it later, because the checksum computation runs over these
+	// bytes and expects them to be zero.
+	binary.BigEndian.PutUint16(buf[10:12], 0)
+	binary.BigEndian.PutUint32(buf[12:16], uint32(h.SrcIP)) // Src
+	binary.BigEndian.PutUint32(buf[16:20], uint32(h.DstIP)) // Dst
+
+	binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum
 
 	return nil
 }
 
-// MarshalPseudo serializes the header into buf in the "pseudo-header"
-// form required when calculating UDP checksums. Overwrites the first
-// h.Length() bytes of buf.
-func (h IP4Header) MarshalPseudo(buf []byte) error {
-	if len(buf) < ipHeaderLength {
+// ToResponse implements Header.
+func (h *IP4Header) ToResponse() {
+	h.SrcIP, h.DstIP = h.DstIP, h.SrcIP
+	// Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
+	h.IPID = ^h.IPID
+}
+
+// ip4Checksum computes an IPv4 checksum, as specified in
+// https://tools.ietf.org/html/rfc1071
+func ip4Checksum(b []byte) uint16 {
+	var ac uint32
+	i := 0
+	n := len(b)
+	for n >= 2 {
+		ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
+		n -= 2
+		i += 2
+	}
+	if n == 1 {
+		ac += uint32(b[i]) << 8
+	}
+	for (ac >> 16) > 0 {
+		ac = (ac >> 16) + (ac & 0xffff)
+	}
+	return uint16(^ac)
+}
+
+// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP
+// pseudo-header is smaller than the real IPv4 header.
+const ip4PseudoHeaderOffset = 8
+
+// marshalPseudo serializes h into buf in the "pseudo-header" form
+// required when calculating UDP checksums. The pseudo-header starts
+// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP
+// header, while leaving enough space in buf for a full IPv4 header.
+func (h IP4Header) marshalPseudo(buf []byte) error {
+	if len(buf) < h.Len() {
 		return errSmallBuffer
 	}
 	if len(buf) > maxPacketLength {
 		return errLargePacket
 	}
 
-	length := len(buf) - ipHeaderLength
-	put32(buf[8:12], uint32(h.SrcIP))
-	put32(buf[12:16], uint32(h.DstIP))
+	length := len(buf) - h.Len()
+	binary.BigEndian.PutUint32(buf[8:12], uint32(h.SrcIP))
+	binary.BigEndian.PutUint32(buf[12:16], uint32(h.DstIP))
 	buf[16] = 0x0
 	buf[17] = uint8(h.IPProto)
-	put16(buf[18:20], uint16(length))
-
+	binary.BigEndian.PutUint16(buf[18:20], uint16(length))
 	return nil
 }
-
-// ToResponse implements Header.
-func (h *IP4Header) ToResponse() {
-	h.SrcIP, h.DstIP = h.DstIP, h.SrcIP
-	// Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
-	h.IPID = ^h.IPID
-}

+ 34 - 0
net/packet/ip6.go

@@ -0,0 +1,34 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+	"fmt"
+
+	"inet.af/netaddr"
+)
+
+// IP6 is an IPv6 address.
+type IP6 [16]byte
+
+// IP6FromNetaddr converts a netaddr.IP to an IP6. Panics if !ip.Is6.
+func IP6FromNetaddr(ip netaddr.IP) IP6 {
+	if !ip.Is6() {
+		panic(fmt.Sprintf("IP6FromNetaddr called with non-v6 addr %q", ip))
+	}
+	return IP6(ip.As16())
+}
+
+// Netaddr converts ip to a netaddr.IP.
+func (ip IP6) Netaddr() netaddr.IP {
+	return netaddr.IPFrom16(ip)
+}
+
+func (ip IP6) String() string {
+	return ip.Netaddr().String()
+}
+
+// ip6HeaderLength is the length of an IPv6 header with no IP options.
+const ip6HeaderLength = 40

+ 187 - 104
net/packet/packet.go

@@ -21,17 +21,7 @@ const (
 	TCPSynAck = TCPSyn | TCPAck
 )
 
-var (
-	get16 = binary.BigEndian.Uint16
-	get32 = binary.BigEndian.Uint32
-
-	put16 = binary.BigEndian.PutUint16
-	put32 = binary.BigEndian.PutUint32
-)
-
 // Parsed is a minimal decoding of a packet suitable for use in filters.
-//
-// In general, it only supports IPv4. The IPv6 parsing is very minimal.
 type Parsed struct {
 	// b is the byte buffer that this decodes.
 	b []byte
@@ -43,37 +33,53 @@ type Parsed struct {
 	// This is not the same as len(b) because b can have trailing zeros.
 	length int
 
-	IPVersion uint8    // 4, 6, or 0
-	IPProto   IP4Proto // IP subprotocol (UDP, TCP, etc); the NextHeader field for IPv6
-	SrcIP     IP4      // IP source address (not used for IPv6)
-	DstIP     IP4      // IP destination address (not used for IPv6)
-	SrcPort   uint16   // TCP/UDP source port
-	DstPort   uint16   // TCP/UDP destination port
-	TCPFlags  uint8    // TCP flags (SYN, ACK, etc)
+	// IPVersion is the IP protocol version of the packet (4 or
+	// 6), or 0 if the packet doesn't look like IPv4 or IPv6.
+	IPVersion uint8
+	// IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0.
+	IPProto IPProto
+	// SrcIP4 is the IPv4 source address. Valid iff IPVersion == 4.
+	SrcIP4 IP4
+	// DstIP4 is the IPv4 destination address. Valid iff IPVersion == 4.
+	DstIP4 IP4
+	// SrcIP6 is the IPv6 source address. Valid iff IPVersion == 6.
+	SrcIP6 IP6
+	// DstIP6 is the IPv6 destination address. Valid iff IPVersion == 6.
+	DstIP6 IP6
+	// SrcPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP.
+	SrcPort uint16
+	// DstPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP.
+	DstPort uint16
+	// TCPFlags is the packet's TCP flag bigs. Valid iff IPProto == TCP.
+	TCPFlags uint8
 }
 
-// NextHeader
-type NextHeader uint8
-
 func (p *Parsed) String() string {
-	if p.IPVersion == 6 {
-		return fmt.Sprintf("IPv6{Proto=%d}", p.IPProto)
-	}
-	switch p.IPProto {
-	case Unknown:
+	switch p.IPVersion {
+	case 4:
+		sb := strbuilder.Get()
+		sb.WriteString(p.IPProto.String())
+		sb.WriteByte('{')
+		writeIP4Port(sb, p.SrcIP4, p.SrcPort)
+		sb.WriteString(" > ")
+		writeIP4Port(sb, p.DstIP4, p.DstPort)
+		sb.WriteByte('}')
+		return sb.String()
+	case 6:
+		sb := strbuilder.Get()
+		sb.WriteString(p.IPProto.String())
+		sb.WriteByte('{')
+		writeIP6Port(sb, p.SrcIP6, p.SrcPort)
+		sb.WriteString(" > ")
+		writeIP6Port(sb, p.DstIP6, p.DstPort)
+		sb.WriteByte('}')
+		return sb.String()
+	default:
 		return "Unknown{???}"
 	}
-	sb := strbuilder.Get()
-	sb.WriteString(p.IPProto.String())
-	sb.WriteByte('{')
-	writeIPPort(sb, p.SrcIP, p.SrcPort)
-	sb.WriteString(" > ")
-	writeIPPort(sb, p.DstIP, p.DstPort)
-	sb.WriteByte('}')
-	return sb.String()
 }
 
-func writeIPPort(sb *strbuilder.Builder, ip IP4, port uint16) {
+func writeIP4Port(sb *strbuilder.Builder, ip IP4, port uint16) {
 	sb.WriteUint(uint64(byte(ip >> 24)))
 	sb.WriteByte('.')
 	sb.WriteUint(uint64(byte(ip >> 16)))
@@ -85,23 +91,11 @@ func writeIPPort(sb *strbuilder.Builder, ip IP4, port uint16) {
 	sb.WriteUint(uint64(port))
 }
 
-// based on https://tools.ietf.org/html/rfc1071
-func ipChecksum(b []byte) uint16 {
-	var ac uint32
-	i := 0
-	n := len(b)
-	for n >= 2 {
-		ac += uint32(get16(b[i : i+2]))
-		n -= 2
-		i += 2
-	}
-	if n == 1 {
-		ac += uint32(b[i]) << 8
-	}
-	for (ac >> 16) > 0 {
-		ac = (ac >> 16) + (ac & 0xffff)
-	}
-	return uint16(^ac)
+func writeIP6Port(sb *strbuilder.Builder, ip IP6, port uint16) {
+	sb.WriteByte('[')
+	sb.WriteString(ip.Netaddr().String()) // TODO: faster?
+	sb.WriteString("]:")
+	sb.WriteUint(uint64(port))
 }
 
 // Decode extracts data from the packet in b into q.
@@ -111,28 +105,34 @@ func ipChecksum(b []byte) uint16 {
 func (q *Parsed) Decode(b []byte) {
 	q.b = b
 
-	if len(b) < ipHeaderLength {
+	if len(b) < 1 {
 		q.IPVersion = 0
 		q.IPProto = Unknown
 		return
 	}
 
-	// Check that it's IPv4.
-	// TODO(apenwarr): consider IPv6 support
 	q.IPVersion = (b[0] & 0xF0) >> 4
 	switch q.IPVersion {
 	case 4:
-		q.IPProto = IP4Proto(b[9])
+		q.decode4(b)
 	case 6:
-		q.IPProto = IP4Proto(b[6]) // "Next Header" field
-		return
+		q.decode6(b)
 	default:
 		q.IPVersion = 0
 		q.IPProto = Unknown
+	}
+}
+
+func (q *Parsed) decode4(b []byte) {
+	if len(b) < ip4HeaderLength {
+		q.IPVersion = 0
+		q.IPProto = Unknown
 		return
 	}
 
-	q.length = int(get16(b[2:4]))
+	// Check that it's IPv4.
+	q.IPProto = IPProto(b[9])
+	q.length = int(binary.BigEndian.Uint16(b[2:4]))
 	if len(b) < q.length {
 		// Packet was cut off before full IPv4 length.
 		q.IPProto = Unknown
@@ -140,10 +140,15 @@ func (q *Parsed) Decode(b []byte) {
 	}
 
 	// If it's valid IPv4, then the IP addresses are valid
-	q.SrcIP = IP4(get32(b[12:16]))
-	q.DstIP = IP4(get32(b[16:20]))
+	q.SrcIP4 = IP4(binary.BigEndian.Uint32(b[12:16]))
+	q.DstIP4 = IP4(binary.BigEndian.Uint32(b[16:20]))
 
 	q.subofs = int((b[0] & 0x0F) << 2)
+	if q.subofs > q.length {
+		// next-proto starts beyond end of packet.
+		q.IPProto = Unknown
+		return
+	}
 	sub := b[q.subofs:]
 
 	// We don't care much about IP fragmentation, except insofar as it's
@@ -158,7 +163,7 @@ func (q *Parsed) Decode(b []byte) {
 	// zero reason to send such a short first fragment, so we can treat
 	// it as Unknown. We can also treat any subsequent fragment that starts
 	// at such a low offset as Unknown.
-	fragFlags := get16(b[6:8])
+	fragFlags := binary.BigEndian.Uint16(b[6:8])
 	moreFrags := (fragFlags & 0x20) != 0
 	fragOfs := fragFlags & 0x1FFF
 	if fragOfs == 0 {
@@ -172,22 +177,22 @@ func (q *Parsed) Decode(b []byte) {
 		// or a big enough initial fragment that we can read the
 		// whole subprotocol header.
 		switch q.IPProto {
-		case ICMP:
-			if len(sub) < icmpHeaderLength {
+		case ICMPv4:
+			if len(sub) < icmp4HeaderLength {
 				q.IPProto = Unknown
 				return
 			}
 			q.SrcPort = 0
 			q.DstPort = 0
-			q.dataofs = q.subofs + icmpHeaderLength
+			q.dataofs = q.subofs + icmp4HeaderLength
 			return
 		case TCP:
 			if len(sub) < tcpHeaderLength {
 				q.IPProto = Unknown
 				return
 			}
-			q.SrcPort = get16(sub[0:2])
-			q.DstPort = get16(sub[2:4])
+			q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
+			q.DstPort = binary.BigEndian.Uint16(sub[2:4])
 			q.TCPFlags = sub[13] & 0x3F
 			headerLength := (sub[12] & 0xF0) >> 2
 			q.dataofs = q.subofs + int(headerLength)
@@ -197,8 +202,8 @@ func (q *Parsed) Decode(b []byte) {
 				q.IPProto = Unknown
 				return
 			}
-			q.SrcPort = get16(sub[0:2])
-			q.DstPort = get16(sub[2:4])
+			q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
+			q.DstPort = binary.BigEndian.Uint16(sub[2:4])
 			q.dataofs = q.subofs + udpHeaderLength
 			return
 		default:
@@ -224,27 +229,103 @@ func (q *Parsed) Decode(b []byte) {
 	}
 }
 
-func (q *Parsed) IPHeader() IP4Header {
-	ipid := get16(q.b[4:6])
+func (q *Parsed) decode6(b []byte) {
+	if len(b) < ip6HeaderLength {
+		q.IPVersion = 0
+		q.IPProto = Unknown
+		return
+	}
+
+	q.IPProto = IPProto(b[6])
+	q.length = int(binary.BigEndian.Uint16(b[4:6])) + ip6HeaderLength
+	if len(b) < q.length {
+		// Packet was cut off before the full IPv6 length.
+		q.IPProto = Unknown
+		return
+	}
+
+	copy(q.SrcIP6[:], b[8:24])
+	copy(q.DstIP6[:], b[24:40])
+
+	// We don't support any IPv6 extension headers. Don't try to
+	// be clever. Therefore, the IP subprotocol always starts at
+	// byte 40.
+	//
+	// Note that this means we don't support fragmentation in
+	// IPv6. This is fine, because IPv6 strongly mandates that you
+	// should not fragment, which makes fragmentation on the open
+	// internet extremely uncommon.
+	//
+	// This also means we don't support IPSec headers (AH/ESP), or
+	// IPv6 jumbo frames. Those will get marked Unknown and
+	// dropped.
+	q.subofs = 40
+	sub := b[q.subofs:]
+
+	switch q.IPProto {
+	case ICMPv6:
+		if len(sub) < icmp6HeaderLength {
+			q.IPProto = Unknown
+			return
+		}
+		q.SrcPort = 0
+		q.DstPort = 0
+		q.dataofs = q.subofs + icmp6HeaderLength
+	case TCP:
+		if len(sub) < tcpHeaderLength {
+			q.IPProto = Unknown
+			return
+		}
+		q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
+		q.DstPort = binary.BigEndian.Uint16(sub[2:4])
+		q.TCPFlags = sub[13] & 0x3F
+		headerLength := (sub[12] & 0xF0) >> 2
+		q.dataofs = q.subofs + int(headerLength)
+		return
+	case UDP:
+		if len(sub) < udpHeaderLength {
+			q.IPProto = Unknown
+			return
+		}
+		q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
+		q.DstPort = binary.BigEndian.Uint16(sub[2:4])
+		q.dataofs = q.subofs + udpHeaderLength
+	default:
+		q.IPProto = Unknown
+		return
+	}
+}
+
+func (q *Parsed) IP4Header() IP4Header {
+	if q.IPVersion != 4 {
+		panic("IP4Header called on non-IPv4 Parsed")
+	}
+	ipid := binary.BigEndian.Uint16(q.b[4:6])
 	return IP4Header{
 		IPID:    ipid,
 		IPProto: q.IPProto,
-		SrcIP:   q.SrcIP,
-		DstIP:   q.DstIP,
+		SrcIP:   q.SrcIP4,
+		DstIP:   q.DstIP4,
 	}
 }
 
-func (q *Parsed) ICMPHeader() ICMP4Header {
+func (q *Parsed) ICMP4Header() ICMP4Header {
+	if q.IPVersion != 4 {
+		panic("IP4Header called on non-IPv4 Parsed")
+	}
 	return ICMP4Header{
-		IP4Header: q.IPHeader(),
+		IP4Header: q.IP4Header(),
 		Type:      ICMP4Type(q.b[q.subofs+0]),
 		Code:      ICMP4Code(q.b[q.subofs+1]),
 	}
 }
 
-func (q *Parsed) UDPHeader() UDP4Header {
+func (q *Parsed) UDP4Header() UDP4Header {
+	if q.IPVersion != 4 {
+		panic("IP4Header called on non-IPv4 Parsed")
+	}
 	return UDP4Header{
-		IP4Header: q.IPHeader(),
+		IP4Header: q.IP4Header(),
 		SrcPort:   q.SrcPort,
 		DstPort:   q.DstPort,
 	}
@@ -256,58 +337,60 @@ func (q *Parsed) Buffer() []byte {
 	return q.b
 }
 
-// Sub returns the IP subprotocol section.
-// This is a read-only view; that is, q retains the ownership of the buffer.
-func (q *Parsed) Sub(begin, n int) []byte {
-	return q.b[q.subofs+begin : q.subofs+begin+n]
-}
-
 // Payload returns the payload of the IP subprotocol section.
 // This is a read-only view; that is, q retains the ownership of the buffer.
 func (q *Parsed) Payload() []byte {
 	return q.b[q.dataofs:q.length]
 }
 
-// Trim trims the buffer to its IPv4 length.
-// Sometimes packets arrive from an interface with extra bytes on the end.
-// This removes them.
-func (q *Parsed) Trim() []byte {
-	return q.b[:q.length]
-}
-
 // IsTCPSyn reports whether q is a TCP SYN packet
 // (i.e. the first packet in a new connection).
 func (q *Parsed) IsTCPSyn() bool {
 	return (q.TCPFlags & TCPSynAck) == TCPSyn
 }
 
-// IsError reports whether q is an IPv4 ICMP "Error" packet.
+// IsError reports whether q is an ICMP "Error" packet.
 func (q *Parsed) IsError() bool {
-	if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
-		switch ICMP4Type(q.b[q.subofs]) {
-		case ICMP4Unreachable, ICMP4TimeExceeded:
-			return true
+	switch q.IPProto {
+	case ICMPv4:
+		if len(q.b) < q.subofs+8 {
+			return false
 		}
+		t := ICMP4Type(q.b[q.subofs])
+		return t == ICMP4Unreachable || t == ICMP4TimeExceeded
+	case ICMPv6:
+		if len(q.b) < q.subofs+8 {
+			return false
+		}
+		t := ICMP6Type(q.b[q.subofs])
+		return t == ICMP6Unreachable || t == ICMP6TimeExceeded
+	default:
+		return false
 	}
-	return false
 }
 
-// IsEchoRequest reports whether q is an IPv4 ICMP Echo Request.
+// IsEchoRequest reports whether q is an ICMP Echo Request.
 func (q *Parsed) IsEchoRequest() bool {
-	if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
-		return ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest &&
-			ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
+	switch q.IPProto {
+	case ICMPv4:
+		return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
+	case ICMPv6:
+		return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoRequest && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode
+	default:
+		return false
 	}
-	return false
 }
 
 // IsEchoRequest reports whether q is an IPv4 ICMP Echo Response.
 func (q *Parsed) IsEchoResponse() bool {
-	if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
-		return ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply &&
-			ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
+	switch q.IPProto {
+	case ICMPv4:
+		return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
+	case ICMPv6:
+		return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoReply && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode
+	default:
+		return false
 	}
-	return false
 }
 
 func Hexdump(b []byte) string {

+ 181 - 55
net/packet/packet_test.go

@@ -6,14 +6,31 @@ package packet
 
 import (
 	"bytes"
-	"net"
 	"reflect"
 	"testing"
+
+	"inet.af/netaddr"
 )
 
+func mustIP4(s string) IP4 {
+	ip, err := netaddr.ParseIP(s)
+	if err != nil {
+		panic(err)
+	}
+	return IP4FromNetaddr(ip)
+}
+
+func mustIP6(s string) IP6 {
+	ip, err := netaddr.ParseIP(s)
+	if err != nil {
+		panic(err)
+	}
+	return IP6FromNetaddr(ip)
+}
+
 func TestIP4String(t *testing.T) {
 	const str = "1.2.3.4"
-	ip := NewIP4(net.ParseIP(str))
+	ip := mustIP4(str)
 
 	var got string
 	allocs := testing.AllocsPerRun(1000, func() {
@@ -28,7 +45,24 @@ func TestIP4String(t *testing.T) {
 	}
 }
 
-var icmpRequestBuffer = []byte{
+func TestIP6String(t *testing.T) {
+	const str = "2607:f8b0:400a:809::200e"
+	ip := mustIP6(str)
+
+	var got string
+	allocs := testing.AllocsPerRun(1000, func() {
+		got = ip.String()
+	})
+
+	if got != str {
+		t.Errorf("got %q; want %q", got, str)
+	}
+	if allocs != 2 {
+		t.Errorf("allocs = %v; want 1", allocs)
+	}
+}
+
+var icmp4RequestBuffer = []byte{
 	// IP header up to checksum
 	0x45, 0x00, 0x00, 0x27, 0xde, 0xad, 0x00, 0x00, 0x40, 0x01, 0x8c, 0x15,
 	// source ip
@@ -41,21 +75,21 @@ var icmpRequestBuffer = []byte{
 	0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
 }
 
-var icmpRequestDecode = Parsed{
-	b:       icmpRequestBuffer,
+var icmp4RequestDecode = Parsed{
+	b:       icmp4RequestBuffer,
 	subofs:  20,
 	dataofs: 24,
-	length:  len(icmpRequestBuffer),
+	length:  len(icmp4RequestBuffer),
 
 	IPVersion: 4,
-	IPProto:   ICMP,
-	SrcIP:     NewIP4(net.ParseIP("1.2.3.4")),
-	DstIP:     NewIP4(net.ParseIP("5.6.7.8")),
+	IPProto:   ICMPv4,
+	SrcIP4:    mustIP4("1.2.3.4"),
+	DstIP4:    mustIP4("5.6.7.8"),
 	SrcPort:   0,
 	DstPort:   0,
 }
 
-var icmpReplyBuffer = []byte{
+var icmp4ReplyBuffer = []byte{
 	0x45, 0x00, 0x00, 0x25, 0x21, 0x52, 0x00, 0x00, 0x40, 0x01, 0x49, 0x73,
 	// source ip
 	0x05, 0x06, 0x07, 0x08,
@@ -67,22 +101,22 @@ var icmpReplyBuffer = []byte{
 	0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
 }
 
-var icmpReplyDecode = Parsed{
-	b:       icmpReplyBuffer,
+var icmp4ReplyDecode = Parsed{
+	b:       icmp4ReplyBuffer,
 	subofs:  20,
 	dataofs: 24,
-	length:  len(icmpReplyBuffer),
+	length:  len(icmp4ReplyBuffer),
 
 	IPVersion: 4,
-	IPProto:   ICMP,
-	SrcIP:     NewIP4(net.ParseIP("1.2.3.4")),
-	DstIP:     NewIP4(net.ParseIP("5.6.7.8")),
+	IPProto:   ICMPv4,
+	SrcIP4:    mustIP4("1.2.3.4"),
+	DstIP4:    mustIP4("5.6.7.8"),
 	SrcPort:   0,
 	DstPort:   0,
 }
 
-// IPv6 Router Solicitation
-var ipv6PacketBuffer = []byte{
+// ICMPv6 Router Solicitation
+var icmp6PacketBuffer = []byte{
 	0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x3a, 0xff,
 	0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
 	0xfb, 0x57, 0x1d, 0xea, 0x9c, 0x39, 0x8f, 0xb7,
@@ -91,10 +125,15 @@ var ipv6PacketBuffer = []byte{
 	0x85, 0x00, 0x38, 0x04, 0x00, 0x00, 0x00, 0x00,
 }
 
-var ipv6PacketDecode = Parsed{
-	b:         ipv6PacketBuffer,
+var icmp6PacketDecode = Parsed{
+	b:         icmp6PacketBuffer,
+	subofs:    40,
+	dataofs:   44,
+	length:    len(icmp6PacketBuffer),
 	IPVersion: 6,
 	IPProto:   ICMPv6,
+	SrcIP6:    mustIP6("fe80::fb57:1dea:9c39:8fb7"),
+	DstIP6:    mustIP6("ff02::2"),
 }
 
 // This is a malformed IPv4 packet.
@@ -109,7 +148,7 @@ var unknownPacketDecode = Parsed{
 	IPProto:   Unknown,
 }
 
-var tcpPacketBuffer = []byte{
+var tcp4PacketBuffer = []byte{
 	// IP header up to checksum
 	0x45, 0x00, 0x00, 0x37, 0xde, 0xad, 0x00, 0x00, 0x40, 0x06, 0x49, 0x5f,
 	// source ip
@@ -123,22 +162,50 @@ var tcpPacketBuffer = []byte{
 	0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
 }
 
-var tcpPacketDecode = Parsed{
-	b:       tcpPacketBuffer,
+var tcp4PacketDecode = Parsed{
+	b:       tcp4PacketBuffer,
 	subofs:  20,
 	dataofs: 40,
-	length:  len(tcpPacketBuffer),
+	length:  len(tcp4PacketBuffer),
 
 	IPVersion: 4,
 	IPProto:   TCP,
-	SrcIP:     NewIP4(net.ParseIP("1.2.3.4")),
-	DstIP:     NewIP4(net.ParseIP("5.6.7.8")),
+	SrcIP4:    mustIP4("1.2.3.4"),
+	DstIP4:    mustIP4("5.6.7.8"),
 	SrcPort:   123,
 	DstPort:   567,
 	TCPFlags:  TCPSynAck,
 }
 
-var udpRequestBuffer = []byte{
+var tcp6RequestBuffer = []byte{
+	// IPv6 header up to hop limit
+	0x60, 0x06, 0xef, 0xcc, 0x00, 0x28, 0x06, 0x40,
+	// Src addr
+	0x20, 0x01, 0x05, 0x59, 0xbc, 0x13, 0x54, 0x00, 0x17, 0x49, 0x46, 0x28, 0x39, 0x34, 0x0e, 0x1b,
+	// Dst addr
+	0x26, 0x07, 0xf8, 0xb0, 0x40, 0x0a, 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x0e,
+	// TCP SYN segment, no payload
+	0xa4, 0x60, 0x00, 0x50, 0xf3, 0x82, 0xa1, 0x25, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfd, 0x20,
+	0xb1, 0xc6, 0x00, 0x00, 0x02, 0x04, 0x05, 0xa0, 0x04, 0x02, 0x08, 0x0a, 0xca, 0x76, 0xa6, 0x8e,
+	0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07,
+}
+
+var tcp6RequestDecode = Parsed{
+	b:       tcp6RequestBuffer,
+	subofs:  40,
+	dataofs: len(tcp6RequestBuffer),
+	length:  len(tcp6RequestBuffer),
+
+	IPVersion: 6,
+	IPProto:   TCP,
+	SrcIP6:    mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"),
+	DstIP6:    mustIP6("2607:f8b0:400a:809::200e"),
+	SrcPort:   42080,
+	DstPort:   80,
+	TCPFlags:  TCPSyn,
+}
+
+var udp4RequestBuffer = []byte{
 	// IP header up to checksum
 	0x45, 0x00, 0x00, 0x2b, 0xde, 0xad, 0x00, 0x00, 0x40, 0x11, 0x8c, 0x01,
 	// source ip
@@ -151,21 +218,70 @@ var udpRequestBuffer = []byte{
 	0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
 }
 
-var udpRequestDecode = Parsed{
-	b:       udpRequestBuffer,
+var udp4RequestDecode = Parsed{
+	b:       udp4RequestBuffer,
 	subofs:  20,
 	dataofs: 28,
-	length:  len(udpRequestBuffer),
+	length:  len(udp4RequestBuffer),
 
 	IPVersion: 4,
 	IPProto:   UDP,
-	SrcIP:     NewIP4(net.ParseIP("1.2.3.4")),
-	DstIP:     NewIP4(net.ParseIP("5.6.7.8")),
+	SrcIP4:    mustIP4("1.2.3.4"),
+	DstIP4:    mustIP4("5.6.7.8"),
 	SrcPort:   123,
 	DstPort:   567,
 }
 
-var udpReplyBuffer = []byte{
+var invalid4RequestBuffer = []byte{
+	// IP header up to checksum. IHL field points beyond end of packet.
+	0x4a, 0x00, 0x00, 0x14, 0xde, 0xad, 0x00, 0x00, 0x40, 0x11, 0x8c, 0x01,
+	// source ip
+	0x01, 0x02, 0x03, 0x04,
+	// destination ip
+	0x05, 0x06, 0x07, 0x08,
+}
+
+// Regression check for the IHL field pointing beyond the end of the
+// packet.
+var invalid4RequestDecode = Parsed{
+	b:      invalid4RequestBuffer,
+	subofs: 40,
+	length: len(invalid4RequestBuffer),
+
+	IPVersion: 4,
+	IPProto:   Unknown,
+	SrcIP4:    mustIP4("1.2.3.4"),
+	DstIP4:    mustIP4("5.6.7.8"),
+}
+
+var udp6RequestBuffer = []byte{
+	// IPv6 header up to hop limit
+	0x60, 0x0e, 0xc9, 0x67, 0x00, 0x29, 0x11, 0x40,
+	// Src addr
+	0x20, 0x01, 0x05, 0x59, 0xbc, 0x13, 0x54, 0x00, 0x17, 0x49, 0x46, 0x28, 0x39, 0x34, 0x0e, 0x1b,
+	// Dst addr
+	0x26, 0x07, 0xf8, 0xb0, 0x40, 0x0a, 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x0e,
+	// UDP header
+	0xd4, 0x04, 0x01, 0xbb, 0x00, 0x29, 0x96, 0x84,
+	// Payload
+	0x5c, 0x06, 0xae, 0x85, 0x02, 0xf5, 0xdb, 0x90, 0xe0, 0xe0, 0x93, 0xed, 0x9a, 0xd9, 0x92, 0x69, 0xbe, 0x36, 0x8a, 0x7d, 0xd7, 0xce, 0xd0, 0x8a, 0xf2, 0x51, 0x95, 0xff, 0xb6, 0x92, 0x70, 0x10, 0xd7,
+}
+
+var udp6RequestDecode = Parsed{
+	b:       udp6RequestBuffer,
+	subofs:  40,
+	dataofs: 48,
+	length:  len(udp6RequestBuffer),
+
+	IPVersion: 6,
+	IPProto:   UDP,
+	SrcIP6:    mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"),
+	DstIP6:    mustIP6("2607:f8b0:400a:809::200e"),
+	SrcPort:   54276,
+	DstPort:   443,
+}
+
+var udp4ReplyBuffer = []byte{
 	// IP header up to checksum
 	0x45, 0x00, 0x00, 0x29, 0x21, 0x52, 0x00, 0x00, 0x40, 0x11, 0x49, 0x5f,
 	// source ip
@@ -178,15 +294,15 @@ var udpReplyBuffer = []byte{
 	0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
 }
 
-var udpReplyDecode = Parsed{
-	b:       udpReplyBuffer,
+var udp4ReplyDecode = Parsed{
+	b:       udp4ReplyBuffer,
 	subofs:  20,
 	dataofs: 28,
-	length:  len(udpReplyBuffer),
+	length:  len(udp4ReplyBuffer),
 
 	IPProto: UDP,
-	SrcIP:   NewIP4(net.ParseIP("1.2.3.4")),
-	DstIP:   NewIP4(net.ParseIP("5.6.7.8")),
+	SrcIP4:  mustIP4("1.2.3.4"),
+	DstIP4:  mustIP4("5.6.7.8"),
 	SrcPort: 567,
 	DstPort: 123,
 }
@@ -197,10 +313,13 @@ func TestParsed(t *testing.T) {
 		qdecode Parsed
 		want    string
 	}{
-		{"tcp", tcpPacketDecode, "TCP{1.2.3.4:123 > 5.6.7.8:567}"},
-		{"icmp", icmpRequestDecode, "ICMP{1.2.3.4:0 > 5.6.7.8:0}"},
+		{"tcp4", tcp4PacketDecode, "TCP{1.2.3.4:123 > 5.6.7.8:567}"},
+		{"tcp6", tcp6RequestDecode, "TCP{[2001:559:bc13:5400:1749:4628:3934:e1b]:42080 > [2607:f8b0:400a:809::200e]:80}"},
+		{"udp4", udp4RequestDecode, "UDP{1.2.3.4:123 > 5.6.7.8:567}"},
+		{"udp6", udp6RequestDecode, "UDP{[2001:559:bc13:5400:1749:4628:3934:e1b]:54276 > [2607:f8b0:400a:809::200e]:443}"},
+		{"icmp4", icmp4RequestDecode, "ICMPv4{1.2.3.4:0 > 5.6.7.8:0}"},
+		{"icmp6", icmp6PacketDecode, "ICMPv6{[fe80::fb57:1dea:9c39:8fb7]:0 > [ff02::2]:0}"},
 		{"unknown", unknownPacketDecode, "Unknown{???}"},
-		{"ipv6", ipv6PacketDecode, "IPv6{Proto=58}"},
 	}
 
 	for _, tt := range tests {
@@ -228,11 +347,14 @@ func TestDecode(t *testing.T) {
 		buf  []byte
 		want Parsed
 	}{
-		{"icmp", icmpRequestBuffer, icmpRequestDecode},
-		{"ipv6", ipv6PacketBuffer, ipv6PacketDecode},
+		{"icmp4", icmp4RequestBuffer, icmp4RequestDecode},
+		{"icmp6", icmp6PacketBuffer, icmp6PacketDecode},
+		{"tcp4", tcp4PacketBuffer, tcp4PacketDecode},
+		{"tcp6", tcp6RequestBuffer, tcp6RequestDecode},
+		{"udp4", udp4RequestBuffer, udp4RequestDecode},
+		{"udp6", udp6RequestBuffer, udp6RequestDecode},
 		{"unknown", unknownPacketBuffer, unknownPacketDecode},
-		{"tcp", tcpPacketBuffer, tcpPacketDecode},
-		{"udp", udpRequestBuffer, udpRequestDecode},
+		{"invalid4", invalid4RequestBuffer, invalid4RequestDecode},
 	}
 
 	for _, tt := range tests {
@@ -259,9 +381,13 @@ func BenchmarkDecode(b *testing.B) {
 		name string
 		buf  []byte
 	}{
-		{"icmp", icmpRequestBuffer},
+		{"tcp4", tcp4PacketBuffer},
+		{"tcp6", tcp6RequestBuffer},
+		{"udp4", udp4RequestBuffer},
+		{"udp6", udp6RequestBuffer},
+		{"icmp4", icmp4RequestBuffer},
+		{"icmp6", icmp6PacketBuffer},
 		{"unknown", unknownPacketBuffer},
-		{"tcp", tcpPacketBuffer},
 	}
 
 	for _, bench := range benches {
@@ -280,15 +406,15 @@ func TestMarshalRequest(t *testing.T) {
 	var small [20]byte
 	var large [64]byte
 
-	icmpHeader := icmpRequestDecode.ICMPHeader()
-	udpHeader := udpRequestDecode.UDPHeader()
+	icmpHeader := icmp4RequestDecode.ICMP4Header()
+	udpHeader := udp4RequestDecode.UDP4Header()
 	tests := []struct {
 		name   string
 		header Header
 		want   []byte
 	}{
-		{"icmp", &icmpHeader, icmpRequestBuffer},
-		{"udp", &udpHeader, udpRequestBuffer},
+		{"icmp", &icmpHeader, icmp4RequestBuffer},
+		{"udp", &udpHeader, udp4RequestBuffer},
 	}
 
 	for _, tt := range tests {
@@ -317,16 +443,16 @@ func TestMarshalRequest(t *testing.T) {
 func TestMarshalResponse(t *testing.T) {
 	var buf [64]byte
 
-	icmpHeader := icmpRequestDecode.ICMPHeader()
-	udpHeader := udpRequestDecode.UDPHeader()
+	icmpHeader := icmp4RequestDecode.ICMP4Header()
+	udpHeader := udp4RequestDecode.UDP4Header()
 
 	tests := []struct {
 		name   string
 		header Header
 		want   []byte
 	}{
-		{"icmp", &icmpHeader, icmpReplyBuffer},
-		{"udp", &udpHeader, udpReplyBuffer},
+		{"icmp", &icmpHeader, icmp4ReplyBuffer},
+		{"udp", &udpHeader, udp4ReplyBuffer},
 	}
 
 	for _, tt := range tests {

+ 19 - 17
net/packet/udp4.go

@@ -4,25 +4,27 @@
 
 package packet
 
-// UDPHeader represents an UDP packet header.
+import "encoding/binary"
+
+// udpHeaderLength is the size of the UDP packet header, not including
+// the outer IP header.
+const udpHeaderLength = 8
+
+// UDP4Header is an IPv4+UDP header.
 type UDP4Header struct {
 	IP4Header
 	SrcPort uint16
 	DstPort uint16
 }
 
-const (
-	udpHeaderLength = 8
-	// udpTotalHeaderLength is the length of all headers in a UDP packet.
-	udpTotalHeaderLength = ipHeaderLength + udpHeaderLength
-)
-
-func (UDP4Header) Len() int {
-	return udpTotalHeaderLength
+// Len implements Header.
+func (h UDP4Header) Len() int {
+	return h.IP4Header.Len() + udpHeaderLength
 }
 
+// Marshal implements Header.
 func (h UDP4Header) Marshal(buf []byte) error {
-	if len(buf) < udpTotalHeaderLength {
+	if len(buf) < h.Len() {
 		return errSmallBuffer
 	}
 	if len(buf) > maxPacketLength {
@@ -32,21 +34,21 @@ func (h UDP4Header) Marshal(buf []byte) error {
 	h.IPProto = UDP
 
 	length := len(buf) - h.IP4Header.Len()
-	put16(buf[20:22], h.SrcPort)
-	put16(buf[22:24], h.DstPort)
-	put16(buf[24:26], uint16(length))
-	put16(buf[26:28], 0) // blank checksum
-
-	h.IP4Header.MarshalPseudo(buf)
+	binary.BigEndian.PutUint16(buf[20:22], h.SrcPort)
+	binary.BigEndian.PutUint16(buf[22:24], h.DstPort)
+	binary.BigEndian.PutUint16(buf[24:26], uint16(length))
+	binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum
 
 	// UDP checksum with IP pseudo header.
-	put16(buf[26:28], ipChecksum(buf[8:]))
+	h.IP4Header.marshalPseudo(buf)
+	binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:]))
 
 	h.IP4Header.Marshal(buf)
 
 	return nil
 }
 
+// ToResponse implements Header.
 func (h *UDP4Header) ToResponse() {
 	h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
 	h.IP4Header.ToResponse()

+ 10 - 10
wgengine/filter/filter.go

@@ -191,8 +191,8 @@ func (f *Filter) CheckTCP(srcIP, dstIP netaddr.IP, dstPort uint16) Response {
 	pkt.IPVersion = 4
 	pkt.IPProto = packet.TCP
 	pkt.TCPFlags = packet.TCPSyn
-	pkt.SrcIP = packet.IP4FromNetaddr(srcIP) // TODO: IPv6
-	pkt.DstIP = packet.IP4FromNetaddr(dstIP)
+	pkt.SrcIP4 = packet.IP4FromNetaddr(srcIP) // TODO: IPv6
+	pkt.DstIP4 = packet.IP4FromNetaddr(dstIP)
 	pkt.SrcPort = 0
 	pkt.DstPort = dstPort
 
@@ -233,7 +233,7 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) {
 	// A compromised peer could try to send us packets for
 	// destinations we didn't explicitly advertise. This check is to
 	// prevent that.
-	if !ip4InList(q.DstIP, f.local4) {
+	if !ip4InList(q.DstIP4, f.local4) {
 		return Drop, "destination not allowed"
 	}
 
@@ -243,7 +243,7 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) {
 	}
 
 	switch q.IPProto {
-	case packet.ICMP:
+	case packet.ICMPv4:
 		if q.IsEchoResponse() || q.IsError() {
 			// ICMP responses are allowed.
 			// TODO(apenwarr): consider using conntrack state.
@@ -271,7 +271,7 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) {
 			return Accept, "tcp ok"
 		}
 	case packet.UDP:
-		t := tuple{q.SrcIP, q.DstIP, q.SrcPort, q.DstPort}
+		t := tuple{q.SrcIP4, q.DstIP4, q.SrcPort, q.DstPort}
 
 		f.state.mu.Lock()
 		_, ok := f.state.lru.Get(t)
@@ -292,7 +292,7 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) {
 // runIn runs the output-specific part of the filter logic.
 func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) {
 	if q.IPProto == packet.UDP {
-		t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort}
+		t := tuple{q.DstIP4, q.SrcIP4, q.DstPort, q.SrcPort}
 		var ti interface{} = t // allocate once, rather than twice inside mutex
 
 		f.state.mu.Lock()
@@ -338,11 +338,11 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response {
 		f.logRateLimit(rf, q, dir, Drop, "ipv6")
 		return Drop
 	}
-	if q.DstIP.IsMulticast() {
+	if q.DstIP4.IsMulticast() {
 		f.logRateLimit(rf, q, dir, Drop, "multicast")
 		return Drop
 	}
-	if q.DstIP.IsLinkLocalUnicast() {
+	if q.DstIP4.IsLinkLocalUnicast() {
 		f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
 		return Drop
 	}
@@ -383,13 +383,13 @@ func omitDropLogging(p *packet.Parsed, dir direction) bool {
 			// it doesn't know about, so parse it out ourselves if needed.
 			ipProto := p.IPProto
 			if ipProto == 0 && len(b) > 8 {
-				ipProto = packet.IP4Proto(b[9])
+				ipProto = packet.IPProto(b[9])
 			}
 			// Omit logging about outgoing IGMP.
 			if ipProto == packet.IGMP {
 				return true
 			}
-			if p.DstIP.IsMulticast() || p.DstIP.IsLinkLocalUnicast() {
+			if p.DstIP4.IsMulticast() || p.DstIP4.IsLinkLocalUnicast() {
 				return true
 			}
 		case 6:

+ 23 - 16
wgengine/filter/filter_test.go

@@ -9,7 +9,6 @@ import (
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
-	"net"
 	"strconv"
 	"strings"
 	"testing"
@@ -20,11 +19,19 @@ import (
 )
 
 var Unknown = packet.Unknown
-var ICMP = packet.ICMP
+var ICMPv4 = packet.ICMPv4
 var TCP = packet.TCP
 var UDP = packet.UDP
 var Fragment = packet.Fragment
 
+func mustIP4(s string) packet.IP4 {
+	ip, err := netaddr.ParseIP(s)
+	if err != nil {
+		panic(err)
+	}
+	return packet.IP4FromNetaddr(ip)
+}
+
 func pfx(s string) netaddr.IPPrefix {
 	pfx, err := netaddr.ParseIPPrefix(s)
 	if err != nil {
@@ -140,7 +147,7 @@ func TestFilter(t *testing.T) {
 		// Basic
 		{Accept, parsed(TCP, 0x08010101, 0x01020304, 999, 22)},
 		{Accept, parsed(UDP, 0x08010101, 0x01020304, 999, 22)},
-		{Accept, parsed(ICMP, 0x08010101, 0x01020304, 0, 0)},
+		{Accept, parsed(ICMPv4, 0x08010101, 0x01020304, 0, 0)},
 		{Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 0)},
 		{Accept, parsed(TCP, 0x08010101, 0x01020304, 0, 22)},
 		{Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 21)},
@@ -168,7 +175,7 @@ func TestFilter(t *testing.T) {
 			t.Errorf("#%d runIn got=%v want=%v packet:%v", i, got, test.want, test.p)
 		}
 		if test.p.IPProto == TCP {
-			if got := acl.CheckTCP(test.p.SrcIP.Netaddr(), test.p.DstIP.Netaddr(), test.p.DstPort); test.want != got {
+			if got := acl.CheckTCP(test.p.SrcIP4.Netaddr(), test.p.DstIP4.Netaddr(), test.p.DstPort); test.want != got {
 				t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
 			}
 		}
@@ -250,7 +257,7 @@ func BenchmarkFilter(b *testing.B) {
 
 	tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
 	udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0)
-	icmpPacket := rawpacket(ICMP, 0x08010101, 0x01020304, 0, 0, 0)
+	icmpPacket := rawpacket(ICMPv4, 0x08010101, 0x01020304, 0, 0, 0)
 
 	tcpSynPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
 	// TCP filtering is trivial (Accept) for non-SYN packets.
@@ -299,7 +306,7 @@ func TestPreFilter(t *testing.T) {
 		{"fragment", Accept, rawdefault(Fragment, 40)},
 		{"tcp", noVerdict, rawdefault(TCP, 200)},
 		{"udp", noVerdict, rawdefault(UDP, 200)},
-		{"icmp", noVerdict, rawdefault(ICMP, 200)},
+		{"icmp", noVerdict, rawdefault(ICMPv4, 200)},
 	}
 	f := NewAllowNone(t.Logf)
 	for _, testPacket := range packets {
@@ -312,11 +319,11 @@ func TestPreFilter(t *testing.T) {
 	}
 }
 
-func parsed(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16) packet.Parsed {
+func parsed(proto packet.IPProto, src, dst packet.IP4, sport, dport uint16) packet.Parsed {
 	return packet.Parsed{
 		IPProto:  proto,
-		SrcIP:    src,
-		DstIP:    dst,
+		SrcIP4:   src,
+		DstIP4:   dst,
 		SrcPort:  sport,
 		DstPort:  dport,
 		TCPFlags: packet.TCPSyn,
@@ -325,11 +332,11 @@ func parsed(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16) pac
 
 // rawpacket generates a packet with given source and destination ports and IPs
 // and resizes the header to trimLength if it is nonzero.
-func rawpacket(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16, trimLength int) []byte {
+func rawpacket(proto packet.IPProto, src, dst packet.IP4, sport, dport uint16, trimLength int) []byte {
 	var headerLength int
 
 	switch proto {
-	case ICMP:
+	case ICMPv4:
 		headerLength = 24
 	case TCP:
 		headerLength = 40
@@ -357,7 +364,7 @@ func rawpacket(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16,
 	bin.PutUint16(hdr[22:24], dport)
 
 	switch proto {
-	case ICMP:
+	case ICMPv4:
 		hdr[9] = 1
 	case TCP:
 		hdr[9] = 6
@@ -379,7 +386,7 @@ func rawpacket(proto packet.IP4Proto, src, dst packet.IP4, sport, dport uint16,
 }
 
 // rawdefault calls rawpacket with default ports and IPs.
-func rawdefault(proto packet.IP4Proto, trimLength int) []byte {
+func rawdefault(proto packet.IPProto, trimLength int) []byte {
 	ip := packet.IP4(0x08080808) // 8.8.8.8
 	port := uint16(53)
 	return rawpacket(proto, ip, ip, port, port, trimLength)
@@ -435,19 +442,19 @@ func TestOmitDropLogging(t *testing.T) {
 		},
 		{
 			name: "v4_multicast_out_low",
-			pkt:  &packet.Parsed{IPVersion: 4, DstIP: packet.NewIP4(net.ParseIP("224.0.0.0"))},
+			pkt:  &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("224.0.0.0")},
 			dir:  out,
 			want: true,
 		},
 		{
 			name: "v4_multicast_out_high",
-			pkt:  &packet.Parsed{IPVersion: 4, DstIP: packet.NewIP4(net.ParseIP("239.255.255.255"))},
+			pkt:  &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("239.255.255.255")},
 			dir:  out,
 			want: true,
 		},
 		{
 			name: "v4_link_local_unicast",
-			pkt:  &packet.Parsed{IPVersion: 4, DstIP: packet.NewIP4(net.ParseIP("169.254.1.2"))},
+			pkt:  &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("169.254.1.2")},
 			dir:  out,
 			want: true,
 		},

+ 4 - 4
wgengine/filter/match4.go

@@ -104,11 +104,11 @@ func newMatches4(ms []Match) (ret matches4) {
 // any of ms.
 func (ms matches4) match(q *packet.Parsed) bool {
 	for _, m := range ms {
-		if !ip4InList(q.SrcIP, m.srcs) {
+		if !ip4InList(q.SrcIP4, m.srcs) {
 			continue
 		}
 		for _, dst := range m.dsts {
-			if !dst.net.Contains(q.DstIP) {
+			if !dst.net.Contains(q.DstIP4) {
 				continue
 			}
 			if !dst.ports.contains(q.DstPort) {
@@ -124,11 +124,11 @@ func (ms matches4) match(q *packet.Parsed) bool {
 // any of ms.
 func (ms matches4) matchIPsOnly(q *packet.Parsed) bool {
 	for _, m := range ms {
-		if !ip4InList(q.SrcIP, m.srcs) {
+		if !ip4InList(q.SrcIP4, m.srcs) {
 			continue
 		}
 		for _, dst := range m.dsts {
-			if dst.net.Contains(q.DstIP) {
+			if dst.net.Contains(q.DstIP4) {
 				return true
 			}
 		}

+ 1 - 1
wgengine/tstun/tun.go

@@ -283,7 +283,7 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
 	p.Decode(buf[offset : offset+n])
 
 	if m, ok := t.destIPActivity.Load().(map[packet.IP4]func()); ok {
-		if fn := m[p.DstIP]; fn != nil {
+		if fn := m[p.DstIP4]; fn != nil {
 			fn()
 		}
 	}

+ 6 - 6
wgengine/userspace.go

@@ -372,7 +372,7 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
 // echoRespondToAll is an inbound post-filter responding to all echo requests.
 func echoRespondToAll(p *packet.Parsed, t *tstun.TUN) filter.Response {
 	if p.IsEchoRequest() {
-		header := p.ICMPHeader()
+		header := p.ICMP4Header()
 		header.ToResponse()
 		outp := packet.Generate(&header, p.Payload())
 		t.InjectOutbound(outp)
@@ -397,7 +397,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.TUN) fil
 		return filter.Drop
 	}
 
-	if runtime.GOOS == "darwin" && e.isLocalAddr(p.DstIP) {
+	if runtime.GOOS == "darwin" && e.isLocalAddr(p.DstIP4) {
 		// macOS NetworkExtension directs packets destined to the
 		// tunnel's local IP address into the tunnel, instead of
 		// looping back within the kernel network stack. We have to
@@ -421,10 +421,10 @@ func (e *userspaceEngine) isLocalAddr(ip packet.IP4) bool {
 
 // handleDNS is an outbound pre-filter resolving Tailscale domains.
 func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.TUN) filter.Response {
-	if p.DstIP == magicDNSIP && p.DstPort == magicDNSPort && p.IPProto == packet.UDP {
+	if p.DstIP4 == magicDNSIP && p.DstPort == magicDNSPort && p.IPProto == packet.UDP {
 		request := tsdns.Packet{
 			Payload: append([]byte(nil), p.Payload()...),
-			Addr:    netaddr.IPPort{IP: p.SrcIP.Netaddr(), Port: p.SrcPort},
+			Addr:    netaddr.IPPort{IP: p.SrcIP4.Netaddr(), Port: p.SrcPort},
 		}
 		err := e.resolver.EnqueueRequest(request)
 		if err != nil {
@@ -515,7 +515,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src
 	start := time.Now()
 	var dstIPs []packet.IP4
 	for _, ip := range ips {
-		dstIPs = append(dstIPs, packet.NewIP4(ip.IP()))
+		dstIPs = append(dstIPs, packet.IP4FromNetaddr(netaddr.IPFrom16(ip.Addr)))
 	}
 
 	payload := []byte("magicsock_spray") // no meaning
@@ -555,7 +555,7 @@ func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) {
 
 	e.wgLock.Lock()
 	if len(e.lastCfgFull.Addresses) > 0 {
-		srcIP = packet.NewIP4(e.lastCfgFull.Addresses[0].IP.IP())
+		srcIP = packet.IP4FromNetaddr(netaddr.IPFrom16(e.lastCfgFull.Addresses[0].IP.Addr))
 	}
 	e.wgLock.Unlock()