mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-09-26 20:51:13 +08:00
ping: Fix rewriter
This commit is contained in:
@@ -76,7 +76,7 @@ func ConnectGVisor(
|
||||
return nil, gonet.TranslateNetstackError(gErr)
|
||||
}
|
||||
endpoint.SocketOptions().SetHeaderIncluded(true)
|
||||
rewriter := NewRewriter(bindAddress4, bindAddress6)
|
||||
rewriter := NewRewriter(ctx, logger, bindAddress4, bindAddress6)
|
||||
rewriter.CreateSession(tun.DirectRouteSession{Source: sourceAddress, Destination: destinationAddress}, routeContext)
|
||||
destination := &GVisorDestination{
|
||||
ctx: ctx,
|
||||
|
@@ -1,27 +1,33 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
type Rewriter struct {
|
||||
access sync.RWMutex
|
||||
sessions map[tun.DirectRouteSession]tun.DirectRouteContext
|
||||
source4Address map[uint16]netip.Addr
|
||||
source6Address map[uint16]netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
access sync.RWMutex
|
||||
sessions map[tun.DirectRouteSession]tun.DirectRouteContext
|
||||
sourceAddress map[uint16]netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
}
|
||||
|
||||
func NewRewriter(inet4Address netip.Addr, inet6Address netip.Addr) *Rewriter {
|
||||
func NewRewriter(ctx context.Context, logger logger.ContextLogger, inet4Address netip.Addr, inet6Address netip.Addr) *Rewriter {
|
||||
return &Rewriter{
|
||||
sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext),
|
||||
inet4Address: inet4Address,
|
||||
inet6Address: inet6Address,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext),
|
||||
sourceAddress: make(map[uint16]netip.Addr),
|
||||
inet4Address: inet4Address,
|
||||
inet6Address: inet6Address,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,8 +66,9 @@ func (m *Rewriter) RewritePacket(packet []byte) {
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
m.access.Lock()
|
||||
m.source4Address[icmpHdr.Ident()] = sourceAddr
|
||||
m.access.Lock()
|
||||
m.sourceAddress[icmpHdr.Ident()] = sourceAddr
|
||||
m.access.Unlock()
|
||||
m.logger.TraceContext(m.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -71,8 +78,9 @@ func (m *Rewriter) RewritePacket(packet []byte) {
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
m.access.Lock()
|
||||
m.source6Address[icmpHdr.Ident()] = sourceAddr
|
||||
m.access.Lock()
|
||||
m.sourceAddress[icmpHdr.Ident()] = sourceAddr
|
||||
m.access.Unlock()
|
||||
m.logger.TraceContext(m.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,25 +102,25 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
m.access.Lock()
|
||||
ident := icmpHdr.Ident()
|
||||
source, loaded := m.source4Address[ident]
|
||||
source, loaded := m.sourceAddress[ident]
|
||||
if !loaded {
|
||||
m.access.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
delete(m.source4Address, icmpHdr.Ident())
|
||||
m.access.Lock()
|
||||
delete(m.sourceAddress, icmpHdr.Ident())
|
||||
m.access.Unlock()
|
||||
routeSession.Source = source
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
m.access.Lock()
|
||||
ident := icmpHdr.Ident()
|
||||
source, loaded := m.source6Address[ident]
|
||||
source, loaded := m.sourceAddress[ident]
|
||||
if !loaded {
|
||||
m.access.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
delete(m.source6Address, icmpHdr.Ident())
|
||||
m.access.Lock()
|
||||
delete(m.sourceAddress, icmpHdr.Ident())
|
||||
m.access.Unlock()
|
||||
routeSession.Source = source
|
||||
default:
|
||||
return false, nil
|
||||
@@ -129,6 +137,9 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
|
||||
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
|
||||
}
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(0)
|
||||
@@ -137,6 +148,7 @@ func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
m.logger.TraceContext(m.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
}
|
||||
return true, context.WritePacket(packet)
|
||||
}
|
||||
|
Reference in New Issue
Block a user