浏览代码

废除同步的 STUN API ,Socks5 UDP 代理与代理接口基础实现

Student Main 5 年之前
父节点
当前提交
1f3fb6b705

+ 3 - 3
NatTypeTester/ViewModels/MainWindowViewModel.cs

@@ -145,7 +145,7 @@ namespace NatTypeTester.ViewModels
 
         private IObservable<Unit> TestClassicNatTypeImpl()
         {
-            return Observable.Start(() =>
+            return Observable.FromAsync(async () =>
             {
                 try
                 {
@@ -158,7 +158,7 @@ namespace NatTypeTester.ViewModels
                                 .Subscribe(t => ClassicNatType = $@"{t}");
                         client.PubChanged.ObserveOn(RxApp.MainThreadScheduler).Subscribe(t => PublicEnd = $@"{t}");
                         client.LocalChanged.ObserveOn(RxApp.MainThreadScheduler).Subscribe(t => LocalEnd = $@"{t}");
-                        client.Query();
+                        await client.Query3489Async();
                     }
                     else
                     {
@@ -169,7 +169,7 @@ namespace NatTypeTester.ViewModels
                 {
                     MessageBox.Show(ex.Message, nameof(NatTypeTester), MessageBoxButton.OK, MessageBoxImage.Error);
                 }
-            });
+            }).SubscribeOn(RxApp.TaskpoolScheduler);
         }
 
         private IObservable<Unit> DiscoveryNatTypeImpl()

+ 14 - 12
STUN/Client/StunClient3489.cs

@@ -1,6 +1,7 @@
 using STUN.Enums;
 using STUN.Interfaces;
 using STUN.Message;
+using STUN.Proxy;
 using STUN.StunResult;
 using STUN.Utils;
 using System;
@@ -10,6 +11,7 @@ using System.Net;
 using System.Net.Sockets;
 using System.Reactive.Linq;
 using System.Reactive.Subjects;
+using System.Threading.Tasks;
 
 namespace STUN.Client
 {
@@ -47,6 +49,8 @@ namespace STUN.Client
 
         public IPEndPoint RemoteEndPoint => Server == null ? null : new IPEndPoint(Server, Port);
 
+        protected IUdpProxy Proxy;
+
         public StunClient3489(string server, ushort port = 3478, IPEndPoint local = null, IDnsQuery dnsQuery = null)
         {
             Func<string, IPAddress> dnsQuery1;
@@ -81,7 +85,7 @@ namespace STUN.Client
             Timeout = TimeSpan.FromSeconds(1.6);
         }
 
-        public ClassicStunResult Query()
+        public async Task<ClassicStunResult> Query3489Async()
         {
             var res = new ClassicStunResult();
             _natTypeSubj.OnNext(res.NatType);
@@ -92,7 +96,7 @@ namespace STUN.Client
                 // test I
                 var test1 = new StunMessage5389 { StunMessageType = StunMessageType.BindingRequest, MagicCookie = 0 };
 
-                var (response1, remote1, local1) = Test(test1, RemoteEndPoint, RemoteEndPoint);
+                var (response1, remote1, local1) = await TestAsync(test1, RemoteEndPoint, RemoteEndPoint);
                 if (response1 == null)
                 {
                     res.NatType = NatType.UdpBlocked;
@@ -127,7 +131,7 @@ namespace STUN.Client
                 };
 
                 // test II
-                var (response2, remote2, _) = Test(test2, RemoteEndPoint, changedAddress1);
+                var (response2, remote2, _) = await TestAsync(test2, RemoteEndPoint, changedAddress1);
                 var mappedAddress2 = AttributeExtensions.GetMappedAddressAttribute(response2);
 
                 if (Equals(mappedAddress1.Address, local1) && mappedAddress1.Port == LocalEndPoint.Port)
@@ -156,7 +160,7 @@ namespace STUN.Client
 
                 // Test I(#2)
                 var test12 = new StunMessage5389 { StunMessageType = StunMessageType.BindingRequest, MagicCookie = 0 };
-                var (response12, _, _) = Test(test12, changedAddress1, changedAddress1);
+                var (response12, _, _) = await TestAsync(test12, changedAddress1, changedAddress1);
                 var mappedAddress12 = AttributeExtensions.GetMappedAddressAttribute(response12);
 
                 if (mappedAddress12 == null)
@@ -179,7 +183,7 @@ namespace STUN.Client
                     MagicCookie = 0,
                     Attributes = new[] { AttributeExtensions.BuildChangeRequest(false, true) }
                 };
-                var (response3, _, _) = Test(test3, changedAddress1, changedAddress1);
+                var (response3, _, _) = await TestAsync(test3, changedAddress1, changedAddress1);
                 var mappedAddress3 = AttributeExtensions.GetMappedAddressAttribute(response3);
                 if (mappedAddress3 != null)
                 {
@@ -198,10 +202,7 @@ namespace STUN.Client
             }
         }
 
