options.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package zstdframe
  4. import (
  5. "math/bits"
  6. "strconv"
  7. "sync"
  8. "github.com/klauspost/compress/zstd"
  9. "tailscale.com/util/must"
  10. )
  11. // Option is an option that can be passed to [AppendEncode] or [AppendDecode].
  12. type Option interface{ isOption() }
  13. type encoderLevel int
  14. // Constants that implement [Option] and can be passed to [AppendEncode].
  15. const (
  16. FastestCompression = encoderLevel(zstd.SpeedFastest)
  17. DefaultCompression = encoderLevel(zstd.SpeedDefault)
  18. BetterCompression = encoderLevel(zstd.SpeedBetterCompression)
  19. BestCompression = encoderLevel(zstd.SpeedBestCompression)
  20. )
  21. func (encoderLevel) isOption() {}
  22. // EncoderLevel specifies the compression level when encoding.
  23. //
  24. // This exists for compatibility with [zstd.EncoderLevel] values.
  25. // Most usages should directly use one of the following constants:
  26. // - [FastestCompression]
  27. // - [DefaultCompression]
  28. // - [BetterCompression]
  29. // - [BestCompression]
  30. //
  31. // By default, [DefaultCompression] is chosen.
  32. // This option is ignored when decoding.
  33. func EncoderLevel(level zstd.EncoderLevel) Option { return encoderLevel(level) }
  34. type withChecksum bool
  35. func (withChecksum) isOption() {}
  36. // WithChecksum specifies whether to produce a checksum when encoding,
  37. // or whether to verify the checksum when decoding.
  38. // By default, checksums are produced and verified.
  39. func WithChecksum(check bool) Option { return withChecksum(check) }
  40. type maxDecodedSize uint64
  41. func (maxDecodedSize) isOption() {}
  42. type maxDecodedSizeLog2 uint8 // uint8 avoids allocation when storing into interface
  43. func (maxDecodedSizeLog2) isOption() {}
  44. // MaxDecodedSize specifies the maximum decoded size and
  45. // is used to protect against hostile content.
  46. // By default, there is no limit.
  47. // This option is ignored when encoding.
  48. func MaxDecodedSize(maxSize uint64) Option {
  49. if bits.OnesCount64(maxSize) == 1 {
  50. return maxDecodedSizeLog2(log2(maxSize))
  51. }
  52. return maxDecodedSize(maxSize)
  53. }
  54. type maxWindowSizeLog2 uint8 // uint8 avoids allocation when storing into interface
  55. func (maxWindowSizeLog2) isOption() {}
  56. // MaxWindowSize specifies the maximum window size, which must be a power-of-two
  57. // and be in the range of [[zstd.MinWindowSize], [zstd.MaxWindowSize]].
  58. //
  59. // The compression or decompression algorithm will use a LZ77 rolling window
  60. // no larger than the specified size. The compression ratio will be
  61. // adversely affected, but memory requirements will be lower.
  62. // When decompressing, an error is reported if a LZ77 back reference exceeds
  63. // the specified maximum window size.
  64. //
  65. // For decompression, [MaxDecodedSize] is generally more useful.
  66. func MaxWindowSize(maxSize uint64) Option {
  67. switch {
  68. case maxSize < zstd.MinWindowSize:
  69. panic("maximum window size cannot be less than " + strconv.FormatUint(zstd.MinWindowSize, 10))
  70. case bits.OnesCount64(maxSize) != 1:
  71. panic("maximum window size must be a power-of-two")
  72. case maxSize > zstd.MaxWindowSize:
  73. panic("maximum window size cannot be greater than " + strconv.FormatUint(zstd.MaxWindowSize, 10))
  74. default:
  75. return maxWindowSizeLog2(log2(maxSize))
  76. }
  77. }
  78. type lowMemory bool
  79. func (lowMemory) isOption() {}
  80. // LowMemory specifies that the encoder and decoder should aim to use
  81. // lower amounts of memory at the cost of speed.
  82. // By default, more memory used for better speed.
  83. func LowMemory(low bool) Option { return lowMemory(low) }
  84. var encoderPools sync.Map // map[encoderOptions]*sync.Pool -> *zstd.Encoder
  85. type encoderOptions struct {
  86. level zstd.EncoderLevel
  87. maxWindowLog2 uint8
  88. checksum bool
  89. lowMemory bool
  90. }
  91. type encoder struct {
  92. pool *sync.Pool
  93. *zstd.Encoder
  94. }
  95. func getEncoder(opts ...Option) encoder {
  96. eopts := encoderOptions{level: zstd.SpeedDefault, checksum: true}
  97. for _, opt := range opts {
  98. switch opt := opt.(type) {
  99. case encoderLevel:
  100. eopts.level = zstd.EncoderLevel(opt)
  101. case maxWindowSizeLog2:
  102. eopts.maxWindowLog2 = uint8(opt)
  103. case withChecksum:
  104. eopts.checksum = bool(opt)
  105. case lowMemory:
  106. eopts.lowMemory = bool(opt)
  107. }
  108. }
  109. vpool, ok := encoderPools.Load(eopts)
  110. if !ok {
  111. vpool, _ = encoderPools.LoadOrStore(eopts, new(sync.Pool))
  112. }
  113. pool := vpool.(*sync.Pool)
  114. enc, _ := pool.Get().(*zstd.Encoder)
  115. if enc == nil {
  116. var noopts int
  117. zopts := [...]zstd.EOption{
  118. // Set concurrency=1 to ensure synchronous operation.
  119. zstd.WithEncoderConcurrency(1),
  120. // In stateless compression, the data is already in a single buffer,
  121. // so we might as well encode it as a single segment,
  122. // which ensures that the Frame_Content_Size is always populated,
  123. // informing decoders up-front the expected decompressed size.
  124. zstd.WithSingleSegment(true),
  125. // Ensure strict compliance with RFC 8878, section 3.1.,
  126. // where zstandard "is made up of one or more frames".
  127. zstd.WithZeroFrames(true),
  128. zstd.WithEncoderLevel(eopts.level),
  129. zstd.WithEncoderCRC(eopts.checksum),
  130. zstd.WithLowerEncoderMem(eopts.lowMemory),
  131. nil, // reserved for zstd.WithWindowSize
  132. }
  133. if eopts.maxWindowLog2 > 0 {
  134. zopts[len(zopts)-noopts-1] = zstd.WithWindowSize(1 << eopts.maxWindowLog2)
  135. } else {
  136. noopts++
  137. }
  138. enc = must.Get(zstd.NewWriter(nil, zopts[:len(zopts)-noopts]...))
  139. }
  140. return encoder{pool, enc}
  141. }
  142. func putEncoder(e encoder) { e.pool.Put(e.Encoder) }
  143. var decoderPools sync.Map // map[decoderOptions]*sync.Pool -> *zstd.Decoder
  144. type decoderOptions struct {
  145. maxSizeLog2 uint8
  146. maxWindowLog2 uint8
  147. checksum bool
  148. lowMemory bool
  149. }
  150. type decoder struct {
  151. pool *sync.Pool
  152. *zstd.Decoder
  153. maxSize uint64
  154. }
  155. func getDecoder(opts ...Option) decoder {
  156. maxSize := uint64(1 << 63)
  157. dopts := decoderOptions{maxSizeLog2: 63, checksum: true}
  158. for _, opt := range opts {
  159. switch opt := opt.(type) {
  160. case maxDecodedSizeLog2:
  161. maxSize = 1 << uint8(opt)
  162. dopts.maxSizeLog2 = uint8(opt)
  163. case maxDecodedSize:
  164. maxSize = uint64(opt)
  165. dopts.maxSizeLog2 = uint8(log2(maxSize))
  166. case maxWindowSizeLog2:
  167. dopts.maxWindowLog2 = uint8(opt)
  168. case withChecksum:
  169. dopts.checksum = bool(opt)
  170. case lowMemory:
  171. dopts.lowMemory = bool(opt)
  172. }
  173. }
  174. vpool, ok := decoderPools.Load(dopts)
  175. if !ok {
  176. vpool, _ = decoderPools.LoadOrStore(dopts, new(sync.Pool))
  177. }
  178. pool := vpool.(*sync.Pool)
  179. dec, _ := pool.Get().(*zstd.Decoder)
  180. if dec == nil {
  181. var noopts int
  182. zopts := [...]zstd.DOption{
  183. // Set concurrency=1 to ensure synchronous operation.
  184. zstd.WithDecoderConcurrency(1),
  185. zstd.WithDecoderMaxMemory(1 << min(max(10, dopts.maxSizeLog2), 63)),
  186. zstd.IgnoreChecksum(!dopts.checksum),
  187. zstd.WithDecoderLowmem(dopts.lowMemory),
  188. nil, // reserved for zstd.WithDecoderMaxWindow
  189. }
  190. if dopts.maxWindowLog2 > 0 {
  191. zopts[len(zopts)-noopts-1] = zstd.WithDecoderMaxWindow(1 << dopts.maxWindowLog2)
  192. } else {
  193. noopts++
  194. }
  195. dec = must.Get(zstd.NewReader(nil, zopts[:len(zopts)-noopts]...))
  196. }
  197. return decoder{pool, dec, maxSize}
  198. }
  199. func putDecoder(d decoder) { d.pool.Put(d.Decoder) }
  200. func (d decoder) DecodeAll(src, dst []byte) ([]byte, error) {
  201. // We only configure DecodeAll to enforce MaxDecodedSize by powers-of-two.
  202. // Perform a more fine grain check based on the exact value.
  203. dst2, err := d.Decoder.DecodeAll(src, dst)
  204. if err == nil && uint64(len(dst2)-len(dst)) > d.maxSize {
  205. err = zstd.ErrDecoderSizeExceeded
  206. }
  207. return dst2, err
  208. }
  209. // log2 computes log2 of x rounded up to the nearest integer.
  210. func log2(x uint64) int { return 64 - bits.LeadingZeros64(x-1) }