auth_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package crypto_test
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "crypto/rand"
  7. "io"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/xtls/xray-core/common"
  11. "github.com/xtls/xray-core/common/buf"
  12. . "github.com/xtls/xray-core/common/crypto"
  13. "github.com/xtls/xray-core/common/protocol"
  14. )
  15. func TestAuthenticationReaderWriter(t *testing.T) {
  16. key := make([]byte, 16)
  17. rand.Read(key)
  18. block, err := aes.NewCipher(key)
  19. common.Must(err)
  20. aead, err := cipher.NewGCM(block)
  21. common.Must(err)
  22. const payloadSize = 1024 * 80
  23. rawPayload := make([]byte, payloadSize)
  24. rand.Read(rawPayload)
  25. payload := buf.MergeBytes(nil, rawPayload)
  26. cache := bytes.NewBuffer(nil)
  27. iv := make([]byte, 12)
  28. rand.Read(iv)
  29. writer := NewAuthenticationWriter(&AEADAuthenticator{
  30. AEAD: aead,
  31. NonceGenerator: GenerateStaticBytes(iv),
  32. AdditionalDataGenerator: GenerateEmptyBytes(),
  33. }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
  34. common.Must(writer.WriteMultiBuffer(payload))
  35. if cache.Len() <= 1024*80 {
  36. t.Error("cache len: ", cache.Len())
  37. }
  38. common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
  39. reader := NewAuthenticationReader(&AEADAuthenticator{
  40. AEAD: aead,
  41. NonceGenerator: GenerateStaticBytes(iv),
  42. AdditionalDataGenerator: GenerateEmptyBytes(),
  43. }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
  44. var mb buf.MultiBuffer
  45. for mb.Len() < payloadSize {
  46. mb2, err := reader.ReadMultiBuffer()
  47. common.Must(err)
  48. mb, _ = buf.MergeMulti(mb, mb2)
  49. }
  50. if mb.Len() != payloadSize {
  51. t.Error("mb len: ", mb.Len())
  52. }
  53. mbContent := make([]byte, payloadSize)
  54. buf.SplitBytes(mb, mbContent)
  55. if r := cmp.Diff(mbContent, rawPayload); r != "" {
  56. t.Error(r)
  57. }
  58. _, err = reader.ReadMultiBuffer()
  59. if err != io.EOF {
  60. t.Error("error: ", err)
  61. }
  62. }
  63. func TestAuthenticationReaderWriterPacket(t *testing.T) {
  64. key := make([]byte, 16)
  65. common.Must2(rand.Read(key))
  66. block, err := aes.NewCipher(key)
  67. common.Must(err)
  68. aead, err := cipher.NewGCM(block)
  69. common.Must(err)
  70. cache := buf.New()
  71. iv := make([]byte, 12)
  72. rand.Read(iv)
  73. writer := NewAuthenticationWriter(&AEADAuthenticator{
  74. AEAD: aead,
  75. NonceGenerator: GenerateStaticBytes(iv),
  76. AdditionalDataGenerator: GenerateEmptyBytes(),
  77. }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)
  78. var payload buf.MultiBuffer
  79. pb1 := buf.New()
  80. pb1.Write([]byte("abcd"))
  81. payload = append(payload, pb1)
  82. pb2 := buf.New()
  83. pb2.Write([]byte("efgh"))
  84. payload = append(payload, pb2)
  85. common.Must(writer.WriteMultiBuffer(payload))
  86. if cache.Len() == 0 {
  87. t.Error("cache len: ", cache.Len())
  88. }
  89. common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
  90. reader := NewAuthenticationReader(&AEADAuthenticator{
  91. AEAD: aead,
  92. NonceGenerator: GenerateStaticBytes(iv),
  93. AdditionalDataGenerator: GenerateEmptyBytes(),
  94. }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)
  95. mb, err := reader.ReadMultiBuffer()
  96. common.Must(err)
  97. mb, b1 := buf.SplitFirst(mb)
  98. if b1.String() != "abcd" {
  99. t.Error("b1: ", b1.String())
  100. }
  101. mb, b2 := buf.SplitFirst(mb)
  102. if b2.String() != "efgh" {
  103. t.Error("b2: ", b2.String())
  104. }
  105. if !mb.IsEmpty() {
  106. t.Error("not empty")
  107. }
  108. _, err = reader.ReadMultiBuffer()
  109. if err != io.EOF {
  110. t.Error("error: ", err)
  111. }
  112. }