ping: Fix test

This commit is contained in:
世界
2025-08-25 10:48:44 +08:00
parent ce050baa58
commit 548f51cc9d
2 changed files with 52 additions and 27 deletions

View File

@@ -26,7 +26,6 @@ type Conn struct {
ctx context.Context
logger logger.ContextLogger
privileged bool
bitwiseID bool
conn net.Conn
destination netip.Addr
source atomic.TypedValue[netip.Addr]
@@ -38,15 +37,10 @@ func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool,
if err != nil {
return nil, err
}
replaceID := true
if _, ok := conn.(*UnprivilegedConn); ok {
replaceID = false
}
return &Conn{
ctx: ctx,
logger: logger,
privileged: privileged,
bitwiseID: replaceID,
conn: conn,
destination: destination,
}, nil
@@ -108,7 +102,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
}
ttl = controlMessage.TTL
}
if c.bitwiseID {
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
@@ -147,7 +141,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
trafficClass = controlMessage.TrafficClass
}
icmpHdr := header.ICMPv6(buffer.Bytes())
if c.bitwiseID {
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
@@ -188,7 +182,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(ipHdr.Payload())
if c.bitwiseID {
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
@@ -201,7 +195,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
}
ipHdr.SetDestinationAddr(c.source.Load())
icmpHdr := header.ICMPv6(ipHdr.Payload())
if c.bitwiseID {
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
@@ -221,17 +215,25 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
if err != nil {
return err
}
if c.destination.Is6() || (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged {
return nil
}
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
buffer.Advance(int(ipHdr.HeaderLength()))
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
} else {
ipHdr := header.IPv6(buffer.Bytes())
buffer.Advance(buffer.Len() - int(ipHdr.PayloadLength()))
c.logger.TraceContext(c.ctx, "read icmpv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
buffer.Advance(int(ipHdr.HeaderLength()))
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: c.destination.AsSlice(),
Dst: c.source.Load().AsSlice(),
}))
}
}
return nil
}
@@ -240,7 +242,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
defer buffer.Release()
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
if c.bitwiseID {
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
icmpHdr := header.ICMPv4(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
@@ -251,7 +253,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
return common.Error(c.conn.Write(ipHdr.Payload()))
} else {
ipHdr := header.IPv6(buffer.Bytes())
if c.bitwiseID {
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
@@ -269,6 +271,29 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
defer buffer.Release()
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: c.source.Load().AsSlice(),
Dst: c.destination.AsSlice(),
}))
}
}
if !c.destination.Is6() {
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
} else {
c.logger.TraceContext(c.ctx, "write icmpv6 echo request to ", c.destination)
}
return common.Error(c.conn.Write(buffer.Bytes()))
}

View File

@@ -84,7 +84,7 @@ func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
request.SetIdent(uint16(rand.Uint32()))
request.SetChecksum(header.ICMPv4Checksum(request, 0))
err = conn.WriteICMP(buf.As(request))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
conn.SetLocalAddr(netip.MustParseAddr("127.0.0.1"))
@@ -117,7 +117,7 @@ func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
request.SetIdent(uint16(rand.Uint32()))
request.SetChecksum(header.ICMPv4Checksum(request, 0))
err = conn.WriteICMP(buf.As(request))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
@@ -148,7 +148,7 @@ func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
request.SetType(header.ICMPv6EchoRequest)
request.SetIdent(uint16(rand.Uint32()))
err = conn.WriteICMP(buf.As(request))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
conn.SetLocalAddr(netip.MustParseAddr("::1"))
@@ -180,7 +180,7 @@ func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) {
request.SetType(header.ICMPv6EchoRequest)
request.SetIdent(uint16(rand.Uint32()))
err = conn.WriteICMP(buf.As(request))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))