| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- package zstdframe
- import (
- "math/bits"
- "strconv"
- "sync"
- "github.com/klauspost/compress/zstd"
- "tailscale.com/util/must"
- )
// Option configures a single call to [AppendEncode] or [AppendDecode].
//
// The interface is sealed: its only method is unexported, so every Option
// value comes from a constructor or constant declared in this package.
type Option interface {
	isOption()
}
- type encoderLevel int
- // Constants that implement [Option] and can be passed to [AppendEncode].
- const (
- FastestCompression = encoderLevel(zstd.SpeedFastest)
- DefaultCompression = encoderLevel(zstd.SpeedDefault)
- BetterCompression = encoderLevel(zstd.SpeedBetterCompression)
- BestCompression = encoderLevel(zstd.SpeedBestCompression)
- )
- func (encoderLevel) isOption() {}
- // EncoderLevel specifies the compression level when encoding.
- //
- // This exists for compatibility with [zstd.EncoderLevel] values.
- // Most usages should directly use one of the following constants:
- // - [FastestCompression]
- // - [DefaultCompression]
- // - [BetterCompression]
- // - [BestCompression]
- //
- // By default, [DefaultCompression] is chosen.
- // This option is ignored when decoding.
- func EncoderLevel(level zstd.EncoderLevel) Option { return encoderLevel(level) }
- type withChecksum bool
- func (withChecksum) isOption() {}
- // WithChecksum specifies whether to produce a checksum when encoding,
- // or whether to verify the checksum when decoding.
- // By default, checksums are produced and verified.
- func WithChecksum(check bool) Option { return withChecksum(check) }
- type maxDecodedSize uint64
- func (maxDecodedSize) isOption() {}
- type maxDecodedSizeLog2 uint8 // uint8 avoids allocation when storing into interface
- func (maxDecodedSizeLog2) isOption() {}
- // MaxDecodedSize specifies the maximum decoded size and
- // is used to protect against hostile content.
- // By default, there is no limit.
- // This option is ignored when encoding.
- func MaxDecodedSize(maxSize uint64) Option {
- if bits.OnesCount64(maxSize) == 1 {
- return maxDecodedSizeLog2(log2(maxSize))
- }
- return maxDecodedSize(maxSize)
- }
- type maxWindowSizeLog2 uint8 // uint8 avoids allocation when storing into interface
- func (maxWindowSizeLog2) isOption() {}
- // MaxWindowSize specifies the maximum window size, which must be a power-of-two
- // and be in the range of [[zstd.MinWindowSize], [zstd.MaxWindowSize]].
- //
- // The compression or decompression algorithm will use a LZ77 rolling window
- // no larger than the specified size. The compression ratio will be
- // adversely affected, but memory requirements will be lower.
- // When decompressing, an error is reported if a LZ77 back reference exceeds
- // the specified maximum window size.
- //
- // For decompression, [MaxDecodedSize] is generally more useful.
- func MaxWindowSize(maxSize uint64) Option {
- switch {
- case maxSize < zstd.MinWindowSize:
- panic("maximum window size cannot be less than " + strconv.FormatUint(zstd.MinWindowSize, 10))
- case bits.OnesCount64(maxSize) != 1:
- panic("maximum window size must be a power-of-two")
- case maxSize > zstd.MaxWindowSize:
- panic("maximum window size cannot be greater than " + strconv.FormatUint(zstd.MaxWindowSize, 10))
- default:
- return maxWindowSizeLog2(log2(maxSize))
- }
- }
- type lowMemory bool
- func (lowMemory) isOption() {}
- // LowMemory specifies that the encoder and decoder should aim to use
- // lower amounts of memory at the cost of speed.
- // By default, more memory used for better speed.
- func LowMemory(low bool) Option { return lowMemory(low) }
- var encoderPools sync.Map // map[encoderOptions]*sync.Pool -> *zstd.Encoder
- type encoderOptions struct {
- level zstd.EncoderLevel
- maxWindowLog2 uint8
- checksum bool
- lowMemory bool
- }
- type encoder struct {
- pool *sync.Pool
- *zstd.Encoder
- }
- func getEncoder(opts ...Option) encoder {
- eopts := encoderOptions{level: zstd.SpeedDefault, checksum: true}
- for _, opt := range opts {
- switch opt := opt.(type) {
- case encoderLevel:
- eopts.level = zstd.EncoderLevel(opt)
- case maxWindowSizeLog2:
- eopts.maxWindowLog2 = uint8(opt)
- case withChecksum:
- eopts.checksum = bool(opt)
- case lowMemory:
- eopts.lowMemory = bool(opt)
- }
- }
- vpool, ok := encoderPools.Load(eopts)
- if !ok {
- vpool, _ = encoderPools.LoadOrStore(eopts, new(sync.Pool))
- }
- pool := vpool.(*sync.Pool)
- enc, _ := pool.Get().(*zstd.Encoder)
- if enc == nil {
- var noopts int
- zopts := [...]zstd.EOption{
- // Set concurrency=1 to ensure synchronous operation.
- zstd.WithEncoderConcurrency(1),
- // In stateless compression, the data is already in a single buffer,
- // so we might as well encode it as a single segment,
- // which ensures that the Frame_Content_Size is always populated,
- // informing decoders up-front the expected decompressed size.
- zstd.WithSingleSegment(true),
- // Ensure strict compliance with RFC 8878, section 3.1.,
- // where zstandard "is made up of one or more frames".
- zstd.WithZeroFrames(true),
- zstd.WithEncoderLevel(eopts.level),
- zstd.WithEncoderCRC(eopts.checksum),
- zstd.WithLowerEncoderMem(eopts.lowMemory),
- nil, // reserved for zstd.WithWindowSize
- }
- if eopts.maxWindowLog2 > 0 {
- zopts[len(zopts)-noopts-1] = zstd.WithWindowSize(1 << eopts.maxWindowLog2)
- } else {
- noopts++
- }
- enc = must.Get(zstd.NewWriter(nil, zopts[:len(zopts)-noopts]...))
- }
- return encoder{pool, enc}
- }
- func putEncoder(e encoder) { e.pool.Put(e.Encoder) }
- var decoderPools sync.Map // map[decoderOptions]*sync.Pool -> *zstd.Decoder
- type decoderOptions struct {
- maxSizeLog2 uint8
- maxWindowLog2 uint8
- checksum bool
- lowMemory bool
- }
- type decoder struct {
- pool *sync.Pool
- *zstd.Decoder
- maxSize uint64
- }
- func getDecoder(opts ...Option) decoder {
- maxSize := uint64(1 << 63)
- dopts := decoderOptions{maxSizeLog2: 63, checksum: true}
- for _, opt := range opts {
- switch opt := opt.(type) {
- case maxDecodedSizeLog2:
- maxSize = 1 << uint8(opt)
- dopts.maxSizeLog2 = uint8(opt)
- case maxDecodedSize:
- maxSize = uint64(opt)
- dopts.maxSizeLog2 = uint8(log2(maxSize))
- case maxWindowSizeLog2:
- dopts.maxWindowLog2 = uint8(opt)
- case withChecksum:
- dopts.checksum = bool(opt)
- case lowMemory:
- dopts.lowMemory = bool(opt)
- }
- }
- vpool, ok := decoderPools.Load(dopts)
- if !ok {
- vpool, _ = decoderPools.LoadOrStore(dopts, new(sync.Pool))
- }
- pool := vpool.(*sync.Pool)
- dec, _ := pool.Get().(*zstd.Decoder)
- if dec == nil {
- var noopts int
- zopts := [...]zstd.DOption{
- // Set concurrency=1 to ensure synchronous operation.
- zstd.WithDecoderConcurrency(1),
- zstd.WithDecoderMaxMemory(1 << min(max(10, dopts.maxSizeLog2), 63)),
- zstd.IgnoreChecksum(!dopts.checksum),
- zstd.WithDecoderLowmem(dopts.lowMemory),
- nil, // reserved for zstd.WithDecoderMaxWindow
- }
- if dopts.maxWindowLog2 > 0 {
- zopts[len(zopts)-noopts-1] = zstd.WithDecoderMaxWindow(1 << dopts.maxWindowLog2)
- } else {
- noopts++
- }
- dec = must.Get(zstd.NewReader(nil, zopts[:len(zopts)-noopts]...))
- }
- return decoder{pool, dec, maxSize}
- }
- func putDecoder(d decoder) { d.pool.Put(d.Decoder) }
- func (d decoder) DecodeAll(src, dst []byte) ([]byte, error) {
- // We only configure DecodeAll to enforce MaxDecodedSize by powers-of-two.
- // Perform a more fine grain check based on the exact value.
- dst2, err := d.Decoder.DecodeAll(src, dst)
- if err == nil && uint64(len(dst2)-len(dst)) > d.maxSize {
- err = zstd.ErrDecoderSizeExceeded
- }
- return dst2, err
- }
// log2 computes log2 of x rounded up to the nearest integer,
// i.e. the smallest k such that 1<<k >= x. Results range from 0 to 64;
// any x above 1<<63 yields 64.
//
// log2(0) is defined as 0, the smallest k with 1<<k >= 0. Without that
// special case, x-1 would wrap around to the maximum uint64 and produce 64.
// getDecoder would then turn MaxDecodedSize(0) into a coarse zstd memory
// limit of 1<<63 (effectively unbounded) instead of the tightest bound.
func log2(x uint64) int {
	if x == 0 {
		return 0
	}
	return 64 - bits.LeadingZeros64(x-1)
}
|