|
|
@@ -10,6 +10,7 @@ import (
|
|
|
"net/netip"
|
|
|
"os"
|
|
|
"reflect"
|
|
|
+ "runtime"
|
|
|
"slices"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
@@ -17,6 +18,7 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/gaissmai/bart"
|
|
|
+ "github.com/tailscale/wireguard-go/conn"
|
|
|
"github.com/tailscale/wireguard-go/device"
|
|
|
"github.com/tailscale/wireguard-go/tun"
|
|
|
"go4.org/mem"
|
|
|
@@ -894,13 +896,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
|
|
return 0, res.err
|
|
|
}
|
|
|
if res.data == nil {
|
|
|
- n, err := t.injectedRead(res.injected, buffs[0], offset)
|
|
|
- sizes[0] = n
|
|
|
- if err != nil && n == 0 {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
-
|
|
|
- return 1, err
|
|
|
+ return t.injectedRead(res.injected, buffs, sizes, offset)
|
|
|
}
|
|
|
|
|
|
metricPacketOut.Add(int64(len(res.data)))
|
|
|
@@ -955,27 +951,85 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
|
|
return buffsPos, res.err
|
|
|
}
|
|
|
|
|
|
-// injectedRead handles injected reads, which bypass filters.
|
|
|
-func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int, error) {
|
|
|
- metricPacketOut.Add(1)
|
|
|
+const (
|
|
|
+ minTCPHeaderSize = 20
|
|
|
+)
|
|
|
|
|
|
- var n int
|
|
|
- if !res.packet.IsNil() {
|
|
|
+func stackGSOToTunGSO(pkt []byte, gso stack.GSO) (tun.GSOOptions, error) {
|
|
|
+ options := tun.GSOOptions{
|
|
|
+ CsumStart: gso.L3HdrLen,
|
|
|
+ CsumOffset: gso.CsumOffset,
|
|
|
+ GSOSize: gso.MSS,
|
|
|
+ NeedsCsum: gso.NeedsCsum,
|
|
|
+ }
|
|
|
+ switch gso.Type {
|
|
|
+ case stack.GSONone:
|
|
|
+ options.GSOType = tun.GSONone
|
|
|
+ return options, nil
|
|
|
+ case stack.GSOTCPv4:
|
|
|
+ options.GSOType = tun.GSOTCPv4
|
|
|
+ case stack.GSOTCPv6:
|
|
|
+ options.GSOType = tun.GSOTCPv6
|
|
|
+ default:
|
|
|
+ return tun.GSOOptions{}, fmt.Errorf("unsupported gVisor GSOType: %v", gso.Type)
|
|
|
+ }
|
|
|
+ // options.HdrLen is both layer 3 and 4 together, whereas gVisor only
|
|
|
+ // gives us layer 3 length. We have to gather TCP header length
|
|
|
+ // ourselves.
|
|
|
+ if len(pkt) < int(gso.L3HdrLen)+minTCPHeaderSize {
|
|
|
+ return tun.GSOOptions{}, errors.New("gVisor GSOTCP packet length too short")
|
|
|
+ }
|
|
|
+ tcphLen := uint16(pkt[int(gso.L3HdrLen)+12] >> 4 * 4)
|
|
|
+ options.HdrLen = gso.L3HdrLen + tcphLen
|
|
|
+ return options, nil
|
|
|
+}
|
|
|
|
|
|
- n = copy(buf[offset:], res.packet.NetworkHeader().Slice())
|
|
|
- n += copy(buf[offset+n:], res.packet.TransportHeader().Slice())
|
|
|
- n += copy(buf[offset+n:], res.packet.Data().AsRange().ToSlice())
|
|
|
- res.packet.DecRef()
|
|
|
+func invertGSOChecksum(pkt []byte, gso stack.GSO) {
|
|
|
+ if gso.NeedsCsum != true {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ at := int(gso.L3HdrLen + gso.CsumOffset)
|
|
|
+ if at+1 > len(pkt)-1 {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ pkt[at] = ^pkt[at]
|
|
|
+ pkt[at+1] = ^pkt[at+1]
|
|
|
+}
|
|
|
+
|
|
|
+// injectedRead handles injected reads, which bypass filters.
|
|
|
+func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []int, offset int) (n int, err error) {
|
|
|
+ var gso stack.GSO
|
|
|
+
|
|
|
+ pkt := outBuffs[0][offset:]
|
|
|
+ if res.packet != nil {
|
|
|
+ bufN := copy(pkt, res.packet.NetworkHeader().Slice())
|
|
|
+ bufN += copy(pkt[bufN:], res.packet.TransportHeader().Slice())
|
|
|
+ bufN += copy(pkt[bufN:], res.packet.Data().AsRange().ToSlice())
|
|
|
+ gso = res.packet.GSOOptions
|
|
|
+ pkt = pkt[:bufN]
|
|
|
+ defer res.packet.DecRef() // defer DecRef so we may continue to reference it
|
|
|
} else {
|
|
|
- n = copy(buf[offset:], res.data)
|
|
|
+ sizes[0] = copy(pkt, res.data)
|
|
|
+ pkt = pkt[:sizes[0]]
|
|
|
+ n = 1
|
|
|
}
|
|
|
|
|
|
pc := t.peerConfig.Load()
|
|
|
|
|
|
p := parsedPacketPool.Get().(*packet.Parsed)
|
|
|
defer parsedPacketPool.Put(p)
|
|
|
- p.Decode(buf[offset : offset+n])
|
|
|
+ p.Decode(pkt)
|
|
|
+
|
|
|
+ // We invert the transport layer checksum before and after snat() if gVisor
|
|
|
+ // handed us a segment with a partial checksum. A partial checksum is not a
|
|
|
+ // ones' complement of the sum, and incremental checksum updating that could
|
|
|
+ // occur as a result of snat() is not aware of this. Alternatively we could
|
|
|
+ // plumb partial transport layer checksum awareness down through snat(),
|
|
|
+ // but the surface area of such a change is much larger, and not yet
|
|
|
+ // justified by this singular case.
|
|
|
+ invertGSOChecksum(pkt, gso)
|
|
|
pc.snat(p)
|
|
|
+ invertGSOChecksum(pkt, gso)
|
|
|
|
|
|
if m := t.destIPActivity.Load(); m != nil {
|
|
|
if fn := m[p.Dst.Addr()]; fn != nil {
|
|
|
@@ -983,11 +1037,24 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ if res.packet != nil {
|
|
|
+ var gsoOptions tun.GSOOptions
|
|
|
+ gsoOptions, err = stackGSOToTunGSO(pkt, gso)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ n, err = tun.GSOSplit(pkt, gsoOptions, outBuffs, sizes, offset)
|
|
|
+ }
|
|
|
+
|
|
|
if stats := t.stats.Load(); stats != nil {
|
|
|
- stats.UpdateTxVirtual(buf[offset:][:n])
|
|
|
+ for i := 0; i < n; i++ {
|
|
|
+ stats.UpdateTxVirtual(outBuffs[i][offset : offset+sizes[i]])
|
|
|
+ }
|
|
|
}
|
|
|
+
|
|
|
t.noteActivity()
|
|
|
- return n, nil
|
|
|
+ metricPacketOut.Add(int64(n))
|
|
|
+ return n, err
|
|
|
}
|
|
|
|
|
|
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable) filter.Response {
|
|
|
@@ -1288,6 +1355,14 @@ func (t *Wrapper) InjectOutboundPacketBuffer(pkt *stack.PacketBuffer) error {
|
|
|
}
|
|
|
|
|
|
func (t *Wrapper) BatchSize() int {
|
|
|
+ if runtime.GOOS == "linux" {
|
|
|
+ // Always setup Linux to handle vectors, even in the very rare case that
|
|
|
+ // the underlying t.tdev returns 1. gVisor GSO is always enabled for
|
|
|
+ // Linux, and we cannot make a determination on gVisor usage at
|
|
|
+ // wireguard-go.Device startup, which is when this value matters for
|
|
|
+ // packet memory init.
|
|
|
+ return conn.IdealBatchSize
|
|
|
+ }
|
|
|
return t.tdev.BatchSize()
|
|
|
}
|
|
|
|