recorder_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package agent
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "path/filepath"
  8. "reflect"
  9. "strings"
  10. "testing"
  11. "go.yaml.in/yaml/v4"
  12. "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
  13. "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
  14. )
  15. func newRecorder(t *testing.T) *recorder.Recorder {
  16. cassetteName := filepath.Join("testdata", t.Name())
  17. r, err := recorder.New(
  18. cassetteName,
  19. recorder.WithMode(recorder.ModeRecordOnce),
  20. recorder.WithMatcher(customMatcher(t)),
  21. recorder.WithMarshalFunc(marshalFunc),
  22. recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
  23. recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
  24. )
  25. if err != nil {
  26. t.Fatalf("recorder: failed to create recorder: %v", err)
  27. }
  28. t.Cleanup(func() {
  29. if err := r.Stop(); err != nil {
  30. t.Errorf("recorder: failed to stop recorder: %v", err)
  31. }
  32. })
  33. return r
  34. }
  35. func customMatcher(t *testing.T) recorder.MatcherFunc {
  36. return func(r *http.Request, i cassette.Request) bool {
  37. if r.Body == nil || r.Body == http.NoBody {
  38. return cassette.DefaultMatcher(r, i)
  39. }
  40. if r.Method != i.Method || r.URL.String() != i.URL {
  41. return false
  42. }
  43. reqBody, err := io.ReadAll(r.Body)
  44. if err != nil {
  45. t.Fatalf("recorder: failed to read request body")
  46. }
  47. r.Body.Close()
  48. r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
  49. // Some providers can sometimes generate JSON requests with keys in
  50. // a different order, which means a direct string comparison will fail.
  51. // Falling back to deserializing the content if we don't have a match.
  52. requestContent := normalizeLineEndings(reqBody)
  53. cassetteContent := normalizeLineEndings(i.Body)
  54. if requestContent == cassetteContent {
  55. return true
  56. }
  57. var content1, content2 any
  58. if err := json.Unmarshal([]byte(requestContent), &content1); err != nil {
  59. return false
  60. }
  61. if err := json.Unmarshal([]byte(cassetteContent), &content2); err != nil {
  62. return false
  63. }
  64. return reflect.DeepEqual(content1, content2)
  65. }
  66. }
  67. func marshalFunc(in any) ([]byte, error) {
  68. var buff bytes.Buffer
  69. enc := yaml.NewEncoder(&buff)
  70. enc.SetIndent(2)
  71. enc.CompactSeqIndent()
  72. if err := enc.Encode(in); err != nil {
  73. return nil, err
  74. }
  75. return buff.Bytes(), nil
  76. }
  // headersToKeep is the allow-list of HTTP header names, in lower case, that
// hookRemoveHeaders leaves on recorded requests and responses. Any header whose
// lower-cased name is not a key here is deleted. The map is used purely as a
// set; the struct{} values carry no data.
var headersToKeep = map[string]struct{}{
	"accept":       {},
	"content-type": {},
	"user-agent":   {},
}
  82. func hookRemoveHeaders(i *cassette.Interaction) error {
  83. for k := range i.Request.Headers {
  84. if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
  85. delete(i.Request.Headers, k)
  86. }
  87. }
  88. for k := range i.Response.Headers {
  89. if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
  90. delete(i.Response.Headers, k)
  91. }
  92. }
  93. return nil
  94. }
  // normalizeLineEndings converts s to a string and rewrites Windows-style line
// endings to Unix-style ones. Two forms are replaced, in this order:
//
//   - the raw control characters CR LF ("\r\n") become a single LF ("\n");
//   - the escaped four-character text `\r\n` becomes `\n`.
//
// The second form matters when s is JSON. Line breaks inside JSON string
// values are encoded as escape sequences rather than raw control characters,
// and those should be normalized as well.
func normalizeLineEndings[T string | []byte](s T) string {
	out := string(s)
	for _, rep := range [...][2]string{
		{"\r\n", "\n"}, // raw CRLF
		{`\r\n`, `\n`}, // JSON-escaped CRLF
	} {
		out = strings.ReplaceAll(out, rep[0], rep[1])
	}
	return out
}