桌面共享使用打洞或者服务器中继进行代理

This commit is contained in:
少年郎秃头呀
2024-05-07 01:06:08 +08:00
parent e9af825ac2
commit fe61072d0a
13 changed files with 390 additions and 560 deletions

View File

@@ -1,10 +1,11 @@
using common.libs.extends;
using System.Net.Sockets;
using System.Text.Json.Serialization;
namespace cmonitor.client.tunnel
{
public delegate Task TunnelReceivceCallback(Memory<byte> data, object state);
public delegate Task TunnelCloseCallback(object state);
public delegate Task TunnelReceivceCallback(ITunnelConnection connection,Memory<byte> data, object state);
public delegate Task TunnelCloseCallback(ITunnelConnection connection,object state);
public enum TunnelProtocolType : byte
{
@@ -57,6 +58,7 @@ namespace cmonitor.client.tunnel
public bool Connected => Socket != null && Socket.Connected;
[JsonIgnore]
public Socket Socket { get; init; }
@@ -101,7 +103,7 @@ namespace cmonitor.client.tunnel
int offset = e.Offset;
int length = e.BytesTransferred;
await receiveCallback(e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false);
await receiveCallback(this,e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false);
if (Socket.Available > 0)
{
@@ -110,7 +112,7 @@ namespace cmonitor.client.tunnel
length = Socket.Receive(e.Buffer);
if (length > 0)
{
await receiveCallback(e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false);
await receiveCallback(this,e.Buffer.AsMemory(offset, length), e.UserToken).ConfigureAwait(false);
}
else
{
@@ -143,7 +145,7 @@ namespace cmonitor.client.tunnel
}
private void CloseClientSocket(SocketAsyncEventArgs e)
{
closeCallback(e.UserToken);
closeCallback(this,e.UserToken);
e.Dispose();
Close();
}

View File

@@ -1,9 +1,8 @@
using cmonitor.client.api;
using cmonitor.client.tunnel;
using cmonitor.config;
using cmonitor.plugins.relay.transport;
using common.libs;
using common.libs.api;
using common.libs.extends;
using System.Text;
namespace cmonitor.plugins.relay
@@ -27,17 +26,16 @@ namespace cmonitor.plugins.relay
{
try
{
RelayTransportState state = await relayTransfer.ConnectAsync(param.Content, "test", config.Data.Client.Relay.SecretKey);
if (state != null)
ITunnelConnection connection = await relayTransfer.ConnectAsync(param.Content, "test", config.Data.Client.Relay.SecretKey);
if (connection != null)
{
var socket = state.Socket;
for (int i = 0; i < 10; i++)
{
Logger.Instance.Debug($"relay [test] send {i}");
socket.Send(Encoding.UTF8.GetBytes($"snltty.relay.{i}"));
await connection.SendAsync(Encoding.UTF8.GetBytes($"snltty.relay.{i}"));
await Task.Delay(10);
}
socket.SafeClose();
connection.Close();
}
}
catch (Exception ex)
@@ -48,23 +46,21 @@ namespace cmonitor.plugins.relay
}
private void RelayTest()
{
relayTransfer.OnConnected += (RelayTransportState state) =>
{
if (state.Info.TransactionId == "test")
relayTransfer.SetConnectCallback("test", (ITunnelConnection connection) =>
{
Task.Run(() =>
{
byte[] bytes = new byte[1024];
while (true)
connection.BeginReceive(async (ITunnelConnection connection, Memory<byte> data, object state) =>
{
int length = state.Socket.Receive(bytes);
if (length == 0) break;
Logger.Instance.Debug($"relay [{state.Info.TransactionId}] receive {Encoding.UTF8.GetString(bytes.AsSpan(0,length))}");
}
Logger.Instance.Debug($"relay [{connection.TransactionId}] receive {Encoding.UTF8.GetString(data.Span)}");
await Task.CompletedTask;
},
async (ITunnelConnection connection, object state) =>
{
await Task.CompletedTask;
}, null);
});
});
}
};
}
}

View File

@@ -1,12 +1,9 @@
using cmonitor.client;
using cmonitor.client.tunnel;
using cmonitor.config;
using cmonitor.plugins.relay.transport;
using cmonitor.server;
using common.libs;
using common.libs.extends;
using Microsoft.Extensions.DependencyInjection;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
namespace cmonitor.plugins.relay
@@ -18,7 +15,7 @@ namespace cmonitor.plugins.relay
private readonly Config config;
private readonly ServiceProvider serviceProvider;
public Action<RelayTransportState> OnConnected { get; set; } = (state) => { };
private Dictionary<string, Action<ITunnelConnection>> OnConnected { get; } = new Dictionary<string, Action<ITunnelConnection>>();
public RelayTransfer(Config config, ServiceProvider serviceProvider)
{
@@ -35,7 +32,19 @@ namespace cmonitor.plugins.relay
Logger.Instance.Warning($"load relay transport:{string.Join(",", transports.Select(c => c.Name))}");
}
public async Task<RelayTransportState> ConnectAsync(string remoteMachineName, string transactionId, string secretKey)
public void SetConnectCallback(string transactionId, Action<ITunnelConnection> callback)
{
if (OnConnected.TryGetValue(transactionId, out Action<ITunnelConnection> _callback) == false)
{
OnConnected[transactionId] = callback;
}
else
{
OnConnected[transactionId] += callback;
}
}
public async Task<ITunnelConnection> ConnectAsync(string remoteMachineName, string transactionId, string secretKey)
{
IEnumerable<ITransport> _transports = transports.OrderBy(c => c.Name);
foreach (RelayCompactInfo item in config.Data.Client.Relay.Servers.Where(c => c.Disabled == false))
@@ -58,10 +67,10 @@ namespace cmonitor.plugins.relay
TransactionId = transactionId,
TransportName = transport.Name
};
Socket socket = await transport.RelayAsync(relayInfo);
if (socket != null)
ITunnelConnection connection = await transport.RelayAsync(relayInfo);
if (connection != null)
{
return new RelayTransportState { Info = relayInfo, Socket = socket, Direction = RelayTransportDirection.Forward };
return connection;
}
}
catch (Exception ex)
@@ -76,10 +85,13 @@ namespace cmonitor.plugins.relay
ITransport _transports = transports.FirstOrDefault(c => c.Name == relayInfo.TransportName);
if (_transports != null)
{
Socket socket = await _transports.OnBeginAsync(relayInfo);
if (socket != null)
ITunnelConnection connection = await _transports.OnBeginAsync(relayInfo);
if (connection != null)
{
OnConnected(new RelayTransportState { Info = relayInfo, Socket = socket, Direction = RelayTransportDirection.Reverse });
if (OnConnected.TryGetValue(connection.TransactionId, out Action<ITunnelConnection> callback))
{
callback(connection);
}
return true;
}
}

View File

@@ -1,14 +1,16 @@
using MemoryPack;
using cmonitor.client.tunnel;
using MemoryPack;
using System.Net;
using System.Net.Sockets;
namespace cmonitor.plugins.relay.transport
{
public interface ITransport
{
public string Name { get; }
public Task<Socket> RelayAsync(RelayInfo relayInfo);
public Task<Socket> OnBeginAsync(RelayInfo relayInfo);
public TunnelProtocolType ProtocolType { get; }
public Task<ITunnelConnection> RelayAsync(RelayInfo relayInfo);
public Task<ITunnelConnection> OnBeginAsync(RelayInfo relayInfo);
}
[MemoryPackable]
@@ -26,19 +28,4 @@ namespace cmonitor.plugins.relay.transport
}
public enum RelayTransportDirection : byte
{
Forward = 0,
Reverse = 1
}
public sealed class RelayTransportState
{
public RelayInfo Info { get; set; }
public RelayTransportDirection Direction { get; set; } = RelayTransportDirection.Reverse;
public Socket Socket { get; set; }
}
}

