diff --git a/internal/gtcpip/header/ipv4.go b/internal/gtcpip/header/ipv4.go index 8253936..d1cbf7c 100644 --- a/internal/gtcpip/header/ipv4.go +++ b/internal/gtcpip/header/ipv4.go @@ -479,7 +479,9 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) { // CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { - return checksum.Checksum(b[:b.HeaderLength()], 0) + xsum0 := checksum.Checksum(b[:xsum], 0) + xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0) + return xsum0 } // Encode encodes all the fields of the IPv4 header. diff --git a/internal/gtcpip/header/tcp.go b/internal/gtcpip/header/tcp.go index 5855253..da5d3d8 100644 --- a/internal/gtcpip/header/tcp.go +++ b/internal/gtcpip/header/tcp.go @@ -351,7 +351,9 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) { // and the checksum of the segment data. func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. - return checksum.Checksum(b[:b.DataOffset()], partialChecksum) + xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum) + xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum) + return xsum } // IsChecksumValid returns true iff the TCP header's checksum is valid. diff --git a/internal/gtcpip/header/udp.go b/internal/gtcpip/header/udp.go index 080a97f..eac9d63 100644 --- a/internal/gtcpip/header/udp.go +++ b/internal/gtcpip/header/udp.go @@ -113,8 +113,10 @@ func (b UDP) SetLength(length uint16) { // CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { - // Calculate the rest of the checksum. - return checksum.Checksum(b[:UDPMinimumSize], partialChecksum) + // Calculate the rest of the checksum.\ + xsum := checksum.Checksum(b[:udpChecksum], partialChecksum) + xsum = checksum.Checksum(b[udpChecksum+2:], xsum) + return xsum } // IsChecksumValid returns true iff the UDP header's checksum is valid. diff --git a/ping/ping.go b/ping/ping.go index 5140c2d..eab977b 100644 --- a/ping/ping.go +++ b/ping/ping.go @@ -9,7 +9,6 @@ import ( "sync/atomic" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -106,8 +105,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error { if !c.isLinuxUnprivileged() { icmpHdr := header.ICMPv4(buffer.Bytes()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) } ipHdr := header.IPv4(buffer.ExtendHeader(header.IPv4MinimumSize)) ipHdr.Encode(&header.IPv4Fields{ @@ -144,7 +142,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error { if !c.isLinuxUnprivileged() { icmpHdr.SetIdent(^icmpHdr.Ident()) } - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: addr.AsSlice(), @@ -178,14 +175,12 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error { return E.New("invalid IPv4 header received") } ipHdr.SetDestinationAddr(c.source.Load()) - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) icmpHdr := header.ICMPv4(ipHdr.Payload()) if !c.isLinuxUnprivileged() { icmpHdr.SetIdent(^icmpHdr.Ident()) } - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) } else { ipHdr := header.IPv6(buffer.Bytes()) if !ipHdr.IsValid(buffer.Len()) { @@ -196,7 +191,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error { if !c.isLinuxUnprivileged() { icmpHdr.SetIdent(^icmpHdr.Ident()) } - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: ipHdr.SourceAddressSlice(), @@ -219,12 +213,10 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error { icmpHdr := header.ICMPv4(buffer.Bytes()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) } else { icmpHdr := header.ICMPv6(buffer.Bytes()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: c.destination.AsSlice(), @@ -242,8 +234,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error { if !c.isLinuxUnprivileged() { icmpHdr := header.ICMPv4(ipHdr.Payload()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) } c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) return common.Error(c.conn.Write(ipHdr.Payload())) @@ -252,7 +243,6 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error { if !c.isLinuxUnprivileged() { icmpHdr := header.ICMPv6(ipHdr.Payload()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: ipHdr.SourceAddressSlice(), @@ -270,12 +260,10 @@ func (c *Conn) WriteICMP(buffer *buf.Buffer) error { if !c.destination.Is6() { icmpHdr := header.ICMPv4(buffer.Bytes()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) } else { icmpHdr := header.ICMPv6(buffer.Bytes()) icmpHdr.SetIdent(^icmpHdr.Ident()) - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: c.source.Load().AsSlice(), diff --git a/ping/rewriter.go b/ping/rewriter.go index 52666b7..4d0dcd8 100644 --- a/ping/rewriter.go +++ b/ping/rewriter.go @@ -59,7 +59,6 @@ func (m *Rewriter) RewritePacket(packet []byte) { sourceAddr := ipHdr.SourceAddr() ipHdr.SetSourceAddr(bindAddr) if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 { - ipHdr4.SetChecksum(0) ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum()) } switch ipHdr.TransportProtocol() { @@ -71,7 +70,6 @@ func (m *Rewriter) RewritePacket(packet []byte) { m.logger.TraceContext(m.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) case header.ICMPv6ProtocolNumber: icmpHdr := header.ICMPv6(ipHdr.Payload()) - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: ipHdr.SourceAddressSlice(), @@ -133,7 +131,6 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) { } ipHdr.SetDestinationAddr(routeSession.Source) if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 { - ipHdr4.SetChecksum(0) ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum()) } switch ipHdr.TransportProtocol() { @@ -142,7 +139,6 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) { m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) case header.ICMPv6ProtocolNumber: icmpHdr := header.ICMPv6(ipHdr.Payload()) - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, Src: ipHdr.SourceAddressSlice(), diff --git a/ping/socket_linux_unprivileged.go b/ping/socket_linux_unprivileged.go index 4e17945..79fd682 100644 --- a/ping/socket_linux_unprivileged.go +++ b/ping/socket_linux_unprivileged.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -130,8 +129,7 @@ func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) { if !c.destination.Is6() { icmpHdr := header.ICMPv4(buffer.Bytes()) icmpHdr.SetIdent(identifier) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) } else { icmpHdr := header.ICMPv6(buffer.Bytes()) icmpHdr.SetIdent(identifier) diff --git a/stack_gvisor_icmp.go b/stack_gvisor_icmp.go index 78177d7..160d3ec 100644 --- a/stack_gvisor_icmp.go +++ b/stack_gvisor_icmp.go @@ -90,7 +90,6 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa sourceAddress := ipHdr.SourceAddress() ipHdr.SetSourceAddress(ipHdr.DestinationAddress()) ipHdr.SetDestinationAddress(sourceAddress) - icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], pkt.Data().Checksum())) ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) @@ -154,9 +153,11 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa ipHdr.SetSourceAddress(ipHdr.DestinationAddress()) ipHdr.SetDestinationAddress(sourceAddress) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: ipHdr.SourceAddress(), - Dst: ipHdr.DestinationAddress(), + Header: icmpHdr, + Src: ipHdr.SourceAddress(), + Dst: ipHdr.DestinationAddress(), + PayloadCsum: pkt.Data().Checksum(), + PayloadLen: pkt.Data().Size(), })) outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber) if gErr != nil { diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index 1592799..024f4b4 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -51,7 +51,6 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac ipHdr.SetDestinationAddressWithChecksumUpdate(ipHdr.SourceAddress()) ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress) tcpHdr := header.TCP(pkt.TransportHeader().Slice()) - tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()), ))) @@ -65,7 +64,6 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac ipHdr.SetDestinationAddress(ipHdr.SourceAddress()) ipHdr.SetSourceAddress(inet6LoopbackAddress) tcpHdr := header.TCP(pkt.TransportHeader().Slice()) - tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()), ))) diff --git a/stack_system.go b/stack_system.go index c262c5d..c354bda 100644 --- a/stack_system.go +++ b/stack_system.go @@ -430,14 +430,12 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err } } if !s.txChecksumOffload { - tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) } else { tcpHdr.SetChecksum(0) } - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) return true, nil } @@ -478,7 +476,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro if !s.txChecksumOffload { tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize))) } - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) @@ -528,7 +525,6 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err } } if !s.txChecksumOffload { - tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) @@ -682,8 +678,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool sourceAddress := ipHdr.SourceAddr() ipHdr.SetSourceAddr(ipHdr.DestinationAddr()) ipHdr.SetDestinationAddr(sourceAddress) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) - ipHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) return true, nil } @@ -717,7 +712,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e icmpHdr := header.ICMPv4(newIPHdr.Payload()) icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(code) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) copy(icmpHdr.Payload(), payload) if PacketOffset > 0 { newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET @@ -831,14 +826,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize)) if !w.txChecksumOffload { - udpHdr.SetChecksum(0) udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) } else { udpHdr.SetChecksum(0) } - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) @@ -872,7 +865,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(udpLen) if !w.txChecksumOffload { - udpHdr.SetChecksum(0) udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) @@ -900,7 +892,6 @@ func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error { newPacket.Write(p) ipHdr := header.IPv4(newPacket.Bytes()) ipHdr.SetDestinationAddr(w.source) - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)