io.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package dns
  2. import (
  3. "encoding/binary"
  4. "sync"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/common/buf"
  7. "github.com/xtls/xray-core/common/errors"
  8. "github.com/xtls/xray-core/common/serial"
  9. "golang.org/x/net/dns/dnsmessage"
  10. )
  11. func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) {
  12. buffer := buf.New()
  13. rawBytes := buffer.Extend(buf.Size)
  14. packed, err := msg.AppendPack(rawBytes[:0])
  15. if err != nil {
  16. buffer.Release()
  17. return nil, err
  18. }
  19. buffer.Resize(0, int32(len(packed)))
  20. return buffer, nil
  21. }
  22. type MessageReader interface {
  23. ReadMessage() (*buf.Buffer, error)
  24. }
  25. type UDPReader struct {
  26. buf.Reader
  27. access sync.Mutex
  28. cache buf.MultiBuffer
  29. }
  30. func (r *UDPReader) readCache() *buf.Buffer {
  31. r.access.Lock()
  32. defer r.access.Unlock()
  33. mb, b := buf.SplitFirst(r.cache)
  34. r.cache = mb
  35. return b
  36. }
  37. func (r *UDPReader) refill() error {
  38. mb, err := r.Reader.ReadMultiBuffer()
  39. if err != nil {
  40. return err
  41. }
  42. r.access.Lock()
  43. r.cache = mb
  44. r.access.Unlock()
  45. return nil
  46. }
  47. // ReadMessage implements MessageReader.
  48. func (r *UDPReader) ReadMessage() (*buf.Buffer, error) {
  49. for {
  50. b := r.readCache()
  51. if b != nil {
  52. return b, nil
  53. }
  54. if err := r.refill(); err != nil {
  55. return nil, err
  56. }
  57. }
  58. }
  59. // Close implements common.Closable.
  60. func (r *UDPReader) Close() error {
  61. defer func() {
  62. r.access.Lock()
  63. buf.ReleaseMulti(r.cache)
  64. r.cache = nil
  65. r.access.Unlock()
  66. }()
  67. return common.Close(r.Reader)
  68. }
  69. type TCPReader struct {
  70. reader *buf.BufferedReader
  71. }
  72. func NewTCPReader(reader buf.Reader) *TCPReader {
  73. return &TCPReader{
  74. reader: &buf.BufferedReader{
  75. Reader: reader,
  76. },
  77. }
  78. }
  79. func (r *TCPReader) ReadMessage() (*buf.Buffer, error) {
  80. size, err := serial.ReadUint16(r.reader)
  81. if err != nil {
  82. return nil, err
  83. }
  84. if size > buf.Size {
  85. return nil, errors.New("message size too large: ", size)
  86. }
  87. b := buf.New()
  88. if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
  89. return nil, err
  90. }
  91. return b, nil
  92. }
  93. func (r *TCPReader) Interrupt() {
  94. common.Interrupt(r.reader)
  95. }
  96. func (r *TCPReader) Close() error {
  97. return common.Close(r.reader)
  98. }
  99. type MessageWriter interface {
  100. WriteMessage(msg *buf.Buffer) error
  101. }
  102. type UDPWriter struct {
  103. buf.Writer
  104. }
  105. func (w *UDPWriter) WriteMessage(b *buf.Buffer) error {
  106. return w.WriteMultiBuffer(buf.MultiBuffer{b})
  107. }
  108. type TCPWriter struct {
  109. buf.Writer
  110. }
  111. func (w *TCPWriter) WriteMessage(b *buf.Buffer) error {
  112. if b.IsEmpty() {
  113. return nil
  114. }
  115. mb := make(buf.MultiBuffer, 0, 2)
  116. size := buf.New()
  117. binary.BigEndian.PutUint16(size.Extend(2), uint16(b.Len()))
  118. mb = append(mb, size, b)
  119. return w.WriteMultiBuffer(mb)
  120. }