session.go 8.3 KB


  1. package sessdata
  import (
	crand "crypto/rand"
	"fmt"
	"math/rand"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/bjdgyc/anylink/base"
	"github.com/bjdgyc/anylink/dbdata"
	"github.com/bjdgyc/anylink/pkg/utils"
	"github.com/ivpusic/grpool"
	atomic2 "go.uber.org/atomic"
)
  17. var (
  18. // session_token -> SessUser
  19. sessions = make(map[string]*Session)
  20. // dtlsId -> session_token
  21. dtlsIds = make(map[string]string)
  22. sessMux sync.RWMutex
  23. )
  24. // 连接sess
  25. type ConnSession struct {
  26. Sess *Session
  27. MasterSecret string // dtls协议的 master_secret
  28. IpAddr net.IP // 分配的ip地址
  29. LocalIp net.IP
  30. MacHw net.HardwareAddr // 客户端mac地址,从Session取出
  31. Username string
  32. RemoteAddr string
  33. Mtu int
  34. IfName string
  35. Client string // 客户端 mobile pc
  36. CstpDpd int
  37. Group *dbdata.Group
  38. Limit *LimitRater
  39. BandwidthUp atomic2.Uint32 // 使用上行带宽 Byte
  40. BandwidthDown atomic2.Uint32 // 使用下行带宽 Byte
  41. BandwidthUpPeriod atomic2.Uint32 // 前一周期的总量
  42. BandwidthDownPeriod atomic2.Uint32
  43. BandwidthUpAll atomic2.Uint64 // 使用上行带宽总量
  44. BandwidthDownAll atomic2.Uint64 // 使用下行带宽总量
  45. closeOnce sync.Once
  46. CloseChan chan struct{}
  47. PayloadIn chan *Payload
  48. PayloadOutCstp chan *Payload // Cstp的数据
  49. PayloadOutDtls chan *Payload // Dtls的数据
  50. IpAuditMap utils.IMaps // 审计的ip数据
  51. IpAuditPool *grpool.Pool // 审计的IP包解析池
  52. // dSess *DtlsSession
  53. dSess *atomic.Value
  54. }
  // DtlsSession tracks the DTLS channel belonging to a ConnSession.
type DtlsSession struct {
	// isActive is accessed atomically: 1 while active, -1 when closed or
	// when it is the placeholder stored before any DTLS handshake.
	isActive  int32
	CloseChan chan struct{} // closed exactly once by Close
	closeOnce sync.Once
	IpAddr    net.IP // copied from the owning ConnSession
}
  61. type Session struct {
  62. mux sync.RWMutex
  63. Sid string // auth返回的 session-id
  64. Token string // session信息的唯一token
  65. DtlsSid string // dtls协议的 session_id
  66. MacAddr string // 客户端mac地址
  67. UniqueIdGlobal string // 客户端唯一标示
  68. MacHw net.HardwareAddr
  69. Username string // 用户名
  70. Group string
  71. AuthStep string
  72. AuthPass string
  73. LastLogin time.Time
  74. IsActive bool
  75. // 开启link需要设置的参数
  76. CSess *ConnSession
  77. }
  // init seeds the math/rand global source from the current wall-clock time.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
  81. func checkSession() {
  82. // 检测过期的session
  83. go func() {
  84. if base.Cfg.SessionTimeout == 0 {
  85. return
  86. }
  87. timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second
  88. tick := time.NewTicker(time.Second * 60)
  89. for range tick.C {
  90. sessMux.Lock()
  91. t := time.Now()
  92. for k, v := range sessions {
  93. v.mux.Lock()
  94. if !v.IsActive {
  95. if t.Sub(v.LastLogin) > timeout {
  96. delete(sessions, k)
  97. }
  98. }
  99. v.mux.Unlock()
  100. }
  101. sessMux.Unlock()
  102. }
  103. }()
  104. }
  // GenToken returns a new session token: 32 bytes from crypto/rand, encoded
