using linker.libs;
using linker.libs.extends;
using linker.libs.timer;
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Concurrent;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using System.Runtime.InteropServices;
namespace linker.nat
{
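/// <summary>
/// Simple application-layer SNAT built on WinDivert.
/// Deployment: for a 64-bit process ship the x64 WinDivert.dll and WinDivert64.sys;
/// for a 32-bit process ship the x86 WinDivert.dll together with WinDivert64.sys and WinDivert.sys.
/// Flow:
/// 1. A packet arrives from client A: 10.18.18.23 (client A's virtual NIC IP) -> 192.168.56.6 (LAN IP).
/// 2. It is rewritten to 192.168.56.2 (this machine's IP) -> 192.168.56.6 (LAN IP).
/// 3. The reply comes back as 192.168.56.6 (LAN IP) -> 192.168.56.2 (this machine's IP).
/// 4. It is rewritten to 192.168.56.6 (LAN IP) -> 10.18.18.23 (client A's virtual NIC IP).
/// 5. The packet returns to client A, completing the NAT round trip.
/// </summary>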
public sealed partial class LinkerSrcNat
{
/// <summary>
/// Whether a WinDivert handle is currently open.
/// </summary>
public bool Running => winDivert != null;

/// <summary>
/// WinDivert driver handle; null while not running.
/// </summary>
private WinDivert winDivert;

/// <summary>
/// Virtual NIC IP as a numeric value (compared against <see cref="IPV4Packet.DstAddr"/> in Inject).
/// </summary>
private uint srcIp;

/// <summary>
/// Virtual NIC IP in WinDivert address form.
/// </summary>
private NetworkIPv4Addr srcAddr;

/// <summary>
/// Address template used when injecting packets: network layer, outbound, IPv4.
/// </summary>
private WinDivertAddress addr = new()
{
    Layer = WinDivert.Layer.Network,
    Outbound = true,
    IPv6 = false
};

/// <summary>
/// Cancellation source of the current session (receive loop and cleanup timer).
/// </summary>
private CancellationTokenSource cts;

/// <summary>
/// Five-tuple NAT table, keyed by the translated (post-SNAT) tuple.
/// </summary>
private readonly ConcurrentDictionary<(uint src, ushort srcPort, uint dst, ushort dstPort, ProtocolType pro), NatMapInfo> natMap = new();

/// <summary>
/// Replacement source ports, keyed by the original (source address, source port).
/// </summary>
private readonly ConcurrentDictionary<(uint src, ushort port), ushort> source2portMap = new();

/// <summary>
/// Resolves which local interface address a destination should be sent from.
/// </summary>
private readonly LinkerSrcNatInterfaceHelper interfaceHelper = new();

public LinkerSrcNat()
{
}
/// <summary>
/// Starts SNAT. Any previous session is shut down first. Then it:
/// enables Windows IP routing, snapshots local interfaces, and opens a WinDivert handle that
/// captures inbound packets whose source lies in one of the destination networks.
/// Finally it starts the receive loop and the mapping-cleanup timer.
/// </summary>
/// <param name="info">Startup parameters.</param>
/// <param name="error">Set to a failure reason when startup fails; set to empty once validation passes.</param>
/// <returns>true when started; false on failure (see <paramref name="error"/>).</returns>
public bool Setup(SetupInfo info, ref string error)
{
    if (OperatingSystem.IsWindows() == false || (RuntimeInformation.ProcessArchitecture != Architecture.X86 && RuntimeInformation.ProcessArchitecture != Architecture.X64))
    {
        error = "snat only win x64 and win x86";
        return false;
    }
    if (info == null)
    {
        error = "info is null,snat fail";
        return false;
    }
    if (info.Src == null)
    {
        error = "src is null,snat fail";
        return false;
    }
    if (info.Dsts == null || info.Dsts.Length == 0)
    {
        // report the reason like the other validation failures do
        error = "dsts is empty,snat fail";
        return false;
    }
    error = string.Empty;
    try
    {
        // enable IP forwarding (IPEnableRouter = 1)
        CommandHelper.Windows(string.Empty, ["reg add \"HKLM\\SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\" /v IPEnableRouter /t REG_DWORD /d 1 /f"]);
        Shutdown();
        srcIp = NetworkHelper.ToValue(info.Src);
        srcAddr = IPv4Addr.Parse(info.Src.ToString());
        interfaceHelper.Setup();
        string filters = BuildFilter(info.Dsts);
        winDivert = new WinDivert(filters, WinDivert.Layer.Network, 0, 0);
        cts = new CancellationTokenSource();
        Recv(cts);
        ClearTask(cts);
        return true;
    }
    catch (Exception ex)
    {
        error = ex.Message;
    }
    return false;
}
/// <summary>
/// Builds the WinDivert filter. It captures only inbound packets whose source address lies inside
/// one of the destination networks (network address through broadcast address, inclusive).
/// </summary>
/// <param name="dsts">Destination networks whose replies must be captured.</param>
/// <returns>The WinDivert filter expression.</returns>
private static string BuildFilter(AddrInfo[] dsts)
{
    string ranges = string.Join(" or ", dsts.Select(dst => $"(ip.SrcAddr >= {dst.NetworkIP} and ip.SrcAddr <= {dst.BroadcastIP})"));
    return $"inbound and ({ranges})";
}
/// <summary>
/// Starts the background loop for one session. The loop receives captured inbound packets in
/// batches, restores their NAT mappings in place, and re-injects the whole batch.
/// </summary>
/// <param name="cts">Cancellation source of the session that owns this loop.</param>
private void Recv(CancellationTokenSource cts)
{
    // Bind the loop to the handle this session opened. A later Setup() replaces the field,
    // and the old loop must not read from or send through the new session's handle.
    WinDivert divert = winDivert;
    TimerHelper.Async(() =>
    {
        using IMemoryOwner<byte> buffer = MemoryPool<byte>.Shared.Rent(10 * WinDivert.MTUMax);
        Memory<WinDivertAddress> addrBuffer = new Memory<WinDivertAddress>(new WinDivertAddress[10]);
        while (cts.IsCancellationRequested == false)
        {
            try
            {
                var (recvLen, addrLen) = divert.RecvEx(buffer.Memory.Span, addrBuffer.Span);
                Memory<byte> recv = buffer.Memory[..(int)recvLen];
                Memory<WinDivertAddress> addrs = addrBuffer[..(int)addrLen];
                foreach (var (i, p) in new WinDivertIndexedPacketParser(recv))
                {
                    Recv(p, ref addrs.Span[i]);
                }
                // every captured packet must be re-injected, rewritten or not
                divert.SendEx(recv.Span, addrs.Span);
            }
            catch (Exception)
            {
                // a driver/handle error (including the handle being disposed) ends the loop
                break;
            }
        }
        // Only tear down when the loop died on its own. If cts was cancelled, Shutdown()/Setup()
        // already released this session, and a new session may be running that must not be stopped.
        if (cts.IsCancellationRequested == false)
        {
            Shutdown();
        }
    });
}
/// <summary>
/// Restores (reverse-translates) one captured inbound packet in place and recalculates its
/// checksums when it was rewritten.
/// </summary>
/// <param name="p">Parsed packet from the receive batch.</param>
/// <param name="addr">WinDivert address of the packet, passed to checksum calculation.</param>
/// <returns>true if a mapping matched and the packet was rewritten; otherwise false.</returns>
private unsafe bool Recv(WinDivertParseResult p, ref WinDivertAddress addr)
{
    // the capture filter only matches IPv4, but never dereference a missing header
    if (p.IPv4Hdr == null) return false;
    fixed (byte* ptr = p.Packet.Span)
    {
        bool result = (ProtocolType)p.IPv4Hdr->Protocol switch
        {
            ProtocolType.Icmp => RecvIcmp(p, ptr),
            ProtocolType.Tcp => RecvTcp(p, ptr),
            ProtocolType.Udp => RecvUdp(p, ptr),
            _ => false,
        };
        if (result)
        {
            WinDivert.CalcChecksums(p.Packet.Span, ref addr, 0);
        }
        return result;
    }
}
/// <summary>
/// Injects one IPv4 packet (e.g. read from the virtual NIC) through WinDivert as outbound traffic.
/// When a local interface's subnet contains the destination, the packet is SNATed first: the
/// source address becomes that interface's address, and the port/ICMP identifier is remapped.
/// </summary>
/// <param name="packet">A single complete IPv4 packet.</param>
/// <returns>
/// true if the packet was handed to WinDivert. false if it was not: not running, shorter than an
/// IPv4 header, not IPv4, addressed to the virtual NIC IP, broadcast/multicast, or not
/// translatable. In that case the caller may keep writing it to its virtual NIC.
/// </returns>
public unsafe bool Inject(ReadOnlyMemory<byte> packet)
{
    // snapshot the handle: Shutdown() can null the field from another thread
    WinDivert divert = winDivert;
    if (divert == null) return false;
    // an IPv4 header is at least 20 bytes; reading Version/DstAddr from less would overrun the buffer
    if (packet.Length < 20) return false;

    fixed (byte* ptr = packet.Span)
    {
        // the header view must be created while the buffer is pinned
        IPV4Packet ipv4 = new IPV4Packet(ptr);
        //not ipv4, addressed to the virtual NIC ip, or broadcast: no nat
        if (ipv4.Version != 4 || ipv4.DstAddr == srcIp || ipv4.DstAddrSpan.IsCast())
        {
            return false;
        }

        NetworkIPv4Addr interfaceAddr = interfaceHelper.GetInterfaceAddr(ipv4.DstAddr);
        // work on a private copy: CalcChecksums takes the address by ref, and Inject may run concurrently
        WinDivertAddress sendAddr = addr;
        foreach (var (_, p) in new WinDivertIndexedPacketParser(packet))
        {
            if (interfaceAddr.Raw != 0)
            {
                bool result = (ProtocolType)p.IPv4Hdr->Protocol switch
                {
                    ProtocolType.Icmp => InjectIcmp(p, ptr, interfaceAddr),
                    ProtocolType.Tcp => InjectTcp(p, ptr, interfaceAddr),
                    ProtocolType.Udp => InjectUdp(p, ptr, interfaceAddr),
                    _ => false,
                };
                if (result == false) return false;
                // rewrite the source address to the outgoing interface's address
                p.IPv4Hdr->SrcAddr = interfaceAddr;
            }
            WinDivert.CalcChecksums(p.Packet.Span, ref sendAddr, 0);
            divert.SendEx(p.Packet.Span, new ReadOnlySpan<WinDivertAddress>(ref sendAddr));
        }
    }
    return true;
}
/// <summary>
/// SNAT for an outbound ICMP echo request/reply. The two identifier bytes are replaced with the
/// last two bytes of the original source address. The original identifier and source address
/// are recorded so the reply can be restored.
/// </summary>
/// <param name="p">Parsed packet.</param>
/// <param name="ptr">Pointer to the start of the pinned IPv4 packet.</param>
/// <param name="interfaceAddr">Local interface address the packet will leave from.</param>
/// <returns>true if rewritten; false if not an unfragmented echo request/reply.</returns>
private unsafe bool InjectIcmp(WinDivertParseResult p, byte* ptr, NetworkIPv4Addr interfaceAddr)
{
    // the parser leaves the ICMP header null when it could not be parsed (e.g. a non-first fragment)
    if (p.ICMPv4Hdr == null) return false;
    //only echo reply (0) and echo request (8)
    if (p.ICMPv4Hdr->Type != 0 && p.ICMPv4Hdr->Type != 8) return false;
    IPV4Packet ipv4 = new IPV4Packet(ptr);
    if (ipv4.IsFragment) return false;
    //original identifier, two bytes
    byte* ptr0 = ipv4.IcmpIdentifier0;
    byte* ptr1 = ipv4.IcmpIdentifier1;
    //the 3rd and 4th bytes of the source address become the new identifier
    byte identifier0 = ipv4.SrcAddrSpan[2];
    byte identifier1 = ipv4.SrcAddrSpan[3];
    //stored as (source, identifier0, destination, identifier1, ICMP);
    //the reply swaps the addresses, so it is looked up as (destination, identifier0, source, identifier1, ICMP)
    (uint, ushort, uint, ushort, ProtocolType) key = (interfaceAddr.Raw, identifier0, p.IPv4Hdr->DstAddr.Raw, identifier1, ProtocolType.Icmp);
    if (natMap.TryGetValue(key, out NatMapInfo natMapInfo) == false)
    {
        // GetOrAdd returns the stored instance even if another thread added one first
        natMapInfo = natMap.GetOrAdd(key, new NatMapInfo());
    }
    natMapInfo.SrcAddr = p.IPv4Hdr->SrcAddr;
    natMapInfo.Identifier0 = *ptr0;
    natMapInfo.Identifier1 = *ptr1;
    natMapInfo.LastTime = Environment.TickCount64;
    natMapInfo.Timeout = 15 * 1000;
    //write the new identifier
    *ptr0 = identifier0;
    *ptr1 = identifier1;
    return true;
}
/// <summary>
/// Restores an inbound ICMP echo request/reply that matches a mapping created by InjectIcmp.
/// It puts back the original identifier bytes and rewrites the destination to the original source.
/// </summary>
/// <param name="p">Parsed packet.</param>
/// <param name="ptr">Pointer to the start of the pinned IPv4 packet.</param>
/// <returns>true if a mapping matched and the packet was rewritten; otherwise false.</returns>
private unsafe bool RecvIcmp(WinDivertParseResult p, byte* ptr)
{
    // a missing ICMP header (e.g. non-first fragment) must not crash the receive loop
    if (p.ICMPv4Hdr == null) return false;
    //only echo reply (0) and echo request (8)
    if (p.ICMPv4Hdr->Type != 0 && p.ICMPv4Hdr->Type != 8) return false;
    IPV4Packet ipv4 = new IPV4Packet(ptr);
    //identifier, two bytes
    byte* ptr0 = ipv4.IcmpIdentifier0;
    byte* ptr1 = ipv4.IcmpIdentifier1;
    (uint, ushort, uint, ushort, ProtocolType) key = (p.IPv4Hdr->DstAddr.Raw, *ptr0, p.IPv4Hdr->SrcAddr.Raw, *ptr1, ProtocolType.Icmp);
    if (natMap.TryGetValue(key, out NatMapInfo natMapInfo) == false)
    {
        return false;
    }
    //restore the original identifier
    *ptr0 = natMapInfo.Identifier0;
    *ptr1 = natMapInfo.Identifier1;
    p.IPv4Hdr->DstAddr = natMapInfo.SrcAddr;
    return true;
}
/// <summary>
/// SNAT for an outbound TCP segment. It allocates a replacement source port (only on an initial
/// SYN) or reuses the one already allocated for the original (source address, source port). It
/// then records or refreshes the five-tuple mapping and rewrites the TCP source port. The caller
/// rewrites the IP source address.
/// </summary>
/// <param name="p">Parsed packet.</param>
/// <param name="ptr">Pointer to the start of the pinned IPv4 packet.</param>
/// <param name="interfaceAddr">Local interface address the packet will leave from.</param>
/// <returns>true if rewritten; false if there is no TCP header, or no port is allocated and this is not an initial SYN.</returns>
private unsafe bool InjectTcp(WinDivertParseResult p, byte* ptr, NetworkIPv4Addr interfaceAddr)
{
    // no parsed TCP header (e.g. a non-first IP fragment): nothing to translate
    if (p.TCPHdr == null) return false;
    IPV4Packet ipv4 = new IPV4Packet(ptr);
    (uint, ushort) portKey = (p.IPv4Hdr->SrcAddr.Raw, p.TCPHdr->SrcPort);
    if (source2portMap.TryGetValue(portKey, out ushort newPort) == false)
    {
        //a port is only allocated for the initial SYN of a connection
        if (ipv4.TcpFlagSyn == false || ipv4.TcpFlagAck) return false;
        // GetOrAdd returns the stored value, so racing SYNs agree on one port
        newPort = source2portMap.GetOrAdd(portKey, static _ => ApplyNewPort());
    }
    (uint, ushort, uint, ushort, ProtocolType) key = (interfaceAddr.Raw, newPort, p.IPv4Hdr->DstAddr.Raw, p.TCPHdr->DstPort, ProtocolType.Tcp);
    if (natMap.TryGetValue(key, out NatMapInfo natMapInfo) == false)
    {
        natMapInfo = natMap.GetOrAdd(key, new NatMapInfo
        {
            SrcAddr = p.IPv4Hdr->SrcAddr,
            SrcPort = p.TCPHdr->SrcPort,
            LastTime = Environment.TickCount64,
            Timeout = 2 * 60 * 60 * 1000 //tcp 2 hours
        });
    }
    natMapInfo.LastTime = Environment.TickCount64;
    UpdateTcpCloseState(natMapInfo, ipv4, outbound: true);
    p.TCPHdr->SrcPort = newPort;
    return true;
}

/// <summary>
/// Restores an inbound TCP segment that matches a mapping. It rewrites the destination address
/// and port back to the original source and updates the connection-close state.
/// </summary>
/// <param name="p">Parsed packet.</param>
/// <param name="ptr">Pointer to the start of the pinned IPv4 packet.</param>
/// <returns>true if a mapping matched and the segment was rewritten; otherwise false.</returns>
private unsafe bool RecvTcp(WinDivertParseResult p, byte* ptr)
{
    // a missing TCP header must not crash the receive loop
    if (p.TCPHdr == null) return false;
    IPV4Packet ipv4 = new IPV4Packet(ptr);
    (uint, ushort, uint, ushort, ProtocolType) key = (p.IPv4Hdr->DstAddr.Raw, p.TCPHdr->DstPort, p.IPv4Hdr->SrcAddr.Raw, p.TCPHdr->SrcPort, ProtocolType.Tcp);
    if (natMap.TryGetValue(key, out NatMapInfo natMapInfo) == false)
    {
        return false;
    }
    natMapInfo.LastTime = Environment.TickCount64;
    UpdateTcpCloseState(natMapInfo, ipv4, outbound: false);
    p.IPv4Hdr->DstAddr = natMapInfo.SrcAddr;
    p.TCPHdr->DstPort = natMapInfo.SrcPort;
    return true;
}

/// <summary>
/// Tracks TCP teardown for a mapping; ClearTask removes mappings flagged Rst or FinAck.
/// An RST closes the mapping immediately. A FIN only marks its own direction
/// (Fin0 = outbound, Fin1 = inbound). The connection counts as finished only once both
/// directions have sent FIN and a later non-FIN ACK is seen. Half-closed connections, where
/// one side keeps sending after the other's FIN, therefore keep working. A FIN segment itself
/// almost always carries ACK, so it must not trigger teardown.
/// </summary>
private static unsafe void UpdateTcpCloseState(NatMapInfo natMapInfo, IPV4Packet ipv4, bool outbound)
{
    if (ipv4.TcpFlagRst)
    {
        natMapInfo.Rst = true;
    }
    if (ipv4.TcpFlagFin)
    {
        if (outbound) natMapInfo.Fin0 = true;
        else natMapInfo.Fin1 = true;
    }
    else if (natMapInfo.Fin0 && natMapInfo.Fin1 && ipv4.TcpFlagAck)
    {
        natMapInfo.FinAck = true;
    }
}
/// <summary>
/// SNAT for an outbound UDP datagram. It allocates or reuses a replacement source port for the
/// original (source address, source port), records or refreshes the mapping, and rewrites the UDP
/// source port. The caller rewrites the IP source address.
/// </summary>
/// <param name="p">Parsed packet.</param>
/// <param name="ptr">Pointer to the start of the pinned IPv4 packet.</param>
/// <param name="interfaceAddr">Local interface address the packet will leave from.</param>
/// <returns>true if rewritten; false when no UDP header was parsed.</returns>
private unsafe bool InjectUdp(WinDivertParseResult p, byte* ptr, NetworkIPv4Addr interfaceAddr)
{
    // no parsed UDP header (e.g. a non-first IP fragment): nothing to translate
    if (p.UDPHdr == null) return false;
    (uint, ushort) portKey = (p.IPv4Hdr->SrcAddr.Raw, p.UDPHdr->SrcPort);
    // GetOrAdd returns the stored value, so concurrent first packets agree on one port
    ushort newPort = source2portMap.GetOrAdd(portKey, static _ => ApplyNewPort());
    // keyed as UDP so a TCP mapping with the same tuple can never be picked up
    (uint, ushort, uint, ushort, ProtocolType) key = (interfaceAddr.Raw, newPort, p.IPv4Hdr->DstAddr.Raw, p.UDPHdr->DstPort, ProtocolType.Udp);
    if (natMap.TryGetValue(key, out NatMapInfo natMapInfo) == false)
    {
        natMapInfo = natMap.GetOrAdd(key, new NatMapInfo
        {
            SrcAddr = p.IPv4Hdr->SrcAddr,
            SrcPort = p.UDPHdr->SrcPort,
            LastTime = Environment.TickCount64,
            Timeout = 30 * 60 * 1000 //UDP 30 minutes
        });
    }
    natMapInfo.LastTime = Environment.TickCount64;
    p.UDPHdr->SrcPort = newPort;
    return true;
}

/// <summary>
/// Restores an inbound UDP datagram that matches a mapping. It rewrites the destination address
/// and port back to the original source.
/// </summary>
/// <param name="p">Parsed packet.</param>
/// <param name="ptr">Pointer to the start of the pinned IPv4 packet.</param>
/// <returns>true if a mapping matched and the datagram was rewritten; otherwise false.</returns>
private unsafe bool RecvUdp(WinDivertParseResult p, byte* ptr)
{
    // a missing UDP header must not crash the receive loop
    if (p.UDPHdr == null) return false;
    (uint, ushort, uint, ushort, ProtocolType) key = (p.IPv4Hdr->DstAddr.Raw, p.UDPHdr->DstPort, p.IPv4Hdr->SrcAddr.Raw, p.UDPHdr->SrcPort, ProtocolType.Udp);
    if (natMap.TryGetValue(key, out NatMapInfo natMapInfo) == false)
    {
        return false;
    }
    natMapInfo.LastTime = Environment.TickCount64;
    p.IPv4Hdr->DstAddr = natMapInfo.SrcAddr;
    p.UDPHdr->DstPort = natMapInfo.SrcPort;
    return true;
}
/// <summary>
/// Asks the OS for a free port by binding an ephemeral UDP socket on all IPv4 addresses and
/// reading back the assigned port. The socket is disposed before returning, so the port is not
/// reserved afterwards.
/// </summary>
/// <returns>The port number the OS assigned.</returns>
private static ushort ApplyNewPort()
{
    using (Socket probe = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
    {
        probe.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
        probe.Bind(new IPEndPoint(IPAddress.Any, 0));
        IPEndPoint bound = (IPEndPoint)probe.LocalEndPoint;
        return (ushort)bound.Port;
    }
}
/// <summary>
/// Stops the current session. It cancels the receive loop and cleanup timer, closes the WinDivert
/// handle, and clears all NAT state and interface data. Safe to call when not running.
/// </summary>
public void Shutdown()
{
    cts?.Cancel();
    // Shutdown runs from both the receive thread and callers. Swap atomically so only one
    // caller disposes the handle, and a handle stored concurrently is never nulled without disposal.
    Interlocked.Exchange(ref winDivert, null)?.Dispose();
    natMap.Clear();
    source2portMap.Clear();
    interfaceHelper.Shutdown();
}
/// <summary>
/// Starts a 5 s periodic cleanup. It removes mappings that have been idle longer than their
/// Timeout (ms) or whose TCP connection is flagged finished (FinAck) or reset (Rst). It then
/// releases replacement-port allocations that no remaining mapping still uses. The callback
/// returns false once <paramref name="cts"/> is cancelled (presumably stopping the timer).
/// </summary>
/// <param name="cts">Cancellation source of the owning session.</param>
private void ClearTask(CancellationTokenSource cts)
{
    TimerHelper.SetIntervalLong(() =>
    {
        long now = Environment.TickCount64;
        var expired = natMap
            .Where(c => now - c.Value.LastTime > c.Value.Timeout || c.Value.FinAck || c.Value.Rst)
            .Select(c => c.Key)
            .ToList();

        HashSet<(uint src, ushort port)> released = new();
        foreach (var key in expired)
        {
            if (natMap.TryRemove(key, out NatMapInfo natMapInfo))
            {
                released.Add((natMapInfo.SrcAddr.Raw, natMapInfo.SrcPort));
            }
        }

        if (released.Count > 0)
        {
            // One (source, port) allocation can serve several mappings (e.g. one UDP socket talking
            // to many peers). Keep it while any surviving mapping still uses it, so the port does not change mid-flow.
            foreach (var item in natMap)
            {
                released.Remove((item.Value.SrcAddr.Raw, item.Value.SrcPort));
            }
            foreach (var portKey in released)
            {
                source2portMap.TryRemove(portKey, out _);
            }
        }
        return cts.IsCancellationRequested == false;
    }, 5000);
}
/// <summary>
/// An IPv4 network given as an address plus prefix length. The network and broadcast values
/// and address forms are derived once in the constructor.
/// </summary>
public sealed class AddrInfo
{
    /// <param name="ip">Any address inside the network.</param>
    /// <param name="prefixLength">CIDR prefix length.</param>
    public AddrInfo(IPAddress ip, byte prefixLength)
    {
        IP = ip;
        PrefixLength = prefixLength;

        PrefixValue = NetworkHelper.ToPrefixValue(prefixLength);
        NetworkValue = NetworkHelper.ToNetworkValue(ip, prefixLength);
        BroadcastValue = NetworkHelper.ToBroadcastValue(ip, prefixLength);

        NetworkIP = NetworkHelper.ToIP(NetworkValue);
        BroadcastIP = NetworkHelper.ToIP(BroadcastValue);

        Addr = IPv4Addr.Parse(ip.ToString());
        NetworkAddr = IPv4Addr.Parse(NetworkIP.ToString());
    }

    public IPAddress IP { get; }
    public byte PrefixLength { get; }
    /// <summary>The address in WinDivert form.</summary>
    public NetworkIPv4Addr Addr { get; }
    /// <summary>The network address in WinDivert form.</summary>
    public NetworkIPv4Addr NetworkAddr { get; }
    public uint PrefixValue { get; }
    public uint NetworkValue { get; }
    public uint BroadcastValue { get; }
    /// <summary>First address of the range (used as the lower bound of the capture filter).</summary>
    public IPAddress NetworkIP { get; }
    /// <summary>Last address of the range (used as the upper bound of the capture filter).</summary>
    public IPAddress BroadcastIP { get; }
}
/// <summary>
/// Parameters for LinkerSrcNat.Setup.
/// </summary>
public sealed class SetupInfo
{
    /// <summary>
    /// Virtual NIC IP of this node. Packets addressed to it are never injected by Inject.
    /// </summary>
    public IPAddress Src { get; init; }

    /// <summary>
    /// Networks to NAT into. Inbound packets whose source lies in these ranges are captured.
    /// </summary>
    public AddrInfo[] Dsts { get; init; }
}
/// <summary>
/// One NAT mapping record. It holds what is needed to restore a reply to the original sender,
/// plus idle-timeout and TCP-close bookkeeping read by the cleanup task.
/// </summary>
sealed class NatMapInfo
{
    /// <summary>Original (pre-NAT) source address from the IP header.</summary>
    public NetworkIPv4Addr SrcAddr { get; set; }
    /// <summary>Original source port (TCP/UDP mappings).</summary>
    public NetworkUInt16 SrcPort { get; set; }
    /// <summary>Original ICMP echo identifier, first byte (ICMP mappings).</summary>
    public byte Identifier0 { get; set; }
    /// <summary>Original ICMP echo identifier, second byte (ICMP mappings).</summary>
    public byte Identifier1 { get; set; }
    /// <summary>A FIN was seen on an outbound (injected) segment.</summary>
    public bool Fin0 { get; set; }
    /// <summary>A FIN was seen on an inbound (received) segment.</summary>
    public bool Fin1 { get; set; }
    /// <summary>The TCP translation code marked the connection as closed; the cleanup task removes it.</summary>
    public bool FinAck { get; set; }
    /// <summary>An RST was seen; the cleanup task removes the mapping.</summary>
    public bool Rst { get; set; }
    /// <summary>Last activity time, in Environment.TickCount64 milliseconds.</summary>
    public long LastTime { get; set; } = Environment.TickCount64;
    /// <summary>
    /// Idle timeout in milliseconds (compared against TickCount64 deltas). Defaults to one hour;
    /// the former default omitted the *1000 factor and meant only 3.6 seconds.
    /// </summary>
    public int Timeout { get; set; } = 1 * 60 * 60 * 1000;
}
/// <summary>
/// Pointer-based view over a raw IPv4 packet: IP header fields, plus the first few fields of the
/// ICMP/TCP/UDP header that follows it. No bounds checks are performed; the caller must make sure
/// the buffer is long enough for every member it reads.
/// </summary>
public unsafe struct IPV4Packet
{
    byte* ptr;

    /// <summary>Creates a view over memory the caller keeps pinned.</summary>
    public IPV4Packet(byte* ptr)
    {
        this.ptr = ptr;
    }

    /// <summary>
    /// Creates a view over <paramref name="span"/>.
    /// NOTE(review): the pointer is taken inside a <c>fixed</c> block that ends when this
    /// constructor returns, so the memory is NOT pinned afterwards. This is only safe for memory
    /// that cannot move (native, pinned, or stack memory). Prefer the <c>byte*</c> overload
    /// inside the caller's own <c>fixed</c> block.
    /// </summary>
    public IPV4Packet(ReadOnlySpan<byte> span)
    {
        fixed (byte* start = span)
        {
            this.ptr = start;
        }
    }

    /// <summary>IP header length in bytes (IHL field * 4).</summary>
    public int IPHeadLength => (ptr[0] & 0x0F) * 4;

    /// <summary>IP version (high nibble of the first byte).</summary>
    public byte Version => (byte)(ptr[0] >> 4);

    /// <summary>Upper-layer protocol (byte 9).</summary>
    public ProtocolType Protocol => (ProtocolType)ptr[9];

    /// <summary>Source address (bytes 12-15), byte-reversed from the wire order.</summary>
    public uint SrcAddr => BinaryPrimitives.ReverseEndianness(*(uint*)(ptr + 12));

    /// <summary>Destination address (bytes 16-19), byte-reversed from the wire order.</summary>
    public uint DstAddr => BinaryPrimitives.ReverseEndianness(*(uint*)(ptr + 16));

    /// <summary>Raw source address bytes.</summary>
    public ReadOnlySpan<byte> SrcAddrSpan => new ReadOnlySpan<byte>(ptr + 12, 4);

    /// <summary>Raw destination address bytes.</summary>
    public ReadOnlySpan<byte> DstAddrSpan => new ReadOnlySpan<byte>(ptr + 16, 4);

    /// <summary>Start of the transport (ICMP/TCP/UDP) header.</summary>
    private byte* TransportHeader => ptr + IPHeadLength;

    /// <summary>TCP/UDP source port (first two bytes of the transport header).</summary>
    public ushort SrcPort => BinaryPrimitives.ReverseEndianness(*(ushort*)TransportHeader);

    /// <summary>TCP/UDP destination port (bytes 2-3 of the transport header).</summary>
    public ushort DstPort => BinaryPrimitives.ReverseEndianness(*(ushort*)(TransportHeader + 2));

    /// <summary>The 3 IP flag bits (reserved, DF, MF) from the top of byte 6.</summary>
    public byte Flag => (byte)(ptr[6] >> 5);

    /// <summary>Don't Fragment bit.</summary>
    public bool DontFragment => (Flag & 0x02) == 0x02;

    /// <summary>More Fragments bit.</summary>
    public bool MoreFragment => (Flag & 0x01) == 0x01;

    /// <summary>Fragment offset (low 13 bits of bytes 6-7).</summary>
    public ushort Offset => (ushort)(BinaryPrimitives.ReverseEndianness(*(ushort*)(ptr + 6)) & 0x1FFF);

    /// <summary>True when this packet is any fragment (MF set or non-zero offset).</summary>
    public bool IsFragment => MoreFragment || Offset > 0;

    /// <summary>ICMP echo identifier, first byte (byte 4 of the ICMP header).</summary>
    public byte* IcmpIdentifier0 => TransportHeader + 4;

    /// <summary>ICMP echo identifier, second byte (byte 5 of the ICMP header).</summary>
    public byte* IcmpIdentifier1 => TransportHeader + 5;

    /// <summary>TCP flags byte (byte 13 of the TCP header).</summary>
    public byte TcpFlag => TransportHeader[13];
    public bool TcpFlagFin => (TcpFlag & 0x01) != 0;
    public bool TcpFlagSyn => (TcpFlag & 0x02) != 0;
    public bool TcpFlagRst => (TcpFlag & 0x04) != 0;
    public bool TcpFlagPsh => (TcpFlag & 0x08) != 0;
    public bool TcpFlagAck => (TcpFlag & 0x10) != 0;
    public bool TcpFlagUrg => (TcpFlag & 0x20) != 0;
}
}
/// <summary>
/// Maps a destination IPv4 value to the address of the local interface whose subnet contains it.
/// The SNAT uses that address as the new source address.
/// </summary>
public sealed class LinkerSrcNatInterfaceHelper
{
    /// <summary>Distinct interface masks, most specific (longest prefix) first.</summary>
    private uint[] interfaceMasks = [];
    /// <summary>Network value -> (interface IP value, interface address).</summary>
    private readonly ConcurrentDictionary<uint, (uint, NetworkIPv4Addr)> network2ipMap = new();
    /// <summary>Lookup cache: destination value -> interface address (default when the destination is the interface itself).</summary>
    private readonly ConcurrentDictionary<uint, NetworkIPv4Addr> ip2ipMap = new();

    /// <summary>
    /// Snapshots the IPv4 unicast addresses of all interfaces that are up and are neither loopback
    /// nor tunnel. Earlier state is discarded.
    /// </summary>
    public void Setup()
    {
        Shutdown();
        List<(IPAddress Address, IPAddress IPv4Mask)> interfaces = NetworkInterface.GetAllNetworkInterfaces()
            .Where(nic => nic.OperationalStatus == OperationalStatus.Up
                && nic.NetworkInterfaceType != NetworkInterfaceType.Loopback
                && nic.NetworkInterfaceType != NetworkInterfaceType.Tunnel)
            .Select(nic => nic.GetIPProperties())
            .SelectMany(props => props.UnicastAddresses
                .Where(ua => ua.Address.AddressFamily == AddressFamily.InterNetwork)
                .Select(ua => (ua.Address, ua.IPv4Mask)))
            .ToList();

        // Try the most specific subnet first (longest-prefix match, as routing does), so a /24
        // interface wins over an overlapping /16. A mask with more leading 1-bits is numerically
        // larger. This assumes NetworkHelper.ToValue yields the same numeric byte order as the
        // dstAddr values masked below, which the NAT code already relies on.
        interfaceMasks = interfaces
            .Select(c => NetworkHelper.ToValue(c.IPv4Mask))
            .Distinct()
            .OrderByDescending(mask => mask)
            .ToArray();

        foreach ((IPAddress address, IPAddress mask) in interfaces)
        {
            uint value = NetworkHelper.ToValue(address);
            uint network = NetworkHelper.ToNetworkValue(value, NetworkHelper.ToValue(mask));
            network2ipMap.TryAdd(network, (value, IPv4Addr.Parse(address.ToString())));
        }
    }

    /// <summary>
    /// Returns the address of the local interface whose subnet contains <paramref name="dstAddr"/>.
    /// Returns default (Raw == 0) when no interface matches, or when the destination is that
    /// interface's own address.
    /// </summary>
    /// <param name="dstAddr">Destination address value (as produced by IPV4Packet.DstAddr).</param>
    public NetworkIPv4Addr GetInterfaceAddr(uint dstAddr)
    {
        if (ip2ipMap.TryGetValue(dstAddr, out NetworkIPv4Addr interfaceAddr))
        {
            return interfaceAddr;
        }
        for (int i = 0; i < interfaceMasks.Length; i++)
        {
            //found the matching interface
            if (network2ipMap.TryGetValue(interfaceMasks[i] & dstAddr, out (uint, NetworkIPv4Addr) info))
            {
                //destination is the interface ip itself, no translation needed
                if (info.Item1 == dstAddr)
                {
                    ip2ipMap.TryAdd(dstAddr, default);
                    return default;
                }
                ip2ipMap.TryAdd(dstAddr, info.Item2);
                return info.Item2;
            }
        }
        return default;
    }

    /// <summary>
    /// Clears the interface snapshot and the lookup cache.
    /// </summary>
    public void Shutdown()
    {
        ip2ipMap.Clear();
        network2ipMap.Clear();
        interfaceMasks = [];
    }
}
}