View File

@@ -1,4 +1,5 @@
using cmonitor.plugins.relay.messenger;
using cmonitor.client.tunnel;
using cmonitor.plugins.relay.messenger;
using cmonitor.server;
using common.libs;
using common.libs.extends;
@@ -11,6 +12,7 @@ namespace cmonitor.plugins.relay.transport
public sealed class TransportSelfHost : ITransport
{
public string Name => "self";
public TunnelProtocolType ProtocolType => TunnelProtocolType.Tcp;
private readonly TcpServer tcpServer;
private readonly MessengerSender messengerSender;
@@ -22,9 +24,9 @@ namespace cmonitor.plugins.relay.transport
this.messengerSender = messengerSender;
}
public async Task<Socket> RelayAsync(RelayInfo relayInfo)
public async Task<ITunnelConnection> RelayAsync(RelayInfo relayInfo)
{
Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp);
socket.Reuse(true);
socket.IPv6Only(relayInfo.Server.AddressFamily, false);
await socket.ConnectAsync(relayInfo.Server).WaitAsync(TimeSpan.FromSeconds(5));
@@ -43,12 +45,21 @@ namespace cmonitor.plugins.relay.transport
}
await socket.SendAsync(relayFlagData);
await Task.Delay(10);
return socket;
return new TunnelConnectionTcp
{
Direction = TunnelDirection.Forward,
ProtocolType = TunnelProtocolType.Tcp,
RemoteMachineName = relayInfo.RemoteMachineName,
Socket = socket,
TransactionId = relayInfo.TransactionId,
TransportName = Name,
Type = TunnelType.Relay
};
}
public async Task<Socket> OnBeginAsync(RelayInfo relayInfo)
public async Task<ITunnelConnection> OnBeginAsync(RelayInfo relayInfo)
{
Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
Socket socket = new Socket(relayInfo.Server.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp);
socket.Reuse(true);
socket.IPv6Only(relayInfo.Server.AddressFamily, false);
await socket.ConnectAsync(relayInfo.Server).WaitAsync(TimeSpan.FromSeconds(5));
@@ -67,7 +78,16 @@ namespace cmonitor.plugins.relay.transport
}
await socket.SendAsync(relayFlagData);
await Task.Delay(10);
return socket;
return new TunnelConnectionTcp
{
Direction = TunnelDirection.Reverse,
ProtocolType = TunnelProtocolType.Tcp,
RemoteMachineName = relayInfo.RemoteMachineName,
Socket = socket,
TransactionId = relayInfo.TransactionId,
TransportName = Name,
Type = TunnelType.Relay
};
}
}
}

View File

@@ -1,5 +1,6 @@
using cmonitor.client;
using cmonitor.client.api;
using cmonitor.client.tunnel;
using cmonitor.config;
using cmonitor.plugins.signin.messenger;
using cmonitor.plugins.tunnel.server;
@@ -38,17 +39,16 @@ namespace cmonitor.plugins.tunnel
{
try
{
TunnelTransportState state = await tunnelTransfer.ConnectAsync(param.Content, "test");
if (state != null)
ITunnelConnection connection = await tunnelTransfer.ConnectAsync(param.Content, "test");
if (connection != null)
{
var socket = state.ConnectedObject as Socket;
for (int i = 0; i < 10; i++)
{
Logger.Instance.Debug($"tunnel [test] send {i}");
socket.Send(BitConverter.GetBytes(i));
await connection.SendAsync(BitConverter.GetBytes(i));
await Task.Delay(10);
}
socket.SafeClose();
connection.Close();
}
}
catch (Exception ex)
@@ -59,17 +59,18 @@ namespace cmonitor.plugins.tunnel
}
private void TunnelTest()
{
tunnelTransfer.OnConnected += (TunnelTransportState state) =>
tunnelTransfer.SetConnectCallback("test", (ITunnelConnection connection) =>
{
if (state.TransactionId == "test" && state.TransportType == ProtocolType.Tcp)
{
tunnelBindServer.BindReceive(state.ConnectedObject as Socket, null, async (token, data) =>
{
Logger.Instance.Debug($"tunnel [{state.TransactionId}] receive {BitConverter.ToInt32(data.Span)}");
connection.BeginReceive(async (ITunnelConnection connection, Memory<byte> data, object state) => {
Logger.Instance.Debug($"tunnel [{connection.TransactionId}] receive {BitConverter.ToInt32(data.Span)}");
await Task.CompletedTask;
}, async (ITunnelConnection connection, object state) => {
await Task.CompletedTask;
}, null);
});
}
};
}
}
}

View File