// as a 64-character lowercase hex string. Tokens authenticate sessions, so
// they must be unpredictable; math/rand is not suitable for this.
func GenToken() string {
	buf := make([]byte, 32)
	if _, err := crand.Read(buf); err != nil {
		// The system CSPRNG failed. Issuing a guessable token would be worse
		// than failing loudly.
		panic("sessdata: crypto/rand read failed: " + err.Error())
	}
	return fmt.Sprintf("%x", buf)
}
  111. func NewSession(token string) *Session {
  112. if token == "" {
  113. btoken := make([]byte, 32)
  114. rand.Read(btoken)
  115. token = fmt.Sprintf("%x", btoken)
  116. }
  117. // 生成 dtlsn session_id
  118. dtlsid := make([]byte, 32)
  119. rand.Read(dtlsid)
  120. sess := &Session{
  121. Sid: fmt.Sprintf("%d", time.Now().Unix()),
  122. Token: token,
  123. DtlsSid: fmt.Sprintf("%x", dtlsid),
  124. LastLogin: time.Now(),
  125. }
  126. sessMux.Lock()
  127. sessions[token] = sess
  128. dtlsIds[sess.DtlsSid] = token
  129. sessMux.Unlock()
  130. return sess
  131. }
  132. func (s *Session) NewConn() *ConnSession {
  133. s.mux.RLock()
  134. active := s.IsActive
  135. macAddr := s.MacAddr
  136. macHw := s.MacHw
  137. username := s.Username
  138. s.mux.RUnlock()
  139. if active {
  140. s.CSess.Close()
  141. }
  142. limit := LimitClient(username, false)
  143. if !limit {
  144. return nil
  145. }
  146. ip := AcquireIp(username, macAddr)
  147. if ip == nil {
  148. LimitClient(username, true)
  149. return nil
  150. }
  151. // 查询group信息
  152. group := &dbdata.Group{}
  153. err := dbdata.One("Name", s.Group, group)
  154. if err != nil {
  155. base.Error(err)
  156. return nil
  157. }
  158. cSess := &ConnSession{
  159. Sess: s,
  160. MacHw: macHw,
  161. Username: username,
  162. IpAddr: ip,
  163. closeOnce: sync.Once{},
  164. CloseChan: make(chan struct{}),
  165. PayloadIn: make(chan *Payload, 64),
  166. PayloadOutCstp: make(chan *Payload, 64),
  167. PayloadOutDtls: make(chan *Payload, 64),
  168. dSess: &atomic.Value{},
  169. }
  170. // ip 审计
  171. if base.Cfg.AuditInterval >= 0 {
  172. cSess.IpAuditMap = utils.NewMap("cmap", 0)
  173. cSess.IpAuditPool = grpool.NewPool(1, 600)
  174. }
  175. dSess := &DtlsSession{
  176. isActive: -1,
  177. }
  178. cSess.dSess.Store(dSess)
  179. cSess.Group = group
  180. if group.Bandwidth > 0 {
  181. // 限流设置
  182. cSess.Limit = NewLimitRater(group.Bandwidth, group.Bandwidth)
  183. }
  184. go cSess.ratePeriod()
  185. s.mux.Lock()
  186. s.MacAddr = macAddr
  187. s.IsActive = true
  188. s.CSess = cSess
  189. s.mux.Unlock()
  190. return cSess
  191. }
  192. func (cs *ConnSession) Close() {
  193. cs.closeOnce.Do(func() {
  194. base.Info("closeOnce:", cs.IpAddr)
  195. cs.Sess.mux.Lock()
  196. defer cs.Sess.mux.Unlock()
  197. close(cs.CloseChan)
  198. cs.Sess.IsActive = false
  199. cs.Sess.LastLogin = time.Now()
  200. cs.Sess.CSess = nil
  201. dSess := cs.GetDtlsSession()
  202. if dSess != nil {
  203. dSess.Close()
  204. }
  205. ReleaseIp(cs.IpAddr, cs.Sess.MacAddr)
  206. LimitClient(cs.Username, true)
  207. })
  208. }
  209. // 创建dtls链接
  210. func (cs *ConnSession) NewDtlsConn() *DtlsSession {
  211. ds := cs.dSess.Load().(*DtlsSession)
  212. isActive := atomic.LoadInt32(&ds.isActive)
  213. if isActive > 0 {
  214. // 判断原有连接存在,不进行创建
  215. return nil
  216. }
  217. dSess := &DtlsSession{
  218. isActive: 1,
  219. CloseChan: make(chan struct{}),
  220. closeOnce: sync.Once{},
  221. IpAddr: cs.IpAddr,
  222. }
  223. cs.dSess.Store(dSess)
  224. return dSess
  225. }
  226. // 关闭dtls链接
  227. func (ds *DtlsSession) Close() {
  228. ds.closeOnce.Do(func() {
  229. base.Info("closeOnce dtls:", ds.IpAddr)
  230. atomic.StoreInt32(&ds.isActive, -1)
  231. close(ds.CloseChan)
  232. })
  233. }
  234. func (cs *ConnSession) GetDtlsSession() *DtlsSession {
  235. ds := cs.dSess.Load().(*DtlsSession)
  236. isActive := atomic.LoadInt32(&ds.isActive)
  237. if isActive > 0 {
  238. return ds
  239. }
  240. return nil
  241. }
  242. const BandwidthPeriodSec = 10 // 流量速率统计周期(秒)
  243. func (cs *ConnSession) ratePeriod() {
  244. tick := time.NewTicker(time.Second * BandwidthPeriodSec)
  245. defer tick.Stop()
  246. for range tick.C {
  247. select {
  248. case <-cs.CloseChan:
  249. return
  250. default:
  251. }
  252. // 实时流量清零
  253. rtUp := cs.BandwidthUp.Swap(0)
  254. rtDown := cs.BandwidthDown.Swap(0)
  255. // 设置上一周期每秒的流量
  256. cs.BandwidthUpPeriod.Swap(rtUp / BandwidthPeriodSec)
  257. cs.BandwidthDownPeriod.Swap(rtDown / BandwidthPeriodSec)
  258. // 累加所有流量
  259. cs.BandwidthUpAll.Add(uint64(rtUp))
  260. cs.BandwidthDownAll.Add(uint64(rtDown))
  261. }
  262. }
  263. var MaxMtu = 1460
  264. func (cs *ConnSession) SetMtu(mtu string) {
  265. if base.Cfg.Mtu > 0 {
  266. MaxMtu = base.Cfg.Mtu
  267. }
  268. cs.Mtu = MaxMtu
  269. mi, err := strconv.Atoi(mtu)
  270. if err != nil || mi < 100 {
  271. return
  272. }
  273. if mi < MaxMtu {
  274. cs.Mtu = mi
  275. }
  276. }
  277. func (cs *ConnSession) SetIfName(name string) {
  278. cs.Sess.mux.Lock()
  279. defer cs.Sess.mux.Unlock()
  280. cs.IfName = name
  281. }
  282. func (cs *ConnSession) RateLimit(byt int, isUp bool) error {
  283. if isUp {
  284. cs.BandwidthUp.Add(uint32(byt))
  285. return nil
  286. }
  287. // 只对下行速率限制
  288. cs.BandwidthDown.Add(uint32(byt))
  289. if cs.Limit == nil {
  290. return nil
  291. }
  292. return cs.Limit.Wait(byt)
  293. }
  294. func SToken2Sess(stoken string) *Session {
  295. stoken = strings.TrimSpace(stoken)
  296. sarr := strings.Split(stoken, "@")
  297. token := sarr[1]
  298. return Token2Sess(token)
  299. }
  300. func Token2Sess(token string) *Session {
  301. sessMux.RLock()
  302. defer sessMux.RUnlock()
  303. return sessions[token]
  304. }
  305. func Dtls2Sess(did string) *Session {
  306. sessMux.RLock()
  307. defer sessMux.RUnlock()
  308. token := dtlsIds[did]
  309. return sessions[token]
  310. }
  311. func Dtls2MasterSecret(did string) string {
  312. sessMux.RLock()
  313. token := dtlsIds[did]
  314. sess := sessions[token]
  315. sessMux.RUnlock()
  316. if sess == nil {
  317. return ""
  318. }
  319. sess.mux.RLock()
  320. defer sess.mux.RUnlock()
  321. if sess.CSess == nil {
  322. return ""
  323. }
  324. return sess.CSess.MasterSecret
  325. }
  // DelSess is intentionally a no-op: deleting a session by token is disabled.
// Use CloseSess or DelSessByStoken to remove a session.
func DelSess(token string) {
	_ = token // unused while deletion is disabled
}
  329. func CloseSess(token string) {
  330. sessMux.Lock()
  331. defer sessMux.Unlock()
  332. sess, ok := sessions[token]
  333. if !ok {
  334. return
  335. }
  336. delete(sessions, token)
  337. sess.CSess.Close()
  338. }
  339. func CloseCSess(token string) {
  340. sessMux.RLock()
  341. defer sessMux.RUnlock()
  342. sess, ok := sessions[token]
  343. if !ok {
  344. return
  345. }
  346. sess.CSess.Close()
  347. }
  348. func DelSessByStoken(stoken string) {
  349. stoken = strings.TrimSpace(stoken)
  350. sarr := strings.Split(stoken, "@")
  351. token := sarr[1]
  352. sessMux.Lock()
  353. delete(sessions, token)
  354. sessMux.Unlock()
  355. }