link_dtls.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package handler
  2. import (
  3. "net"
  4. "time"
  5. "github.com/bjdgyc/anylink/base"
  6. "github.com/bjdgyc/anylink/pkg/utils"
  7. "github.com/bjdgyc/anylink/sessdata"
  8. )
  9. func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
  10. base.Debug("LinkDtls connect ip:", cSess.IpAddr, "user:", cSess.Username, "udp-rip:", conn.RemoteAddr())
  11. dSess := cSess.NewDtlsConn()
  12. if dSess == nil {
  13. // 创建失败,直接关闭链接
  14. _ = conn.Close()
  15. return
  16. }
  17. defer func() {
  18. base.Debug("LinkDtls return", cSess.IpAddr)
  19. _ = conn.Close()
  20. dSess.Close()
  21. }()
  22. var (
  23. err error
  24. n int
  25. dead = time.Duration(cSess.CstpDpd+5) * time.Second
  26. )
  27. go dtlsWrite(conn, dSess, cSess)
  28. for {
  29. err = conn.SetReadDeadline(utils.NowSec().Add(dead))
  30. if err != nil {
  31. base.Error("SetDeadline: ", err)
  32. return
  33. }
  34. pl := getPayload()
  35. n, err = conn.Read(pl.Data)
  36. if err != nil {
  37. base.Error("read hdata: ", err)
  38. return
  39. }
  40. // 限流设置
  41. err = cSess.RateLimit(n, true)
  42. if err != nil {
  43. base.Error(err)
  44. }
  45. switch pl.Data[0] {
  46. case 0x07: // KEEPALIVE
  47. // do nothing
  48. // base.Debug("recv keepalive", cSess.IpAddr)
  49. case 0x05: // DISCONNECT
  50. base.Debug("DISCONNECT DTLS", cSess.IpAddr)
  51. return
  52. case 0x03: // DPD-REQ
  53. // base.Debug("recv DPD-REQ", cSess.IpAddr)
  54. pl.PType = 0x04
  55. if payloadOutDtls(cSess, dSess, pl) {
  56. return
  57. }
  58. case 0x04:
  59. // base.Debug("recv DPD-RESP", cSess.IpAddr)
  60. case 0x00: // DATA
  61. // 去除数据头
  62. // copy(pl.Data, pl.Data[1:n])
  63. // 更新切片长度
  64. // pl.Data = pl.Data[:n-1]
  65. pl.Data = append(pl.Data[:0], pl.Data[1:n]...)
  66. if payloadIn(cSess, pl) {
  67. return
  68. }
  69. }
  70. }
  71. }
  72. func dtlsWrite(conn net.Conn, dSess *sessdata.DtlsSession, cSess *sessdata.ConnSession) {
  73. defer func() {
  74. base.Debug("dtlsWrite return", cSess.IpAddr)
  75. _ = conn.Close()
  76. dSess.Close()
  77. }()
  78. var (
  79. pl *sessdata.Payload
  80. )
  81. for {
  82. // dtls优先推送数据
  83. select {
  84. case pl = <-cSess.PayloadOutDtls:
  85. case <-dSess.CloseChan:
  86. return
  87. }
  88. if pl.LType != sessdata.LTypeIPData {
  89. continue
  90. }
  91. // header = []byte{payload.PType}
  92. if pl.PType == 0x00 { // data
  93. // 获取数据长度
  94. l := len(pl.Data)
  95. // 先扩容 +1
  96. pl.Data = pl.Data[:l+1]
  97. // 数据后移
  98. copy(pl.Data[1:], pl.Data)
  99. // 添加头信息
  100. pl.Data[0] = pl.PType
  101. } else {
  102. // 设置头类型
  103. pl.Data = append(pl.Data[:0], pl.PType)
  104. }
  105. n, err := conn.Write(pl.Data)
  106. if err != nil {
  107. base.Error("write err", err)
  108. return
  109. }
  110. putPayload(pl)
  111. // 限流设置
  112. err = cSess.RateLimit(n, false)
  113. if err != nil {
  114. base.Error(err)
  115. }
  116. }
  117. }