ping: Add gVisor destination

This commit is contained in:
世界
2025-08-24 10:36:16 +08:00
parent 12c9fb6a5d
commit 86d96064d5
9 changed files with 332 additions and 221 deletions

113
ping/destination_gvisor.go Normal file
View File

@@ -0,0 +1,113 @@
//go:build with_gvisor
package ping
import (
"context"
"net/netip"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ tun.DirectRouteDestination = (*GVisorDestination)(nil)
type GVisorDestination struct {
ctx context.Context
logger logger.ContextLogger
conn *gonet.TCPConn
rewriter *Rewriter
}
func ConnectGVisor(
ctx context.Context, logger logger.ContextLogger,
sourceAddress, destinationAddress netip.Addr,
routeContext tun.DirectRouteContext,
stack *stack.Stack,
bindAddress4, bindAddress6 netip.Addr,
) (*GVisorDestination, error) {
var (
bindAddress tcpip.Address
wq waiter.Queue
endpoint tcpip.Endpoint
gErr tcpip.Error
)
if !destinationAddress.Is6() {
if !bindAddress4.IsValid() {
return nil, E.New("missing IPv4 interface address")
}
bindAddress = tun.AddressFromAddr(bindAddress4)
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv4ProtocolNumber, header.IPv4ProtocolNumber, &wq, true)
} else {
if !bindAddress6.IsValid() {
return nil, E.New("missing IPv6 interface address")
}
bindAddress = tun.AddressFromAddr(bindAddress6)
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv6ProtocolNumber, header.IPv6ProtocolNumber, &wq, true)
}
if gErr != nil {
return nil, gonet.TranslateNetstackError(gErr)
}
gErr = endpoint.Bind(tcpip.FullAddress{
NIC: 1,
Addr: bindAddress,
})
if gErr != nil {
return nil, gonet.TranslateNetstackError(gErr)
}
gErr = endpoint.Connect(tcpip.FullAddress{
NIC: 1,
Addr: tun.AddressFromAddr(destinationAddress),
})
if gErr != nil {
return nil, gonet.TranslateNetstackError(gErr)
}
endpoint.SocketOptions().SetHeaderIncluded(true)
rewriter := NewRewriter(bindAddress4, bindAddress6)
rewriter.CreateSession(tun.DirectRouteSession{Source: sourceAddress, Destination: destinationAddress}, routeContext)
destination := &GVisorDestination{
ctx: ctx,
logger: logger,
conn: gonet.NewTCPConn(&wq, endpoint),
rewriter: rewriter,
}
go destination.loopRead()
return destination, nil
}
func (d *GVisorDestination) loopRead() {
for {
buffer := buf.NewPacket()
n, err := d.conn.Read(buffer.FreeBytes())
if err != nil {
buffer.Release()
if !E.IsClosed(err) {
d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP echo reply"))
}
return
}
buffer.Truncate(n)
_, err = d.rewriter.WriteBack(buffer.Bytes())
if err != nil {
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
}
buffer.Release()
}
}
func (d *GVisorDestination) WritePacket(packet *buf.Buffer) error {
d.rewriter.RewritePacket(packet.Bytes())
return common.Error(d.conn.Write(packet.Bytes()))
}
func (d *GVisorDestination) Close() error {
return d.conn.Close()
}

142
ping/rewriter.go Normal file
View File

@@ -0,0 +1,142 @@
package ping
import (
"net/netip"
"sync"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
)
type Rewriter struct {
access sync.RWMutex
sessions map[tun.DirectRouteSession]tun.DirectRouteContext
source4Address map[uint16]netip.Addr
source6Address map[uint16]netip.Addr
inet4Address netip.Addr
inet6Address netip.Addr
}
func NewRewriter(inet4Address netip.Addr, inet6Address netip.Addr) *Rewriter {
return &Rewriter{
sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext),
inet4Address: inet4Address,
inet6Address: inet6Address,
}
}
func (m *Rewriter) CreateSession(session tun.DirectRouteSession, context tun.DirectRouteContext) {
m.access.Lock()
m.sessions[session] = context
m.access.Unlock()
}
func (m *Rewriter) DeleteSession(session tun.DirectRouteSession) {
m.access.Lock()
delete(m.sessions, session)
m.access.Unlock()
}
func (m *Rewriter) RewritePacket(packet []byte) {
var ipHdr header.Network
var bindAddr netip.Addr
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
bindAddr = m.inet4Address
case header.IPv6Version:
ipHdr = header.IPv6(packet)
bindAddr = m.inet6Address
default:
return
}
sourceAddr := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(bindAddr)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(0)
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(ipHdr.Payload())
m.access.Lock()
m.source4Address[icmpHdr.Ident()] = sourceAddr
m.access.Lock()
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
m.access.Lock()
m.source6Address[icmpHdr.Ident()] = sourceAddr
m.access.Lock()
}
}
func (m *Rewriter) WriteBack(packet []byte) (bool, error) {
var ipHdr header.Network
var routeSession tun.DirectRouteSession
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
routeSession.Destination = ipHdr.SourceAddr()
case header.IPv6Version:
ipHdr = header.IPv6(packet)
routeSession.Destination = ipHdr.SourceAddr()
default:
return false, nil
}
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(ipHdr.Payload())
m.access.Lock()
ident := icmpHdr.Ident()
source, loaded := m.source4Address[ident]
if !loaded {
m.access.Unlock()
return false, nil
}
delete(m.source4Address, icmpHdr.Ident())
m.access.Lock()
routeSession.Source = source
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
m.access.Lock()
ident := icmpHdr.Ident()
source, loaded := m.source6Address[ident]
if !loaded {
m.access.Unlock()
return false, nil
}
delete(m.source6Address, icmpHdr.Ident())
m.access.Lock()
routeSession.Source = source
default:
return false, nil
}
m.access.RLock()
context, loaded := m.sessions[routeSession]
m.access.RUnlock()
if !loaded {
return false, nil
}
ipHdr.SetDestinationAddr(routeSession.Source)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(0)
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
switch ipHdr.TransportProtocol() {
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
}
return true, context.WritePacket(packet)
}

