auth.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package mtproto
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "crypto/sha256"
  6. "io"
  7. "sync"
  8. "github.com/xtls/xray-core/common"
  9. )
  10. const (
  11. HeaderSize = 64
  12. )
  13. type SessionContext struct {
  14. ConnectionType [4]byte
  15. DataCenterID uint16
  16. }
  17. func DefaultSessionContext() SessionContext {
  18. return SessionContext{
  19. ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
  20. DataCenterID: 0,
  21. }
  22. }
  23. type contextKey int32
  24. const (
  25. sessionContextKey contextKey = iota
  26. )
  27. func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
  28. return context.WithValue(ctx, sessionContextKey, c)
  29. }
  30. func SessionContextFromContext(ctx context.Context) SessionContext {
  31. if c := ctx.Value(sessionContextKey); c != nil {
  32. return c.(SessionContext)
  33. }
  34. return DefaultSessionContext()
  35. }
  36. type Authentication struct {
  37. Header [HeaderSize]byte
  38. DecodingKey [32]byte
  39. EncodingKey [32]byte
  40. DecodingNonce [16]byte
  41. EncodingNonce [16]byte
  42. }
  43. func (a *Authentication) DataCenterID() uint16 {
  44. x := ((int16(a.Header[61]) << 8) | int16(a.Header[60]))
  45. if x < 0 {
  46. x = -x
  47. }
  48. return uint16(x) - 1
  49. }
  50. func (a *Authentication) ConnectionType() [4]byte {
  51. var x [4]byte
  52. copy(x[:], a.Header[56:60])
  53. return x
  54. }
  55. func (a *Authentication) ApplySecret(b []byte) {
  56. a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
  57. a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
  58. }
  59. func generateRandomBytes(random []byte, connType [4]byte) {
  60. for {
  61. common.Must2(rand.Read(random))
  62. if random[0] == 0xef {
  63. continue
  64. }
  65. val := (uint32(random[3]) << 24) | (uint32(random[2]) << 16) | (uint32(random[1]) << 8) | uint32(random[0])
  66. if val == 0x44414548 || val == 0x54534f50 || val == 0x20544547 || val == 0x4954504f || val == 0xeeeeeeee {
  67. continue
  68. }
  69. if (uint32(random[7])<<24)|(uint32(random[6])<<16)|(uint32(random[5])<<8)|uint32(random[4]) == 0x00000000 {
  70. continue
  71. }
  72. copy(random[56:60], connType[:])
  73. return
  74. }
  75. }
  76. func NewAuthentication(sc SessionContext) *Authentication {
  77. auth := getAuthenticationObject()
  78. random := auth.Header[:]
  79. generateRandomBytes(random, sc.ConnectionType)
  80. copy(auth.EncodingKey[:], random[8:])
  81. copy(auth.EncodingNonce[:], random[8+32:])
  82. keyivInverse := Inverse(random[8 : 8+32+16])
  83. copy(auth.DecodingKey[:], keyivInverse)
  84. copy(auth.DecodingNonce[:], keyivInverse[32:])
  85. return auth
  86. }
  87. func ReadAuthentication(reader io.Reader) (*Authentication, error) {
  88. auth := getAuthenticationObject()
  89. if _, err := io.ReadFull(reader, auth.Header[:]); err != nil {
  90. putAuthenticationObject(auth)
  91. return nil, err
  92. }
  93. copy(auth.DecodingKey[:], auth.Header[8:])
  94. copy(auth.DecodingNonce[:], auth.Header[8+32:])
  95. keyivInverse := Inverse(auth.Header[8 : 8+32+16])
  96. copy(auth.EncodingKey[:], keyivInverse)
  97. copy(auth.EncodingNonce[:], keyivInverse[32:])
  98. return auth, nil
  99. }
  100. // Inverse returns a new byte array. It is a sequence of bytes when the input is read from end to beginning.Inverse
  101. // Visible for testing only.
  102. func Inverse(b []byte) []byte {
  103. lenb := len(b)
  104. b2 := make([]byte, lenb)
  105. for i, v := range b {
  106. b2[lenb-i-1] = v
  107. }
  108. return b2
  109. }
  110. var authPool = sync.Pool{
  111. New: func() interface{} {
  112. return new(Authentication)
  113. },
  114. }
  115. func getAuthenticationObject() *Authentication {
  116. return authPool.Get().(*Authentication)
  117. }
  118. func putAuthenticationObject(auth *Authentication) {
  119. authPool.Put(auth)
  120. }