fragment_buffer.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package dtls
  2. import (
  3. "github.com/pion/dtls/v2/pkg/protocol"
  4. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  5. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  6. )
  7. type fragment struct {
  8. recordLayerHeader recordlayer.Header
  9. handshakeHeader handshake.Header
  10. data []byte
  11. }
  12. type fragmentBuffer struct {
  13. // map of MessageSequenceNumbers that hold slices of fragments
  14. cache map[uint16][]*fragment
  15. currentMessageSequenceNumber uint16
  16. }
  17. func newFragmentBuffer() *fragmentBuffer {
  18. return &fragmentBuffer{cache: map[uint16][]*fragment{}}
  19. }
  20. // Attempts to push a DTLS packet to the fragmentBuffer
  21. // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
  22. // when an error returns it is fatal, and the DTLS connection should be stopped
  23. func (f *fragmentBuffer) push(buf []byte) (bool, error) {
  24. frag := new(fragment)
  25. if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
  26. return false, err
  27. }
  28. // fragment isn't a handshake, we don't need to handle it
  29. if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
  30. return false, nil
  31. }
  32. for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
  33. if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
  34. return false, err
  35. }
  36. if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
  37. f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
  38. }
  39. // end index should be the length of handshake header but if the handshake
  40. // was fragmented, we should keep them all
  41. end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
  42. if size := len(buf); end > size {
  43. end = size
  44. }
  45. // Discard all headers, when rebuilding the packet we will re-build
  46. frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
  47. f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
  48. buf = buf[end:]
  49. }
  50. return true, nil
  51. }
  52. func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
  53. frags, ok := f.cache[f.currentMessageSequenceNumber]
  54. if !ok {
  55. return nil, 0
  56. }
  57. // Go doesn't support recursive lambdas
  58. var appendMessage func(targetOffset uint32) bool
  59. rawMessage := []byte{}
  60. appendMessage = func(targetOffset uint32) bool {
  61. for _, f := range frags {
  62. if f.handshakeHeader.FragmentOffset == targetOffset {
  63. fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
  64. if fragmentEnd != f.handshakeHeader.Length {
  65. if !appendMessage(fragmentEnd) {
  66. return false
  67. }
  68. }
  69. rawMessage = append(f.data, rawMessage...)
  70. return true
  71. }
  72. }
  73. return false
  74. }
  75. // Recursively collect up
  76. if !appendMessage(0) {
  77. return nil, 0
  78. }
  79. firstHeader := frags[0].handshakeHeader
  80. firstHeader.FragmentOffset = 0
  81. firstHeader.FragmentLength = firstHeader.Length
  82. rawHeader, err := firstHeader.Marshal()
  83. if err != nil {
  84. return nil, 0
  85. }
  86. messageEpoch := frags[0].recordLayerHeader.Epoch
  87. delete(f.cache, f.currentMessageSequenceNumber)
  88. f.currentMessageSequenceNumber++
  89. return append(rawHeader, rawMessage...), messageEpoch
  90. }