Ver Fonte

net/packet: add ICMP6Header, like ICMP4Header

So we can generate IPv6 ping replies.

Change-Id: I79a9a38d8aa242e5dfca4cd15dfaffaea6cb1aee
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick há 4 anos atrás
pai
commit
21741e111b
4 ficheiros alterados com 226 adições e 7 exclusões
  1. 14 0
      net/packet/header.go
  2. 123 0
      net/packet/icmp6.go
  3. 80 0
      net/packet/icmp6_test.go
  4. 9 7
      net/packet/packet.go

+ 14 - 0
net/packet/header.go

@@ -39,6 +39,16 @@ type Header interface {
 	Marshal(buf []byte) error
 }
 
+// HeaderChecksummer is implemented by Header implementations that
+// need to do a checksum over their paylods.
+type HeaderChecksummer interface {
+	Header
+
+	// WriteCheck writes the correct checksum into buf, which should
+	// be be the already-marshalled header and payload.
+	WriteChecksum(buf []byte)
+}
+
 // Generate generates a new packet with the given Header and
 // payload. This function allocates memory, see Header.Marshal for an
 // allocation-free option.
@@ -49,5 +59,9 @@ func Generate(h Header, payload []byte) []byte {
 	copy(buf[hlen:], payload)
 	h.Marshal(buf)
 
+	if hc, ok := h.(HeaderChecksummer); ok {
+		hc.WriteChecksum(buf)
+	}
+
 	return buf
 }

+ 123 - 0
net/packet/icmp6.go

@@ -4,6 +4,12 @@
 
 package packet
 
+import (
+	"encoding/binary"
+
+	"tailscale.com/types/ipproto"
+)
+
 // icmp6HeaderLength is the size of the ICMPv6 packet header, not
 // including the outer IP layer or the variable "response data"
 // trailer.
@@ -42,3 +48,120 @@ type ICMP6Code uint8
 const (
 	ICMP6NoCode ICMP6Code = 0
 )
+
+// ICMP6Header is an IPv4+ICMPv4 header.
+type ICMP6Header struct {
+	IP6Header
+	Type ICMP6Type
+	Code ICMP6Code
+}
+
+// Len implements Header.
+func (h ICMP6Header) Len() int {
+	return h.IP6Header.Len() + icmp6HeaderLength
+}
+
+// Marshal implements Header.
+func (h ICMP6Header) Marshal(buf []byte) error {
+	if len(buf) < h.Len() {
+		return errSmallBuffer
+	}
+	if len(buf) > maxPacketLength {
+		return errLargePacket
+	}
+	// The caller does not need to set this.
+	h.IPProto = ipproto.ICMPv6
+
+	h.IP6Header.Marshal(buf)
+
+	const o = ip6HeaderLength // start offset of ICMPv6 header
+	buf[o+0] = uint8(h.Type)
+	buf[o+1] = uint8(h.Code)
+	buf[o+2] = 0 // checksum, to be filled in later
+	buf[o+3] = 0 // checksum, to be filled in later
+	return nil
+}
+
+// ToResponse implements Header. TODO: it doesn't implement it
+// correctly, instead it statically generates an ICMP Echo Reply
+// packet.
+func (h *ICMP6Header) ToResponse() {
+	// TODO: this doesn't implement ToResponse correctly, as it
+	// assumes the ICMP request type.
+	h.Type = ICMP6EchoReply
+	h.Code = ICMP6NoCode
+	h.IP6Header.ToResponse()
+}
+
+// WriteChecksum implements HeaderChecksummer, writing just the checksum bytes
+// into the otherwise fully marshaled ICMP6 packet p (which should include the
+// IPv6 header, ICMPv6 header, and payload).
+func (h ICMP6Header) WriteChecksum(p []byte) {
+	const payOff = ip6HeaderLength + icmp6HeaderLength
+	xsum := icmp6Checksum(p[ip6HeaderLength:payOff], h.Src.As16(), h.Dst.As16(), p[payOff:])
+	binary.BigEndian.PutUint16(p[ip6HeaderLength+2:], xsum)
+}
+
+// Adapted from gVisor:
+
+// icmp6Checksum calculates the ICMP checksum over the provided ICMPv6
+// header (without the IPv6 header), IPv6 src/dst addresses and the
+// payload.
+//
+// The header's existing checksum must be zeroed.
+func icmp6Checksum(header []byte, src, dst [16]byte, payload []byte) uint16 {
+	// Calculate the IPv6 pseudo-header upper-layer checksum.
+	xsum := checksumBytes(src[:], 0)
+	xsum = checksumBytes(dst[:], xsum)
+
+	var scratch [4]byte
+	binary.BigEndian.PutUint32(scratch[:], uint32(len(header)+len(payload)))
+	xsum = checksumBytes(scratch[:], xsum)
+	xsum = checksumBytes(append(scratch[:0], 0, 0, 0, uint8(ipproto.ICMPv6)), xsum)
+	xsum = checksumBytes(payload, xsum)
+
+	var hdrz [icmp6HeaderLength]byte
+	copy(hdrz[:], header)
+	// Zero out the header.
+	hdrz[2] = 0
+	hdrz[3] = 0
+	xsum = ^checksumBytes(hdrz[:], xsum)
+	return xsum
+}
+
+// checksumCombine combines the two uint16 to form their
+// checksum. This is done by adding them and the carry.
+//
+// Note that checksum a must have been computed on an even number of
+// bytes.
+func checksumCombine(a, b uint16) uint16 {
+	v := uint32(a) + uint32(b)
+	return uint16(v + v>>16)
+}
+
+// checksumBytes calculates the checksum (as defined in RFC 1071) of
+// the bytes in buf.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func checksumBytes(buf []byte, initial uint16) uint16 {
+	v := uint32(initial)
+
+	odd := len(buf)%2 == 1
+	if odd {
+		v += uint32(buf[0])
+		buf = buf[1:]
+	}
+
+	n := len(buf)
+	odd = n&1 != 0
+	if odd {
+		n--
+		v += uint32(buf[n]) << 8
+	}
+
+	for i := 0; i < n; i += 2 {
+		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+	}
+
+	return checksumCombine(uint16(v), uint16(v>>16))
+}

