Browse Source

util/cstruct: add package for decoding padded C structures (#5429)

I was working on my "dump iptables rules using only syscalls" branch and
had a bunch of C structure decoding to do. Rather than manually
calculating the padding or using unsafe trickery to actually cast
variable-length structures to Go types, I'd rather use a helper package
that deals with padding for me.

Padding rules were taken from the following article:
  http://www.catb.org/esr/structure-packing/

Signed-off-by: Andrew Dunham <[email protected]>
Andrew Dunham 3 years ago
parent
commit
58cc049a9f
3 changed files with 405 additions and 0 deletions
  1. 179 0
      util/cstruct/cstruct.go
  2. 75 0
      util/cstruct/cstruct_example_test.go
  3. 151 0
      util/cstruct/cstruct_test.go

+ 179 - 0
util/cstruct/cstruct.go

@@ -0,0 +1,179 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cstruct provides a helper for decoding binary data that is in the
+// form of a padded C structure.
+package cstruct
+
+import (
+	"errors"
+	"io"
+
+	"tailscale.com/util/endian"
+)
+
+// Size of a pointer-typed value, in bits
+const pointerSize = 32 << (^uintptr(0) >> 63)
+
+// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on
+// a 16- or 8-bit architecture any time soon.
+const is64Bit = pointerSize == 64
+
+// Decoder reads and decodes padded fields from a slice of bytes. All fields
+// are decoded with native endianness.
+//
+// Methods of a Decoder do not return errors, but rather store any error within
+// the Decoder. The first error can be obtained via the Err method; after the
+// first error, methods will return the zero value for their type.
+type Decoder struct {
+	b    []byte
+	off  int
+	err  error
+	dbuf [8]byte // for decoding
+}
+
+// NewDecoder creates a Decoder from a byte slice.
+func NewDecoder(b []byte) *Decoder {
+	return &Decoder{b: b}
+}
+
+var errUnsupportedSize = errors.New("unsupported size")
+
+func padBytes(offset, size int) int {
+	if offset == 0 || size == 1 {
+		return 0
+	}
+	remainder := offset % size
+	return size - remainder
+}
+
+func (d *Decoder) getField(b []byte) error {
+	size := len(b)
+
+	// We only support fields that are multiples of 2 (or 1-sized)
+	if size != 1 && size&1 == 1 {
+		return errUnsupportedSize
+	}
+
+	// Fields are aligned to their size
+	padBytes := padBytes(d.off, size)
+	if d.off+size+padBytes > len(d.b) {
+		return io.EOF
+	}
+	d.off += padBytes
+
+	copy(b, d.b[d.off:d.off+size])
+	d.off += size
+	return nil
+}
+
+// Err returns the first error that was encountered by this Decoder.
+func (d *Decoder) Err() error {
+	return d.err
+}
+
+// Offset returns the current read offset for data in the buffer.
+func (d *Decoder) Offset() int {
+	return d.off
+}
+
+// Byte returns a single byte from the buffer.
+func (d *Decoder) Byte() byte {
+	if d.err != nil {
+		return 0
+	}
+
+	if err := d.getField(d.dbuf[0:1]); err != nil {
+		d.err = err
+		return 0
+	}
+	return d.dbuf[0]
+}
+
+// Byte returns a number of bytes from the buffer based on the size of the
+// input slice. No padding is applied.
+//
+// If an error is encountered or this Decoder has previously encountered an
+// error, no changes are made to the provided buffer.
+func (d *Decoder) Bytes(b []byte) {
+	if d.err != nil {
+		return
+	}
+
+	// No padding for byte slices
+	size := len(b)
+	if d.off+size >= len(d.b) {
+		d.err = io.EOF
+		return
+	}
+	copy(b, d.b[d.off:d.off+size])
+	d.off += size
+}
+
+// Uint16 returns a uint16 decoded from the buffer.
+func (d *Decoder) Uint16() uint16 {
+	if d.err != nil {
+		return 0
+	}
+
+	if err := d.getField(d.dbuf[0:2]); err != nil {
+		d.err = err
+		return 0
+	}
+	return endian.Native.Uint16(d.dbuf[0:2])
+}
+
+// Uint32 returns a uint32 decoded from the buffer.
+func (d *Decoder) Uint32() uint32 {
+	if d.err != nil {
+		return 0
+	}
+
+	if err := d.getField(d.dbuf[0:4]); err != nil {
+		d.err = err
+		return 0
+	}
+	return endian.Native.Uint32(d.dbuf[0:4])
+}
+
+// Uint64 returns a uint64 decoded from the buffer.
+func (d *Decoder) Uint64() uint64 {
+	if d.err != nil {
+		return 0
+	}
+
+	if err := d.getField(d.dbuf[0:8]); err != nil {
+		d.err = err
+		return 0
+	}
+	return endian.Native.Uint64(d.dbuf[0:8])
+}
+
+// Uintptr returns a uintptr decoded from the buffer.
+func (d *Decoder) Uintptr() uintptr {
+	if d.err != nil {
+		return 0
+	}
+
+	if is64Bit {
+		return uintptr(d.Uint64())
+	} else {
+		return uintptr(d.Uint32())
+	}
+}
+
+// Int16 returns a int16 decoded from the buffer.
+func (d *Decoder) Int16() int16 {
+	return int16(d.Uint16())
+}
+
+// Int32 returns a int32 decoded from the buffer.
+func (d *Decoder) Int32() int32 {
+	return int32(d.Uint32())
+}
+
+// Int64 returns a int64 decoded from the buffer.
+func (d *Decoder) Int64() int64 {
+	return int64(d.Uint64())
+}

+ 75 - 0
util/cstruct/cstruct_example_test.go

@@ -0,0 +1,75 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Only built on 64-bit platforms to avoid complexity
+
+//go:build amd64 || arm64 || mips64le || ppc64le || riscv64
+// +build amd64 arm64 mips64le ppc64le riscv64
+
+package cstruct
+
+import "fmt"
+
+// This test provides a semi-realistic example of how you can
+// use this package to decode a C structure.
+func ExampleDecoder() {
+	// Our example C structure:
+	//    struct mystruct {
+	//      char *p;
+	//      char c;
+	//	/* implicit: char _pad[3]; */
+	//      int x;
+	//    };
+	//
+	// The Go structure definition:
+	type myStruct struct {
+		Ptr    uintptr
+		Ch     byte
+		Intval uint32
+	}
+
+	// Our "in-memory" version of the above structure
+	buf := []byte{
+		1, 2, 3, 4, 0, 0, 0, 0, // ptr
+		5,          // ch
+		99, 99, 99, // padding
+		78, 6, 0, 0, // x
+	}
+	d := NewDecoder(buf)
+
+	// Decode the structure; if one of these function returns an error,
+	// then subsequent decoder functions will return the zero value.
+	var x myStruct
+	x.Ptr = d.Uintptr()
+	x.Ch = d.Byte()
+	x.Intval = d.Uint32()
+
+	// Note that per the Go language spec:
+	//    [...] when evaluating the operands of an expression, assignment,
+	//    or return statement, all function calls, method calls, and
+	//    (channel) communication operations are evaluated in lexical
+	//    left-to-right order
+	//
+	// Since each field is assigned via a function call, one could use the
+	// following snippet to decode the struct.
+	//     x := myStruct{
+	//         Ptr:    d.Uintptr(),
+	//         Ch:     d.Byte(),
+	//         Intval: d.Uint32(),
+	//     }
+	//
+	// However, this means that reordering the fields in the initialization
+	// statement–normally a semantically identical operation–would change
+	// the way the structure is parsed. Thus we do it as above with
+	// explicit ordering.
+
+	// After finishing with the decoder, check errors
+	if err := d.Err(); err != nil {
+		panic(err)
+	}
+
+	// Print the decoder offset and structure
+	fmt.Printf("off=%d struct=%#v\n", d.Offset(), x)
+	// Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e}
+}

+ 151 - 0
util/cstruct/cstruct_test.go

@@ -0,0 +1,151 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cstruct
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"testing"
+)
+
+func TestPadBytes(t *testing.T) {
+	testCases := []struct {
+		offset int
+		size   int
+		want   int
+	}{
+		// No padding at beginning of structure
+		{0, 1, 0},
+		{0, 2, 0},
+		{0, 4, 0},
+		{0, 8, 0},
+
+		// No padding for single bytes
+		{1, 1, 0},
+
+		// Single byte padding
+		{1, 2, 1},
+		{3, 4, 1},
+
+		// Multi-byte padding
+		{1, 4, 3},
+		{2, 8, 6},
+	}
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("%d_%d_%d", tc.offset, tc.size, tc.want), func(t *testing.T) {
+			got := padBytes(tc.offset, tc.size)
+			if got != tc.want {
+				t.Errorf("got=%d; want=%d", got, tc.want)
+			}
+		})
+	}
+}
+
+func TestDecoder(t *testing.T) {
+	t.Run("UnsignedTypes", func(t *testing.T) {
+		dec := func(n int) *Decoder {
+			buf := make([]byte, n)
+			buf[0] = 1
+
+			d := NewDecoder(buf)
+
+			// Use t.Cleanup to perform an assertion on this
+			// decoder after the test code is finished with it.
+			t.Cleanup(func() {
+				if err := d.Err(); err != nil {
+					t.Fatal(err)
+				}
+			})
+			return d
+		}
+		if got := dec(2).Uint16(); got != 1 {
+			t.Errorf("uint16: got=%d; want=1", got)
+		}
+		if got := dec(4).Uint32(); got != 1 {
+			t.Errorf("uint32: got=%d; want=1", got)
+		}
+		if got := dec(8).Uint64(); got != 1 {
+			t.Errorf("uint64: got=%d; want=1", got)
+		}
+		if got := dec(pointerSize / 8).Uintptr(); got != 1 {
+			t.Errorf("uintptr: got=%d; want=1", got)
+		}
+	})
+
+	t.Run("SignedTypes", func(t *testing.T) {
+		dec := func(n int) *Decoder {
+			// Make a buffer of the exact size that consists of 0xff bytes
+			buf := make([]byte, n)
+			for i := 0; i < n; i++ {
+				buf[i] = 0xff
+			}
+
+			d := NewDecoder(buf)
+
+			// Use t.Cleanup to perform an assertion on this
+			// decoder after the test code is finished with it.
+			t.Cleanup(func() {
+				if err := d.Err(); err != nil {
+					t.Fatal(err)
+				}
+			})
+			return d
+		}
+		if got := dec(2).Int16(); got != -1 {
+			t.Errorf("int16: got=%d; want=-1", got)
+		}
+		if got := dec(4).Int32(); got != -1 {
+			t.Errorf("int32: got=%d; want=-1", got)
+		}
+		if got := dec(8).Int64(); got != -1 {
+			t.Errorf("int64: got=%d; want=-1", got)
+		}
+	})
+
+	t.Run("InsufficientData", func(t *testing.T) {
+		dec := func(n int) *Decoder {
+			// Make a buffer that's too small and contains arbitrary bytes
+			buf := make([]byte, n-1)
+			for i := 0; i < n-1; i++ {
+				buf[i] = 0xAD
+			}
+
+			// Use t.Cleanup to perform an assertion on this
+			// decoder after the test code is finished with it.
+			d := NewDecoder(buf)
+			t.Cleanup(func() {
+				if err := d.Err(); err == nil || !errors.Is(err, io.EOF) {
+					t.Errorf("(n=%d) expected io.EOF; got=%v", n, err)
+				}
+			})
+			return d
+		}
+
+		dec(2).Uint16()
+		dec(4).Uint32()
+		dec(8).Uint64()
+		dec(pointerSize / 8).Uintptr()
+
+		dec(2).Int16()
+		dec(4).Int32()
+		dec(8).Int64()
+	})
+
+	t.Run("Bytes", func(t *testing.T) {
+		d := NewDecoder([]byte("hello worldasdf"))
+		t.Cleanup(func() {
+			if err := d.Err(); err != nil {
+				t.Fatal(err)
+			}
+		})
+
+		buf := make([]byte, 11)
+		d.Bytes(buf)
+		if got := string(buf); got != "hello world" {
+			t.Errorf("bytes: got=%q; want=%q", got, "hello world")
+		}
+	})
+}