using System; using System.Buffers.Binary; using System.Net.Sockets; namespace linker.libs { public sealed class ChecksumHelper { /// /// 计算IP包的校验和,当校验和为0时才计算 /// /// 一个完整的IP包 /// 是否计算IP头校验和 /// 是否计算荷载协议校验和 public static unsafe void ChecksumWithZero(ReadOnlyMemory packet, bool ipHeader = true, bool payload = true) { ChecksumWithZero(packet.Span, ipHeader, payload); } /// /// 计算IP包的校验和,当校验和为0时才计算 /// /// 一个完整的IP包 /// 是否计算IP头校验和 /// 是否计算荷载协议校验和 public static unsafe void ChecksumWithZero(ReadOnlySpan packet, bool ipHeader = true, bool payload = true) { fixed (byte* ptr = packet) { ChecksumWithZero(ptr, ipHeader, payload); } } /// /// 计算IP包的校验和,当校验和为0时才计算 /// /// IP包指针 /// 是否计算IP头校验和 /// 是否计算荷载协议校验和 public static unsafe void ChecksumWithZero(byte* ptr, bool ipHeader = true, bool payload = true) { byte ipHeaderLength = (byte)((*ptr & 0b1111) * 4); byte* packetPtr = ptr + ipHeaderLength; ipHeader = ipHeader && *(ushort*)(ptr + 10) == 0; payload = payload && ((ProtocolType)(*(ptr + 9)) switch { ProtocolType.Icmp => *(ushort*)(packetPtr + 2) == 0, ProtocolType.Tcp => *(ushort*)(packetPtr + 16) == 0, ProtocolType.Udp => *(ushort*)(packetPtr + 6) == 0, _ => false, }); if (ipHeader || payload) Checksum(ptr, ipHeader, payload); } /// /// 计算IP包的校验和 /// /// 一个完整的IP包 /// 是否计算IP头校验和 /// 是否计算荷载协议校验和 public static unsafe void Checksum(ReadOnlyMemory packet, bool ipHeader = true, bool payload = true) { Checksum(packet.Span, ipHeader, payload); } /// /// 计算IP包的校验和 /// /// 一个完整的IP包 /// 是否计算IP头校验和 /// 是否计算荷载协议校验和 public static unsafe void Checksum(ReadOnlySpan packet, bool ipHeader = true, bool payload = true) { fixed (byte* ptr = packet) { Checksum(ptr, ipHeader, payload); } } /// /// 计算IP包的校验和 /// /// IP包指针 /// 是否计算IP头校验和 /// 是否计算荷载协议校验和 public static unsafe void Checksum(byte* ptr, bool ipHeader = true, bool payload = true) { byte ipHeaderLength = (byte)((*ptr & 0b1111) * 4); byte* packetPtr = ptr + ipHeaderLength; uint totalLength = BinaryPrimitives.ReverseEndianness(*(ushort*)(ptr + 2)); uint packetLength = totalLength - ipHeaderLength; if (ipHeader) { //重新计算IP头校验和 *(ushort*)(ptr + 10) = 0; *(ushort*)(ptr + 10) = Checksum((ushort*)ptr, ipHeaderLength); } if (payload) { ProtocolType protocol = (ProtocolType)(*(ptr + 9)); switch (protocol) { case ProtocolType.Tcp: { *(ushort*)(packetPtr + 16) = 0; ulong sum = PseudoHeaderSum(ptr, packetLength); *(ushort*)(packetPtr + 16) = Checksum((ushort*)(packetPtr), packetLength, sum); } break; case ProtocolType.Udp: { *(ushort*)(packetPtr + 6) = 0; ulong sum = PseudoHeaderSum(ptr, packetLength); *(ushort*)(packetPtr + 6) = Checksum((ushort*)(packetPtr), packetLength, sum); } break; case ProtocolType.Icmp: { *(ushort*)(packetPtr + 2) = 0; *(ushort*)(packetPtr + 2) = Checksum((ushort*)(packetPtr), packetLength); } break; } } } /// /// 计算校验和 /// /// 包头开始位置 /// 计算长度,不同协议不同长度,请自己斟酌 /// 伪头部和,默认0不需要伪头部和 /// private static unsafe ushort Checksum(ushort* addr, uint length, ulong pseudoHeaderSum = 0) { //每两个字节为一个数,之和 while (length > 1) { pseudoHeaderSum += (ushort)((*addr >> 8) + (*addr << 8)); addr++; length -= 2; } //奇数字节末尾补零 if (length > 0) pseudoHeaderSum += (ushort)((*addr) << 8); //溢出处理 while ((pseudoHeaderSum >> 16) != 0) pseudoHeaderSum = (pseudoHeaderSum & 0xffff) + (pseudoHeaderSum >> 16); //取反 return BinaryPrimitives.ReverseEndianness((ushort)(~pseudoHeaderSum)); } /// /// 计算伪头部和,如TCP/UDP校验和需要一个伪头部 /// /// IP包头开始 /// TCP/UDP长度 /// private static unsafe ulong PseudoHeaderSum(byte* addr, uint length) { uint sum = 0; //源IP+目的IP for (byte i = 12; i < 20; i += 2) sum += (uint)((*(addr + i) << 8) | *(addr + i + 1)); //协议 sum += *(addr + 9); //协议内容长度 sum += length; return sum; } } }