@@ -1,4 +1,5 @@
using cmonitor.client;
using cmonitor.client.tunnel;
using cmonitor.config;
using cmonitor.plugins.tunnel.compact;
using cmonitor.plugins.tunnel.messenger;
@@ -8,8 +9,8 @@ using common.libs;
using common.libs.extends;
using MemoryPack;
using Microsoft.Extensions.DependencyInjection;
using System.Net.Sockets;
using System.Reflection;
using System.Transactions;
namespace cmonitor.plugins.tunnel
{
@@ -23,7 +24,7 @@ namespace cmonitor.plugins.tunnel
private readonly MessengerSender messengerSender;
private readonly CompactTransfer compactTransfer;
public Action<TunnelTransportState> OnConnected { get; set; } = (state) => { };
private Dictionary<string, Action<ITunnelConnection>> OnConnected { get; } = new Dictionary<string, Action<ITunnelConnection>>();
public TunnelTransfer(Config config, ServiceProvider serviceProvider, ClientSignInState clientSignInState, MessengerSender messengerSender, CompactTransfer compactTransfer)
{
@@ -53,45 +54,44 @@ namespace cmonitor.plugins.tunnel
Logger.Instance.Warning($"load tunnel transport:{string.Join(",", transports.Select(c => c.Name))}");
}
public async Task<TunnelTransportState> ConnectAsync(string remoteMachineName, string transactionId)
public async Task<ITunnelConnection> ConnectAsync(string remoteMachineName, string transactionId)
{
IEnumerable<ITransport> _transports = transports.OrderBy(c => c.Type);
IEnumerable<ITransport> _transports = transports.OrderBy(c => c.ProtocolType);
foreach (ITransport transport in _transports)
{
//获取自己的外网ip
TunnelTransportExternalIPInfo localInfo = await GetLocalInfo(transport.Type);
TunnelTransportExternalIPInfo localInfo = await GetLocalInfo(transport.ProtocolType);
if (localInfo == null)
{
continue;
}
//获取对方的外网ip
TunnelTransportExternalIPInfo remoteInfo = await GetRemoteInfo(remoteMachineName, transport.Type);
TunnelTransportExternalIPInfo remoteInfo = await GetRemoteInfo(remoteMachineName, transport.ProtocolType);
if (remoteInfo == null)
{
continue;
}
TunnelTransportInfo tunnelTransportInfo = new TunnelTransportInfo
{
Direction = TunnelTransportDirection.Forward,
Direction = TunnelDirection.Forward,
TransactionId = transactionId,
TransportName = transport.Name,
TransportType = transport.Type,
TransportType = transport.ProtocolType,
Local = localInfo,
Remote = remoteInfo,
};
TunnelTransportState state = await transport.ConnectAsync(tunnelTransportInfo);
if (state != null)
ITunnelConnection connection = await transport.ConnectAsync(tunnelTransportInfo);
if (connection != null)
{
state.Direction = TunnelTransportDirection.Forward;
_OnConnected(state);
return state;
_OnConnected(connection);
return connection;
}
}
return null;
}
public void OnBegin(TunnelTransportInfo tunnelTransportInfo)
{
ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.Type == tunnelTransportInfo.TransportType);
ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.ProtocolType == tunnelTransportInfo.TransportType);
if (_transports != null)
{
_transports.OnBegin(tunnelTransportInfo);
@@ -99,7 +99,7 @@ namespace cmonitor.plugins.tunnel
}
public void OnFail(TunnelTransportInfo tunnelTransportInfo)
{
ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.Type == tunnelTransportInfo.TransportType);
ITransport _transports = transports.FirstOrDefault(c => c.Name == tunnelTransportInfo.TransportName && c.ProtocolType == tunnelTransportInfo.TransportType);
if (_transports != null)
{
_transports.OnFail(tunnelTransportInfo);
@@ -111,7 +111,7 @@ namespace cmonitor.plugins.tunnel
return await GetLocalInfo(request.TransportType);
}
private async Task<TunnelTransportExternalIPInfo> GetLocalInfo(ProtocolType transportType)
private async Task<TunnelTransportExternalIPInfo> GetLocalInfo(TunnelProtocolType transportType)
{
TunnelCompactIPEndPoint[] ips = await compactTransfer.GetExternalIPAsync(transportType);
if (ips != null && ips.Length > 0)
@@ -126,7 +126,7 @@ namespace cmonitor.plugins.tunnel
}
return null;
}
private async Task<TunnelTransportExternalIPInfo> GetRemoteInfo(string remoteMachineName, ProtocolType transportType)
private async Task<TunnelTransportExternalIPInfo> GetRemoteInfo(string remoteMachineName, TunnelProtocolType transportType)
{
MessageResponeInfo resp = await messengerSender.SendReply(new MessageRequestWrap
{
@@ -168,6 +168,19 @@ namespace cmonitor.plugins.tunnel
}
public void SetConnectCallback(string transactionId, Action<ITunnelConnection> callback)
{
if (OnConnected.TryGetValue(transactionId, out Action<ITunnelConnection> _callback) == false)
{
OnConnected[transactionId] = callback;
}
else
{
OnConnected[transactionId] += callback;
}
}
public Dictionary<string, TunnelConnectInfo> Connections { get; } = new Dictionary<string, TunnelConnectInfo>();
private int connectionsChangeFlag = 1;
public bool ConnectionChanged => Interlocked.CompareExchange(ref connectionsChangeFlag, 0, 1) == 1;
@@ -191,23 +204,27 @@ namespace cmonitor.plugins.tunnel
info.Status = TunnelConnectStatus.Connecting;
Interlocked.Exchange(ref connectionsChangeFlag, 1);
}
private void _OnConnected(TunnelTransportState state)
private void _OnConnected(ITunnelConnection connection)
{
if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG)
{
Logger.Instance.Debug($"tunnel connect [{state.TransactionId}]->{state.RemoteMachineName} success");
Logger.Instance.Debug($"tunnel connect [{connection.TransactionId}]->{connection.RemoteMachineName} success");
}
CheckDic(state.RemoteMachineName, out TunnelConnectInfo info);
CheckDic(connection.RemoteMachineName, out TunnelConnectInfo info);
info.Status = TunnelConnectStatus.Connected;
info.State = state;
info.Connection = connection;
Interlocked.Exchange(ref connectionsChangeFlag, 1);
OnConnected(state);
}
private void OnDisConnected(TunnelTransportState state)
if (OnConnected.TryGetValue(connection.TransactionId, out Action<ITunnelConnection> _callback) == false)
{
CheckDic(state.RemoteMachineName, out TunnelConnectInfo info);
_callback(connection);
}
}
private void OnDisConnected(ITunnelConnection connection)
{
CheckDic(connection.RemoteMachineName, out TunnelConnectInfo info);
info.Status = TunnelConnectStatus.None;
info.State = null;
info.Connection = null;
Interlocked.Exchange(ref connectionsChangeFlag, 1);
}
private void OnConnectFail(string machineName)
@@ -218,7 +235,7 @@ namespace cmonitor.plugins.tunnel
}
CheckDic(machineName, out TunnelConnectInfo info);
info.Status = TunnelConnectStatus.None;
info.State = null;
info.Connection = null;
Interlocked.Exchange(ref connectionsChangeFlag, 1);
}
private void CheckDic(string name, out TunnelConnectInfo info)
@@ -233,7 +250,7 @@ namespace cmonitor.plugins.tunnel
public sealed class TunnelConnectInfo
{
public TunnelConnectStatus Status { get; set; }
public TunnelTransportState State { get; set; }
public ITunnelConnection Connection { get; set; }
}
public enum TunnelConnectStatus
{

View File

@@ -1,4 +1,5 @@
using cmonitor.config;
using cmonitor.client.tunnel;
using cmonitor.config;
using common.libs;
using Microsoft.Extensions.DependencyInjection;
using System.Net;
@@ -28,7 +29,7 @@ namespace cmonitor.plugins.tunnel.compact
Logger.Instance.Warning($"load tunnel compacts:{string.Join(",", compacts.Select(c => c.Name))}");
}
public async Task<TunnelCompactIPEndPoint[]> GetExternalIPAsync(ProtocolType protocolType)
public async Task<TunnelCompactIPEndPoint[]> GetExternalIPAsync(TunnelProtocolType protocolType)
{
TunnelCompactIPEndPoint[] endpoints = new TunnelCompactIPEndPoint[config.Data.Client.Tunnel.Servers.Length];
@@ -42,12 +43,12 @@ namespace cmonitor.plugins.tunnel.compact
try
{
IPEndPoint server = NetworkHelper.GetEndPoint(item.Host, 3478);
if (protocolType == ProtocolType.Tcp)
if (protocolType == TunnelProtocolType.Tcp)
{
TunnelCompactIPEndPoint externalIP = await compact.GetTcpExternalIPAsync(server);
endpoints[i] = externalIP;
}
else if (protocolType == ProtocolType.Udp)
else if (protocolType == TunnelProtocolType.Udp)
{
TunnelCompactIPEndPoint externalIP = await compact.GetUdpExternalIPAsync(server);
endpoints[i] = externalIP;

View File

@@ -12,7 +12,6 @@ namespace cmonitor.plugins.tunnel.server
public Action<object, Socket> OnTcpConnected { get; set; } = (state, socket) => { };
public Action<object, UdpClient> OnUdpConnected { get; set; } = (state, udpClient) => { };
public Action<object> OnDisConnected { get; set; } = (state) => { };
private ConcurrentDictionary<int, SocketAsyncEventArgs> acceptBinds = new ConcurrentDictionary<int, SocketAsyncEventArgs>();
@@ -85,9 +84,6 @@ namespace cmonitor.plugins.tunnel.server
case SocketAsyncOperation.Accept:
ProcessAccept(e);
break;
case SocketAsyncOperation.Receive:
ProcessReceive(e);
break;
default:
break;
}
@@ -119,87 +115,6 @@ namespace cmonitor.plugins.tunnel.server
}
}
public void BindReceive(Socket socket, object state, OnTunnelData dataCallback)
{
if (socket == null || socket.RemoteEndPoint == null)
{
return;
}
socket.KeepAlive();
AsyncUserToken userToken = new AsyncUserToken
{
SourceSocket = socket,
State = state,
OnData = dataCallback,
LocalPort = (socket.LocalEndPoint as IPEndPoint).Port,
};
SocketAsyncEventArgs readEventArgs = new SocketAsyncEventArgs
{
UserToken = userToken,
SocketFlags = SocketFlags.None,
};
readEventArgs.SetBuffer(new byte[8 * 1024], 0, 8 * 1024);
readEventArgs.Completed += IO_Completed;
if (socket.ReceiveAsync(readEventArgs) == false)
{
ProcessReceive(readEventArgs);
}
}
private async void ProcessReceive(SocketAsyncEventArgs e)
{
try
{
AsyncUserToken token = (AsyncUserToken)e.UserToken;
if (e.BytesTransferred > 0 && e.SocketError == SocketError.Success)
{
int offset = e.Offset;
int length = e.BytesTransferred;
await token.OnData(token, e.Buffer.AsMemory(0, length));
if (token.SourceSocket.Available > 0)
{
while (token.SourceSocket.Available > 0)
{
length = token.SourceSocket.Receive(e.Buffer);
if (length > 0)
{
await token.OnData(token, e.Buffer.AsMemory(0, length));
}
else
{
CloseClientSocket(e);
return;
}
}
}
if (token.SourceSocket.Connected == false)
{
CloseClientSocket(e);
return;
}
if (token.SourceSocket.ReceiveAsync(e) == false)
{
ProcessReceive(e);
}
}
else
{
CloseClientSocket(e);
}
}
catch (Exception ex)
{
if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG)
Logger.Instance.Error(ex);
CloseClientSocket(e);
}
}
private void CloseClientSocket(SocketAsyncEventArgs e)
{
if (e == null || e.UserToken == null) return;
@@ -214,7 +129,6 @@ namespace cmonitor.plugins.tunnel.server
{
CloseClientSocket(saea1);
}
OnDisConnected(token.State);
}
}
@@ -226,9 +140,6 @@ namespace cmonitor.plugins.tunnel.server
public Socket SourceSocket { get; set; }
public SocketAsyncEventArgs Saea { get; set; }
public object State { get; set; }
public OnTunnelData OnData { get; set; }
public int LocalPort { get; set; }
public void Clear()

