link_dtls.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package handler
  2. import (
  3. "net"
  4. "time"
  5. "github.com/bjdgyc/anylink/base"
  6. "github.com/bjdgyc/anylink/dbdata"
  7. "github.com/bjdgyc/anylink/pkg/utils"
  8. "github.com/bjdgyc/anylink/sessdata"
  9. )
  10. func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
  11. base.Debug("LinkDtls connect ip:", cSess.IpAddr, "user:", cSess.Username, "udp-rip:", conn.RemoteAddr())
  12. dSess := cSess.NewDtlsConn()
  13. if dSess == nil {
  14. // 创建失败,直接关闭链接
  15. _ = conn.Close()
  16. return
  17. }
  18. defer func() {
  19. base.Debug("LinkDtls return", cSess.Username, cSess.IpAddr)
  20. _ = conn.Close()
  21. dSess.Close()
  22. }()
  23. var (
  24. err error
  25. n int
  26. dead = time.Duration(cSess.CstpDpd+5) * time.Second
  27. )
  28. go dtlsWrite(conn, dSess, cSess)
  29. for {
  30. err = conn.SetReadDeadline(utils.NowSec().Add(dead))
  31. if err != nil {
  32. base.Error("SetDeadline: ", cSess.Username, err)
  33. return
  34. }
  35. pl := getPayload()
  36. n, err = conn.Read(pl.Data)
  37. if err != nil {
  38. base.Error("read hdata: ", cSess.Username, err)
  39. return
  40. }
  41. // 限流设置
  42. err = cSess.RateLimit(n, true)
  43. if err != nil {
  44. base.Error(err)
  45. }
  46. switch pl.Data[0] {
  47. case 0x07: // KEEPALIVE
  48. // do nothing
  49. // base.Debug("recv keepalive", cSess.IpAddr)
  50. case 0x05: // DISCONNECT
  51. cSess.UserLogoutCode = dbdata.UserLogoutClient
  52. base.Debug("DISCONNECT DTLS", cSess.Username, cSess.IpAddr)
  53. return
  54. case 0x03: // DPD-REQ
  55. // base.Debug("recv DPD-REQ", cSess.IpAddr)
  56. pl.PType = 0x04
  57. if payloadOutDtls(cSess, dSess, pl) {
  58. return
  59. }
  60. case 0x04:
  61. // base.Debug("recv DPD-RESP", cSess.IpAddr)
  62. case 0x08: // decompress
  63. if cSess.DtlsPickCmp == nil {
  64. continue
  65. }
  66. dst := getByteFull()
  67. nn, err := cSess.DtlsPickCmp.Uncompress(pl.Data[1:], *dst)
  68. if err != nil {
  69. putByte(dst)
  70. base.Error("dtls decompress error", err, n)
  71. continue
  72. }
  73. pl.Data = append(pl.Data[:1], (*dst)[:nn]...)
  74. putByte(dst)
  75. n = nn + 1
  76. fallthrough
  77. case 0x00: // DATA
  78. // 去除数据头
  79. // copy(pl.Data, pl.Data[1:n])
  80. // 更新切片长度
  81. // pl.Data = pl.Data[:n-1]
  82. pl.Data = append(pl.Data[:0], pl.Data[1:n]...)
  83. if payloadIn(cSess, pl) {
  84. return
  85. }
  86. }
  87. }
  88. }
  89. func dtlsWrite(conn net.Conn, dSess *sessdata.DtlsSession, cSess *sessdata.ConnSession) {
  90. defer func() {
  91. base.Debug("dtlsWrite return", cSess.Username, cSess.IpAddr)
  92. _ = conn.Close()
  93. dSess.Close()
  94. }()
  95. var (
  96. pl *sessdata.Payload
  97. )
  98. for {
  99. // dtls优先推送数据
  100. select {
  101. case pl = <-cSess.PayloadOutDtls:
  102. case <-dSess.CloseChan:
  103. return
  104. }
  105. if pl.LType != sessdata.LTypeIPData {
  106. continue
  107. }
  108. // header = []byte{payload.PType}
  109. if pl.PType == 0x00 { // data
  110. isCompress := false
  111. if cSess.DtlsPickCmp != nil && len(pl.Data) > base.Cfg.NoCompressLimit {
  112. dst := getByteFull()
  113. size, err := cSess.DtlsPickCmp.Compress(pl.Data, (*dst)[1:])
  114. if err == nil && size < len(pl.Data) {
  115. (*dst)[0] = 0x08
  116. pl.Data = append(pl.Data[:0], (*dst)[:size+1]...)
  117. isCompress = true
  118. }
  119. putByte(dst)
  120. }
  121. // 未压缩
  122. if !isCompress {
  123. // 获取数据长度
  124. l := len(pl.Data)
  125. // 先扩容 +1
  126. pl.Data = pl.Data[:l+1]
  127. // 数据后移
  128. copy(pl.Data[1:], pl.Data)
  129. // 添加头信息
  130. pl.Data[0] = pl.PType
  131. }
  132. } else {
  133. // 设置头类型
  134. pl.Data = append(pl.Data[:0], pl.PType)
  135. }
  136. n, err := conn.Write(pl.Data)
  137. if err != nil {
  138. base.Error("write err", cSess.Username, err)
  139. return
  140. }
  141. putPayload(pl)
  142. // 限流设置
  143. err = cSess.RateLimit(n, false)
  144. if err != nil {
  145. base.Error(err)
  146. }
  147. }
  148. }