io.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package dns
  2. import (
  3. "encoding/binary"
  4. "sync"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/common/buf"
  7. "github.com/xtls/xray-core/common/serial"
  8. "golang.org/x/net/dns/dnsmessage"
  9. )
  10. func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) {
  11. buffer := buf.New()
  12. rawBytes := buffer.Extend(buf.Size)
  13. packed, err := msg.AppendPack(rawBytes[:0])
  14. if err != nil {
  15. buffer.Release()
  16. return nil, err
  17. }
  18. buffer.Resize(0, int32(len(packed)))
  19. return buffer, nil
  20. }
  21. type MessageReader interface {
  22. ReadMessage() (*buf.Buffer, error)
  23. }
  24. type UDPReader struct {
  25. buf.Reader
  26. access sync.Mutex
  27. cache buf.MultiBuffer
  28. }
  29. func (r *UDPReader) readCache() *buf.Buffer {
  30. r.access.Lock()
  31. defer r.access.Unlock()
  32. mb, b := buf.SplitFirst(r.cache)
  33. r.cache = mb
  34. return b
  35. }
  36. func (r *UDPReader) refill() error {
  37. mb, err := r.Reader.ReadMultiBuffer()
  38. if err != nil {
  39. return err
  40. }
  41. r.access.Lock()
  42. r.cache = mb
  43. r.access.Unlock()
  44. return nil
  45. }
  46. // ReadMessage implements MessageReader.
  47. func (r *UDPReader) ReadMessage() (*buf.Buffer, error) {
  48. for {
  49. b := r.readCache()
  50. if b != nil {
  51. return b, nil
  52. }
  53. if err := r.refill(); err != nil {
  54. return nil, err
  55. }
  56. }
  57. }
  58. // Close implements common.Closable.
  59. func (r *UDPReader) Close() error {
  60. defer func() {
  61. r.access.Lock()
  62. buf.ReleaseMulti(r.cache)
  63. r.cache = nil
  64. r.access.Unlock()
  65. }()
  66. return common.Close(r.Reader)
  67. }
  68. type TCPReader struct {
  69. reader *buf.BufferedReader
  70. }
  71. func NewTCPReader(reader buf.Reader) *TCPReader {
  72. return &TCPReader{
  73. reader: &buf.BufferedReader{
  74. Reader: reader,
  75. },
  76. }
  77. }
  78. func (r *TCPReader) ReadMessage() (*buf.Buffer, error) {
  79. size, err := serial.ReadUint16(r.reader)
  80. if err != nil {
  81. return nil, err
  82. }
  83. if size > buf.Size {
  84. return nil, newError("message size too large: ", size)
  85. }
  86. b := buf.New()
  87. if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
  88. return nil, err
  89. }
  90. return b, nil
  91. }
  92. func (r *TCPReader) Interrupt() {
  93. common.Interrupt(r.reader)
  94. }
  95. func (r *TCPReader) Close() error {
  96. return common.Close(r.reader)
  97. }
  98. type MessageWriter interface {
  99. WriteMessage(msg *buf.Buffer) error
  100. }
  101. type UDPWriter struct {
  102. buf.Writer
  103. }
  104. func (w *UDPWriter) WriteMessage(b *buf.Buffer) error {
  105. return w.WriteMultiBuffer(buf.MultiBuffer{b})
  106. }
  107. type TCPWriter struct {
  108. buf.Writer
  109. }
  110. func (w *TCPWriter) WriteMessage(b *buf.Buffer) error {
  111. if b.IsEmpty() {
  112. return nil
  113. }
  114. mb := make(buf.MultiBuffer, 0, 2)
  115. size := buf.New()
  116. binary.BigEndian.PutUint16(size.Extend(2), uint16(b.Len()))
  117. mb = append(mb, size, b)
  118. return w.WriteMultiBuffer(mb)
  119. }