validator.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package vmess
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha256"
  5. "hash/crc64"
  6. "strings"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/xtls/xray-core/common"
  11. "github.com/xtls/xray-core/common/dice"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/serial"
  14. "github.com/xtls/xray-core/common/task"
  15. "github.com/xtls/xray-core/proxy/vmess/aead"
  16. )
  17. const (
  18. updateInterval = 10 * time.Second
  19. cacheDurationSec = 120
  20. )
// user is the validator's internal record for one registered account.
type user struct {
	// user is a copy of the MemoryUser passed to Add.
	user protocol.MemoryUser
	// lastSec is the last timestamp for which generateNewHashes produced
	// hashes for this user; the next refresh resumes from here.
	lastSec protocol.Timestamp
}
// TimedUserValidator is a user Validator based on time.
//
// The embedded RWMutex guards users and userHash.
type TimedUserValidator struct {
	sync.RWMutex
	// users holds every registered user; Add appends, Remove deletes.
	users []*user
	// userHash maps a 16-byte time-dependent hash to the user and the
	// timestamp offset it was generated for.
	userHash map[[16]byte]indexTimePair
	// hasher builds the hash used to derive userHash keys from an ID.
	hasher protocol.IDHash
	// baseTime is fixed at construction; indexTimePair.timeInc values are
	// offsets in seconds from it.
	baseTime protocol.Timestamp
	// task periodically runs updateUserHash.
	task *task.Periodic
	// behaviorSeed is accumulated from account IDs in Add until it is
	// fused, and is returned by GetBehaviorSeed.
	behaviorSeed uint64
	// behaviorFused is set by Get and GetBehaviorSeed; once true, Add no
	// longer modifies behaviorSeed.
	behaviorFused bool
	// aeadDecoderHolder resolves AEAD auth IDs to users (see GetAEAD).
	aeadDecoderHolder *aead.AuthIDDecoderHolder
	// legacyWarnShown records that ShouldShowLegacyWarn has returned true.
	legacyWarnShown bool
}
// indexTimePair is the value stored in TimedUserValidator.userHash for a
// single (ID, timestamp) hash.
type indexTimePair struct {
	// user is the owner of the hash.
	user *user
	// timeInc is the hashed timestamp expressed as seconds since the
	// validator's baseTime.
	timeInc uint32
	// taintedFuse is 0 while the hash is usable; BurnTaintFuse atomically
	// flips it to 1, after which Get reports ErrTainted.
	taintedFuse *uint32
}
  43. // NewTimedUserValidator creates a new TimedUserValidator.
  44. func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
  45. tuv := &TimedUserValidator{
  46. users: make([]*user, 0, 16),
  47. userHash: make(map[[16]byte]indexTimePair, 1024),
  48. hasher: hasher,
  49. baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
  50. aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
  51. }
  52. tuv.task = &task.Periodic{
  53. Interval: updateInterval,
  54. Execute: func() error {
  55. tuv.updateUserHash()
  56. return nil
  57. },
  58. }
  59. common.Must(tuv.task.Start())
  60. return tuv
  61. }
  62. // visible for testing
  63. func (v *TimedUserValidator) GetBaseTime() protocol.Timestamp {
  64. return v.baseTime
  65. }
  66. func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
  67. var hashValue [16]byte
  68. genEndSec := nowSec + cacheDurationSec
  69. genHashForID := func(id *protocol.ID) {
  70. idHash := v.hasher(id.Bytes())
  71. genBeginSec := user.lastSec
  72. if genBeginSec < nowSec-cacheDurationSec {
  73. genBeginSec = nowSec - cacheDurationSec
  74. }
  75. for ts := genBeginSec; ts <= genEndSec; ts++ {
  76. common.Must2(serial.WriteUint64(idHash, uint64(ts)))
  77. idHash.Sum(hashValue[:0])
  78. idHash.Reset()
  79. v.userHash[hashValue] = indexTimePair{
  80. user: user,
  81. timeInc: uint32(ts - v.baseTime),
  82. taintedFuse: new(uint32),
  83. }
  84. }
  85. }
  86. account := user.user.Account.(*MemoryAccount)
  87. genHashForID(account.ID)
  88. for _, id := range account.AlterIDs {
  89. genHashForID(id)
  90. }
  91. user.lastSec = genEndSec
  92. }
  93. func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
  94. for key, pair := range v.userHash {
  95. if pair.timeInc < expire {
  96. delete(v.userHash, key)
  97. }
  98. }
  99. }
  100. func (v *TimedUserValidator) updateUserHash() {
  101. now := time.Now()
  102. nowSec := protocol.Timestamp(now.Unix())
  103. v.Lock()
  104. defer v.Unlock()
  105. for _, user := range v.users {
  106. v.generateNewHashes(nowSec, user)
  107. }
  108. expire := protocol.Timestamp(now.Unix() - cacheDurationSec)
  109. if expire > v.baseTime {
  110. v.removeExpiredHashes(uint32(expire - v.baseTime))
  111. }
  112. }
  113. func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
  114. v.Lock()
  115. defer v.Unlock()
  116. nowSec := time.Now().Unix()
  117. uu := &user{
  118. user: *u,
  119. lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
  120. }
  121. v.users = append(v.users, uu)
  122. v.generateNewHashes(protocol.Timestamp(nowSec), uu)
  123. account := uu.user.Account.(*MemoryAccount)
  124. if !v.behaviorFused {
  125. hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF"))
  126. hashkdf.Write(account.ID.Bytes())
  127. v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
  128. }
  129. var cmdkeyfl [16]byte
  130. copy(cmdkeyfl[:], account.ID.CmdKey())
  131. v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
  132. return nil
  133. }
  134. func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
  135. v.RLock()
  136. defer v.RUnlock()
  137. v.behaviorFused = true
  138. var fixedSizeHash [16]byte
  139. copy(fixedSizeHash[:], userHash)
  140. pair, found := v.userHash[fixedSizeHash]
  141. if found {
  142. user := pair.user.user
  143. if atomic.LoadUint32(pair.taintedFuse) == 0 {
  144. return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
  145. }
  146. return nil, 0, false, ErrTainted
  147. }
  148. return nil, 0, false, ErrNotFound
  149. }
  150. func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) {
  151. v.RLock()
  152. defer v.RUnlock()
  153. var userHashFL [16]byte
  154. copy(userHashFL[:], userHash)
  155. userd, err := v.aeadDecoderHolder.Match(userHashFL)
  156. if err != nil {
  157. return nil, false, err
  158. }
  159. return userd.(*protocol.MemoryUser), true, err
  160. }
  161. func (v *TimedUserValidator) Remove(email string) bool {
  162. v.Lock()
  163. defer v.Unlock()
  164. email = strings.ToLower(email)
  165. idx := -1
  166. for i, u := range v.users {
  167. if strings.EqualFold(u.user.Email, email) {
  168. idx = i
  169. var cmdkeyfl [16]byte
  170. copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
  171. v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
  172. break
  173. }
  174. }
  175. if idx == -1 {
  176. return false
  177. }
  178. ulen := len(v.users)
  179. v.users[idx] = v.users[ulen-1]
  180. v.users[ulen-1] = nil
  181. v.users = v.users[:ulen-1]
  182. return true
  183. }
  184. // Close implements common.Closable.
  185. func (v *TimedUserValidator) Close() error {
  186. return v.task.Close()
  187. }
  188. func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
  189. v.Lock()
  190. defer v.Unlock()
  191. v.behaviorFused = true
  192. if v.behaviorSeed == 0 {
  193. v.behaviorSeed = dice.RollUint64()
  194. }
  195. return v.behaviorSeed
  196. }
  197. func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
  198. v.RLock()
  199. defer v.RUnlock()
  200. var userHashFL [16]byte
  201. copy(userHashFL[:], userHash)
  202. pair, found := v.userHash[userHashFL]
  203. if found {
  204. if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) {
  205. return nil
  206. }
  207. return ErrTainted
  208. }
  209. return ErrNotFound
  210. }
  211. /* ShouldShowLegacyWarn will return whether a Legacy Warning should be shown
  212. Not guaranteed to only return true once for every inbound, but it is okay.
  213. */
  214. func (v *TimedUserValidator) ShouldShowLegacyWarn() bool {
  215. if v.legacyWarnShown {
  216. return false
  217. }
  218. v.legacyWarnShown = true
  219. return true
  220. }
  221. var ErrNotFound = newError("Not Found")
  222. var ErrTainted = newError("ErrTainted")