using System.Net.Sockets;
using System.Buffers;
using linker.libs.extends;
using System.Collections.Concurrent;
using System.Net;
using linker.libs;
namespace linker.messenger.relay.server
{
///
/// 中继连接处理
///
public class RelayServerResolver: IResolver
{
public ResolverType Type => ResolverType.Relay;
private readonly RelayServerNodeTransfer relayServerNodeTransfer;
private readonly ISerializer serializer;
public RelayServerResolver(RelayServerNodeTransfer relayServerNodeTransfer, ISerializer serializer)
{
this.relayServerNodeTransfer = relayServerNodeTransfer;
this.serializer = serializer;
}
private readonly ConcurrentDictionary relayDic = new ConcurrentDictionary();
public virtual void AddReceive(string key, string from, string to, string groupid, ulong bytes)
{
}
public virtual void AddSendt(string key, string from, string to, string groupid, ulong bytes)
{
}
public virtual void AddReceive(string key, ulong bytes)
{
}
public virtual void AddSendt(string key, ulong bytes)
{
}
public async Task Resolve(Socket socket, IPEndPoint ep, Memory memory)
{
await Task.CompletedTask;
}
public async Task Resolve(Socket socket, Memory memory)
{
byte[] buffer = ArrayPool.Shared.Rent(1024);
try
{
int length = await socket.ReceiveAsync(buffer.AsMemory(), SocketFlags.None).ConfigureAwait(false);
RelayMessageInfo relayMessage = serializer.Deserialize(buffer.AsMemory(0, length).Span);
if (relayMessage.Type == RelayMessengerType.Ask && relayMessage.NodeId != RelayServerNodeInfo.MASTER_NODE_ID)
{
if (relayServerNodeTransfer.Validate() == false)
{
if (LoggerHelper.Instance.LoggerLevel <= LoggerTypes.DEBUG)
LoggerHelper.Instance.Error($"relay {relayMessage.Type} Validate false,flowid:{relayMessage.FlowId}");
await socket.SendAsync(new byte[] { 1 });
socket.SafeClose();
return;
}
}
//ask 是发起端来的,那key就是 发起端->目标端, answer的,目标和来源会交换,所以转换一下
string key = relayMessage.Type == RelayMessengerType.Ask ? $"{relayMessage.FromId}->{relayMessage.ToId}->{relayMessage.FlowId}" : $"{relayMessage.ToId}->{relayMessage.FromId}->{relayMessage.FlowId}";
//获取缓存
RelayCacheInfo relayCache = await relayServerNodeTransfer.TryGetRelayCache(key, relayMessage.NodeId);
if (relayCache == null)
{
if (LoggerHelper.Instance.LoggerLevel <= LoggerTypes.DEBUG)
LoggerHelper.Instance.Error($"relay {relayMessage.Type} get cache fail,flowid:{relayMessage.FlowId}");
socket.SafeClose();
return;
}
//流量统计
AddReceive(relayCache.FromId, relayCache.FromName, relayCache.ToName, relayCache.GroupId, (ulong)length);
try
{
switch (relayMessage.Type)
{
case RelayMessengerType.Ask:
{
//添加本地缓存
RelayWrapInfo relayWrap = new RelayWrapInfo { Socket = socket, Tcs = new TaskCompletionSource() };
relayWrap.Limit.SetLimit(relayServerNodeTransfer.GetBandwidthLimit());
relayDic.TryAdd(relayCache.FlowId, relayWrap);
await socket.SendAsync(new byte[] { 0 });
//等待对方连接
Socket targetSocket = await relayWrap.Tcs.Task.WaitAsync(TimeSpan.FromMilliseconds(15000));
_ = CopyToAsync(relayCache, relayWrap.Limit, socket, targetSocket);
}
break;
case RelayMessengerType.Answer:
{
//看发起端缓存
if (relayDic.TryRemove(relayCache.FlowId, out RelayWrapInfo relayWrap) == false || relayWrap.Socket == null)
{
if (LoggerHelper.Instance.LoggerLevel <= LoggerTypes.DEBUG)
LoggerHelper.Instance.Error($"relay {relayMessage.Type} get cache fail,flowid:{relayMessage.FlowId}");
socket.SafeClose();
return;
}
//告诉发起端我的socket
relayWrap.Tcs.SetResult(socket);
_ = CopyToAsync(relayCache, relayWrap.Limit, socket, relayWrap.Socket);
}
break;
default:
{
if (LoggerHelper.Instance.LoggerLevel <= LoggerTypes.DEBUG)
LoggerHelper.Instance.Error($"relay {relayMessage.Type} unknow type,flowid:{relayMessage.FlowId}");
socket.SafeClose();
}
break;
}
}
catch (Exception ex)
{
if (LoggerHelper.Instance.LoggerLevel <= LoggerTypes.DEBUG)
LoggerHelper.Instance.Error($"{ex},flowid:{relayMessage.FlowId}");
if (relayDic.TryRemove(relayCache.FlowId, out RelayWrapInfo remove))
{
remove.Socket?.SafeClose();
}
}
}
catch (Exception ex)
{
if (LoggerHelper.Instance.LoggerLevel <= LoggerTypes.DEBUG)
LoggerHelper.Instance.Error(ex);
socket.SafeClose();
}
finally
{
ArrayPool.Shared.Return(buffer);
}
}
private async Task CopyToAsync(RelayCacheInfo cache, RelaySpeedLimit limit, Socket source, Socket destination)
{
byte[] buffer = new byte[4 * 1024];
try
{
relayServerNodeTransfer.IncrementConnectionNum();
int bytesRead;
while ((bytesRead = await source.ReceiveAsync(buffer.AsMemory()).ConfigureAwait(false)) != 0)
{
//流量限制
if (relayServerNodeTransfer.AddBytes((ulong)bytesRead) == false)
{
source.SafeClose();
break;
}
//总速度
if (relayServerNodeTransfer.NeedLimit())
{
int length = bytesRead;
relayServerNodeTransfer.TryLimit(ref length);
while (length > 0)
{
await Task.Delay(30).ConfigureAwait(false);
relayServerNodeTransfer.TryLimit(ref length);
}
}
//单个速度
if (limit.NeedLimit())
{
int length = bytesRead;
limit.TryLimit(ref length);
while (length > 0)
{
await Task.Delay(30).ConfigureAwait(false);
limit.TryLimit(ref length);
}
}
AddReceive(cache.FromId, cache.FromName, cache.ToName, cache.GroupId, (ulong)bytesRead);
AddSendt(cache.FromId, cache.FromName, cache.ToName, cache.GroupId, (ulong)bytesRead);
await destination.SendAsync(buffer.AsMemory(0, bytesRead)).ConfigureAwait(false);
}
}
catch (Exception)
{
}
finally
{
relayServerNodeTransfer.DecrementConnectionNum();
}
}
}
public enum RelayMessengerType : byte
{
Ask = 0,
Answer = 1,
}
public class RelaySpeedLimit
{
private uint relayLimit = 0;
private double relayLimitToken = 0;
private double relayLimitBucket = 0;
private long relayLimitTicks = Environment.TickCount64;
public bool NeedLimit()
{
return relayLimit > 0;
}
public void SetLimit(uint bytes)
{
relayLimit = bytes;
relayLimitToken = relayLimit / 1000.0;
relayLimitBucket = relayLimit;
}
public bool TryLimit(ref int length)
{
if (relayLimit == 0) return false;
lock (this)
{
long _relayLimitTicks = Environment.TickCount64;
long relayLimitTicksTemp = _relayLimitTicks - relayLimitTicks;
relayLimitTicks = _relayLimitTicks;
relayLimitBucket += relayLimitTicksTemp * relayLimitToken;
if (relayLimitBucket > relayLimit) relayLimitBucket = relayLimit;
if (relayLimitBucket >= length)
{
relayLimitBucket -= length;
length = 0;
}
else
{
length -= (int)relayLimitBucket;
relayLimitBucket = 0;
}
}
return true;
}
}
public sealed partial class RelayCacheInfo
{
public ulong FlowId { get; set; }
public string FromId { get; set; }
public string FromName { get; set; }
public string ToId { get; set; }
public string ToName { get; set; }
public string GroupId { get; set; }
}
public sealed class RelayWrapInfo
{
public TaskCompletionSource Tcs { get; set; }
public Socket Socket { get; set; }
public RelaySpeedLimit Limit { get; set; } = new RelaySpeedLimit();
}
public sealed partial class RelayMessageInfo
{
public RelayMessengerType Type { get; set; }
public ulong FlowId { get; set; }
public string FromId { get; set; }
public string ToId { get; set; }
public string NodeId { get; set; }
}
}