-        /// <returns>
-        /// (StunMessage, Remote, Local)
-        /// </returns>
-        private (StunMessage5389, IPEndPoint, IPAddress) Test(StunMessage5389 sendMessage, IPEndPoint remote, IPEndPoint receive)
+        protected async Task<(StunMessage5389, IPEndPoint, IPAddress)> TestAsync(StunMessage5389 sendMessage, IPEndPoint remote, IPEndPoint receive)
         {
             try
             {
@@ -209,12 +210,13 @@ namespace STUN.Client
                 //var t = DateTime.Now;
 
                 // Simple retransmissions
-                //https://tools.ietf.org/html/rfc3489#section-9.3
-                //while (t + TimeSpan.FromSeconds(3) > DateTime.Now)
+                //https://tools.ietf.org/html/rfc5389#section-7.2.1
+                //while (t + TimeSpan.FromSeconds(6) > DateTime.Now)
                 {
                     try
                     {
-                        var (receive1, ipe, local) = UdpClient.UdpReceive(b1, remote, receive);
+                        // var (receive1, ipe, local) = await Proxy.RecieveAsync(b1, remote, receive);
+                        var (receive1, ipe, local) = await UdpClient.UdpReceiveAsync(b1, remote, receive);
 
                         var message = new StunMessage5389();
                         if (message.TryParse(receive1) &&

+ 0 - 36
STUN/Client/StunClient5389UDP.cs

@@ -259,42 +259,6 @@ namespace STUN.Client
             }
         }
 
-        private async Task<(StunMessage5389, IPEndPoint, IPAddress)> TestAsync(StunMessage5389 sendMessage, IPEndPoint remote, IPEndPoint receive)
-        {
-            try
-            {
-                var b1 = sendMessage.Bytes.ToArray();
-                //var t = DateTime.Now;
-
-                // Simple retransmissions
-                //https://tools.ietf.org/html/rfc5389#section-7.2.1
-                //while (t + TimeSpan.FromSeconds(6) > DateTime.Now)
-                {
-                    try
-                    {
-
-                        var (receive1, ipe, local) = await UdpClient.UdpReceiveAsync(b1, remote, receive);
-
-                        var message = new StunMessage5389();
-                        if (message.TryParse(receive1) &&
-                            message.ClassicTransactionId.IsEqual(sendMessage.ClassicTransactionId))
-                        {
-                            return (message, ipe, local);
-                        }
-                    }
-                    catch (Exception ex)
-                    {
-                        Debug.WriteLine(ex);
-                    }
-                }
-            }
-            catch (Exception ex)
-            {
-                Debug.WriteLine(ex);
-            }
-            return (null, null, null);
-        }
-
         public override void Dispose()
         {
             base.Dispose();

+ 8 - 0
STUN/Enums/ProxyType.cs

@@ -0,0 +1,8 @@
+namespace STUN.Enums
+{
+    public enum ProxyType
+    {
+        Plain,
+        Socks5,
+    }
+}

+ 12 - 0
STUN/Enums/TransportType.cs

@@ -0,0 +1,12 @@
+namespace STUN.Enums
+{
+    // Only UDP is supported
+
+    public enum TransportType
+    {
+        Udp,
+        Tcp,
+        Tls,
+        Dtls,
+    }
+}

+ 15 - 0
STUN/Proxy/IUdpProxy.cs

@@ -0,0 +1,15 @@
+using System;
+using System.Net;
+using System.Threading.Tasks;
+
+namespace STUN.Proxy
+{
+    public interface IUdpProxy
+    {
+        TimeSpan Timeout { get; set; }
+        IPEndPoint LocalEndPoint { get; }
+        Task ConnectAsync(IPEndPoint local, IPEndPoint remote);
+        Task<(byte[], IPEndPoint, IPAddress)> RecieveAsync(byte[] bytes, IPEndPoint remote, EndPoint receive);
+        Task DisconnectAsync();
+    }
+}

+ 59 - 0
STUN/Proxy/NoneUdpProxy.cs

@@ -0,0 +1,59 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace STUN.Proxy
+{
+    class NoneUdpProxy : IUdpProxy
+    {
+
+        public TimeSpan Timeout
+        {
+            get => TimeSpan.FromMilliseconds(UdpClient.Client.ReceiveTimeout);
+            set => UdpClient.Client.ReceiveTimeout = Convert.ToInt32(value.TotalMilliseconds);
+        }
+
+        public IPEndPoint LocalEndPoint { get => (IPEndPoint)UdpClient.Client.LocalEndPoint; }
+
+        protected UdpClient UdpClient;
+
+        public Task ConnectAsync(IPEndPoint local, IPEndPoint remote)
+        {
+            UdpClient = local == null ? new UdpClient() : new UdpClient(local);
+            return Task.CompletedTask;
+        }
+
+        public Task DisconnectAsync()
+        {
+            UdpClient.Close();
+            return Task.CompletedTask;
+        }
+
+        public async Task<(byte[], IPEndPoint, IPAddress)> RecieveAsync(byte[] bytes, IPEndPoint remote, EndPoint receive)
+        {
+            var localEndPoint = (IPEndPoint)UdpClient.Client.LocalEndPoint;
+
+            Debug.WriteLine($@"{localEndPoint} => {remote} {bytes.Length} 字节");
+
+            await UdpClient.SendAsync(bytes, bytes.Length, remote);
+
+            var res = new byte[ushort.MaxValue];
+            var flag = SocketFlags.None;
+
+            var length = UdpClient.Client.ReceiveMessageFrom(res, 0, res.Length, ref flag, ref receive, out var ipPacketInformation);
+
+            var local = ipPacketInformation.Address;
+
+            Debug.WriteLine($@"{(IPEndPoint)receive} => {local} {length} 字节");
+
+            return (res.Take(length).ToArray(),
+                    (IPEndPoint)receive
+                    , local);
+        }
+    }
+}

+ 219 - 0
STUN/Proxy/Socks5UdpProxy.cs

@@ -0,0 +1,219 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Globalization;
+using System.Net;
+using System.Net.NetworkInformation;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading.Tasks;
+using STUN.Utils;
+
+namespace STUN.Proxy
+{
+    class Socks5UdpProxy : IUdpProxy
+    {
+        TcpClient assoc = new TcpClient();
+        IPEndPoint socksTcpEndPoint;
+
+        IPEndPoint assocEndPoint;
+
+        public TimeSpan Timeout
+        {
+            get => TimeSpan.FromMilliseconds(UdpClient.Client.ReceiveTimeout);
+            set => UdpClient.Client.ReceiveTimeout = Convert.ToInt32(value.TotalMilliseconds);
+        }
+
+        public IPEndPoint LocalEndPoint => throw new NotImplementedException();
+
+        UdpClient UdpClient;
+
+        string user;
+        string passwd;
+        public Socks5UdpProxy(IPEndPoint proxy)
+        {
+            socksTcpEndPoint = proxy;
+        }
+
+        public async Task ConnectAsync(IPEndPoint local, IPEndPoint remote)
+        {
+            byte[] buf = new byte[1024];
+
+            UdpClient = local == null ? new UdpClient() : new UdpClient(local);
+            await assoc.ConnectAsync(socksTcpEndPoint.Address, socksTcpEndPoint.Port);
+            try
+            {
+                var s = assoc.GetStream();
+                bool requestPasswordAuth = user != null;
+
+                #region Handshake
+                // we have no gssapi support
+                if (requestPasswordAuth)
+                {
+                    // 5 authlen auth[](0=none, 2=userpasswd)
+                    s.Write(new byte[] { 5, 2, 0, 2 }, 0, 4);
+                }
+                else
+                {
+                    s.Write(new byte[] { 5, 1, 0 }, 0, 3);
+                }
+                // 5 auth(ff=deny)
+                if (s.Read(buf, 0, 2) != 2) throw new ProtocolViolationException();
+                if (buf[0] != 5) throw new ProtocolViolationException();
+                #endregion
+
+                #region Auth
+                var auth = buf[1];
+                switch (auth)
+                {
+                    case 0:
+                        break;
+                    case 2:
+                        byte[] ubyte = Encoding.UTF8.GetBytes(user);
+                        byte[] pbyte = Encoding.UTF8.GetBytes(passwd);
+                        buf[0] = 1;
+                        buf[1] = (byte)ubyte.Length;
+                        Array.Copy(ubyte, 0, buf, 2, ubyte.Length);
+                        buf[ubyte.Length + 3] = (byte)pbyte.Length;
+                        Array.Copy(pbyte, 0, buf, ubyte.Length + 4, pbyte.Length);
+                        // 1 userlen user passlen pass
+                        s.Write(buf, 0, ubyte.Length + pbyte.Length + 4);
+                        // 1 state(0=ok)
+                        if (s.Read(buf, 0, 2) != 2) throw new ProtocolViolationException();
+                        if (buf[0] != 1) throw new ProtocolViolationException();
+                        if (buf[1] != 0) throw new UnauthorizedAccessException();
+                        break;
+                    case 0xff:
+                        throw new UnauthorizedAccessException();
+                    default:
+                        throw new ProtocolViolationException();
+                }
+                #endregion
+
+                #region UDP Assoc Send
+                buf[0] = 5;
+                buf[1] = 3;
+                buf[2] = 0;
+                int addrLen;
+                int port;
+                if (remote is IPEndPoint ir)
+                {
+                    byte[] abyte = ir.Address.GetAddressBytes();
+                    addrLen = abyte.Length;
+                    buf[3] = (byte)(abyte.Length == 4 ? 1 : 4);
+                    Array.Copy(abyte, 0, buf, 4, addrLen);
+                    port = ir.Port;
+                }
+                else
+                {
+                    throw new NotImplementedException();
+                }
+                buf[addrLen + 4] = (byte)(port / 256);
+                buf[addrLen + 5] = (byte)(port % 256);
+
+                // 5 cmd(3=udpassoc) 0 atyp(1=v4 3=dns 4=v5) addr port
+                s.Write(buf, 0, addrLen + 4);
+                #endregion
+
+                #region UDP Assoc Response
+                if (s.Read(buf, 0, 4) != 4) throw new ProtocolViolationException();
+                if (buf[0] != 5 || buf[2] != 0) throw new ProtocolViolationException();
+                if (buf[1] != 0) throw new UnauthorizedAccessException();
+
+                switch (buf[3])
+                {
+                    case 1:
+                        addrLen = 4;
+                        break;
+                    case 4:
+                        addrLen = 16;
+                        break;
+                    default:
+                        throw new NotSupportedException();
+                }
+
+                byte[] addr = new byte[addrLen];
+                if (s.Read(buf, 0, addrLen) != addrLen) throw new ProtocolViolationException();
+                IPAddress assocIP = new IPAddress(addr);
+                if (s.Read(buf, 0, 2) != 2) throw new ProtocolViolationException();
+                int assocPort = buf[0] * 256 + buf[1];
+                #endregion
+
+                assocEndPoint = new IPEndPoint(assocIP, assocPort);
+            }
+            catch (Exception e)
+            {
+                Debug.WriteLine(e);
+                assoc.Close();
+            }
+        }
+
+        public async Task<(byte[], IPEndPoint, IPAddress)> RecieveAsync(byte[] bytes, IPEndPoint remote, EndPoint receive)
+        {
+            TcpState state = assoc.GetState();
+            if (state != TcpState.Established) throw new Exception();
+
+            byte[] remoteBytes = GetEndPointByte(remote);
+            byte[] proxyBytes = new byte[bytes.Length + remoteBytes.Length + 3];
+            Array.Copy(remoteBytes, 0, proxyBytes, 3, proxyBytes.Length);
+            Array.Copy(bytes, 0, proxyBytes, remoteBytes.Length + 3, bytes.Length);
+
+            await UdpClient.SendAsync(proxyBytes, proxyBytes.Length, assocEndPoint);
+            var res = new byte[ushort.MaxValue];
+            var flag = SocketFlags.None;
+
+            var length = UdpClient.Client.ReceiveMessageFrom(res, 0, res.Length, ref flag, ref receive, out var ipPacketInformation);
+
+            if (res[0] != 0 || res[1] != 0 || res[2] != 0)
+            {
+                throw new Exception();
+            }
+
+            int addrLen;
+            switch (res[3])
+            {
+                case 1:
+                    addrLen = 4;
+                    break;
+                case 4:
+                    addrLen = 16;
+                    break;
+                default:
+                    throw new Exception();
+            }
+
+            byte[] ipbyte = new byte[addrLen];
+            Array.Copy(res, 4, ipbyte, 0, addrLen);
+
+            IPAddress ip = new IPAddress(ipbyte);
+            int port = res[addrLen + 4] * 256 + res[addrLen + 5];
+            byte[] ret = new byte[length - addrLen - 6];
+            Array.Copy(res, addrLen + 6, ret, 0, length - addrLen - 6);
+            return (
+                ret,
+                new IPEndPoint(ip, port),
+                ipPacketInformation.Address);
+        }
+
+        public Task DisconnectAsync()
+        {
+            try
+            {
+                assoc.Close();
+            }
+            catch { }
+            return Task.CompletedTask;
+        }
+
+        byte[] GetEndPointByte(IPEndPoint ep)
+        {
+            byte[] ipbyte = ep.Address.GetAddressBytes();
+            byte[] ret = new byte[ipbyte.Length + 3];
+            ret[0] = (byte)(ipbyte.Length == 1 ? 4 : 16);
+            Array.Copy(ipbyte, 0, ret, 1, ipbyte.Length);
+            ret[ipbyte.Length + 1] = (byte)(ep.Port % 256);
+            ret[ipbyte.Length + 2] = (byte)(ep.Port / 256);
+            return ret;
+        }
+    }
+}

+ 9 - 27
STUN/Utils/NetUtils.cs

@@ -4,6 +4,7 @@ using System;
 using System.Diagnostics;
 using System.Linq;
 using System.Net;
+using System.Net.NetworkInformation;
 using System.Net.Sockets;
 using System.Threading.Tasks;
 
@@ -42,33 +43,6 @@ namespace STUN.Utils
             return null;
         }
 
-        public static (string, string, string) NatTypeTestCore(string local, string server, ushort port)
-        {
-            try
-            {
-                if (string.IsNullOrWhiteSpace(server))
-                {
-                    Debug.WriteLine(@"[ERROR]: Please specify STUN server !");
-                    return (string.Empty, DefaultLocalEnd, string.Empty);
-                }
-
-                using var client = new StunClient3489(server, port, ParseEndpoint(local));
-
-                var result = client.Query();
-
-                return (
-                        result.NatType.ToString(),
-                        $@"{client.LocalEndPoint}",
-                        $@"{result.PublicEndPoint}"
-                );
-            }
-            catch (Exception ex)
-            {
-                Debug.WriteLine($@"[ERROR]: {ex}");
-                return (string.Empty, DefaultLocalEnd, string.Empty);
-            }
-        }
-
         public static async Task<StunResult5389> NatBehaviorDiscovery(string server, ushort port, IPEndPoint local)
         {
             using var client = new StunClient5389UDP(server, port, local);
@@ -118,5 +92,13 @@ namespace STUN.Utils
                     (IPEndPoint)receive
                     , local);
         }
+
+        public static TcpState GetState(this TcpClient tcpClient)
+        {
+            var foo = IPGlobalProperties.GetIPGlobalProperties()
+              .GetActiveTcpConnections()
+              .SingleOrDefault(x => x.LocalEndPoint.Equals(tcpClient.Client.LocalEndPoint));
+            return foo != null ? foo.State : TcpState.Unknown;
+        }
     }
 }