validator.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package vmess
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha256"
  5. "hash/crc64"
  6. "strings"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/xtls/xray-core/common"
  11. "github.com/xtls/xray-core/common/dice"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/serial"
  14. "github.com/xtls/xray-core/common/task"
  15. "github.com/xtls/xray-core/proxy/vmess/aead"
  16. )
  const (
	// cacheDurationSec is the half-width, in seconds, of the timestamp window
	// around "now" for which hash-based auth entries are precomputed; it is
	// also the age after which entries are evicted.
	cacheDurationSec = 120

	// updateInterval is how often the periodic task refreshes the hash table.
	updateInterval = 10 * time.Second
)
  21. type user struct {
  22. user protocol.MemoryUser
  23. lastSec protocol.Timestamp
  24. }
  25. // TimedUserValidator is a user Validator based on time.
  26. type TimedUserValidator struct {
  27. sync.RWMutex
  28. users []*user
  29. userHash map[[16]byte]indexTimePair
  30. hasher protocol.IDHash
  31. baseTime protocol.Timestamp
  32. task *task.Periodic
  33. behaviorSeed uint64
  34. behaviorFused bool
  35. aeadDecoderHolder *aead.AuthIDDecoderHolder
  36. }
  37. type indexTimePair struct {
  38. user *user
  39. timeInc uint32
  40. taintedFuse *uint32
  41. }
  42. // NewTimedUserValidator creates a new TimedUserValidator.
  43. func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
  44. tuv := &TimedUserValidator{
  45. users: make([]*user, 0, 16),
  46. userHash: make(map[[16]byte]indexTimePair, 1024),
  47. hasher: hasher,
  48. baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
  49. aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
  50. }
  51. tuv.task = &task.Periodic{
  52. Interval: updateInterval,
  53. Execute: func() error {
  54. tuv.updateUserHash()
  55. return nil
  56. },
  57. }
  58. common.Must(tuv.task.Start())
  59. return tuv
  60. }
  61. func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
  62. var hashValue [16]byte
  63. genEndSec := nowSec + cacheDurationSec
  64. genHashForID := func(id *protocol.ID) {
  65. idHash := v.hasher(id.Bytes())
  66. genBeginSec := user.lastSec
  67. if genBeginSec < nowSec-cacheDurationSec {
  68. genBeginSec = nowSec - cacheDurationSec
  69. }
  70. for ts := genBeginSec; ts <= genEndSec; ts++ {
  71. common.Must2(serial.WriteUint64(idHash, uint64(ts)))
  72. idHash.Sum(hashValue[:0])
  73. idHash.Reset()
  74. v.userHash[hashValue] = indexTimePair{
  75. user: user,
  76. timeInc: uint32(ts - v.baseTime),
  77. taintedFuse: new(uint32),
  78. }
  79. }
  80. }
  81. account := user.user.Account.(*MemoryAccount)
  82. genHashForID(account.ID)
  83. for _, id := range account.AlterIDs {
  84. genHashForID(id)
  85. }
  86. user.lastSec = genEndSec
  87. }
  88. func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
  89. for key, pair := range v.userHash {
  90. if pair.timeInc < expire {
  91. delete(v.userHash, key)
  92. }
  93. }
  94. }
  95. func (v *TimedUserValidator) updateUserHash() {
  96. now := time.Now()
  97. nowSec := protocol.Timestamp(now.Unix())
  98. v.Lock()
  99. defer v.Unlock()
  100. for _, user := range v.users {
  101. v.generateNewHashes(nowSec, user)
  102. }
  103. expire := protocol.Timestamp(now.Unix() - cacheDurationSec)
  104. if expire > v.baseTime {
  105. v.removeExpiredHashes(uint32(expire - v.baseTime))
  106. }
  107. }
  108. func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
  109. v.Lock()
  110. defer v.Unlock()
  111. nowSec := time.Now().Unix()
  112. uu := &user{
  113. user: *u,
  114. lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
  115. }
  116. v.users = append(v.users, uu)
  117. v.generateNewHashes(protocol.Timestamp(nowSec), uu)
  118. account := uu.user.Account.(*MemoryAccount)
  119. if !v.behaviorFused {
  120. hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF"))
  121. hashkdf.Write(account.ID.Bytes())
  122. v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
  123. }
  124. var cmdkeyfl [16]byte
  125. copy(cmdkeyfl[:], account.ID.CmdKey())
  126. v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
  127. return nil
  128. }
  129. func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
  130. v.RLock()
  131. defer v.RUnlock()
  132. v.behaviorFused = true
  133. var fixedSizeHash [16]byte
  134. copy(fixedSizeHash[:], userHash)
  135. pair, found := v.userHash[fixedSizeHash]
  136. if found {
  137. user := pair.user.user
  138. if atomic.LoadUint32(pair.taintedFuse) == 0 {
  139. return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
  140. }
  141. return nil, 0, false, ErrTainted
  142. }
  143. return nil, 0, false, ErrNotFound
  144. }
  145. func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) {
  146. v.RLock()
  147. defer v.RUnlock()
  148. var userHashFL [16]byte
  149. copy(userHashFL[:], userHash)
  150. userd, err := v.aeadDecoderHolder.Match(userHashFL)
  151. if err != nil {
  152. return nil, false, err
  153. }
  154. return userd.(*protocol.MemoryUser), true, err
  155. }
  156. func (v *TimedUserValidator) Remove(email string) bool {
  157. v.Lock()
  158. defer v.Unlock()
  159. email = strings.ToLower(email)
  160. idx := -1
  161. for i, u := range v.users {
  162. if strings.EqualFold(u.user.Email, email) {
  163. idx = i
  164. var cmdkeyfl [16]byte
  165. copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
  166. v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
  167. break
  168. }
  169. }
  170. if idx == -1 {
  171. return false
  172. }
  173. ulen := len(v.users)
  174. v.users[idx] = v.users[ulen-1]
  175. v.users[ulen-1] = nil
  176. v.users = v.users[:ulen-1]
  177. return true
  178. }
  179. // Close implements common.Closable.
  180. func (v *TimedUserValidator) Close() error {
  181. return v.task.Close()
  182. }
  183. func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
  184. v.Lock()
  185. defer v.Unlock()
  186. v.behaviorFused = true
  187. if v.behaviorSeed == 0 {
  188. v.behaviorSeed = dice.RollUint64()
  189. }
  190. return v.behaviorSeed
  191. }
  192. func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
  193. v.RLock()
  194. defer v.RUnlock()
  195. var userHashFL [16]byte
  196. copy(userHashFL[:], userHash)
  197. pair, found := v.userHash[userHashFL]
  198. if found {
  199. if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) {
  200. return nil
  201. }
  202. return ErrTainted
  203. }
  204. return ErrNotFound
  205. }
  206. var ErrNotFound = newError("Not Found")
  207. var ErrTainted = newError("ErrTainted")