mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-09-26 20:51:13 +08:00
ping: Fix test
This commit is contained in:
71
ping/ping.go
71
ping/ping.go
@@ -26,7 +26,6 @@ type Conn struct {
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
privileged bool
|
||||
bitwiseID bool
|
||||
conn net.Conn
|
||||
destination netip.Addr
|
||||
source atomic.TypedValue[netip.Addr]
|
||||
@@ -38,15 +37,10 @@ func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replaceID := true
|
||||
if _, ok := conn.(*UnprivilegedConn); ok {
|
||||
replaceID = false
|
||||
}
|
||||
return &Conn{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
privileged: privileged,
|
||||
bitwiseID: replaceID,
|
||||
conn: conn,
|
||||
destination: destination,
|
||||
}, nil
|
||||
@@ -108,7 +102,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
}
|
||||
ttl = controlMessage.TTL
|
||||
}
|
||||
if c.bitwiseID {
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -147,7 +141,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
trafficClass = controlMessage.TrafficClass
|
||||
}
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
if c.bitwiseID {
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
}
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -188,7 +182,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
if c.bitwiseID {
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
}
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -201,7 +195,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
}
|
||||
ipHdr.SetDestinationAddr(c.source.Load())
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
if c.bitwiseID {
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
}
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -221,17 +215,25 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.destination.Is6() || (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged {
|
||||
return nil
|
||||
}
|
||||
if !c.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
buffer.Advance(int(ipHdr.HeaderLength()))
|
||||
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
|
||||
} else {
|
||||
ipHdr := header.IPv6(buffer.Bytes())
|
||||
buffer.Advance(buffer.Len() - int(ipHdr.PayloadLength()))
|
||||
c.logger.TraceContext(c.ctx, "read icmpv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
if !c.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
buffer.Advance(int(ipHdr.HeaderLength()))
|
||||
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: c.destination.AsSlice(),
|
||||
Dst: c.source.Load().AsSlice(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -240,7 +242,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
|
||||
defer buffer.Release()
|
||||
if !c.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
if c.bitwiseID {
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -251,7 +253,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
|
||||
return common.Error(c.conn.Write(ipHdr.Payload()))
|
||||
} else {
|
||||
ipHdr := header.IPv6(buffer.Bytes())
|
||||
if c.bitwiseID {
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -269,6 +271,29 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
|
||||
|
||||
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
|
||||
defer buffer.Release()
|
||||
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
|
||||
if !c.destination.Is6() {
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: c.source.Load().AsSlice(),
|
||||
Dst: c.destination.AsSlice(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
if !c.destination.Is6() {
|
||||
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
|
||||
} else {
|
||||
c.logger.TraceContext(c.ctx, "write icmpv6 echo request to ", c.destination)
|
||||
}
|
||||
return common.Error(c.conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
|
@@ -84,7 +84,7 @@ func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
request.SetChecksum(header.ICMPv4Checksum(request, 0))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request))
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.SetLocalAddr(netip.MustParseAddr("127.0.0.1"))
|
||||
@@ -117,7 +117,7 @@ func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
request.SetChecksum(header.ICMPv4Checksum(request, 0))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request))
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
|
||||
@@ -148,7 +148,7 @@ func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
|
||||
request.SetType(header.ICMPv6EchoRequest)
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request))
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.SetLocalAddr(netip.MustParseAddr("::1"))
|
||||
@@ -180,7 +180,7 @@ func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) {
|
||||
request.SetType(header.ICMPv6EchoRequest)
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request))
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
|
||||
|
Reference in New Issue
Block a user