View File

@@ -1,14 +1,13 @@
using MemoryPack;
using cmonitor.client.tunnel;
using MemoryPack;
using System.Net;
using System.Net.Sockets;
using System.Text.Json.Serialization;
namespace cmonitor.plugins.tunnel.transport
{
public interface ITransport
{
public string Name { get; }
public ProtocolType Type { get; }
public TunnelProtocolType ProtocolType { get; }
/// <summary>
/// 发送连接信息
@@ -26,11 +25,11 @@ namespace cmonitor.plugins.tunnel.transport
/// <summary>
/// 收到连接
/// </summary>
public Action<TunnelTransportState> OnConnected { get; set; }
public Action<ITunnelConnection> OnConnected { get; set; }
/// <summary>
/// 断开连接
/// </summary>
public Action<TunnelTransportState> OnDisConnected { get; set; }
public Action<ITunnelConnection> OnDisConnected { get; set; }
public Action<string> OnConnectFail { get; set; }
@@ -39,7 +38,7 @@ namespace cmonitor.plugins.tunnel.transport
/// </summary>
/// <param name="tunnelTransportInfo">你的名字</param>
/// <returns></returns>
public Task<TunnelTransportState> ConnectAsync(TunnelTransportInfo tunnelTransportInfo);
public Task<ITunnelConnection> ConnectAsync(TunnelTransportInfo tunnelTransportInfo);
/// <summary>
/// 收到开始连接
/// </summary>
@@ -57,7 +56,7 @@ namespace cmonitor.plugins.tunnel.transport
public sealed partial class TunnelTransportExternalIPRequestInfo
{
public string RemoteMachineName { get; set; }
public ProtocolType TransportType { get; set; }
public TunnelProtocolType TransportType { get; set; }
}
[MemoryPackable]
@@ -83,43 +82,11 @@ namespace cmonitor.plugins.tunnel.transport
public string TransactionId { get; set; }
public ProtocolType TransportType { get; set; }
public TunnelProtocolType TransportType { get; set; }
public string TransportName { get; set; }
public TunnelTransportDirection Direction { get; set; }
}
public enum TunnelTransportDirection : byte
{
Forward = 0,
Reverse = 1
}
public enum TunnelTransportType
{
Tcp = ProtocolType.Tcp,
Udp = ProtocolType.Udp,
}
public sealed class TunnelTransportState
{
public string RemoteMachineName { get; set; }
public string TransactionId { get; set; }
public string TransportName { get; set; }
public ProtocolType TransportType { get; set; }
public TunnelTransportDirection Direction { get; set; } = TunnelTransportDirection.Reverse;
[JsonIgnore]
public object ConnectedObject { get; set; }
public TunnelDirection Direction { get; set; }
}
public interface ITunnelConnection
{
public TunnelTransportType TransportType { get; }
}
}

View File

