123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- package xudp
- import (
- "context"
- "crypto/rand"
- "encoding/base64"
- "fmt"
- "io"
- "os"
- "strings"
- "github.com/xtls/xray-core/common/buf"
- "github.com/xtls/xray-core/common/net"
- "github.com/xtls/xray-core/common/protocol"
- "github.com/xtls/xray-core/common/session"
- "lukechampine.com/blake3"
- )
// AddrParser encodes and decodes the address portion of XUDP frame metadata
// (used by PacketWriter.WriteMultiBuffer and PacketReader.ReadMultiBuffer).
// The port is serialized before the address (PortThenAddress), and the
// address family is tagged with the protocol.AddressType byte values below.
var AddrParser = protocol.NewAddressParser(
	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
	protocol.PortThenAddress(),
)
var (
	// Show enables printing every derived XUDP Global ID to stdout. It is set
	// by init when EnvShow is "true" (case-insensitive).
	Show bool
	// BaseKey is the 32-byte BLAKE3 key used by GetGlobalID. init loads it
	// from EnvBaseKey when that variable is set, and otherwise fills it with
	// fresh random bytes.
	BaseKey []byte
)

const (
	// EnvShow names the environment variable that turns on Show.
	EnvShow = "XRAY_XUDP_SHOW"
	// EnvBaseKey names the environment variable holding BaseKey as unpadded
	// base64url; it must decode to exactly 32 bytes.
	EnvBaseKey = "XRAY_XUDP_BASEKEY"
)

// init configures Show and BaseKey from the environment.
//
// It panics if EnvBaseKey is set but does not cleanly decode to 32 bytes, or
// if no random key can be generated.
func init() {
	if strings.EqualFold(os.Getenv(EnvShow), "true") {
		Show = true
	}
	if raw, found := os.LookupEnv(EnvBaseKey); found {
		key, err := base64.RawURLEncoding.DecodeString(raw)
		// A decode error can still yield a partial result; reject it even if
		// that partial result happens to be 32 bytes long.
		if err != nil || len(key) != 32 {
			panic(EnvBaseKey + ": invalid value: " + raw)
		}
		BaseKey = key
		return
	}
	// BaseKey starts out nil. It must be allocated before rand.Read fills it;
	// otherwise rand.Read reads zero bytes and Global IDs degrade to an
	// unkeyed, predictable hash of the source address.
	BaseKey = make([]byte, 32)
	if _, err := rand.Read(BaseKey); err != nil {
		panic("xudp: generating random BaseKey: " + err.Error())
	}
}
- func GetGlobalID(ctx context.Context) (globalID [8]byte) {
- if cone := ctx.Value("cone"); cone == nil || !cone.(bool) { // cone is nil only in some unit tests
- return
- }
- if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
- (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
- h := blake3.New(8, BaseKey)
- h.Write([]byte(inbound.Source.String()))
- copy(globalID[:], h.Sum(nil))
- if Show {
- fmt.Printf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)
- }
- }
- return
- }
- func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter {
- return &PacketWriter{
- Writer: writer,
- Dest: dest,
- GlobalID: globalID,
- }
- }
// PacketWriter encodes outgoing packets as XUDP frames and writes them to
// Writer (see WriteMultiBuffer).
type PacketWriter struct {
	// Writer receives the encoded frames.
	Writer buf.Writer
	// Dest is the session destination. While Dest.Network is UDP, the next
	// frame written is a "New" frame carrying Dest; WriteMultiBuffer then
	// sets Dest.Network to Network_Unknown so later frames are "Keep" frames.
	Dest net.Destination
	// GlobalID is appended to the "New" frame when the packet being written
	// has a UDP address attached (b.UDP != nil).
	GlobalID [8]byte
}
- func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
- defer buf.ReleaseMulti(mb)
- mb2Write := make(buf.MultiBuffer, 0, len(mb))
- for _, b := range mb {
- length := b.Len()
- if length == 0 || length+666 > buf.Size {
- continue
- }
- eb := buf.New()
- eb.Write([]byte{0, 0, 0, 0}) // Meta data length; Mux Session ID
- if w.Dest.Network == net.Network_UDP {
- eb.WriteByte(1) // New
- eb.WriteByte(1) // Opt
- eb.WriteByte(2) // UDP
- AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
- if b.UDP != nil { // make sure it's user's proxy request
- eb.Write(w.GlobalID[:]) // no need to check whether it's empty
- }
- w.Dest.Network = net.Network_Unknown
- } else {
- eb.WriteByte(2) // Keep
- eb.WriteByte(1) // Opt
- if b.UDP != nil {
- eb.WriteByte(2) // UDP
- AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
- }
- }
- l := eb.Len() - 2
- eb.SetByte(0, byte(l>>8))
- eb.SetByte(1, byte(l))
- eb.WriteByte(byte(length >> 8))
- eb.WriteByte(byte(length))
- eb.Write(b.Bytes())
- mb2Write = append(mb2Write, eb)
- }
- if mb2Write.IsEmpty() {
- return nil
- }
- return w.Writer.WriteMultiBuffer(mb2Write)
- }
- func NewPacketReader(reader io.Reader) *PacketReader {
- return &PacketReader{
- Reader: reader,
- cache: make([]byte, 2),
- }
- }
// PacketReader decodes XUDP frames read from Reader into packet buffers
// (see ReadMultiBuffer).
type PacketReader struct {
	// Reader is the underlying frame stream.
	Reader io.Reader
	// cache is a 2-byte scratch buffer for the big-endian length prefixes;
	// NewPacketReader allocates it, so the zero value is not usable.
	cache []byte
}
- func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
- for {
- if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
- return nil, err
- }
- l := int32(r.cache[0])<<8 | int32(r.cache[1])
- if l < 4 {
- return nil, io.EOF
- }
- b := buf.New()
- if _, err := b.ReadFullFrom(r.Reader, l); err != nil {
- b.Release()
- return nil, err
- }
- discard := false
- switch b.Byte(2) {
- case 2:
- if l > 4 && b.Byte(4) == 2 { // MUST check the flag first
- b.Advance(5)
- // b.Clear() will be called automatically if all data had been read.
- addr, port, err := AddrParser.ReadAddressPort(nil, b)
- if err != nil {
- b.Release()
- return nil, err
- }
- b.UDP = &net.Destination{
- Network: net.Network_UDP,
- Address: addr,
- Port: port,
- }
- }
- case 4:
- discard = true
- default:
- b.Release()
- return nil, io.EOF
- }
- b.Clear() // in case there is padding (empty bytes) attached
- if b.Byte(3) == 1 {
- if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
- b.Release()
- return nil, err
- }
- length := int32(r.cache[0])<<8 | int32(r.cache[1])
- if length > 0 {
- if _, err := b.ReadFullFrom(r.Reader, length); err != nil {
- b.Release()
- return nil, err
- }
- if !discard {
- return buf.MultiBuffer{b}, nil
- }
- }
- }
- b.Release()
- }
- }
|