51
route_direct.go Normal file
View File

@@ -0,0 +1,51 @@
package tun
import (
"net/netip"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type DirectRouteDestination interface {
WritePacket(packet *buf.Buffer) error
Close() error
}
type DirectRouteSession struct {
// IPVersion uint8
// Network uint8
Source netip.Addr
Destination netip.Addr
}
type DirectRouteMapping struct {
mapping freelru.Cache[DirectRouteSession, DirectRouteDestination]
}
func NewDirectRouteMapping(timeout time.Duration) *DirectRouteMapping {
mapping := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
mapping.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
action.Close()
})
mapping.SetLifetime(timeout)
return &DirectRouteMapping{mapping}
}
func (m *DirectRouteMapping) Lookup(session DirectRouteSession, constructor func() (DirectRouteDestination, error)) (DirectRouteDestination, error) {
var (
created DirectRouteDestination
err error
)
action, _, ok := m.mapping.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) {
created, err = constructor()
return created, err == nil
})
if !ok {
return nil, err
}
return action, nil
}

View File

@@ -1,45 +0,0 @@
package tun
import (
"net/netip"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type DirectRouteSession struct {
// IPVersion uint8
// Network uint8
Source netip.Addr
Destination netip.Addr
}
type RouteMapping struct {
status freelru.Cache[DirectRouteSession, DirectRouteDestination]
}
func NewRouteMapping(timeout time.Duration) *RouteMapping {
status := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
status.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
action.Close()
})
status.SetLifetime(timeout)
return &RouteMapping{status}
}
func (m *RouteMapping) Lookup(session DirectRouteSession, constructor func() (DirectRouteDestination, error)) (DirectRouteDestination, error) {
var (
created DirectRouteDestination
err error
)
action, _, ok := m.status.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) {
created, err = constructor()
return created, err == nil
})
if !ok {
return nil, err
}
return action, nil
}

View File

@@ -1,115 +0,0 @@
package tun
import (
"net/netip"
"sync"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/buf"
)
type DirectRouteDestination interface {
WritePacket(packet *buf.Buffer) error
Close() error
}
type NatMapping struct {
access sync.RWMutex
sessions map[DirectRouteSession]DirectRouteContext
ipRewrite bool
}
func NewNatMapping(ipRewrite bool) *NatMapping {
return &NatMapping{
sessions: make(map[DirectRouteSession]DirectRouteContext),
ipRewrite: ipRewrite,
}
}
func (m *NatMapping) CreateSession(session DirectRouteSession, context DirectRouteContext) {
if m.ipRewrite {
session.Source = netip.Addr{}
}
m.access.Lock()
m.sessions[session] = context
m.access.Unlock()
}
func (m *NatMapping) DeleteSession(session DirectRouteSession) {
if m.ipRewrite {
session.Source = netip.Addr{}
}
m.access.Lock()
delete(m.sessions, session)
m.access.Unlock()
}
func (m *NatMapping) WritePacket(packet []byte) (bool, error) {
var routeSession DirectRouteSession
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr := header.IPv4(packet)
routeSession.Source = ipHdr.DestinationAddr()
routeSession.Destination = ipHdr.SourceAddr()
case header.IPv6Version:
ipHdr := header.IPv6(packet)
routeSession.Source = ipHdr.DestinationAddr()
routeSession.Destination = ipHdr.SourceAddr()
default:
return false, nil
}
m.access.RLock()
context, loaded := m.sessions[routeSession]
m.access.RUnlock()
if !loaded {
return false, nil
}
return true, context.WritePacket(packet)
}
type NatWriter struct {
inet4Address netip.Addr
inet6Address netip.Addr
}
func NewNatWriter(inet4Address netip.Addr, inet6Address netip.Addr) *NatWriter {
return &NatWriter{
inet4Address: inet4Address,
inet6Address: inet6Address,
}
}
func (w *NatWriter) RewritePacket(packet []byte) {
var ipHdr header.Network
var bindAddr netip.Addr
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
bindAddr = w.inet4Address
case header.IPv6Version:
ipHdr = header.IPv6(packet)
bindAddr = w.inet6Address
default:
return
}
ipHdr.SetSourceAddr(bindAddr)
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(packet)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(packet)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
}
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(0)
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
}