+ 80 - 0
net/packet/icmp6_test.go

@@ -0,0 +1,80 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+	"testing"
+
+	"inet.af/netaddr"
+	"tailscale.com/types/ipproto"
+)
+
+func TestICMPv6PingResponse(t *testing.T) {
+	pingHdr := ICMP6Header{
+		IP6Header: IP6Header{
+			Src:     netaddr.MustParseIP("1::1"),
+			Dst:     netaddr.MustParseIP("2::2"),
+			IPProto: ipproto.ICMPv6,
+		},
+		Type: ICMP6EchoRequest,
+		Code: ICMP6NoCode,
+	}
+
+	// echoReqLen is 2 bytes identifier + 2 bytes seq number.
+	// https://datatracker.ietf.org/doc/html/rfc4443#section-4.1
+	// Packet.IsEchoRequest verifies that these 4 bytes are present.
+	const echoReqLen = 4
+	buf := make([]byte, pingHdr.Len()+echoReqLen)
+	if err := pingHdr.Marshal(buf); err != nil {
+		t.Fatal(err)
+	}
+
+	var p Parsed
+	p.Decode(buf)
+	if !p.IsEchoRequest() {
+		t.Fatalf("not an echo request, got: %+v", p)
+	}
+
+	pingHdr.ToResponse()
+	buf = make([]byte, pingHdr.Len()+echoReqLen)
+	if err := pingHdr.Marshal(buf); err != nil {
+		t.Fatal(err)
+	}
+
+	p.Decode(buf)
+	if p.IsEchoRequest() {
+		t.Fatalf("unexpectedly still an echo request: %+v", p)
+	}
+	if !p.IsEchoResponse() {
+		t.Fatalf("not an echo response: %+v", p)
+	}
+}
+
+func TestICMPv6Checksum(t *testing.T) {
+	const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
+		"\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
+		"\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" +
+		"\x61\xb1\x9e\xad\x00\x06\x45\xaa"
+	// The packet that we'd originally generated incorrectly, but with the checksum
+	// bytes fixed per WireShark's correct calculation:
+	const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
+		"\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
+		"\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" +
+		"\x61\xb1\x9e\xad\x00\x06\x45\xaa"
+
+	var p Parsed
+	p.Decode([]byte(req))
+	if !p.IsEchoRequest() {
+		t.Fatalf("not an echo request, got: %+v", p)
+	}
+
+	h := p.ICMP6Header()
+	h.ToResponse()
+	pong := Generate(&h, p.Payload())
+
+	if string(pong) != wantRes {
+		t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes)
+	}
+}

+ 9 - 7
net/packet/packet.go

@@ -75,7 +75,7 @@ func (p *Parsed) String() string {
 }
 
 // Decode extracts data from the packet in b into q.
-// It performs extremely simple packet decoding for basic IPv4 packet types.
+// It performs extremely simple packet decoding for basic IPv4 and IPv6 packet types.
 // It extracts only the subprotocol id, IP addresses, and (if any) ports,
 // and shouldn't need any memory allocation.
 func (q *Parsed) Decode(b []byte) {
@@ -339,9 +339,6 @@ func (q *Parsed) IP6Header() IP6Header {
 }
 
 func (q *Parsed) ICMP4Header() ICMP4Header {
-	if q.IPVersion != 4 {
-		panic("IP4Header called on non-IPv4 Parsed")
-	}
 	return ICMP4Header{
 		IP4Header: q.IP4Header(),
 		Type:      ICMP4Type(q.b[q.subofs+0]),
@@ -349,10 +346,15 @@ func (q *Parsed) ICMP4Header() ICMP4Header {
 	}
 }
 
-func (q *Parsed) UDP4Header() UDP4Header {
-	if q.IPVersion != 4 {
-		panic("IP4Header called on non-IPv4 Parsed")
+func (q *Parsed) ICMP6Header() ICMP6Header {
+	return ICMP6Header{
+		IP6Header: q.IP6Header(),
+		Type:      ICMP6Type(q.b[q.subofs+0]),
+		Code:      ICMP6Code(q.b[q.subofs+1]),
 	}
+}
+
+func (q *Parsed) UDP4Header() UDP4Header {
 	return UDP4Header{
 		IP4Header: q.IP4Header(),
 		SrcPort:   q.Src.Port(),