zstd_test.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package zstdframe
import (
	"errors"
	"math/bits"
	"math/rand/v2"
	"os"
	"runtime"
	"strings"
	"sync"
	"testing"

	"github.com/klauspost/compress/zstd"
	"tailscale.com/util/must"
)
  15. // Use the concatenation of all Go source files in zstdframe as testdata.
  16. var src = func() (out []byte) {
  17. for _, de := range must.Get(os.ReadDir(".")) {
  18. if strings.HasSuffix(de.Name(), ".go") {
  19. out = append(out, must.Get(os.ReadFile(de.Name()))...)
  20. }
  21. }
  22. return out
  23. }()
  24. var dst []byte
  25. var dsts [][]byte
  26. // zstdEnc is identical to getEncoder without options,
  27. // except it relies on concurrency managed by the zstd package itself.
  28. var zstdEnc = must.Get(zstd.NewWriter(nil,
  29. zstd.WithEncoderConcurrency(runtime.NumCPU()),
  30. zstd.WithSingleSegment(true),
  31. zstd.WithZeroFrames(true),
  32. zstd.WithEncoderLevel(zstd.SpeedDefault),
  33. zstd.WithEncoderCRC(true),
  34. zstd.WithLowerEncoderMem(false)))
  35. // zstdDec is identical to getDecoder without options,
  36. // except it relies on concurrency managed by the zstd package itself.
  37. var zstdDec = must.Get(zstd.NewReader(nil,
  38. zstd.WithDecoderConcurrency(runtime.NumCPU()),
  39. zstd.WithDecoderMaxMemory(1<<63),
  40. zstd.IgnoreChecksum(false),
  41. zstd.WithDecoderLowmem(false)))
  42. var coders = []struct {
  43. name string
  44. appendEncode func([]byte, []byte) []byte
  45. appendDecode func([]byte, []byte) ([]byte, error)
  46. }{{
  47. name: "zstd",
  48. appendEncode: func(dst, src []byte) []byte { return zstdEnc.EncodeAll(src, dst) },
  49. appendDecode: func(dst, src []byte) ([]byte, error) { return zstdDec.DecodeAll(src, dst) },
  50. }, {
  51. name: "zstdframe",
  52. appendEncode: func(dst, src []byte) []byte { return AppendEncode(dst, src) },
  53. appendDecode: func(dst, src []byte) ([]byte, error) { return AppendDecode(dst, src) },
  54. }}
  55. func TestDecodeMaxSize(t *testing.T) {
  56. var enc, dec []byte
  57. zeros := make([]byte, 1<<16, 2<<16)
  58. check := func(encSize, maxDecSize int) {
  59. var gotErr, wantErr error
  60. enc = AppendEncode(enc[:0], zeros[:encSize])
  61. // Directly calling zstd.Decoder.DecodeAll may not trigger size check
  62. // since it only operates on closest power-of-two.
  63. dec, gotErr = func() ([]byte, error) {
  64. d := getDecoder(MaxDecodedSize(uint64(maxDecSize)))
  65. defer putDecoder(d)
  66. return d.Decoder.DecodeAll(enc, dec[:0]) // directly call zstd.Decoder.DecodeAll
  67. }()
  68. if encSize > 1<<(64-bits.LeadingZeros64(uint64(maxDecSize)-1)) {
  69. wantErr = zstd.ErrDecoderSizeExceeded
  70. }
  71. if gotErr != wantErr {
  72. t.Errorf("DecodeAll(AppendEncode(%d), %d) error = %v, want %v", encSize, maxDecSize, gotErr, wantErr)
  73. }
  74. // Calling AppendDecode should perform the exact size check.
  75. dec, gotErr = AppendDecode(dec[:0], enc, MaxDecodedSize(uint64(maxDecSize)))
  76. if encSize > maxDecSize {
  77. wantErr = zstd.ErrDecoderSizeExceeded
  78. }
  79. if gotErr != wantErr {
  80. t.Errorf("AppendDecode(AppendEncode(%d), %d) error = %v, want %v", encSize, maxDecSize, gotErr, wantErr)
  81. }
  82. }
  83. rn := rand.New(rand.NewPCG(0, 0))
  84. for n := 1 << 10; n <= len(zeros); n <<= 1 {
  85. nl := rn.IntN(n + 1)
  86. check(nl, nl)
  87. check(nl, nl-1)
  88. check(nl, (n+nl)/2)
  89. check(nl, n)
  90. check((n+nl)/2, n)
  91. check(n-1, n-1)
  92. check(n-1, n)
  93. check(n-1, n+1)
  94. check(n, n-1)
  95. check(n, n)
  96. check(n, n+1)
  97. check(n+1, n-1)
  98. check(n+1, n)
  99. check(n+1, n+1)
  100. }
  101. }
  102. func BenchmarkEncode(b *testing.B) {
  103. options := []struct {
  104. name string
  105. opts []Option
  106. }{
  107. {name: "Best", opts: []Option{BestCompression}},
  108. {name: "Better", opts: []Option{BetterCompression}},
  109. {name: "Default", opts: []Option{DefaultCompression}},
  110. {name: "Fastest", opts: []Option{FastestCompression}},
  111. {name: "FastestLowMemory", opts: []Option{FastestCompression, LowMemory(true)}},
  112. {name: "FastestWindowSize", opts: []Option{FastestCompression, MaxWindowSize(1 << 10)}},
  113. {name: "FastestNoChecksum", opts: []Option{FastestCompression, WithChecksum(false)}},
  114. }
  115. for _, bb := range options {
  116. b.Run(bb.name, func(b *testing.B) {
  117. b.ReportAllocs()
  118. b.SetBytes(int64(len(src)))
  119. for range b.N {
  120. dst = AppendEncode(dst[:0], src, bb.opts...)
  121. }
  122. })
  123. if testing.Verbose() {
  124. ratio := float64(len(src)) / float64(len(dst))
  125. b.Logf("ratio: %0.3fx", ratio)
  126. }
  127. }
  128. }
  129. func BenchmarkDecode(b *testing.B) {
  130. options := []struct {
  131. name string
  132. opts []Option
  133. }{
  134. {name: "Checksum", opts: []Option{WithChecksum(true)}},
  135. {name: "NoChecksum", opts: []Option{WithChecksum(false)}},
  136. {name: "LowMemory", opts: []Option{LowMemory(true)}},
  137. }
  138. src := AppendEncode(nil, src)
  139. for _, bb := range options {
  140. b.Run(bb.name, func(b *testing.B) {
  141. b.ReportAllocs()
  142. b.SetBytes(int64(len(src)))
  143. for range b.N {
  144. dst = must.Get(AppendDecode(dst[:0], src, bb.opts...))
  145. }
  146. })
  147. }
  148. }
  149. func BenchmarkEncodeParallel(b *testing.B) {
  150. numCPU := runtime.NumCPU()
  151. for _, coder := range coders {
  152. dsts = dsts[:0]
  153. for range numCPU {
  154. dsts = append(dsts, coder.appendEncode(nil, src))
  155. }
  156. b.Run(coder.name, func(b *testing.B) {
  157. b.ReportAllocs()
  158. for range b.N {
  159. var group sync.WaitGroup
  160. for j := 0; j < numCPU; j++ {
  161. group.Add(1)
  162. go func(j int) {
  163. defer group.Done()
  164. dsts[j] = coder.appendEncode(dsts[j][:0], src)
  165. }(j)
  166. }
  167. group.Wait()
  168. }
  169. })
  170. }
  171. }
  172. func BenchmarkDecodeParallel(b *testing.B) {
  173. numCPU := runtime.NumCPU()
  174. for _, coder := range coders {
  175. dsts = dsts[:0]
  176. src := AppendEncode(nil, src)
  177. for range numCPU {
  178. dsts = append(dsts, must.Get(coder.appendDecode(nil, src)))
  179. }
  180. b.Run(coder.name, func(b *testing.B) {
  181. b.ReportAllocs()
  182. for range b.N {
  183. var group sync.WaitGroup
  184. for j := 0; j < numCPU; j++ {
  185. group.Add(1)
  186. go func(j int) {
  187. defer group.Done()
  188. dsts[j] = must.Get(coder.appendDecode(dsts[j][:0], src))
  189. }(j)
  190. }
  191. group.Wait()
  192. }
  193. })
  194. }
  195. }
  196. var opt Option
  197. func TestOptionAllocs(t *testing.T) {
  198. t.Run("EncoderLevel", func(t *testing.T) {
  199. t.Log(testing.AllocsPerRun(1e3, func() { opt = EncoderLevel(zstd.SpeedFastest) }))
  200. })
  201. t.Run("MaxDecodedSize/PowerOfTwo", func(t *testing.T) {
  202. t.Log(testing.AllocsPerRun(1e3, func() { opt = MaxDecodedSize(1024) }))
  203. })
  204. t.Run("MaxDecodedSize/Prime", func(t *testing.T) {
  205. t.Log(testing.AllocsPerRun(1e3, func() { opt = MaxDecodedSize(1021) }))
  206. })
  207. t.Run("MaxWindowSize", func(t *testing.T) {
  208. t.Log(testing.AllocsPerRun(1e3, func() { opt = MaxWindowSize(1024) }))
  209. })
  210. t.Run("LowMemory", func(t *testing.T) {
  211. t.Log(testing.AllocsPerRun(1e3, func() { opt = LowMemory(true) }))
  212. })
  213. }
  214. func TestGetDecoderAllocs(t *testing.T) {
  215. t.Log(testing.AllocsPerRun(1e3, func() { getDecoder() }))
  216. }
  217. func TestGetEncoderAllocs(t *testing.T) {
  218. t.Log(testing.AllocsPerRun(1e3, func() { getEncoder() }))
  219. }