View File

@@ -1,42 +0,0 @@
//go:build with_gvisor
package tun
import (
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
)
func (w *NatWriter) RewritePacketBuffer(packetBuffer *stack.PacketBuffer) {
var bindAddr tcpip.Address
if packetBuffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
bindAddr = AddressFromAddr(w.inet4Address)
} else {
bindAddr = AddressFromAddr(w.inet6Address)
}
/*var ipHdr header.Network
switch packetBuffer.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
ipHdr = header.IPv4(packetBuffer.NetworkHeader().Slice())
case header.IPv6ProtocolNumber:
ipHdr = header.IPv6(packetBuffer.NetworkHeader().Slice())
default:
return
}*/
ipHdr := packetBuffer.Network()
oldAddr := ipHdr.SourceAddress()
if checksumHdr, needChecksum := ipHdr.(header.ChecksummableNetwork); needChecksum {
checksumHdr.SetSourceAddressWithChecksumUpdate(bindAddr)
} else {
ipHdr.SetSourceAddress(bindAddr)
}
switch packetBuffer.TransportProtocolNumber {
case header.TCPProtocolNumber:
tcpHdr := header.TCP(packetBuffer.TransportHeader().Slice())
tcpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
case header.UDPProtocolNumber:
udpHdr := header.UDP(packetBuffer.TransportHeader().Slice())
udpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
}
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/raw"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
E "github.com/sagernet/sing/common/exceptions"
@@ -86,13 +87,14 @@ func (t *GVisor) Start() error {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions)
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions, false)
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.inet4Address, t.inet6Address, t.handler, t.udpTimeout)
icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout)
icmpForwarder.SetLocalAddresses(t.inet4Address, t.inet6Address)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
t.stack = ipStack
@@ -129,11 +131,11 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
}
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
return NewGVisorStackWithOptions(ep, stack.NICOptions{})
return NewGVisorStackWithOptions(ep, stack.NICOptions{}, false)
}
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions, allowRawEndpoint bool) (*stack.Stack, error) {
stackOptions := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
@@ -144,7 +146,11 @@ func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*s
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
}
if allowRawEndpoint {
stackOptions.RawFactory = new(raw.EndpointFactory)
}
ipStack := stack.New(stackOptions)
err := ipStack.CreateNICWithOptions(DefaultNIC, ep, opts)
if err != nil {
return nil, gonet.TranslateNetstackError(err)

View File

@@ -27,27 +27,28 @@ type ICMPForwarder struct {
inet4Address netip.Addr
inet6Address netip.Addr
handler Handler
directNat *RouteMapping
mapping *DirectRouteMapping
}
func NewICMPForwarder(
ctx context.Context,
stack *stack.Stack,
inet4Address netip.Addr,
inet6Address netip.Addr,
handler Handler,
timeout time.Duration,
) *ICMPForwarder {
return &ICMPForwarder{
ctx: ctx,
stack: stack,
inet4Address: inet4Address,
inet6Address: inet6Address,
handler: handler,
directNat: NewRouteMapping(timeout),
ctx: ctx,
stack: stack,
handler: handler,
mapping: NewDirectRouteMapping(timeout),
}
}
func (f *ICMPForwarder) SetLocalAddresses(inet4Address, inet6Address netip.Addr) {
f.inet4Address = inet4Address
f.inet6Address = inet6Address
}
func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
ipHdr := header.IPv4(pkt.NetworkHeader().Slice())
@@ -58,7 +59,7 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
if destinationAddr != f.inet4Address {
action, err := f.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
return f.handler.PrepareConnection(
N.NetworkICMPv4,
M.SocksaddrFrom(sourceAddr, 0),
@@ -116,7 +117,7 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
if destinationAddr != f.inet6Address {
action, err := f.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
return f.handler.PrepareConnection(
N.NetworkICMPv6,
M.SocksaddrFrom(sourceAddr, 0),

View File

@@ -45,7 +45,7 @@ type System struct {
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service
directNat *RouteMapping
directNat *DirectRouteMapping
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
@@ -160,7 +160,7 @@ func (s *System) start() error {
}
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false)
s.directNat = NewRouteMapping(s.udpTimeout)
s.directNat = NewDirectRouteMapping(s.udpTimeout)
return nil
}