Add ping proxy support

This commit is contained in:
世界
2025-02-17 21:56:18 +08:00
parent 4a56d47035
commit 933bd2b2d5
16 changed files with 542 additions and 19 deletions

4
go.mod
View File

@@ -1,12 +1,12 @@
module github.com/sagernet/sing-tun
go 1.20
go 1.23.1
require (
github.com/go-ole/go-ole v1.3.0
github.com/google/btree v1.1.3
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.7.0-beta.1

8
go.sum
View File

@@ -1,4 +1,5 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
@@ -14,10 +15,11 @@ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb h1:pprQtDqNgqXkRsXn+0E8ikKOemzmum8bODjSfDene38=
github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb/go.mod h1:QkkPEJLw59/tfxgapHta14UL5qMUah5NXhO0Kw2Kan4=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
@@ -25,6 +27,7 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l
github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw=
github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
@@ -41,3 +44,4 @@ golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -86,6 +86,8 @@ type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address
SourceAddr() netip.Addr
// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address
@@ -98,6 +100,8 @@ type Network interface {
// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)
SetSourceAddr(netip.Addr)
// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)

45
route_mapping.go Normal file
View File

@@ -0,0 +1,45 @@
package tun
import (
"net/netip"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type DirectRouteSession struct {
// IPVersion uint8
// Network uint8
Source netip.Addr
Destination netip.Addr
}
type RouteMapping struct {
status freelru.Cache[DirectRouteSession, DirectRouteDestination]
}
func NewRouteMapping(timeout time.Duration) *RouteMapping {
status := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
status.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
action.Close()
})
status.SetLifetime(timeout)
return &RouteMapping{status}
}
func (m *RouteMapping) Lookup(session DirectRouteSession, constructor func() (DirectRouteDestination, error)) (DirectRouteDestination, error) {
var (
created DirectRouteDestination
err error
)
action, _, ok := m.status.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) {
created, err = constructor()
return created, err != nil
})
if !ok {
return created, err
}
return action, nil
}

109
route_nat.go Normal file
View File

@@ -0,0 +1,109 @@
package tun
import (
"net/netip"
"sync"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
)
type NatMapping struct {
access sync.RWMutex
sessions map[DirectRouteSession]DirectRouteContext
ipRewrite bool
}
func NewNatMapping(ipRewrite bool) *NatMapping {
return &NatMapping{
sessions: make(map[DirectRouteSession]DirectRouteContext),
ipRewrite: ipRewrite,
}
}
func (m *NatMapping) CreateSession(session DirectRouteSession, context DirectRouteContext) {
if m.ipRewrite {
session.Source = netip.Addr{}
}
m.access.Lock()
m.sessions[session] = context
m.access.Unlock()
}
func (m *NatMapping) DeleteSession(session DirectRouteSession) {
if m.ipRewrite {
session.Source = netip.Addr{}
}
m.access.Lock()
delete(m.sessions, session)
m.access.Unlock()
}
func (m *NatMapping) WritePacket(packet []byte) (bool, error) {
var routeSession DirectRouteSession
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr := header.IPv4(packet)
routeSession.Source = ipHdr.DestinationAddr()
routeSession.Destination = ipHdr.SourceAddr()
case header.IPv6Version:
ipHdr := header.IPv6(packet)
routeSession.Source = ipHdr.DestinationAddr()
routeSession.Destination = ipHdr.SourceAddr()
default:
return false, nil
}
m.access.RLock()
context, loaded := m.sessions[routeSession]
m.access.RUnlock()
if !loaded {
return false, nil
}
return true, context.WritePacket(packet)
}
type NatWriter struct {
inet4Address netip.Addr
inet6Address netip.Addr
}
func NewNatWriter(inet4Address netip.Addr, inet6Address netip.Addr) *NatWriter {
return &NatWriter{
inet4Address: inet4Address,
inet6Address: inet6Address,
}
}
func (w *NatWriter) RewritePacket(packet []byte) {
var ipHdr header.Network
var bindAddr netip.Addr
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
bindAddr = w.inet4Address
case header.IPv6Version:
ipHdr = header.IPv6(packet)
bindAddr = w.inet6Address
default:
return
}
ipHdr.SetSourceAddr(bindAddr)
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(packet)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(packet)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
}))
}
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(0)
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
}

49
route_nat_gvisor.go Normal file
View File

