Improve checksum usages

This commit is contained in:
世界
2025-08-27 11:03:17 +08:00
parent f9bbb15bfb
commit adc106bcf6
9 changed files with 23 additions and 45 deletions

View File

@@ -479,7 +479,9 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
// CalculateChecksum calculates the checksum of the IPv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return checksum.Checksum(b[:b.HeaderLength()], 0)
xsum0 := checksum.Checksum(b[:xsum], 0)
xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0)
return xsum0
}
// Encode encodes all the fields of the IPv4 header.

View File

@@ -351,7 +351,9 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) {
// and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum)
xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum)
return xsum
}
// IsChecksumValid returns true iff the TCP header's checksum is valid.

View File

@@ -113,8 +113,10 @@ func (b UDP) SetLength(length uint16) {
// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
// Calculate the rest of the checksum.\
xsum := checksum.Checksum(b[:udpChecksum], partialChecksum)
xsum = checksum.Checksum(b[udpChecksum+2:], xsum)
return xsum
}
// IsChecksumValid returns true iff the UDP header's checksum is valid.

View File

@@ -9,7 +9,6 @@ import (
"sync/atomic"
"time"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@@ -106,8 +105,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
}
ipHdr := header.IPv4(buffer.ExtendHeader(header.IPv4MinimumSize))
ipHdr.Encode(&header.IPv4Fields{
@@ -144,7 +142,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: addr.AsSlice(),
@@ -178,14 +175,12 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
return E.New("invalid IPv4 header received")
}
ipHdr.SetDestinationAddr(c.source.Load())
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(ipHdr.Payload())
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
ipHdr := header.IPv6(buffer.Bytes())
if !ipHdr.IsValid(buffer.Len()) {
@@ -196,7 +191,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
@@ -219,12 +213,10 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: c.destination.AsSlice(),
@@ -242,8 +234,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv4(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
}
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
return common.Error(c.conn.Write(ipHdr.Payload()))
@@ -252,7 +243,6 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
@@ -270,12 +260,10 @@ func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: c.source.Load().AsSlice(),

View File

@@ -59,7 +59,6 @@ func (m *Rewriter) RewritePacket(packet []byte) {
sourceAddr := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(bindAddr)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(0)
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
switch ipHdr.TransportProtocol() {
@@ -71,7 +70,6 @@ func (m *Rewriter) RewritePacket(packet []byte) {
m.logger.TraceContext(m.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
@@ -133,7 +131,6 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
}
ipHdr.SetDestinationAddr(routeSession.Source)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(0)
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
switch ipHdr.TransportProtocol() {
@@ -142,7 +139,6 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),

View File

@@ -8,7 +8,6 @@ import (
"sync"
"time"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@@ -130,8 +129,7 @@ func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) {
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(identifier)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(identifier)

View File

@@ -90,7 +90,6 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa
sourceAddress := ipHdr.SourceAddress()
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
ipHdr.SetDestinationAddress(sourceAddress)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], pkt.Data().Checksum()))
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
@@ -157,6 +156,8 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
PayloadCsum: pkt.Data().Checksum(),
PayloadLen: pkt.Data().Size(),
}))
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
if gErr != nil {

View File

@@ -51,7 +51,6 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
ipHdr.SetDestinationAddressWithChecksumUpdate(ipHdr.SourceAddress())
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
@@ -65,7 +64,6 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
ipHdr.SetSourceAddress(inet6LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))

View File

@@ -430,14 +430,12 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
}
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
tcpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return true, nil
}
@@ -478,7 +476,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
@@ -528,7 +525,6 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
}
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
@@ -682,8 +678,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool
sourceAddress := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
ipHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return true, nil
}
@@ -717,7 +712,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
icmpHdr := header.ICMPv4(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
@@ -831,14 +826,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
udpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
@@ -872,7 +865,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
@@ -900,7 +892,6 @@ func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error {
newPacket.Write(p)
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.SetDestinationAddr(w.source)
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)