123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- package mtproto
- import (
- "context"
- "crypto/rand"
- "crypto/sha256"
- "io"
- "sync"
- "github.com/xtls/xray-core/common"
- )
// HeaderSize is the length in bytes of the obfuscation header that opens a
// connection; Authentication.Header has exactly this many bytes.
const HeaderSize = 64
// SessionContext carries per-session MTProto parameters that are attached to a
// context.Context.
type SessionContext struct {
	// ConnectionType is the 4-byte tag stamped into header bytes 56..59 when
	// a new Authentication is generated.
	ConnectionType [4]byte
	// DataCenterID identifies the target data center. NOTE(review): it is not
	// consumed when generating a header here; confirm its use in callers.
	DataCenterID uint16
}
- func DefaultSessionContext() SessionContext {
- return SessionContext{
- ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
- DataCenterID: 0,
- }
- }
// contextKey is an unexported key type, so values this package stores in a
// context.Context cannot collide with keys defined by other packages.
type contextKey int32

// sessionContextKey is the context key under which a SessionContext is stored.
const sessionContextKey contextKey = 0
- func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
- return context.WithValue(ctx, sessionContextKey, c)
- }
- func SessionContextFromContext(ctx context.Context) SessionContext {
- if c := ctx.Value(sessionContextKey); c != nil {
- return c.(SessionContext)
- }
- return DefaultSessionContext()
- }
- type Authentication struct {
- Header [HeaderSize]byte
- DecodingKey [32]byte
- EncodingKey [32]byte
- DecodingNonce [16]byte
- EncodingNonce [16]byte
- }
- func (a *Authentication) DataCenterID() uint16 {
- x := ((int16(a.Header[61]) << 8) | int16(a.Header[60]))
- if x < 0 {
- x = -x
- }
- return uint16(x) - 1
- }
- func (a *Authentication) ConnectionType() [4]byte {
- var x [4]byte
- copy(x[:], a.Header[56:60])
- return x
- }
- func (a *Authentication) ApplySecret(b []byte) {
- a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
- a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
- }
- func generateRandomBytes(random []byte, connType [4]byte) {
- for {
- common.Must2(rand.Read(random))
- if random[0] == 0xef {
- continue
- }
- val := (uint32(random[3]) << 24) | (uint32(random[2]) << 16) | (uint32(random[1]) << 8) | uint32(random[0])
- if val == 0x44414548 || val == 0x54534f50 || val == 0x20544547 || val == 0x4954504f || val == 0xeeeeeeee {
- continue
- }
- if (uint32(random[7])<<24)|(uint32(random[6])<<16)|(uint32(random[5])<<8)|uint32(random[4]) == 0x00000000 {
- continue
- }
- copy(random[56:60], connType[:])
- return
- }
- }
- func NewAuthentication(sc SessionContext) *Authentication {
- auth := getAuthenticationObject()
- random := auth.Header[:]
- generateRandomBytes(random, sc.ConnectionType)
- copy(auth.EncodingKey[:], random[8:])
- copy(auth.EncodingNonce[:], random[8+32:])
- keyivInverse := Inverse(random[8 : 8+32+16])
- copy(auth.DecodingKey[:], keyivInverse)
- copy(auth.DecodingNonce[:], keyivInverse[32:])
- return auth
- }
- func ReadAuthentication(reader io.Reader) (*Authentication, error) {
- auth := getAuthenticationObject()
- if _, err := io.ReadFull(reader, auth.Header[:]); err != nil {
- putAuthenticationObject(auth)
- return nil, err
- }
- copy(auth.DecodingKey[:], auth.Header[8:])
- copy(auth.DecodingNonce[:], auth.Header[8+32:])
- keyivInverse := Inverse(auth.Header[8 : 8+32+16])
- copy(auth.EncodingKey[:], keyivInverse)
- copy(auth.EncodingNonce[:], keyivInverse[32:])
- return auth, nil
- }
// Inverse returns a newly allocated slice containing the bytes of b in reverse
// order. b itself is not modified. A nil or empty b yields an empty, non-nil
// slice.
//
// It is exported only so that tests can exercise it directly.
func Inverse(b []byte) []byte {
	out := make([]byte, len(b))
	for i, j := 0, len(b)-1; j >= 0; i, j = i+1, j-1 {
		out[i] = b[j]
	}
	return out
}
- var authPool = sync.Pool{
- New: func() interface{} {
- return new(Authentication)
- },
- }
- func getAuthenticationObject() *Authentication {
- return authPool.Get().(*Authentication)
- }
- func putAuthenticationObject(auth *Authentication) {
- authPool.Put(auth)
- }
|