@@ -0,0 +1,49 @@
//go:build with_gvisor
package tun
import (
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
stack "github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/buf"
)
type DirectRouteDestination interface {
WritePacket(packet *buf.Buffer) error
WritePacketBuffer(packetBuffer *stack.PacketBuffer) error
Close() error
}
func (w *NatWriter) RewritePacketBuffer(packetBuffer *stack.PacketBuffer) {
var bindAddr tcpip.Address
if packetBuffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
bindAddr = AddressFromAddr(w.inet4Address)
} else {
bindAddr = AddressFromAddr(w.inet6Address)
}
/*var ipHdr header.Network
switch packetBuffer.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
ipHdr = header.IPv4(packetBuffer.NetworkHeader().Slice())
case header.IPv6ProtocolNumber:
ipHdr = header.IPv6(packetBuffer.NetworkHeader().Slice())
default:
return
}*/
ipHdr := packetBuffer.Network()
oldAddr := ipHdr.SourceAddress()
if checksumHdr, needChecksum := ipHdr.(header.ChecksummableNetwork); needChecksum {
checksumHdr.SetSourceAddressWithChecksumUpdate(bindAddr)
} else {
ipHdr.SetSourceAddress(bindAddr)
}
switch packetBuffer.TransportProtocolNumber {
case header.TCPProtocolNumber:
tcpHdr := header.TCP(packetBuffer.TransportHeader().Slice())
tcpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
case header.UDPProtocolNumber:
udpHdr := header.UDP(packetBuffer.TransportHeader().Slice())
udpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
}
}

12
route_nat_non_gvisor.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build !with_gvisor
package tun
import (
"github.com/sagernet/sing/common/buf"
)
type DirectRouteDestination interface {
DirectRouteAction
WritePacket(packet *buf.Buffer) error
}

View File

@@ -77,6 +77,9 @@ func (t *GVisor) Start() error {
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
return nil

207
stack_gvisor_icmp.go Normal file
View File

@@ -0,0 +1,207 @@
//go:build with_gvisor
package tun
import (
"context"
"sync"
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type ICMPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
directNat *RouteMapping
}
func NewICMPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *ICMPForwarder {
return &ICMPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
directNat: NewRouteMapping(timeout),
}
}
func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
ipHdr := header.IPv4(pkt.NetworkHeader().Slice())
icmpHdr := header.ICMPv4(pkt.TransportHeader().Slice())
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
return false
}
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
action, err := f.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
return f.handler.PrepareConnection(
N.NetworkICMPv4,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&ICMPBackWriter{
stack: f.stack,
packet: pkt,
source: ipHdr.SourceAddress(),
sourceNetwork: header.IPv4ProtocolNumber,
},
)
})
if err != nil {
return true
}
if action != nil {
// TODO: handle error
pkt.IncRef()
_ = action.WritePacketBuffer(pkt)
return true
}
icmpHdr.SetType(header.ICMPv4EchoReply)
sourceAddress := ipHdr.SourceAddress()
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
ipHdr.SetDestinationAddress(sourceAddress)
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], pkt.Data().Checksum()))
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
if gErr != nil {
// TODO: log error
return true
}
route, gErr := f.stack.FindRoute(
DefaultNIC,
id.LocalAddress,
id.RemoteAddress,
header.IPv6ProtocolNumber,
false,
)
if gErr != nil {
// TODO: log error
return true
}
defer route.Release()
outgoingEP.(ipv4.ExportedEndpoint).WritePacketDirect(route, pkt)
return true
} else {
ipHdr := header.IPv6(pkt.NetworkHeader().Slice())
icmpHdr := header.ICMPv6(pkt.TransportHeader().Slice())
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
return false
}
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
action, err := f.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
return f.handler.PrepareConnection(
N.NetworkICMPv6,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&ICMPBackWriter{
stack: f.stack,
packet: pkt,
source: ipHdr.SourceAddress(),
sourceNetwork: header.IPv6ProtocolNumber,
},
)
})
if err != nil {
return true
}
if action != nil {
// TODO: handle error
pkt.IncRef()
_ = action.WritePacketBuffer(pkt)
return true
}
icmpHdr.SetType(header.ICMPv6EchoReply)
sourceAddress := ipHdr.SourceAddress()
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
ipHdr.SetDestinationAddress(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
}))
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
if gErr != nil {
// TODO: log error
return true
}
route, gErr := f.stack.FindRoute(
DefaultNIC,
id.LocalAddress,
id.RemoteAddress,
header.IPv6ProtocolNumber,
false,
)
if gErr != nil {
// TODO: log error
return true
}
defer route.Release()
outgoingEP.(ipv6.ExportedEndpoint).WritePacketDirect(route, pkt)
return true
}
}
type ICMPBackWriter struct {
access sync.Mutex
stack *stack.Stack
packet *stack.PacketBuffer
source tcpip.Address
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *ICMPBackWriter) WritePacket(p []byte) error {
if w.sourceNetwork == header.IPv4ProtocolNumber {
route, err := w.stack.FindRoute(
DefaultNIC,
header.IPv4(p).SourceAddress(),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return gonet.TranslateNetstackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(p),
})
defer packet.DecRef()
err = route.WriteHeaderIncludedPacket(packet)
if err != nil {
return gonet.TranslateNetstackError(err)
}
} else {
route, err := w.stack.FindRoute(
DefaultNIC,
header.IPv6(p).SourceAddress(),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return gonet.TranslateNetstackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(p),
})
defer packet.DecRef()
err = route.WriteHeaderIncludedPacket(packet)
if err != nil {
return gonet.TranslateNetstackError(err)
}
}
return nil
}