@@ -1,4 +1,5 @@
using cmonitor.plugins.tunnel.server;
using cmonitor.client.tunnel;
using cmonitor.plugins.tunnel.server;
using common.libs;
using common.libs.extends;
using System.Collections.Concurrent;
@@ -10,14 +11,14 @@ namespace cmonitor.plugins.tunnel.transport
public sealed class TransportTcpNutssb : ITransport
{
public string Name => "TcpNutssb";
public ProtocolType Type => ProtocolType.Tcp;
public TunnelProtocolType ProtocolType => TunnelProtocolType.Tcp;
public Func<TunnelTransportInfo, Task<bool>> OnSendConnectBegin { get; set; } = async (info) => { return await Task.FromResult<bool>(false); };
public Func<TunnelTransportInfo, Task> OnSendConnectFail { get; set; } = async (info) => { await Task.CompletedTask; };
public Action<TunnelTransportInfo> OnConnectBegin { get; set; } = (info) => { };
public Action<TunnelTransportInfo> OnConnecting { get; set; }
public Action<TunnelTransportState> OnConnected { get; set; } = (state) => { };
public Action<TunnelTransportState> OnDisConnected { get; set; } = (state) => { };
public Action<ITunnelConnection> OnConnected { get; set; } = (state) => { };
public Action<ITunnelConnection> OnDisConnected { get; set; } = (state) => { };
public Action<string> OnConnectFail { get; set; } = (machineName) => { };
@@ -26,26 +27,25 @@ namespace cmonitor.plugins.tunnel.transport
{
this.tunnelBindServer = tunnelBindServer;
tunnelBindServer.OnTcpConnected += OnTcpConnected;
tunnelBindServer.OnDisConnected += OnTcpDisConnected;
}
public async Task<TunnelTransportState> ConnectAsync(TunnelTransportInfo tunnelTransportInfo)
public async Task<ITunnelConnection> ConnectAsync(TunnelTransportInfo tunnelTransportInfo)
{
OnConnecting(tunnelTransportInfo);
//正向连接
tunnelTransportInfo.Direction = TunnelTransportDirection.Forward;
tunnelTransportInfo.Direction = TunnelDirection.Forward;
if (await OnSendConnectBegin(tunnelTransportInfo) == false)
{
OnConnectFail(tunnelTransportInfo.Remote.MachineName);
return null;
}
TunnelTransportState state = await ConnectForward(tunnelTransportInfo);
if (state != null) return state;
ITunnelConnection connection = await ConnectForward(tunnelTransportInfo);
if (connection != null) return connection;
//反向连接
TunnelTransportInfo tunnelTransportInfo1 = tunnelTransportInfo.ToJsonFormat().DeJson<TunnelTransportInfo>();
tunnelTransportInfo1.Direction = TunnelTransportDirection.Reverse;
tunnelTransportInfo1.Direction = TunnelDirection.Reverse;
tunnelBindServer.Bind(tunnelTransportInfo1.Local.Local, tunnelTransportInfo1);
BindAndTTL(tunnelTransportInfo1);
if (await OnSendConnectBegin(tunnelTransportInfo1) == false)
@@ -53,8 +53,8 @@ namespace cmonitor.plugins.tunnel.transport
OnConnectFail(tunnelTransportInfo.Remote.MachineName);
return null;
}
state = await WaitReverse(tunnelTransportInfo1);
if (state != null) return state;
connection = await WaitReverse(tunnelTransportInfo1);
if (connection != null) return connection;
//正向反向都失败
await OnSendConnectFail(tunnelTransportInfo);
@@ -65,22 +65,22 @@ namespace cmonitor.plugins.tunnel.transport
public void OnBegin(TunnelTransportInfo tunnelTransportInfo)
{
OnConnectBegin(tunnelTransportInfo);
if (tunnelTransportInfo.Direction == TunnelTransportDirection.Forward)
if (tunnelTransportInfo.Direction == TunnelDirection.Forward)
{
tunnelBindServer.Bind(tunnelTransportInfo.Local.Local, tunnelTransportInfo);
}
Task.Run(async () =>
{
if (tunnelTransportInfo.Direction == TunnelTransportDirection.Forward)
if (tunnelTransportInfo.Direction == TunnelDirection.Forward)
{
BindAndTTL(tunnelTransportInfo);
}
else
{
TunnelTransportState state = await ConnectForward(tunnelTransportInfo);
if (state != null)
ITunnelConnection connection = await ConnectForward(tunnelTransportInfo);
if (connection != null)
{
OnConnected(state);
OnConnected(connection);
}
else
{
@@ -96,7 +96,7 @@ namespace cmonitor.plugins.tunnel.transport
tunnelBindServer.RemoveBind(tunnelTransportInfo.Local.Local.Port);
}
private async Task<TunnelTransportState> ConnectForward(TunnelTransportInfo tunnelTransportInfo)
private async Task<ITunnelConnection> ConnectForward(TunnelTransportInfo tunnelTransportInfo)
{
await Task.Delay(20);
//要连接哪些IP
@@ -109,7 +109,7 @@ namespace cmonitor.plugins.tunnel.transport
};
foreach (IPEndPoint ep in eps)
{
Socket targetSocket = new(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
Socket targetSocket = new(ep.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp);
targetSocket.IPv6Only(ep.Address.AddressFamily, false);
targetSocket.ReuseBind(new IPEndPoint(ep.AddressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any, tunnelTransportInfo.Local.Local.Port));
IAsyncResult result = targetSocket.BeginConnect(ep, null, null);
@@ -133,13 +133,16 @@ namespace cmonitor.plugins.tunnel.transport
}
targetSocket.EndConnect(result);
return new TunnelTransportState
return new TunnelConnectionTcp
{
ConnectedObject = targetSocket,
Socket = targetSocket,
TransactionId = tunnelTransportInfo.TransactionId,
RemoteMachineName = tunnelTransportInfo.Remote.MachineName,
TransportName = Name,
TransportType = Type
Direction = tunnelTransportInfo.Direction,
ProtocolType = ProtocolType,
Type = TunnelType.P2P,
Label = string.Empty
};
}
catch (Exception ex)
@@ -167,7 +170,7 @@ namespace cmonitor.plugins.tunnel.transport
//过滤掉不支持IPV6的情况
IEnumerable<Socket> sockets = eps.Where(c => NotIPv6Support(c.Address) == false).Select(ip =>
{
Socket targetSocket = new(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
Socket targetSocket = new(ip.AddressFamily, SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp);
try
{
targetSocket.IPv6Only(ip.Address.AddressFamily, false);
@@ -188,16 +191,16 @@ namespace cmonitor.plugins.tunnel.transport
}
private ConcurrentDictionary<string, TaskCompletionSource<TunnelTransportState>> reverseDic = new ConcurrentDictionary<string, TaskCompletionSource<TunnelTransportState>>();
private async Task<TunnelTransportState> WaitReverse(TunnelTransportInfo tunnelTransportInfo)
private ConcurrentDictionary<string, TaskCompletionSource<ITunnelConnection>> reverseDic = new ConcurrentDictionary<string, TaskCompletionSource<ITunnelConnection>>();
private async Task<ITunnelConnection> WaitReverse(TunnelTransportInfo tunnelTransportInfo)
{
TaskCompletionSource<TunnelTransportState> tcs = new TaskCompletionSource<TunnelTransportState>();
TaskCompletionSource<ITunnelConnection> tcs = new TaskCompletionSource<ITunnelConnection>();
reverseDic.TryAdd(tunnelTransportInfo.Remote.MachineName, tcs);
try
{
TunnelTransportState state = await tcs.Task.WaitAsync(TimeSpan.FromMilliseconds(3000));
return state;
ITunnelConnection connection = await tcs.Task.WaitAsync(TimeSpan.FromMilliseconds(3000));
return connection;
}
catch (Exception)
{
@@ -214,15 +217,18 @@ namespace cmonitor.plugins.tunnel.transport
{
if (state is TunnelTransportInfo _state && _state.TransportName == Name)
{
TunnelTransportState result = new TunnelTransportState
TunnelConnectionTcp result = new TunnelConnectionTcp
{
RemoteMachineName = _state.Remote.MachineName,
TransportType = ProtocolType.Tcp,
ConnectedObject = socket,
Direction = _state.Direction,
ProtocolType = TunnelProtocolType.Tcp,
Socket = socket,
Type = TunnelType.P2P,
TransactionId = _state.TransactionId,
TransportName = _state.TransportName,
Label = string.Empty,
};
if (reverseDic.TryRemove(_state.Remote.MachineName, out TaskCompletionSource<TunnelTransportState> tcs))
if (reverseDic.TryRemove(_state.Remote.MachineName, out TaskCompletionSource<ITunnelConnection> tcs))
{
tcs.SetResult(result);
return;
@@ -231,20 +237,6 @@ namespace cmonitor.plugins.tunnel.transport
OnConnected(result);
}
}
private void OnTcpDisConnected(object state)
{
if (state is TunnelTransportInfo _state && _state.TransportName == Name)
{
TunnelTransportState result = new TunnelTransportState
{
RemoteMachineName = _state.Remote.MachineName,
TransportType = ProtocolType.Tcp,
TransactionId = _state.TransactionId,
TransportName = _state.TransportName,
};
OnDisConnected(result);
}
}
private bool NotIPv6Support(IPAddress ip)
{

View File

@@ -1,4 +1,5 @@
using common.libs;
using cmonitor.client.tunnel;
using common.libs;
using common.libs.extends;
using System.Buffers;
using System.Collections.Concurrent;
@@ -9,10 +10,11 @@ namespace cmonitor.plugins.viewer.proxy
{
public class ViewerProxy
{
private SocketAsyncEventArgs acceptEventArg;
private AsyncUserToken userToken;
private Socket socket;
private UdpClient udpClient;
private NumberSpace ns = new NumberSpace();
private readonly NumberSpace ns = new NumberSpace();
private readonly ConcurrentDictionary<ConnectId, Socket> dic = new ConcurrentDictionary<ConnectId, Socket>();
public IPEndPoint LocalEndpoint => socket?.LocalEndPoint as IPEndPoint ?? new IPEndPoint(IPAddress.Any, 0);
@@ -32,14 +34,16 @@ namespace cmonitor.plugins.viewer.proxy
socket.ReuseBind(localEndPoint);
socket.Listen(int.MaxValue);
acceptEventArg = new SocketAsyncEventArgs
userToken = new AsyncUserToken
{
UserToken = new AsyncUserToken
Socket = socket
};
SocketAsyncEventArgs acceptEventArg = new SocketAsyncEventArgs
{
SourceSocket = socket
},
UserToken = userToken,
SocketFlags = SocketFlags.None,
};
userToken.Saea = acceptEventArg;
acceptEventArg.Completed += IO_Completed;
StartAccept(acceptEventArg);
@@ -56,11 +60,11 @@ namespace cmonitor.plugins.viewer.proxy
}
}
private readonly AsyncUserUdpToken asyncUserUdpToken = new AsyncUserUdpToken
{
Proxy = new ProxyInfo { Step = ProxyStep.Forward, Direction = ProxyDirection.UnPack, ConnectId = 0 }
Proxy = new ProxyInfo { Step = ProxyStep.Forward, ConnectId = 0 }
};
private async void ReceiveCallbackUdp(IAsyncResult result)
{
try
@@ -83,15 +87,13 @@ namespace cmonitor.plugins.viewer.proxy
await Task.CompletedTask;
}
private void StartAccept(SocketAsyncEventArgs acceptEventArg)
{
acceptEventArg.AcceptSocket = null;
AsyncUserToken token = (AsyncUserToken)acceptEventArg.UserToken;
try
{
if (token.SourceSocket.AcceptAsync(acceptEventArg) == false)
if (token.Socket.AcceptAsync(acceptEventArg) == false)
{
ProcessAccept(acceptEventArg);
}
@@ -138,7 +140,7 @@ namespace cmonitor.plugins.viewer.proxy
socket.KeepAlive();
AsyncUserToken userToken = new AsyncUserToken
{
SourceSocket = socket,
Socket = socket,
Proxy = new ProxyInfo { Data = Helper.EmptyArray, Step = ProxyStep.Request, ConnectId = ns.Increment() }
};
@@ -147,6 +149,8 @@ namespace cmonitor.plugins.viewer.proxy
UserToken = userToken,
SocketFlags = SocketFlags.None,
};
userToken.Saea = readEventArgs;
readEventArgs.SetBuffer(new byte[8 * 1024], 0, 8 * 1024);
readEventArgs.Completed += IO_Completed;
if (socket.ReceiveAsync(readEventArgs) == false)
@@ -162,46 +166,45 @@ namespace cmonitor.plugins.viewer.proxy
}
private async void ProcessReceive(SocketAsyncEventArgs e)
{
AsyncUserToken token = (AsyncUserToken)e.UserToken;
try
{
AsyncUserToken token = (AsyncUserToken)e.UserToken;
if (e.BytesTransferred > 0 && e.SocketError == SocketError.Success)
{
int offset = e.Offset;
int length = e.BytesTransferred;
await ReadPacket(e, token, e.Buffer.AsMemory(offset, length));
if (token.SourceSocket.Available > 0)
await ReadPacket(token, e.Buffer.AsMemory(offset, length));
if (token.Socket.Available > 0)
{
while (token.SourceSocket.Available > 0)
while (token.Socket.Available > 0)
{
length = token.SourceSocket.Receive(e.Buffer);
length = token.Socket.Receive(e.Buffer);
if (length > 0)
{
await ReadPacket(e, token, e.Buffer.AsMemory(0, length));
await ReadPacket(token, e.Buffer.AsMemory(0, length));
}
else
{
CloseClientSocket(e);
CloseClientSocket(token);
return;
}
}
}
if (token.SourceSocket.Connected == false)
if (token.Socket.Connected == false)
{
CloseClientSocket(e);
CloseClientSocket(token);
return;
}
if (token.SourceSocket.ReceiveAsync(e) == false)
if (token.Socket.ReceiveAsync(e) == false)
{
ProcessReceive(e);
}
}
else
{
CloseClientSocket(e);
CloseClientSocket(token);
}
}
catch (Exception ex)
@@ -209,55 +212,38 @@ namespace cmonitor.plugins.viewer.proxy
if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG)
Logger.Instance.Error(ex);
CloseClientSocket(e);
CloseClientSocket(token);
}
}
private async Task ReadPacket(SocketAsyncEventArgs e, AsyncUserToken token, Memory<byte> data)
private async Task ReadPacket(AsyncUserToken token, Memory<byte> data)
{
if (token.Proxy.Step == ProxyStep.Request)
{
await Connect(token);
if (token.TargetSocket != null)
if (token.Connection != null)
{
//发送连接请求包
await SendToTarget(e, token).ConfigureAwait(false);
await SendToConnection(token).ConfigureAwait(false);
token.Proxy.Step = ProxyStep.Forward;
token.Proxy.TargetEP = null;
//发送后续数据包
token.Proxy.Data = data;
await SendToTarget(e, token).ConfigureAwait(false);
await SendToConnection(token).ConfigureAwait(false);
//绑定
dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.TargetSocket.GetHashCode()), token.SourceSocket);
dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.Connection.GetHashCode()), token.Socket);
}
else
{
CloseClientSocket(e);
CloseClientSocket(token);
}
}
else
{
token.Proxy.Data = data;
await SendToTarget(e, token).ConfigureAwait(false);
}
}
private async Task SendToTarget(SocketAsyncEventArgs e, AsyncUserToken token)
{
byte[] connectData = token.Proxy.ToBytes(out int length);
try
{
await token.TargetSocket.SendAsync(connectData.AsMemory(0, length), SocketFlags.None);
}
catch (Exception)
{
CloseClientSocket(e);
}
finally
{
token.Proxy.Return(connectData);
await SendToConnection(token).ConfigureAwait(false);
}
}
@@ -266,26 +252,115 @@ namespace cmonitor.plugins.viewer.proxy
await Task.CompletedTask;
}
protected bool BindReceiveTarget(Socket targetSocket, Socket sourceSocket)
private async Task SendToConnection(AsyncUserToken token)
{
byte[] connectData = token.Proxy.ToBytes(out int length);
try
{
await token.Connection.SendAsync(connectData.AsMemory(0, length)).ConfigureAwait(false);
}
catch (Exception)
{
CloseClientSocket(token);
}
finally
{
token.Proxy.Return(connectData);
}
}
protected void BindConnectionReceive(ITunnelConnection connection)
{
connection.BeginReceive(InputConnectionData, CloseConnection, new AsyncUserToken
{
Connection = connection,
Buffer = new ReceiveDataBuffer(),
Proxy = new ProxyInfo { }
});
}
protected async Task InputConnectionData(ITunnelConnection connection, Memory<byte> memory, object userToken)
{
AsyncUserToken token = userToken as AsyncUserToken;
//是一个完整的包
if (token.Buffer.Size == 0 && memory.Length > 4)
{
int packageLen = memory.ToInt32();
if (packageLen == memory.Length - 4)
{
token.Proxy.DeBytes(memory.Slice(0, packageLen + 4));
await ReadConnectionPack(token).ConfigureAwait(false);
return;
}
}
//不是完整包
token.Buffer.AddRange(memory);
do
{
int packageLen = token.Buffer.Data.ToInt32();
if (packageLen > token.Buffer.Size - 4)
{
break;
}
token.Proxy.DeBytes(token.Buffer.Data.Slice(0, packageLen + 4));
await ReadConnectionPack(token).ConfigureAwait(false);
token.Buffer.RemoveRange(0, packageLen + 4);
} while (token.Buffer.Size > 4);
}
protected async Task CloseConnection(ITunnelConnection connection, object userToken)
{
CloseClientSocket(userToken as AsyncUserToken);
await Task.CompletedTask;
}
private async Task ReadConnectionPack(AsyncUserToken token)
{
if (token.Proxy.Step == ProxyStep.Request)
{
await ConnectBind(token).ConfigureAwait(false);
}
else
{
await SendToSocket(token).ConfigureAwait(false);
}
}
private async Task ConnectBind(AsyncUserToken token)
{
Socket socket = new Socket(token.Proxy.TargetEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
socket.KeepAlive();
await socket.ConnectAsync(token.Proxy.TargetEP);
dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.Connection.GetHashCode()), socket);
BindReceiveTarget(new AsyncUserToken
{
Connection = token.Connection,
Socket = socket,
Proxy = new ProxyInfo
{
ConnectId = token.Proxy.ConnectId,
Step = ProxyStep.Forward
}
});
}
private async Task SendToSocket(AsyncUserToken token)
{
ConnectId connectId = new ConnectId(token.Proxy.ConnectId, token.Connection.GetHashCode());
if (dic.TryGetValue(connectId, out Socket source))
{
try
{
BindReceiveTarget(new AsyncUserToken
await source.SendAsync(token.Proxy.Data);
}
catch (Exception)
{
TargetSocket = targetSocket,
SourceSocket = sourceSocket,
Buffer = new ReceiveDataBuffer(),
Proxy = new ProxyInfo { Direction = ProxyDirection.UnPack }
});
CloseClientSocket(token);
return true;
}
catch (Exception ex)
{
Logger.Instance.Error(ex);
}
return false;
}
private void IO_CompletedTarget(object sender, SocketAsyncEventArgs e)
{
switch (e.LastOperation)
@@ -308,7 +383,7 @@ namespace cmonitor.plugins.viewer.proxy
};
readEventArgs.SetBuffer(new byte[8 * 1024], 0, 8 * 1024);
readEventArgs.Completed += IO_CompletedTarget;
if (userToken.TargetSocket.ReceiveAsync(readEventArgs) == false)
if (userToken.Socket.ReceiveAsync(readEventArgs) == false)
{
ProcessReceiveTarget(readEventArgs);
}
@@ -321,48 +396,49 @@ namespace cmonitor.plugins.viewer.proxy
}
private async void ProcessReceiveTarget(SocketAsyncEventArgs e)
{
AsyncUserToken token = (AsyncUserToken)e.UserToken;
try
{
AsyncUserToken token = (AsyncUserToken)e.UserToken;
if (e.BytesTransferred > 0 && e.SocketError == SocketError.Success)
{
int offset = e.Offset;
int length = e.BytesTransferred;
await ReadPacketTarget(e, token, e.Buffer.AsMemory(offset, length)).ConfigureAwait(false);
token.Proxy.Data = e.Buffer.AsMemory(offset, length);
await SendToConnection(token).ConfigureAwait(false);
if (token.TargetSocket.Available > 0)
if (token.Socket.Available > 0)
{
while (token.TargetSocket.Available > 0)
while (token.Socket.Available > 0)
{
length = token.TargetSocket.Receive(e.Buffer);
length = token.Socket.Receive(e.Buffer);
if (length > 0)
{
await ReadPacketTarget(e, token, e.Buffer.AsMemory(0, length)).ConfigureAwait(false);
token.Proxy.Data = e.Buffer.AsMemory(0, length);
await SendToConnection(token).ConfigureAwait(false);
}
else
{
CloseClientSocket(e);
CloseClientSocket(token);
return;
}
}
}
if (token.TargetSocket.Connected == false)
if (token.Connection.Connected == false)
{
CloseClientSocket(e);
CloseClientSocket(token);
return;
}
if (token.TargetSocket.ReceiveAsync(e) == false)
if (token.Socket.ReceiveAsync(e) == false)
{
ProcessReceiveTarget(e);
}
}
else
{
CloseClientSocket(e);
CloseClientSocket(token);
}
}
catch (Exception ex)
@@ -370,126 +446,19 @@ namespace cmonitor.plugins.viewer.proxy
if (Logger.Instance.LoggerLevel <= LoggerTypes.DEBUG)
Logger.Instance.Error(ex);
CloseClientSocket(e);
CloseClientSocket(token);
}
}
private readonly ConcurrentDictionary<ConnectId, Socket> dic = new ConcurrentDictionary<ConnectId, Socket>();
private async Task ReadPacketTarget(SocketAsyncEventArgs e, AsyncUserToken token, Memory<byte> data)
{
//A 到 B
if (token.Proxy.Direction == ProxyDirection.UnPack)
{
//是一个完整的包
if (token.Buffer.Size == 0 && data.Length > 4)
{
int packageLen = data.ToInt32();
if (packageLen == data.Length - 4)
{
token.Proxy.DeBytes(data.Slice(0, packageLen + 4));
await ReadPacketTarget(e, token).ConfigureAwait(false);
return;
}
}
//不是完整包
token.Buffer.AddRange(data);
do
private void CloseClientSocket(AsyncUserToken token)
{
int packageLen = token.Buffer.Data.ToInt32();
if (packageLen > token.Buffer.Size - 4)
if (token.Connection != null)
{
break;
}
token.Proxy.DeBytes(token.Buffer.Data.Slice(0, packageLen + 4));
await ReadPacketTarget(e, token).ConfigureAwait(false);
token.Buffer.RemoveRange(0, packageLen + 4);
} while (token.Buffer.Size > 4);
}
else
int code = token.Connection.GetHashCode();
if (token.Connection.Connected == false)
{
token.Proxy.Data = data;
await SendToSource(e, token).ConfigureAwait(false);
}
}
private async Task ReadPacketTarget(SocketAsyncEventArgs e, AsyncUserToken token)
{
if (token.Proxy.Step == ProxyStep.Request)
{
await ConnectBind(e, token).ConfigureAwait(false);
}
else
{
await SendToSource(e, token).ConfigureAwait(false);
}
}
private async Task ConnectBind(SocketAsyncEventArgs e, AsyncUserToken token)
{
Socket socket = new Socket(token.Proxy.TargetEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
socket.KeepAlive();
await socket.ConnectAsync(token.Proxy.TargetEP);
dic.TryAdd(new ConnectId(token.Proxy.ConnectId, token.TargetSocket.GetHashCode()), socket);
BindReceiveTarget(new AsyncUserToken
{
TargetSocket = socket,
SourceSocket = token.TargetSocket,
Proxy = new ProxyInfo
{
Direction = ProxyDirection.Pack,
ConnectId = token.Proxy.ConnectId,
Step = ProxyStep.Forward
}
});
}
private async Task SendToSource(SocketAsyncEventArgs e, AsyncUserToken token)
{
if (token.Proxy.Direction == ProxyDirection.UnPack)
{
ConnectId connectId = new ConnectId(token.Proxy.ConnectId, token.TargetSocket.GetHashCode());
if (dic.TryGetValue(connectId, out Socket source))
{
try
{
await source.SendAsync(token.Proxy.Data);
}
catch (Exception)
{
CloseClientSocket(e);
}
}
}
else
{
byte[] connectData = token.Proxy.ToBytes(out int length);
try
{
await token.SourceSocket.SendAsync(connectData.AsMemory(0, length), SocketFlags.None);
}
catch (Exception)
{
CloseClientSocket(e);
}
finally
{
token.Proxy.Return(connectData);
}
}
}
private void CloseClientSocket(SocketAsyncEventArgs e)
{
if (e == null) return;
AsyncUserToken token = e.UserToken as AsyncUserToken;
if (token.TargetSocket != null)
{
int code = token.TargetSocket.GetHashCode();
if (token.TargetSocket.Connected == false)
{
foreach (ConnectId item in dic.Keys.Where(c => c.socket == code).ToList())
foreach (ConnectId item in dic.Keys.Where(c => c.hashCode == code).ToList())
{
dic.TryRemove(item, out _);
}
@@ -499,28 +468,36 @@ namespace cmonitor.plugins.viewer.proxy
dic.TryRemove(new ConnectId(token.Proxy.ConnectId, code), out _);
}
}
if (token.SourceSocket != null)
{
token.Clear();
e.Dispose();
}
}
public void Stop()
{
CloseClientSocket(acceptEventArg);
CloseClientSocket(userToken);
udpClient?.Close();
}
}
public enum ProxyStep : byte
{
Request = 1,
Forward = 2
}
public record struct ConnectId
{
public ulong connectId;
public int hashCode;
public ConnectId(ulong connectId, int hashCode)
{
this.connectId = connectId;
this.hashCode = hashCode;
}
}
public sealed class ProxyInfo
{
public ulong ConnectId { get; set; }
public ProxyStep Step { get; set; } = ProxyStep.Request;
public ProxyDirection Direction { get; set; } = ProxyDirection.Pack;
public IPEndPoint TargetEP { get; set; }
public Memory<byte> Data { get; set; }
@@ -595,7 +572,6 @@ namespace cmonitor.plugins.viewer.proxy
}
}
public sealed class AsyncUserUdpToken
{
public UdpClient SourceSocket { get; set; }
@@ -606,53 +582,31 @@ namespace cmonitor.plugins.viewer.proxy
{
SourceSocket?.Close();
SourceSocket = null;
GC.Collect();
}
}
public sealed class AsyncUserToken
{
public Socket SourceSocket { get; set; }
public Socket TargetSocket { get; set; }
public Socket Socket { get; set; }
public ITunnelConnection Connection { get; set; }
public ProxyInfo Proxy { get; set; }
public ReceiveDataBuffer Buffer { get; set; }
public SocketAsyncEventArgs Saea { get; set; }
public void Clear()
{
SourceSocket?.SafeClose();
SourceSocket = null;
Socket?.SafeClose();
Socket = null;
Buffer?.Clear();
Saea?.Dispose();
GC.Collect();
}
}
public enum ProxyStep : byte
{
Request = 1,
Forward = 2
}
public enum ProxyDirection
{
Pack = 0,
UnPack = 1,
}
public record struct ConnectId
{
public ulong connectId;
public int socket;
public ConnectId(ulong connectId, int socket)
{
this.connectId = connectId;
this.socket = socket;
}
}
}

