diff --git a/cmonitor/client/tunnel/ITunnelConnection.cs b/cmonitor/client/tunnel/ITunnelConnection.cs index e37b7bb0..7779ae4d 100644 --- a/cmonitor/client/tunnel/ITunnelConnection.cs +++ b/cmonitor/client/tunnel/ITunnelConnection.cs @@ -1,10 +1,11 @@ using common.libs.extends; using System.Net.Sockets; +using System.Text.Json.Serialization; namespace cmonitor.client.tunnel { - public delegate Task TunnelReceivceCallback(Memory data, object state); - public delegate Task TunnelCloseCallback(object state); + public delegate Task TunnelReceivceCallback(ITunnelConnection connection,Memory data, object state); + public delegate Task TunnelCloseCallback(ITunnelConnection connection,object state); public enum TunnelProtocolType : byte { @@ -57,6 +58,7 @@ namespace cmonitor.client.tunnel public bool Connected => Socket != null && Socket.Connected; + [JsonIgnore] public Socket Socket { get; init; } @@ -101,7 +103,7 @@ namespace cmonitor.client.tunnel int offset = e.Offset; int length = e.BytesTransferred; - await receiveCallback(e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false); + await receiveCallback(this,e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false); if (Socket.Available > 0) { @@ -110,7 +112,7 @@ namespace cmonitor.client.tunnel length = Socket.Receive(e.Buffer); if (length > 0) { - await receiveCallback(e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false); + await receiveCallback(this,e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false); } else { @@ -143,7 +145,7 @@ namespace cmonitor.client.tunnel } private void CloseClientSocket(SocketAsyncEventArgs e) { - closeCallback(e.UserToken); + closeCallback(this,e.UserToken); e.Dispose(); Close(); } diff --git a/cmonitor/plugins/relay/RelayApiController.cs b/cmonitor/plugins/relay/RelayApiController.cs index cec61775..ff12155b 100644 --- a/cmonitor/plugins/relay/RelayApiController.cs +++ b/cmonitor/plugins/relay/RelayApiController.cs @@ -1,9 +1,8 @@ using cmonitor.client.api; +using cmonitor.client.tunnel; using cmonitor.config; -using cmonitor.plugins.relay.transport; using common.libs; using common.libs.api; -using common.libs.extends; using System.Text; namespace cmonitor.plugins.relay @@ -27,17 +26,16 @@ namespace cmonitor.plugins.relay { try { - RelayTransportState state = await relayTransfer.ConnectAsync(param.Content, "test", config.Data.Client.Relay.SecretKey); - if (state != null) + ITunnelConnection connection = await relayTransfer.ConnectAsync(param.Content, "test", config.Data.Client.Relay.SecretKey); + if (connection != null) { - var socket = state.Socket; for (int i = 0; i < 10; i++) { Logger.Instance.Debug($"relay [test] send {i}"); - socket.Send(Encoding.UTF8.GetBytes($"snltty.relay.{i}")); + await connection.SendAsync(Encoding.UTF8.GetBytes($"snltty.relay.{i}")); await Task.Delay(10); } - socket.SafeClose(); + connection.Close(); } } catch (Exception ex) @@ -48,23 +46,21 @@ namespace cmonitor.plugins.relay } private void RelayTest() { - relayTransfer.OnConnected += (RelayTransportState state) => + relayTransfer.SetConnectCallback("test", (ITunnelConnection connection) => { - if (state.Info.TransactionId == "test") + Task.Run(() => { - Task.Run(() => + connection.BeginReceive(async (ITunnelConnection connection, Memory data, object state) => { - byte[] bytes = new byte[1024]; - while (true) - { - int length = state.Socket.Receive(bytes); - if (length == 0) break; - - Logger.Instance.Debug($"relay [{state.Info.TransactionId}] receive {Encoding.UTF8.GetString(bytes.AsSpan(0,length))}"); - } - }); - } - }; + Logger.Instance.Debug($"relay [{connection.TransactionId}] receive {Encoding.UTF8.GetString(data.Span)}"); + await Task.CompletedTask; + }, + async (ITunnelConnection connection, object state) => + { + await Task.CompletedTask; + }, null); + }); + }); } } diff --git a/cmonitor/plugins/relay/RelayTransfer.cs b/cmonitor/plugins/relay/RelayTransfer.cs index 5fb55d1b..b3e29966 100644 --- a/cmonitor/plugins/relay/RelayTransfer.cs +++ b/cmonitor/plugins/relay/RelayTransfer.cs @@ -1,12 +1,9 @@ -using cmonitor.client; +using cmonitor.client.tunnel; using cmonitor.config; using cmonitor.plugins.relay.transport; -using cmonitor.server; using common.libs; -using common.libs.extends; using Microsoft.Extensions.DependencyInjection; using System.Net; -using System.Net.Sockets; using System.Reflection; namespace cmonitor.plugins.relay @@ -18,7 +15,7 @@ namespace cmonitor.plugins.relay private readonly Config config; private readonly ServiceProvider serviceProvider; - public Action OnConnected { get; set; } = (state) => { }; + private Dictionary> OnConnected { get; } = new Dictionary>(); public RelayTransfer(Config config, ServiceProvider serviceProvider) { @@ -35,7 +32,19 @@ namespace cmonitor.plugins.relay Logger.Instance.Warning($"load relay transport:{string.Join(",", transports.Select(c => c.Name))}"); } - public async Task ConnectAsync(string remoteMachineName, string transactionId, string secretKey) + public void SetConnectCallback(string transactionId, Action callback) + { + if (OnConnected.TryGetValue(transactionId, out Action _callback) == false) + { + OnConnected[transactionId] = callback; + } + else + { + OnConnected[transactionId] += callback; + } + } + + public async Task ConnectAsync(string remoteMachineName, string transactionId, string secretKey) { IEnumerable _transports = transports.OrderBy(c => c.Name); foreach (RelayCompactInfo item in config.Data.Client.Relay.Servers.Where(c => c.Disabled == false)) @@ -58,10 +67,10 @@ namespace cmonitor.plugins.relay TransactionId = transactionId, TransportName = transport.Name }; - Socket socket = await transport.RelayAsync(relayInfo); - if (socket != null) + ITunnelConnection connection = await transport.RelayAsync(relayInfo); + if (connection != null) { - return new RelayTransportState { Info = relayInfo, Socket = socket, Direction = RelayTransportDirection.Forward }; + return connection; } } catch (Exception ex) @@ -76,10 +85,13 @@ namespace cmonitor.plugins.relay ITransport _transports = transports.FirstOrDefault(c => c.Name == relayInfo.TransportName); if (_transports != null) { - Socket socket = await _transports.OnBeginAsync(relayInfo); - if (socket != null) + ITunnelConnection connection = await _transports.OnBeginAsync(relayInfo); + if (connection != null) { - OnConnected(new RelayTransportState { Info = relayInfo, Socket = socket, Direction = RelayTransportDirection.Reverse }); + if (OnConnected.TryGetValue(connection.TransactionId, out Action callback)) + { + callback(connection); + } return true; } } diff --git a/cmonitor/plugins/relay/transport/ITransport.cs b/cmonitor/plugins/relay/transport/ITransport.cs index 1894d8ff..8529856b 100644 --- a/cmonitor/plugins/relay/transport/ITransport.cs +++ b/cmonitor/plugins/relay/transport/ITransport.cs @@ -1,14 +1,16 @@ -using MemoryPack; +using cmonitor.client.tunnel; +using MemoryPack; using System.Net; -using System.Net.Sockets; namespace cmonitor.plugins.relay.transport { public interface ITransport { public string Name { get; } - public Task RelayAsync(RelayInfo relayInfo); - public Task OnBeginAsync(RelayInfo relayInfo); + public TunnelProtocolType ProtocolType { get; } + + public Task RelayAsync(RelayInfo relayInfo); + public Task OnBeginAsync(RelayInfo relayInfo); } [MemoryPackable] @@ -26,19 +28,4 @@ namespace cmonitor.plugins.relay.transport } - - public enum RelayTransportDirection : byte - { - Forward = 0, - Reverse = 1 - } - - public sealed class RelayTransportState - { - public RelayInfo Info { get; set; } - - public RelayTransportDirection Direction { get; set; } = RelayTransportDirection.Reverse; - - public Socket Socket { get; set; } - } } diff --git a/cmonitor/plugins/relay/transport/TransportSelfHost.cs b/cmonitor/plugins/relay/transport/TransportSelfHost.cs index de5eea85..8a53db6c 100644 --- a/cmonitor/plugins/relay/transport/TransportSelfHost.cs +++ b/cmonitor/plugins/relay/transport/TransportSelfHost.cs @@ -1,4 +1,5 @@ -using cmonitor.plugins.relay.messenger; +using cmonitor.client.tunnel; +using cmonitor.plugins.relay.messenger; using cmonitor.server; using common.libs; using common.libs.extends; @@ -11,6 +12,7 @@ namespace cmonitor.plugins.relay.transport public sealed class TransportSelfHost : ITransport { public string Name => "self"; + public TunnelProtocolType ProtocolType => TunnelProtocolType.Tcp; private readonly TcpServer tcpServer; private readonly MessengerSender messengerSender; @@ -22,9 +24,9 @@ namespace cmonitor.plugins.relay.transport this.messengerSender = messengerSender; } - public async Task RelayAsync(RelayInfo relayInfo) + public async Task RelayAsync(RelayInfo relayInfo) { - Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp); socket.Reuse(true); socket.IPv6Only(relayInfo.Server.AddressFamily, false); await socket.ConnectAsync(relayInfo.Server).WaitAsync(TimeSpan.FromSeconds(5)); @@ -43,12 +45,21 @@ namespace cmonitor.plugins.relay.transport } await socket.SendAsync(relayFlagData); await Task.Delay(10); - return socket; + return new TunnelConnectionTcp + { + Direction = TunnelDirection.Forward, + ProtocolType = TunnelProtocolType.Tcp, + RemoteMachineName = relayInfo.RemoteMachineName, + Socket = socket, + TransactionId = relayInfo.TransactionId, + TransportName = Name, + Type = TunnelType.Relay + }; } - public async Task OnBeginAsync(RelayInfo relayInfo) + public async Task OnBeginAsync(RelayInfo relayInfo) { - Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp); socket.Reuse(true); socket.IPv6Only(relayInfo.Server.AddressFamily, false); await socket.ConnectAsync(relayInfo.Server).WaitAsync(TimeSpan.FromSeconds(5)); @@ -67,7 +78,16 @@ namespace cmonitor.plugins.relay.transport } await socket.SendAsync(relayFlagData); await Task.Delay(10); - return socket; + return new TunnelConnectionTcp + { + Direction = TunnelDirection.Reverse, + ProtocolType = TunnelProtocolType.Tcp, + RemoteMachineName = relayInfo.RemoteMachineName, + Socket = socket, + TransactionId = relayInfo.TransactionId, + TransportName = Name, + Type = TunnelType.Relay + }; } } } diff --git a/cmonitor/plugins/tunnel/TunnelApiController.cs b/cmonitor/plugins/tunnel/TunnelApiController.cs index f885e9a4..bc0c43a5 100644 --- a/cmonitor/plugins/tunnel/TunnelApiController.cs +++ b/cmonitor/plugins/tunnel/TunnelApiController.cs @@ -1,5 +1,6 @@ using cmonitor.client; using cmonitor.client.api; +using cmonitor.client.tunnel; using cmonitor.config; using cmonitor.plugins.signin.messenger; using cmonitor.plugins.tunnel.server; @@ -38,17 +39,16 @@ namespace cmonitor.plugins.tunnel { try { - TunnelTransportState state = await tunnelTransfer.ConnectAsync(param.Content, "test"); - if (state != null) + ITunnelConnection connection = await tunnelTransfer.ConnectAsync(param.Content, "test"); + if (connection != null) { - var socket = state.ConnectedObject as Socket; for (int i = 0; i < 10; i++) { Logger.Instance.Debug($"tunnel [test] send {i}"); - socket.Send(BitConverter.GetBytes(i)); + await connection.SendAsync(BitConverter.GetBytes(i)); await Task.Delay(10); } - socket.SafeClose(); + connection.Close(); } } catch (Exception ex) @@ -59,16 +59,17 @@ namespace cmonitor.plugins.tunnel } private void TunnelTest() { - tunnelTransfer.OnConnected += (TunnelTransportState state) => + tunnelTransfer.SetConnectCallback("test", (ITunnelConnection connection) => { - if (state.TransactionId == "test" && state.TransportType == ProtocolType.Tcp) - { - tunnelBindServer.BindReceive(state.ConnectedObject as Socket, null, async (token, data) => - { - Logger.Instance.Debug($"tunnel [{state.TransactionId}] receive {BitConverter.ToInt32(data.Span)}"); - }); - } - }; + connection.BeginReceive(async (ITunnelConnection connection, Memory data, object state) => { + + Logger.Instance.Debug($"tunnel [{connection.TransactionId}] receive {BitConverter.ToInt32(data.Span)}"); + await Task.CompletedTask; + + }, async (ITunnelConnection connection, object state) => { + await Task.CompletedTask; + }, null); + }); } } diff --git a/cmonitor/plugins/tunnel/TunnelTransfer.cs b/cmonitor/plugins/tunnel/TunnelTransfer.cs index 95bd5e3e..de457332 100644 --- a/cmonitor/plugins/tunnel/TunnelTransfer.cs +++ b/cmonitor/plugins/tunnel/TunnelTransfer.cs @@ -1,4 +1,5 @@ using cmonitor.client; +using cmonitor.client.tunnel; using cmonitor.config; using cmonitor.plugins.tunnel.compact; using cmonitor.plugins.tunnel.messenger; @@ -8,8 +9,8 @@ using common.libs; using common.libs.extends; using MemoryPack; using Microsoft.Extensions.DependencyInjection; -using System.Net.Sockets; using System.Reflection; +using System.Transactions; namespace cmonitor.plugins.tunnel { @@ -23,7 +24,7 @@ namespace cmonitor.plugins.tunnel private readonly MessengerSender messengerSender; private readonly CompactTransfer compactTransfer; - public Action OnConnected { get; set; } = (state) => { }; + private Dictionary> OnConnected { get; } = new Dictionary>(); public TunnelTransfer(Config config, ServiceProvider serviceProvider, ClientSignInState clientSignInState, MessengerSender messengerSender, CompactTransfer compactTransfer) { @@ -53,45 +54,44 @@ namespace cmonitor.plugins.tunnel Logger.Instance.Warning($"load tunnel transport:{string.Join(",", transports.Select(c => c.Name))}"); } - public async Task ConnectAsync(string remoteMachineName, string transactionId) + public async Task ConnectAsync(string remoteMachineName, string transactionId) { - IEnumerable _transports = transports.OrderBy(c => c.Type); + IEnumerable _transports = transports.OrderBy(c => c.ProtocolType); foreach (ITransport transport in _transports) { //获取自己的外网ip - TunnelTransportExternalIPInfo localInfo = await GetLocalInfo(transport.Type); + TunnelTransportExternalIPInfo localInfo = await GetLocalInfo(transport.ProtocolType); if (localInfo == null) { continue; } //获取对方的外网ip - TunnelTransportExternalIPInfo remoteInfo = await GetRemoteInfo(remoteMachineName, transport.Type); + TunnelTransportExternalIPInfo remoteInfo = await GetRemoteInfo(remoteMachineName, transport.ProtocolType); if (remoteInfo == null) { continue; } TunnelTransportInfo tunnelTransportInfo = new TunnelTransportInfo { - Direction = TunnelTransportDirection.Forward, + Direction = TunnelDirection.Forward, TransactionId = transactionId, TransportName = transport.Name, - TransportType = transport.Type, + TransportType = transport.ProtocolType, Local = localInfo, Remote = remoteInfo, }; - TunnelTransportState state = await transport.ConnectAsync(tunnelTransportInfo); - if (state != null) + ITunnelConnection connection = await transport.ConnectAsync(tunnelTransportInfo); + if (connection != null) { - state.Direction = TunnelTransportDirection.Forward; - _OnConnected(state); - return state; + _OnConnected(connection); + return connection; } } return null; } public void OnBegin(TunnelTransportInfo tunnelTransportInfo) { - ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.Type == tunnelTransportInfo.TransportType); + ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.ProtocolType == tunnelTransportInfo.TransportType); if (_transports != null) { _transports.OnBegin(tunnelTransportInfo); @@ -99,7 +99,7 @@ namespace cmonitor.plugins.tunnel } public void OnFail(TunnelTransportInfo tunnelTransportInfo) { - ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.Type == tunnelTransportInfo.TransportType); + ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.ProtocolType == tunnelTransportInfo.TransportType); if (_transports != null) { _transports.OnFail(tunnelTransportInfo); @@ -111,7 +111,7 @@ namespace cmonitor.plugins.tunnel return await GetLocalInfo(request.TransportType); } - private async Task GetLocalInfo(ProtocolType transportType) + private async Task GetLocalInfo(TunnelProtocolType transportType) { TunnelCompactIPEndPoint[] ips = await compactTransfer.GetExternalIPAsync(transportType); if (ips != null && ips.Length > 0) @@ -126,7 +126,7 @@ namespace cmonitor.plugins.tunnel } return null; } - private async Task GetRemoteInfo(string remoteMachineName, ProtocolType transportType) + private async Task GetRemoteInfo(string remoteMachineName, TunnelProtocolType transportType) { MessageResponeInfo resp = await messengerSender.SendReply(new MessageRequestWrap { @@ -168,6 +168,19 @@ namespace cmonitor.plugins.tunnel } + public void SetConnectCallback(string transactionId, Action callback) + { + if (OnConnected.TryGetValue(transactionId, out Action _callback) == false) + { + OnConnected[transactionId] = callback; + } + else + { + OnConnected[transactionId] += callback; + } + } + + public Dictionary Connections { get; } = new Dictionary(); private int connectionsChangeFlag = 1; public bool ConnectionChanged => Interlocked.CompareExchange(ref connectionsChangeFlag, 0, 1) == 1; @@ -191,23 +204,27 @@ namespace cmonitor.plugins.tunnel info.Status = TunnelConnectStatus.Connecting; Interlocked.Exchange(ref connectionsChangeFlag, 1); } - private void _OnConnected(TunnelTransportState state) + private void _OnConnected(ITunnelConnection connection) { if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG) { - Logger.Instance.Debug($"tunnel connect [{state.TransactionId}]->{state.RemoteMachineName} success"); + Logger.Instance.Debug($"tunnel connect [{connection.TransactionId}]->{connection.RemoteMachineName} success"); } - CheckDic(state.RemoteMachineName, out TunnelConnectInfo info); + CheckDic(connection.RemoteMachineName, out TunnelConnectInfo info); info.Status = TunnelConnectStatus.Connected; - info.State = state; + info.Connection = connection; Interlocked.Exchange(ref connectionsChangeFlag, 1); - OnConnected(state); + + if (OnConnected.TryGetValue(connection.TransactionId, out Action _callback) == false) + { + _callback(connection); + } } - private void OnDisConnected(TunnelTransportState state) + private void OnDisConnected(ITunnelConnection connection) { - CheckDic(state.RemoteMachineName, out TunnelConnectInfo info); + CheckDic(connection.RemoteMachineName, out TunnelConnectInfo info); info.Status = TunnelConnectStatus.None; - info.State = null; + info.Connection = null; Interlocked.Exchange(ref connectionsChangeFlag, 1); } private void OnConnectFail(string machineName) @@ -218,7 +235,7 @@ namespace cmonitor.plugins.tunnel } CheckDic(machineName, out TunnelConnectInfo info); info.Status = TunnelConnectStatus.None; - info.State = null; + info.Connection = null; Interlocked.Exchange(ref connectionsChangeFlag, 1); } private void CheckDic(string name, out TunnelConnectInfo info) @@ -233,7 +250,7 @@ namespace cmonitor.plugins.tunnel public sealed class TunnelConnectInfo { public TunnelConnectStatus Status { get; set; } - public TunnelTransportState State { get; set; } + public ITunnelConnection Connection { get; set; } } public enum TunnelConnectStatus { diff --git a/cmonitor/plugins/tunnel/compact/CompactTransfer.cs b/cmonitor/plugins/tunnel/compact/CompactTransfer.cs index 98946b88..9294ed54 100644 --- a/cmonitor/plugins/tunnel/compact/CompactTransfer.cs +++ b/cmonitor/plugins/tunnel/compact/CompactTransfer.cs @@ -1,4 +1,5 @@ -using cmonitor.config; +using cmonitor.client.tunnel; +using cmonitor.config; using common.libs; using Microsoft.Extensions.DependencyInjection; using System.Net; @@ -28,7 +29,7 @@ namespace cmonitor.plugins.tunnel.compact Logger.Instance.Warning($"load tunnel compacts:{string.Join(",", compacts.Select(c => c.Name))}"); } - public async Task GetExternalIPAsync(ProtocolType protocolType) + public async Task GetExternalIPAsync(TunnelProtocolType protocolType) { TunnelCompactIPEndPoint[] endpoints = new TunnelCompactIPEndPoint[config.Data.Client.Tunnel.Servers.Length]; @@ -42,12 +43,12 @@ namespace cmonitor.plugins.tunnel.compact try { IPEndPoint server = NetworkHelper.GetEndPoint(item.Host, 3478); - if (protocolType == ProtocolType.Tcp) + if (protocolType == TunnelProtocolType.Tcp) { TunnelCompactIPEndPoint externalIP = await compact.GetTcpExternalIPAsync(server); endpoints[i] = externalIP; } - else if (protocolType == ProtocolType.Udp) + else if (protocolType == TunnelProtocolType.Udp) { TunnelCompactIPEndPoint externalIP = await compact.GetUdpExternalIPAsync(server); endpoints[i] = externalIP; diff --git a/cmonitor/plugins/tunnel/server/TunnelBindServer.cs b/cmonitor/plugins/tunnel/server/TunnelBindServer.cs index 8552a46f..90cf2e81 100644 --- a/cmonitor/plugins/tunnel/server/TunnelBindServer.cs +++ b/cmonitor/plugins/tunnel/server/TunnelBindServer.cs @@ -12,7 +12,6 @@ namespace cmonitor.plugins.tunnel.server public Action OnTcpConnected { get; set; } = (state, socket) => { }; public Action OnUdpConnected { get; set; } = (state, udpClient) => { }; - public Action OnDisConnected { get; set; } = (state) => { }; private ConcurrentDictionary acceptBinds = new ConcurrentDictionary(); @@ -85,9 +84,6 @@ namespace cmonitor.plugins.tunnel.server case SocketAsyncOperation.Accept: ProcessAccept(e); break; - case SocketAsyncOperation.Receive: - ProcessReceive(e); - break; default: break; } @@ -119,87 +115,6 @@ namespace cmonitor.plugins.tunnel.server } } - public void BindReceive(Socket socket, object state, OnTunnelData dataCallback) - { - if (socket == null || socket.RemoteEndPoint == null) - { - return; - } - - socket.KeepAlive(); - AsyncUserToken userToken = new AsyncUserToken - { - SourceSocket = socket, - State = state, - OnData = dataCallback, - LocalPort = (socket.LocalEndPoint as IPEndPoint).Port, - }; - - SocketAsyncEventArgs readEventArgs = new SocketAsyncEventArgs - { - UserToken = userToken, - SocketFlags = SocketFlags.None, - }; - readEventArgs.SetBuffer(new byte[8 * 1024], 0, 8 * 1024); - readEventArgs.Completed += IO_Completed; - if (socket.ReceiveAsync(readEventArgs) == false) - { - ProcessReceive(readEventArgs); - } - } - private async void ProcessReceive(SocketAsyncEventArgs e) - { - try - { - AsyncUserToken token = (AsyncUserToken)e.UserToken; - - if (e.BytesTransferred > 0 && e.SocketError == SocketError.Success) - { - int offset = e.Offset; - int length = e.BytesTransferred; - - await token.OnData(token, e.Buffer.AsMemory(0, length)); - if (token.SourceSocket.Available > 0) - { - while (token.SourceSocket.Available > 0) - { - length = token.SourceSocket.Receive(e.Buffer); - if (length > 0) - { - await token.OnData(token, e.Buffer.AsMemory(0, length)); - } - else - { - CloseClientSocket(e); - return; - } - } - } - - if (token.SourceSocket.Connected == false) - { - CloseClientSocket(e); - return; - } - - if (token.SourceSocket.ReceiveAsync(e) == false) - { - ProcessReceive(e); - } - } - else - { - CloseClientSocket(e); - } - } - catch (Exception ex) - { - if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG) - Logger.Instance.Error(ex); - - CloseClientSocket(e); - } - } private void CloseClientSocket(SocketAsyncEventArgs e) { if (e == null || e.UserToken == null) return; @@ -214,7 +129,6 @@ namespace cmonitor.plugins.tunnel.server { CloseClientSocket(saea1); } - OnDisConnected(token.State); } } @@ -226,9 +140,6 @@ namespace cmonitor.plugins.tunnel.server public Socket SourceSocket { get; set; } public SocketAsyncEventArgs Saea { get; set; } public object State { get; set; } - - public OnTunnelData OnData { get; set; } - public int LocalPort { get; set; } public void Clear() diff --git a/cmonitor/plugins/tunnel/transport/ITransport.cs b/cmonitor/plugins/tunnel/transport/ITransport.cs index ae0af544..4c4e2683 100644 --- a/cmonitor/plugins/tunnel/transport/ITransport.cs +++ b/cmonitor/plugins/tunnel/transport/ITransport.cs @@ -1,14 +1,13 @@ -using MemoryPack; +using cmonitor.client.tunnel; +using MemoryPack; using System.Net; -using System.Net.Sockets; -using System.Text.Json.Serialization; namespace cmonitor.plugins.tunnel.transport { public interface ITransport { public string Name { get; } - public ProtocolType Type { get; } + public TunnelProtocolType ProtocolType { get; } /// /// 发送连接信息 @@ -26,11 +25,11 @@ namespace cmonitor.plugins.tunnel.transport /// /// 收到连接 /// - public Action OnConnected { get; set; } + public Action OnConnected { get; set; } /// /// 断开连接 /// - public Action OnDisConnected { get; set; } + public Action OnDisConnected { get; set; } public Action OnConnectFail { get; set; } @@ -39,7 +38,7 @@ namespace cmonitor.plugins.tunnel.transport /// /// 你的名字 /// - public Task ConnectAsync(TunnelTransportInfo tunnelTransportInfo); + public Task ConnectAsync(TunnelTransportInfo tunnelTransportInfo); /// /// 收到开始连接 /// @@ -57,7 +56,7 @@ namespace cmonitor.plugins.tunnel.transport public sealed partial class TunnelTransportExternalIPRequestInfo { public string RemoteMachineName { get; set; } - public ProtocolType TransportType { get; set; } + public TunnelProtocolType TransportType { get; set; } } [MemoryPackable] @@ -83,43 +82,11 @@ namespace cmonitor.plugins.tunnel.transport public string TransactionId { get; set; } - public ProtocolType TransportType { get; set; } + public TunnelProtocolType TransportType { get; set; } public string TransportName { get; set; } - public TunnelTransportDirection Direction { get; set; } + public TunnelDirection Direction { get; set; } } - public enum TunnelTransportDirection : byte - { - Forward = 0, - Reverse = 1 - } - - public enum TunnelTransportType - { - Tcp = ProtocolType.Tcp, - Udp = ProtocolType.Udp, - } - - public sealed class TunnelTransportState - { - public string RemoteMachineName { get; set; } - public string TransactionId { get; set; } - public string TransportName { get; set; } - public ProtocolType TransportType { get; set; } - - public TunnelTransportDirection Direction { get; set; } = TunnelTransportDirection.Reverse; - - [JsonIgnore] - public object ConnectedObject { get; set; } - } - - - public interface ITunnelConnection - { - public TunnelTransportType TransportType { get; } - } - - } diff --git a/cmonitor/plugins/tunnel/transport/TransportTcpNutssb.cs b/cmonitor/plugins/tunnel/transport/TransportTcpNutssb.cs index d50e2239..8318841f 100644 --- a/cmonitor/plugins/tunnel/transport/TransportTcpNutssb.cs +++ b/cmonitor/plugins/tunnel/transport/TransportTcpNutssb.cs @@ -1,4 +1,5 @@ -using cmonitor.plugins.tunnel.server; +using cmonitor.client.tunnel; +using cmonitor.plugins.tunnel.server; using common.libs; using common.libs.extends; using System.Collections.Concurrent; @@ -10,14 +11,14 @@ namespace cmonitor.plugins.tunnel.transport public sealed class TransportTcpNutssb : ITransport { public string Name => "TcpNutssb"; - public ProtocolType Type => ProtocolType.Tcp; + public TunnelProtocolType ProtocolType => TunnelProtocolType.Tcp; public Func> OnSendConnectBegin { get; set; } = async (info) => { return await Task.FromResult(false); }; public Func OnSendConnectFail { get; set; } = async (info) => { await Task.CompletedTask; }; public Action OnConnectBegin { get; set; } = (info) => { }; public Action OnConnecting { get; set; } - public Action OnConnected { get; set; } = (state) => { }; - public Action OnDisConnected { get; set; } = (state) => { }; + public Action OnConnected { get; set; } = (state) => { }; + public Action OnDisConnected { get; set; } = (state) => { }; public Action OnConnectFail { get; set; } = (machineName) => { }; @@ -26,26 +27,25 @@ namespace cmonitor.plugins.tunnel.transport { this.tunnelBindServer = tunnelBindServer; tunnelBindServer.OnTcpConnected += OnTcpConnected; - tunnelBindServer.OnDisConnected += OnTcpDisConnected; } - public async Task ConnectAsync(TunnelTransportInfo tunnelTransportInfo) + public async Task ConnectAsync(TunnelTransportInfo tunnelTransportInfo) { OnConnecting(tunnelTransportInfo); //正向连接 - tunnelTransportInfo.Direction = TunnelTransportDirection.Forward; + tunnelTransportInfo.Direction = TunnelDirection.Forward; if (await OnSendConnectBegin(tunnelTransportInfo) == false) { OnConnectFail(tunnelTransportInfo.Remote.MachineName); return null; } - TunnelTransportState state = await ConnectForward(tunnelTransportInfo); - if (state != null) return state; + ITunnelConnection connection = await ConnectForward(tunnelTransportInfo); + if (connection != null) return connection; //反向连接 TunnelTransportInfo tunnelTransportInfo1 = tunnelTransportInfo.ToJsonFormat().DeJson(); - tunnelTransportInfo1.Direction = TunnelTransportDirection.Reverse; + tunnelTransportInfo1.Direction = TunnelDirection.Reverse; tunnelBindServer.Bind(tunnelTransportInfo1.Local.Local, tunnelTransportInfo1); BindAndTTL(tunnelTransportInfo1); if (await OnSendConnectBegin(tunnelTransportInfo1) == false) @@ -53,8 +53,8 @@ namespace cmonitor.plugins.tunnel.transport OnConnectFail(tunnelTransportInfo.Remote.MachineName); return null; } - state = await WaitReverse(tunnelTransportInfo1); - if (state != null) return state; + connection = await WaitReverse(tunnelTransportInfo1); + if (connection != null) return connection; //正向反向都失败 await OnSendConnectFail(tunnelTransportInfo); @@ -65,22 +65,22 @@ namespace cmonitor.plugins.tunnel.transport public void OnBegin(TunnelTransportInfo tunnelTransportInfo) { OnConnectBegin(tunnelTransportInfo); - if (tunnelTransportInfo.Direction == TunnelTransportDirection.Forward) + if (tunnelTransportInfo.Direction == TunnelDirection.Forward) { tunnelBindServer.Bind(tunnelTransportInfo.Local.Local, tunnelTransportInfo); } Task.Run(async () => { - if (tunnelTransportInfo.Direction == TunnelTransportDirection.Forward) + if (tunnelTransportInfo.Direction == TunnelDirection.Forward) { BindAndTTL(tunnelTransportInfo); } else { - TunnelTransportState state = await ConnectForward(tunnelTransportInfo); - if (state != null) + ITunnelConnection connection = await ConnectForward(tunnelTransportInfo); + if (connection != null) { - OnConnected(state); + OnConnected(connection); } else { @@ -96,7 +96,7 @@ namespace cmonitor.plugins.tunnel.transport tunnelBindServer.RemoveBind(tunnelTransportInfo.Local.Local.Port); } - private async Task ConnectForward(TunnelTransportInfo tunnelTransportInfo) + private async Task ConnectForward(TunnelTransportInfo tunnelTransportInfo) { await Task.Delay(20); //要连接哪些IP @@ -109,7 +109,7 @@ namespace cmonitor.plugins.tunnel.transport }; foreach (IPEndPoint ep in eps) { - Socket targetSocket = new(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + Socket targetSocket = new(ep.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp); targetSocket.IPv6Only(ep.Address.AddressFamily, false); targetSocket.ReuseBind(new IPEndPoint(ep.AddressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any, tunnelTransportInfo.Local.Local.Port)); IAsyncResult result = targetSocket.BeginConnect(ep, null, null); @@ -133,13 +133,16 @@ namespace cmonitor.plugins.tunnel.transport } targetSocket.EndConnect(result); - return new TunnelTransportState + return new TunnelConnectionTcp { - ConnectedObject = targetSocket, + Socket = targetSocket, TransactionId = tunnelTransportInfo.TransactionId, RemoteMachineName = tunnelTransportInfo.Remote.MachineName, TransportName = Name, - TransportType = Type + Direction = tunnelTransportInfo.Direction, + ProtocolType = ProtocolType, + Type = TunnelType.P2P, + Label = string.Empty }; } catch (Exception ex) @@ -167,7 +170,7 @@ namespace cmonitor.plugins.tunnel.transport //过滤掉不支持IPV6的情况 IEnumerable sockets = eps.Where(c => NotIPv6Support(c.Address) == false).Select(ip => { - Socket targetSocket = new(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + Socket targetSocket = new(ip.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp); try { targetSocket.IPv6Only(ip.Address.AddressFamily, false); @@ -188,16 +191,16 @@ namespace cmonitor.plugins.tunnel.transport } - private ConcurrentDictionary> reverseDic = new ConcurrentDictionary>(); - private async Task WaitReverse(TunnelTransportInfo tunnelTransportInfo) + private ConcurrentDictionary> reverseDic = new ConcurrentDictionary>(); + private async Task WaitReverse(TunnelTransportInfo tunnelTransportInfo) { - TaskCompletionSource tcs = new TaskCompletionSource(); + TaskCompletionSource tcs = new TaskCompletionSource(); reverseDic.TryAdd(tunnelTransportInfo.Remote.MachineName, tcs); try { - TunnelTransportState state = await tcs.Task.WaitAsync(TimeSpan.FromMilliseconds(3000)); - return state; + ITunnelConnection connection = await tcs.Task.WaitAsync(TimeSpan.FromMilliseconds(3000)); + return connection; } catch (Exception) { @@ -214,15 +217,18 @@ namespace cmonitor.plugins.tunnel.transport { if (state is TunnelTransportInfo _state && _state.TransportName == Name) { - TunnelTransportState result = new TunnelTransportState + TunnelConnectionTcp result = new TunnelConnectionTcp { RemoteMachineName = _state.Remote.MachineName, - TransportType = ProtocolType.Tcp, - ConnectedObject = socket, + Direction = _state.Direction, + ProtocolType = TunnelProtocolType.Tcp, + Socket = socket, + Type = TunnelType.P2P, TransactionId = _state.TransactionId, TransportName = _state.TransportName, + Label = string.Empty, }; - if (reverseDic.TryRemove(_state.Remote.MachineName, out TaskCompletionSource tcs)) + if (reverseDic.TryRemove(_state.Remote.MachineName, out TaskCompletionSource tcs)) { tcs.SetResult(result); return; @@ -231,20 +237,6 @@ namespace cmonitor.plugins.tunnel.transport OnConnected(result); } } - private void OnTcpDisConnected(object state) - { - if (state is TunnelTransportInfo _state && _state.TransportName == Name) - { - TunnelTransportState result = new TunnelTransportState - { - RemoteMachineName = _state.Remote.MachineName, - TransportType = ProtocolType.Tcp, - TransactionId = _state.TransactionId, - TransportName = _state.TransportName, - }; - OnDisConnected(result); - } - } private bool NotIPv6Support(IPAddress ip) { diff --git a/cmonitor/plugins/viewer/proxy/ViewerProxy.cs b/cmonitor/plugins/viewer/proxy/ViewerProxy.cs index 26202e01..2b0e4b94 100644 --- a/cmonitor/plugins/viewer/proxy/ViewerProxy.cs +++ b/cmonitor/plugins/viewer/proxy/ViewerProxy.cs @@ -1,4 +1,5 @@ -using common.libs; +using cmonitor.client.tunnel; +using common.libs; using common.libs.extends; using System.Buffers; using System.Collections.Concurrent; @@ -9,10 +10,11 @@ namespace cmonitor.plugins.viewer.proxy { public class ViewerProxy { - private SocketAsyncEventArgs acceptEventArg; + private AsyncUserToken userToken; private Socket socket; private UdpClient udpClient; - private NumberSpace ns = new NumberSpace(); + private readonly NumberSpace ns = new NumberSpace(); + private readonly ConcurrentDictionary dic = new ConcurrentDictionary(); public IPEndPoint LocalEndpoint => socket?.LocalEndPoint as IPEndPoint ?? new IPEndPoint(IPAddress.Any, 0); @@ -32,14 +34,16 @@ namespace cmonitor.plugins.viewer.proxy socket.ReuseBind(localEndPoint); socket.Listen(int.MaxValue); - acceptEventArg = new SocketAsyncEventArgs + userToken = new AsyncUserToken { - UserToken = new AsyncUserToken - { - SourceSocket = socket - }, + Socket = socket + }; + SocketAsyncEventArgs acceptEventArg = new SocketAsyncEventArgs + { + UserToken = userToken, SocketFlags = SocketFlags.None, }; + userToken.Saea = acceptEventArg; acceptEventArg.Completed += IO_Completed; StartAccept(acceptEventArg); @@ -56,11 +60,11 @@ namespace cmonitor.plugins.viewer.proxy } } + private readonly AsyncUserUdpToken asyncUserUdpToken = new AsyncUserUdpToken { - Proxy = new ProxyInfo { Step = ProxyStep.Forward, Direction = ProxyDirection.UnPack, ConnectId = 0 } + Proxy = new ProxyInfo { Step = ProxyStep.Forward, ConnectId = 0 } }; - private async void ReceiveCallbackUdp(IAsyncResult result) { try @@ -83,15 +87,13 @@ namespace cmonitor.plugins.viewer.proxy await Task.CompletedTask; } - - private void StartAccept(SocketAsyncEventArgs acceptEventArg) { acceptEventArg.AcceptSocket = null; AsyncUserToken token = (AsyncUserToken)acceptEventArg.UserToken; try { - if (token.SourceSocket.AcceptAsync(acceptEventArg) == false) + if (token.Socket.AcceptAsync(acceptEventArg) == false) { ProcessAccept(acceptEventArg); } @@ -138,7 +140,7 @@ namespace cmonitor.plugins.viewer.proxy socket.KeepAlive(); AsyncUserToken userToken = new AsyncUserToken { - SourceSocket = socket, + Socket = socket, Proxy = new ProxyInfo { Data = Helper.EmptyArray, Step = ProxyStep.Request, ConnectId = ns.Increment() } }; @@ -147,6 +149,8 @@ namespace cmonitor.plugins.viewer.proxy UserToken = userToken, SocketFlags = SocketFlags.None, }; + userToken.Saea = readEventArgs; + readEventArgs.SetBuffer(new byte[8 * 1024], 0, 8 * 1024); readEventArgs.Completed += IO_Completed; if (socket.ReceiveAsync(readEventArgs) == false) @@ -162,46 +166,45 @@ namespace cmonitor.plugins.viewer.proxy } private async void ProcessReceive(SocketAsyncEventArgs e) { + AsyncUserToken token = (AsyncUserToken)e.UserToken; try { - AsyncUserToken token = (AsyncUserToken)e.UserToken; - if (e.BytesTransferred > 0 && e.SocketError == SocketError.Success) { int offset = e.Offset; int length = e.BytesTransferred; - await ReadPacket(e, token, e.Buffer.AsMemory(offset, length)); - if (token.SourceSocket.Available > 0) + await ReadPacket(token, e.Buffer.AsMemory(offset, length)); + if (token.Socket.Available > 0) { - while (token.SourceSocket.Available > 0) + while (token.Socket.Available > 0) { - length = token.SourceSocket.Receive(e.Buffer); + length = token.Socket.Receive(e.Buffer); if (length > 0) { - await ReadPacket(e, token, e.Buffer.AsMemory(0, length)); + await ReadPacket(token, e.Buffer.AsMemory(0, length)); } else { - CloseClientSocket(e); + CloseClientSocket(token); return; } } } - if (token.SourceSocket.Connected == false) + if (token.Socket.Connected == false) { - CloseClientSocket(e); + CloseClientSocket(token); return; } - if (token.SourceSocket.ReceiveAsync(e) == false) + if (token.Socket.ReceiveAsync(e) == false) { ProcessReceive(e); } } else { - CloseClientSocket(e); + CloseClientSocket(token); } } catch (Exception ex) @@ -209,55 +212,38 @@ namespace cmonitor.plugins.viewer.proxy if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG) Logger.Instance.Error(ex); - CloseClientSocket(e); + CloseClientSocket(token); } } - - private async Task ReadPacket(SocketAsyncEventArgs e, AsyncUserToken token, Memory data) + private async Task ReadPacket(AsyncUserToken token, Memory data) { if (token.Proxy.Step == ProxyStep.Request) { await Connect(token); - if (token.TargetSocket != null) + if (token.Connection != null) { //发送连接请求包 - await SendToTarget(e, token).ConfigureAwait(false); + await SendToConnection(token).ConfigureAwait(false); token.Proxy.Step = ProxyStep.Forward; token.Proxy.TargetEP = null; //发送后续数据包 token.Proxy.Data = data; - await SendToTarget(e, token).ConfigureAwait(false); + await SendToConnection(token).ConfigureAwait(false); //绑定 - dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.TargetSocket.GetHashCode()), token.SourceSocket); + dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.Connection.GetHashCode()), token.Socket); } else { - CloseClientSocket(e); + CloseClientSocket(token); } } else { token.Proxy.Data = data; - await SendToTarget(e, token).ConfigureAwait(false); - } - } - private async Task SendToTarget(SocketAsyncEventArgs e, AsyncUserToken token) - { - byte[] connectData = token.Proxy.ToBytes(out int length); - try - { - await token.TargetSocket.SendAsync(connectData.AsMemory(0, length), SocketFlags.None); - } - catch (Exception) - { - CloseClientSocket(e); - } - finally - { - token.Proxy.Return(connectData); + await SendToConnection(token).ConfigureAwait(false); } } @@ -266,26 +252,115 @@ namespace cmonitor.plugins.viewer.proxy await Task.CompletedTask; } - protected bool BindReceiveTarget(Socket targetSocket, Socket sourceSocket) + private async Task SendToConnection(AsyncUserToken token) { + byte[] connectData = token.Proxy.ToBytes(out int length); try { - BindReceiveTarget(new AsyncUserToken - { - TargetSocket = targetSocket, - SourceSocket = sourceSocket, - Buffer = new ReceiveDataBuffer(), - Proxy = new ProxyInfo { Direction = ProxyDirection.UnPack } - }); - - return true; + await token.Connection.SendAsync(connectData.AsMemory(0, length)).ConfigureAwait(false); } - catch (Exception ex) + catch (Exception) { - Logger.Instance.Error(ex); + CloseClientSocket(token); + } + finally + { + token.Proxy.Return(connectData); } - return false; } + + protected void BindConnectionReceive(ITunnelConnection connection) + { + connection.BeginReceive(InputConnectionData, CloseConnection, new AsyncUserToken + { + Connection = connection, + Buffer = new ReceiveDataBuffer(), + Proxy = new ProxyInfo { } + }); + } + protected async Task InputConnectionData(ITunnelConnection connection, Memory memory, object userToken) + { + AsyncUserToken token = userToken as AsyncUserToken; + //是一个完整的包 + if (token.Buffer.Size == 0 && memory.Length > 4) + { + int packageLen = memory.ToInt32(); + if (packageLen == memory.Length - 4) + { + token.Proxy.DeBytes(memory.Slice(0, packageLen + 4)); + await ReadConnectionPack(token).ConfigureAwait(false); + return; + } + } + + //不是完整包 + token.Buffer.AddRange(memory); + do + { + int packageLen = token.Buffer.Data.ToInt32(); + if (packageLen > token.Buffer.Size - 4) + { + break; + } + token.Proxy.DeBytes(token.Buffer.Data.Slice(0, packageLen + 4)); + await ReadConnectionPack(token).ConfigureAwait(false); + + token.Buffer.RemoveRange(0, packageLen + 4); + } while (token.Buffer.Size > 4); + } + protected async Task CloseConnection(ITunnelConnection connection, object userToken) + { + CloseClientSocket(userToken as AsyncUserToken); + await Task.CompletedTask; + } + private async Task ReadConnectionPack(AsyncUserToken token) + { + if (token.Proxy.Step == ProxyStep.Request) + { + await ConnectBind(token).ConfigureAwait(false); + } + else + { + await SendToSocket(token).ConfigureAwait(false); + } + } + private async Task ConnectBind(AsyncUserToken token) + { + Socket socket = new Socket(token.Proxy.TargetEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + socket.KeepAlive(); + await socket.ConnectAsync(token.Proxy.TargetEP); + + dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.Connection.GetHashCode()), socket); + + BindReceiveTarget(new AsyncUserToken + { + Connection = token.Connection, + Socket = socket, + Proxy = new ProxyInfo + { + ConnectId = token.Proxy.ConnectId, + Step = ProxyStep.Forward + } + }); + } + private async Task SendToSocket(AsyncUserToken token) + { + ConnectId connectId = new ConnectId(token.Proxy.ConnectId, token.Connection.GetHashCode()); + if (dic.TryGetValue(connectId, out Socket source)) + { + try + { + await source.SendAsync(token.Proxy.Data); + } + catch (Exception) + { + CloseClientSocket(token); + + } + } + } + + private void IO_CompletedTarget(object sender, SocketAsyncEventArgs e) { switch (e.LastOperation) @@ -308,7 +383,7 @@ namespace cmonitor.plugins.viewer.proxy }; readEventArgs.SetBuffer(new byte[8 * 1024], 0, 8 * 1024); readEventArgs.Completed += IO_CompletedTarget; - if (userToken.TargetSocket.ReceiveAsync(readEventArgs) == false) + if (userToken.Socket.ReceiveAsync(readEventArgs) == false) { ProcessReceiveTarget(readEventArgs); } @@ -321,48 +396,49 @@ namespace cmonitor.plugins.viewer.proxy } private async void ProcessReceiveTarget(SocketAsyncEventArgs e) { + AsyncUserToken token = (AsyncUserToken)e.UserToken; try { - AsyncUserToken token = (AsyncUserToken)e.UserToken; - if (e.BytesTransferred > 0 && e.SocketError == SocketError.Success) { int offset = e.Offset; int length = e.BytesTransferred; - await ReadPacketTarget(e, token, e.Buffer.AsMemory(offset, length)).ConfigureAwait(false); + token.Proxy.Data = e.Buffer.AsMemory(offset, length); + await SendToConnection(token).ConfigureAwait(false); - if (token.TargetSocket.Available > 0) + if (token.Socket.Available > 0) { - while (token.TargetSocket.Available > 0) + while (token.Socket.Available > 0) { - length = token.TargetSocket.Receive(e.Buffer); + length = token.Socket.Receive(e.Buffer); if (length > 0) { - await ReadPacketTarget(e, token, e.Buffer.AsMemory(0, length)).ConfigureAwait(false); + token.Proxy.Data = e.Buffer.AsMemory(0, length); + await SendToConnection(token).ConfigureAwait(false); } else { - CloseClientSocket(e); + CloseClientSocket(token); return; } } } - if (token.TargetSocket.Connected == false) + if (token.Connection.Connected == false) { - CloseClientSocket(e); + CloseClientSocket(token); return; } - if (token.TargetSocket.ReceiveAsync(e) == false) + if (token.Socket.ReceiveAsync(e) == false) { ProcessReceiveTarget(e); } } else { - CloseClientSocket(e); + CloseClientSocket(token); } } catch (Exception ex) @@ -370,126 +446,19 @@ namespace cmonitor.plugins.viewer.proxy if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG) Logger.Instance.Error(ex); - CloseClientSocket(e); + CloseClientSocket(token); } } - private readonly ConcurrentDictionary dic = new ConcurrentDictionary(); - private async Task ReadPacketTarget(SocketAsyncEventArgs e, AsyncUserToken token, Memory data) + + private void CloseClientSocket(AsyncUserToken token) { - //A 到 B - if (token.Proxy.Direction == ProxyDirection.UnPack) + if (token.Connection != null) { - //是一个完整的包 - if (token.Buffer.Size == 0 && data.Length > 4) + int code = token.Connection.GetHashCode(); + if (token.Connection.Connected == false) { - int packageLen = data.ToInt32(); - if (packageLen == data.Length - 4) - { - token.Proxy.DeBytes(data.Slice(0, packageLen + 4)); - await ReadPacketTarget(e, token).ConfigureAwait(false); - return; - } - } - - //不是完整包 - token.Buffer.AddRange(data); - do - { - int packageLen = token.Buffer.Data.ToInt32(); - if (packageLen > token.Buffer.Size - 4) - { - break; - } - token.Proxy.DeBytes(token.Buffer.Data.Slice(0, packageLen + 4)); - await ReadPacketTarget(e, token).ConfigureAwait(false); - - token.Buffer.RemoveRange(0, packageLen + 4); - } while (token.Buffer.Size > 4); - } - else - { - token.Proxy.Data = data; - await SendToSource(e, token).ConfigureAwait(false); - } - } - private async Task ReadPacketTarget(SocketAsyncEventArgs e, AsyncUserToken token) - { - if (token.Proxy.Step == ProxyStep.Request) - { - await ConnectBind(e, token).ConfigureAwait(false); - } - else - { - await SendToSource(e, token).ConfigureAwait(false); - } - } - private async Task ConnectBind(SocketAsyncEventArgs e, AsyncUserToken token) - { - Socket socket = new Socket(token.Proxy.TargetEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - socket.KeepAlive(); - await socket.ConnectAsync(token.Proxy.TargetEP); - - dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.TargetSocket.GetHashCode()), socket); - - BindReceiveTarget(new AsyncUserToken - { - TargetSocket = socket, - SourceSocket = token.TargetSocket, - Proxy = new ProxyInfo - { - Direction = ProxyDirection.Pack, - ConnectId = token.Proxy.ConnectId, - Step = ProxyStep.Forward - } - }); - } - private async Task SendToSource(SocketAsyncEventArgs e, AsyncUserToken token) - { - if (token.Proxy.Direction == ProxyDirection.UnPack) - { - ConnectId connectId = new ConnectId(token.Proxy.ConnectId, token.TargetSocket.GetHashCode()); - if (dic.TryGetValue(connectId, out Socket source)) - { - try - { - await source.SendAsync(token.Proxy.Data); - } - catch (Exception) - { - CloseClientSocket(e); - - } - } - } - else - { - byte[] connectData = token.Proxy.ToBytes(out int length); - try - { - await token.SourceSocket.SendAsync(connectData.AsMemory(0, length), SocketFlags.None); - } - catch (Exception) - { - CloseClientSocket(e); - } - finally - { - token.Proxy.Return(connectData); - } - } - } - - private void CloseClientSocket(SocketAsyncEventArgs e) - { - if (e == null) return; - AsyncUserToken token = e.UserToken as AsyncUserToken; - if (token.TargetSocket != null) - { - int code = token.TargetSocket.GetHashCode(); - if (token.TargetSocket.Connected == false) - { - foreach (ConnectId item in dic.Keys.Where(c => c.socket == code).ToList()) + foreach (ConnectId item in dic.Keys.Where(c => c.hashCode == code).ToList()) { dic.TryRemove(item, out _); } @@ -499,28 +468,36 @@ namespace cmonitor.plugins.viewer.proxy dic.TryRemove(new ConnectId(token.Proxy.ConnectId, code), out _); } } - if (token.SourceSocket != null) - { - token.Clear(); - e.Dispose(); - } + token.Clear(); } public void Stop() { - CloseClientSocket(acceptEventArg); + CloseClientSocket(userToken); udpClient?.Close(); } } + public enum ProxyStep : byte + { + Request = 1, + Forward = 2 + } + public record struct ConnectId + { + public ulong connectId; + public int hashCode; + + public ConnectId(ulong connectId, int hashCode) + { + this.connectId = connectId; + this.hashCode = hashCode; + } + } public sealed class ProxyInfo { public ulong ConnectId { get; set; } - public ProxyStep Step { get; set; } = ProxyStep.Request; - - public ProxyDirection Direction { get; set; } = ProxyDirection.Pack; - public IPEndPoint TargetEP { get; set; } public Memory Data { get; set; } @@ -595,7 +572,6 @@ namespace cmonitor.plugins.viewer.proxy } } - public sealed class AsyncUserUdpToken { public UdpClient SourceSocket { get; set; } @@ -606,53 +582,31 @@ namespace cmonitor.plugins.viewer.proxy { SourceSocket?.Close(); SourceSocket = null; - GC.Collect(); } } - public sealed class AsyncUserToken { - public Socket SourceSocket { get; set; } - public Socket TargetSocket { get; set; } + public Socket Socket { get; set; } + public ITunnelConnection Connection { get; set; } public ProxyInfo Proxy { get; set; } public ReceiveDataBuffer Buffer { get; set; } + public SocketAsyncEventArgs Saea { get; set; } + public void Clear() { - SourceSocket?.SafeClose(); - SourceSocket = null; + Socket?.SafeClose(); + Socket = null; Buffer?.Clear(); + Saea?.Dispose(); + GC.Collect(); } } - - public enum ProxyStep : byte - { - Request = 1, - Forward = 2 - } - - public enum ProxyDirection - { - Pack = 0, - UnPack = 1, - } - - public record struct ConnectId - { - public ulong connectId; - public int socket; - - public ConnectId(ulong connectId, int socket) - { - this.connectId = connectId; - this.socket = socket; - } - } } diff --git a/cmonitor/plugins/viewer/proxy/ViewerProxyClient.cs b/cmonitor/plugins/viewer/proxy/ViewerProxyClient.cs index e73486cb..49762343 100644 --- a/cmonitor/plugins/viewer/proxy/ViewerProxyClient.cs +++ b/cmonitor/plugins/viewer/proxy/ViewerProxyClient.cs @@ -1,11 +1,9 @@ using cmonitor.client.running; +using cmonitor.client.tunnel; using cmonitor.config; using cmonitor.plugins.relay; -using cmonitor.plugins.relay.transport; using cmonitor.plugins.tunnel; -using cmonitor.plugins.tunnel.transport; using common.libs; -using System.Net.Sockets; namespace cmonitor.plugins.viewer.proxy { @@ -16,7 +14,7 @@ namespace cmonitor.plugins.viewer.proxy private readonly RelayTransfer relayTransfer; private readonly Config config; - private Socket tunnelSocket; + private ITunnelConnection connection; public ViewerProxyClient(RunningConfig runningConfig, TunnelTransfer tunnelTransfer, RelayTransfer relayTransfer, Config config) { @@ -28,55 +26,27 @@ namespace cmonitor.plugins.viewer.proxy Start(0); Logger.Instance.Info($"start viewer proxy, port : {LocalEndpoint.Port}"); - Tunnel(); + tunnelTransfer.SetConnectCallback("viewer", BindConnectionReceive); + relayTransfer.SetConnectCallback("viewer", BindConnectionReceive); } protected override async Task Connect(AsyncUserToken token) { token.Proxy.TargetEP = runningConfig.Data.Viewer.ConnectEP; - if (tunnelSocket == null || tunnelSocket.Connected == false) + token.Connection = connection; + if (connection == null || connection.Connected == false) { - TunnelTransportState state = await tunnelTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer"); - if (state != null) + connection = await tunnelTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer"); + if (connection == null) { - if (state.TransportType == ProtocolType.Tcp) - { - tunnelSocket = state.ConnectedObject as Socket; - token.TargetSocket = tunnelSocket; - BindReceiveTarget(tunnelSocket, token.SourceSocket); - return; - } + connection = await relayTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer", config.Data.Client.Relay.SecretKey); } - - RelayTransportState relayState = await relayTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer", config.Data.Client.Relay.SecretKey); - if (relayState != null) + if (connection != null) { - tunnelSocket = relayState.Socket; - token.TargetSocket = tunnelSocket; - BindReceiveTarget(tunnelSocket, token.SourceSocket); - return; + BindConnectionReceive(connection); + token.Connection = connection; } - - tunnelSocket = null; } } - - private void Tunnel() - { - tunnelTransfer.OnConnected += (TunnelTransportState state) => - { - if (state != null && state.TransportType == ProtocolType.Tcp && state.TransactionId == "viewer" && state.Direction == TunnelTransportDirection.Reverse) - { - BindReceiveTarget(state.ConnectedObject as Socket, null); - } - }; - relayTransfer.OnConnected += (RelayTransportState state) => - { - if (state != null && state.Info.TransactionId == "viewer" && state.Direction == RelayTransportDirection.Reverse) - { - BindReceiveTarget(state.Socket, null); - } - }; - } } }