View File

@@ -79,7 +79,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
_, pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination, nil)
if pErr != nil {
r.Complete(!errors.Is(pErr, ErrDrop))
return

View File

@@ -58,7 +58,7 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
func rangeIterate(r stack.Range, fn func(*buffer.View))
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
_, pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination, nil)
if pErr != nil {
if !errors.Is(pErr, ErrDrop) {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))

View File

@@ -237,7 +237,7 @@ func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
pkt.DecRef()
return
case header.ICMPv4ProtocolNumber:
err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
writeBack, err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
}
return
}
@@ -259,7 +259,7 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt)
pkt.DecRef()
case header.ICMPv6ProtocolNumber:
err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
writeBack, err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
}
return
}

View File

@@ -45,6 +45,7 @@ type System struct {
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service
directNat *RouteMapping
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
@@ -159,6 +160,7 @@ func (s *System) start() error {
}
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false)
s.directNat = NewRouteMapping(s.udpTimeout)
return nil
}
@@ -361,7 +363,10 @@ func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
writeBack = false
err = s.processIPv4UDP(ipHdr, ipHdr.Payload())
case header.ICMPv4ProtocolNumber:
err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
writeBack, err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
}
if err != nil {
writeBack = false
}
return
}
@@ -377,7 +382,10 @@ func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
case header.UDPProtocolNumber:
err = s.processIPv6UDP(ipHdr, ipHdr.Payload())
case header.ICMPv6ProtocolNumber:
err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
writeBack, err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
}
if err != nil {
writeBack = false
}
return
}
@@ -601,7 +609,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
}
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
_, pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination, nil)
if pErr != nil {
if !errors.Is(pErr, ErrDrop) {
if source.IsIPv4() {
@@ -643,9 +651,25 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks
return true, s.ctx, writer, nil
}
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error {
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool, error) {
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
return nil
return false, nil
}
sourceAddr := ipHdr.SourceAddr()
destinationAddr := ipHdr.DestinationAddr()
action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
return s.handler.PrepareConnection(
N.NetworkICMPv4,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&systemICMPDirectPacketWriter4{s.tun, s.frontHeadroom + PacketOffset, sourceAddr},
)
})
if err != nil {
return false, nil
}
if action != nil {
return false, action.WritePacket(buf.As(ipHdr).ToOwned())
}
icmpHdr.SetType(header.ICMPv4EchoReply)
sourceAddress := ipHdr.SourceAddr()
@@ -654,7 +678,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return nil
return true, nil
}
func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error {
@@ -696,9 +720,25 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error {
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) (bool, error) {
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
return nil
return false, nil
}
sourceAddr := ipHdr.SourceAddr()
destinationAddr := ipHdr.DestinationAddr()
action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) {
return s.handler.PrepareConnection(
N.NetworkICMPv6,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&systemICMPDirectPacketWriter6{s.tun, s.frontHeadroom + PacketOffset, sourceAddr},
)
})
if err != nil {
return false, nil
}
if action != nil {
return false, action.WritePacket(buf.As(ipHdr).ToOwned())
}
icmpHdr.SetType(header.ICMPv6EchoReply)
sourceAddress := ipHdr.SourceAddr()
@@ -709,7 +749,7 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
}))
return nil
return true, nil
}
func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error {
@@ -834,3 +874,47 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemICMPDirectPacketWriter4 struct {
tun Tun
frontHeadroom int
source netip.Addr
}
func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error {
newPacket := buf.NewSize(w.frontHeadroom + len(p))
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(p)
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.SetDestinationAddr(w.source)
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemICMPDirectPacketWriter6 struct {
tun Tun
frontHeadroom int
source netip.Addr
}
func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error {
newPacket := buf.NewSize(w.frontHeadroom + len(p))
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(p)
ipHdr := header.IPv6(newPacket.Bytes())
ipHdr.SetDestinationAddr(w.source)
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}

View File

@@ -78,7 +78,7 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handl
if loaded {
return port, nil
}
pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination))
_, pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), nil)
if pErr != nil {
return 0, pErr
}

View File

@@ -5,6 +5,7 @@ import (
"syscall"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
)
func PacketIPVersion(packet []byte) int {
@@ -13,6 +14,7 @@ func PacketIPVersion(packet []byte) int {
func PacketFillHeader(packet []byte, ipVersion int) {
if PacketOffset > 0 {
common.ClearArray(packet[:3])
switch ipVersion {
case header.IPv4Version:
packet[3] = syscall.AF_INET

6
tun.go
View File

@@ -18,11 +18,15 @@ import (
)
type Handler interface {
PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error
PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext DirectRouteContext) (DirectRouteDestination, error)
N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx
}
type DirectRouteContext interface {
WritePacket(packet []byte) error
}
type Tun interface {
io.ReadWriter
Name() (string, error)