View File

@@ -1,11 +1,9 @@
using cmonitor.client.running;
using cmonitor.client.tunnel;
using cmonitor.config;
using cmonitor.plugins.relay;
using cmonitor.plugins.relay.transport;
using cmonitor.plugins.tunnel;
using cmonitor.plugins.tunnel.transport;
using common.libs;
using System.Net.Sockets;
namespace cmonitor.plugins.viewer.proxy
{
@@ -16,7 +14,7 @@ namespace cmonitor.plugins.viewer.proxy
private readonly RelayTransfer relayTransfer;
private readonly Config config;
private Socket tunnelSocket;
private ITunnelConnection connection;
public ViewerProxyClient(RunningConfig runningConfig, TunnelTransfer tunnelTransfer, RelayTransfer relayTransfer, Config config)
{
@@ -28,55 +26,27 @@ namespace cmonitor.plugins.viewer.proxy
Start(0);
Logger.Instance.Info($"start viewer proxy, port : {LocalEndpoint.Port}");
Tunnel();
tunnelTransfer.SetConnectCallback("viewer", BindConnectionReceive);
relayTransfer.SetConnectCallback("viewer", BindConnectionReceive);
}
protected override async Task Connect(AsyncUserToken token)
{
token.Proxy.TargetEP = runningConfig.Data.Viewer.ConnectEP;
if (tunnelSocket == null || tunnelSocket.Connected == false)
token.Connection = connection;
if (connection == null || connection.Connected == false)
{
TunnelTransportState state = await tunnelTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer");
if (state != null)
connection = await tunnelTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer");
if (connection == null)
{
if (state.TransportType == ProtocolType.Tcp)
connection = await relayTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer", config.Data.Client.Relay.SecretKey);
}
if (connection != null)
{
tunnelSocket = state.ConnectedObject as Socket;
token.TargetSocket = tunnelSocket;
BindReceiveTarget(tunnelSocket, token.SourceSocket);
return;
}
}
RelayTransportState relayState = await relayTransfer.ConnectAsync(runningConfig.Data.Viewer.ServerMachine, "viewer", config.Data.Client.Relay.SecretKey);
if (relayState != null)
{
tunnelSocket = relayState.Socket;
token.TargetSocket = tunnelSocket;
BindReceiveTarget(tunnelSocket, token.SourceSocket);
return;
}
tunnelSocket = null;
}
}
private void Tunnel()
{
tunnelTransfer.OnConnected += (TunnelTransportState state) =>
{
if (state != null && state.TransportType == ProtocolType.Tcp && state.TransactionId == "viewer" && state.Direction == TunnelTransportDirection.Reverse)
{
BindReceiveTarget(state.ConnectedObject as Socket, null);
}
};
relayTransfer.OnConnected += (RelayTransportState state) =>
{
if (state != null && state.Info.TransactionId == "viewer" && state.Direction == RelayTransportDirection.Reverse)
{
BindReceiveTarget(state.Socket, null);
}
};
BindConnectionReceive(connection);
token.Connection = connection;
}
}
}
}
}