From 933bd2b2d5fa394234b01af43dc6b7eac0c1d2f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 17 Feb 2025 21:56:18 +0800 Subject: [PATCH] Add ping proxy support --- go.mod | 4 +- go.sum | 8 +- internal/gtcpip/header/interfaces.go | 4 + route_mapping.go | 45 ++++++ route_nat.go | 109 ++++++++++++++ route_nat_gvisor.go | 49 +++++++ route_nat_non_gvisor.go | 12 ++ stack_gvisor.go | 3 + stack_gvisor_icmp.go | 207 +++++++++++++++++++++++++++ stack_gvisor_tcp.go | 2 +- stack_gvisor_udp.go | 2 +- stack_mixed.go | 4 +- stack_system.go | 102 +++++++++++-- stack_system_nat.go | 2 +- stack_system_packet.go | 2 + tun.go | 6 +- 16 files changed, 542 insertions(+), 19 deletions(-) create mode 100644 route_mapping.go create mode 100644 route_nat.go create mode 100644 route_nat_gvisor.go create mode 100644 route_nat_non_gvisor.go create mode 100644 stack_gvisor_icmp.go diff --git a/go.mod b/go.mod index 9e3eb8d..84d6651 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/sagernet/sing-tun -go 1.20 +go 1.23.1 require ( github.com/go-ole/go-ole v1.3.0 github.com/google/btree v1.1.3 github.com/sagernet/fswatch v0.1.1 - github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff + github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 github.com/sagernet/sing v0.7.0-beta.1 diff --git a/go.sum b/go.sum index fb42826..d561248 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,5 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= @@ -14,10 +15,11 @@ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs= github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o= -github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs= -github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= +github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb h1:pprQtDqNgqXkRsXn+0E8ikKOemzmum8bODjSfDene38= +github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb/go.mod h1:QkkPEJLw59/tfxgapHta14UL5qMUah5NXhO0Kw2Kan4= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= @@ -25,6 +27,7 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw= github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= @@ -41,3 +44,4 @@ golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/gtcpip/header/interfaces.go b/internal/gtcpip/header/interfaces.go index b304532..c2f8cdf 100644 --- a/internal/gtcpip/header/interfaces.go +++ b/internal/gtcpip/header/interfaces.go @@ -86,6 +86,8 @@ type Network interface { // SourceAddress returns the value of the "source address" field. SourceAddress() tcpip.Address + SourceAddr() netip.Addr + // DestinationAddress returns the value of the "destination address" // field. DestinationAddress() tcpip.Address @@ -98,6 +100,8 @@ type Network interface { // SetSourceAddress sets the value of the "source address" field. SetSourceAddress(tcpip.Address) + SetSourceAddr(netip.Addr) + // SetDestinationAddress sets the value of the "destination address" // field. SetDestinationAddress(tcpip.Address) diff --git a/route_mapping.go b/route_mapping.go new file mode 100644 index 0000000..bd51212 --- /dev/null +++ b/route_mapping.go @@ -0,0 +1,45 @@ +package tun + +import ( + "net/netip" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" +) + +type DirectRouteSession struct { + // IPVersion uint8 + // Network uint8 + Source netip.Addr + Destination netip.Addr +} + +type RouteMapping struct { + status freelru.Cache[DirectRouteSession, DirectRouteDestination] +} + +func NewRouteMapping(timeout time.Duration) *RouteMapping { + status := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32)) + status.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) { + action.Close() + }) + status.SetLifetime(timeout) + return &RouteMapping{status} +} + +func (m *RouteMapping) Lookup(session DirectRouteSession, constructor func() (DirectRouteDestination, error)) (DirectRouteDestination, error) { + var ( + created DirectRouteDestination + err error + ) + action, _, ok := m.status.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) { + created, err = constructor() + return created, err != nil + }) + if !ok { + return created, err + } + return action, nil +} diff --git a/route_nat.go b/route_nat.go new file mode 100644 index 0000000..5997703 --- /dev/null +++ b/route_nat.go @@ -0,0 +1,109 @@ +package tun + +import ( + "net/netip" + "sync" + + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" +) + +type NatMapping struct { + access sync.RWMutex + sessions map[DirectRouteSession]DirectRouteContext + ipRewrite bool +} + +func NewNatMapping(ipRewrite bool) *NatMapping { + return &NatMapping{ + sessions: make(map[DirectRouteSession]DirectRouteContext), + ipRewrite: ipRewrite, + } +} + +func (m *NatMapping) CreateSession(session DirectRouteSession, context DirectRouteContext) { + if m.ipRewrite { + session.Source = netip.Addr{} + } + m.access.Lock() + m.sessions[session] = context + m.access.Unlock() +} + +func (m *NatMapping) DeleteSession(session DirectRouteSession) { + if m.ipRewrite { + session.Source = netip.Addr{} + } + m.access.Lock() + delete(m.sessions, session) + m.access.Unlock() +} + +func (m *NatMapping) WritePacket(packet []byte) (bool, error) { + var routeSession DirectRouteSession + switch header.IPVersion(packet) { + case header.IPv4Version: + ipHdr := header.IPv4(packet) + routeSession.Source = ipHdr.DestinationAddr() + routeSession.Destination = ipHdr.SourceAddr() + case header.IPv6Version: + ipHdr := header.IPv6(packet) + routeSession.Source = ipHdr.DestinationAddr() + routeSession.Destination = ipHdr.SourceAddr() + default: + return false, nil + } + m.access.RLock() + context, loaded := m.sessions[routeSession] + m.access.RUnlock() + if !loaded { + return false, nil + } + return true, context.WritePacket(packet) +} + +type NatWriter struct { + inet4Address netip.Addr + inet6Address netip.Addr +} + +func NewNatWriter(inet4Address netip.Addr, inet6Address netip.Addr) *NatWriter { + return &NatWriter{ + inet4Address: inet4Address, + inet6Address: inet6Address, + } +} + +func (w *NatWriter) RewritePacket(packet []byte) { + var ipHdr header.Network + var bindAddr netip.Addr + switch header.IPVersion(packet) { + case header.IPv4Version: + ipHdr = header.IPv4(packet) + bindAddr = w.inet4Address + case header.IPv6Version: + ipHdr = header.IPv6(packet) + bindAddr = w.inet6Address + default: + return + } + ipHdr.SetSourceAddr(bindAddr) + switch ipHdr.TransportProtocol() { + case header.ICMPv4ProtocolNumber: + icmpHdr := header.ICMPv4(packet) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + case header.ICMPv6ProtocolNumber: + icmpHdr := header.ICMPv6(packet) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: ipHdr.SourceAddress(), + Dst: ipHdr.DestinationAddress(), + })) + } + if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 { + ipHdr4.SetChecksum(0) + ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum()) + } +} diff --git a/route_nat_gvisor.go b/route_nat_gvisor.go new file mode 100644 index 0000000..be8febb --- /dev/null +++ b/route_nat_gvisor.go @@ -0,0 +1,49 @@ +//go:build with_gvisor + +package tun + +import ( + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/header" + stack "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/sing/common/buf" +) + +type DirectRouteDestination interface { + WritePacket(packet *buf.Buffer) error + WritePacketBuffer(packetBuffer *stack.PacketBuffer) error + Close() error +} + +func (w *NatWriter) RewritePacketBuffer(packetBuffer *stack.PacketBuffer) { + var bindAddr tcpip.Address + if packetBuffer.NetworkProtocolNumber == header.IPv4ProtocolNumber { + bindAddr = AddressFromAddr(w.inet4Address) + } else { + bindAddr = AddressFromAddr(w.inet6Address) + } + /*var ipHdr header.Network + switch packetBuffer.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + ipHdr = header.IPv4(packetBuffer.NetworkHeader().Slice()) + case header.IPv6ProtocolNumber: + ipHdr = header.IPv6(packetBuffer.NetworkHeader().Slice()) + default: + return + }*/ + ipHdr := packetBuffer.Network() + oldAddr := ipHdr.SourceAddress() + if checksumHdr, needChecksum := ipHdr.(header.ChecksummableNetwork); needChecksum { + checksumHdr.SetSourceAddressWithChecksumUpdate(bindAddr) + } else { + ipHdr.SetSourceAddress(bindAddr) + } + switch packetBuffer.TransportProtocolNumber { + case header.TCPProtocolNumber: + tcpHdr := header.TCP(packetBuffer.TransportHeader().Slice()) + tcpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true) + case header.UDPProtocolNumber: + udpHdr := header.UDP(packetBuffer.TransportHeader().Slice()) + udpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true) + } +} diff --git a/route_nat_non_gvisor.go b/route_nat_non_gvisor.go new file mode 100644 index 0000000..049b074 --- /dev/null +++ b/route_nat_non_gvisor.go @@ -0,0 +1,12 @@ +//go:build !with_gvisor + +package tun + +import ( + "github.com/sagernet/sing/common/buf" +) + +type DirectRouteDestination interface { + DirectRouteAction + WritePacket(packet *buf.Buffer) error +} diff --git a/stack_gvisor.go b/stack_gvisor.go index 0dc995b..f67b53f 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -77,6 +77,9 @@ func (t *GVisor) Start() error { } ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) + icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint return nil diff --git a/stack_gvisor_icmp.go b/stack_gvisor_icmp.go new file mode 100644 index 0000000..d93c5e3 --- /dev/null +++ b/stack_gvisor_icmp.go @@ -0,0 +1,207 @@ +//go:build with_gvisor + +package tun + +import ( + "context" + "sync" + "time" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type ICMPForwarder struct { + ctx context.Context + stack *stack.Stack + handler Handler + directNat *RouteMapping +} + +func NewICMPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *ICMPForwarder { + return &ICMPForwarder{ + ctx: ctx, + stack: stack, + handler: handler, + directNat: NewRouteMapping(timeout), + } +} + +func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + ipHdr := header.IPv4(pkt.NetworkHeader().Slice()) + icmpHdr := header.ICMPv4(pkt.TransportHeader().Slice()) + if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 { + return false + } + sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice()) + destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice()) + action, err := f.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) { + return f.handler.PrepareConnection( + N.NetworkICMPv4, + M.SocksaddrFrom(sourceAddr, 0), + M.SocksaddrFrom(destinationAddr, 0), + &ICMPBackWriter{ + stack: f.stack, + packet: pkt, + source: ipHdr.SourceAddress(), + sourceNetwork: header.IPv4ProtocolNumber, + }, + ) + }) + if err != nil { + return true + } + if action != nil { + // TODO: handle error + pkt.IncRef() + _ = action.WritePacketBuffer(pkt) + return true + } + icmpHdr.SetType(header.ICMPv4EchoReply) + sourceAddress := ipHdr.SourceAddress() + ipHdr.SetSourceAddress(ipHdr.DestinationAddress()) + ipHdr.SetDestinationAddress(sourceAddress) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], pkt.Data().Checksum())) + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber) + if gErr != nil { + // TODO: log error + return true + } + route, gErr := f.stack.FindRoute( + DefaultNIC, + id.LocalAddress, + id.RemoteAddress, + header.IPv6ProtocolNumber, + false, + ) + if gErr != nil { + // TODO: log error + return true + } + defer route.Release() + outgoingEP.(ipv4.ExportedEndpoint).WritePacketDirect(route, pkt) + return true + } else { + ipHdr := header.IPv6(pkt.NetworkHeader().Slice()) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Slice()) + if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 { + return false + } + sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice()) + destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice()) + action, err := f.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) { + return f.handler.PrepareConnection( + N.NetworkICMPv6, + M.SocksaddrFrom(sourceAddr, 0), + M.SocksaddrFrom(destinationAddr, 0), + &ICMPBackWriter{ + stack: f.stack, + packet: pkt, + source: ipHdr.SourceAddress(), + sourceNetwork: header.IPv6ProtocolNumber, + }, + ) + }) + if err != nil { + return true + } + if action != nil { + // TODO: handle error + pkt.IncRef() + _ = action.WritePacketBuffer(pkt) + return true + } + icmpHdr.SetType(header.ICMPv6EchoReply) + sourceAddress := ipHdr.SourceAddress() + ipHdr.SetSourceAddress(ipHdr.DestinationAddress()) + ipHdr.SetDestinationAddress(sourceAddress) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: ipHdr.SourceAddress(), + Dst: ipHdr.DestinationAddress(), + })) + outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber) + if gErr != nil { + // TODO: log error + return true + } + route, gErr := f.stack.FindRoute( + DefaultNIC, + id.LocalAddress, + id.RemoteAddress, + header.IPv6ProtocolNumber, + false, + ) + if gErr != nil { + // TODO: log error + return true + } + defer route.Release() + outgoingEP.(ipv6.ExportedEndpoint).WritePacketDirect(route, pkt) + return true + } +} + +type ICMPBackWriter struct { + access sync.Mutex + stack *stack.Stack + packet *stack.PacketBuffer + source tcpip.Address + sourceNetwork tcpip.NetworkProtocolNumber +} + +func (w *ICMPBackWriter) WritePacket(p []byte) error { + if w.sourceNetwork == header.IPv4ProtocolNumber { + route, err := w.stack.FindRoute( + DefaultNIC, + header.IPv4(p).SourceAddress(), + w.source, + w.sourceNetwork, + false, + ) + if err != nil { + return gonet.TranslateNetstackError(err) + } + defer route.Release() + packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(p), + }) + defer packet.DecRef() + err = route.WriteHeaderIncludedPacket(packet) + if err != nil { + return gonet.TranslateNetstackError(err) + } + } else { + route, err := w.stack.FindRoute( + DefaultNIC, + header.IPv6(p).SourceAddress(), + w.source, + w.sourceNetwork, + false, + ) + if err != nil { + return gonet.TranslateNetstackError(err) + } + defer route.Release() + packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(p), + }) + defer packet.DecRef() + err = route.WriteHeaderIncludedPacket(packet) + if err != nil { + return gonet.TranslateNetstackError(err) + } + } + return nil +} diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index aad97cf..84bc3ff 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -79,7 +79,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) { source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination) + _, pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination, nil) if pErr != nil { r.Complete(!errors.Is(pErr, ErrDrop)) return diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 473eec4..db06b64 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -58,7 +58,7 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac func rangeIterate(r stack.Range, fn func(*buffer.View)) func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { - pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination) + _, pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination, nil) if pErr != nil { if !errors.Is(pErr, ErrDrop) { gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer)) diff --git a/stack_mixed.go b/stack_mixed.go index a48639d..083ae97 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -237,7 +237,7 @@ func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) { pkt.DecRef() return case header.ICMPv4ProtocolNumber: - err = m.processIPv4ICMP(ipHdr, ipHdr.Payload()) + writeBack, err = m.processIPv4ICMP(ipHdr, ipHdr.Payload()) } return } @@ -259,7 +259,7 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) { m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt) pkt.DecRef() case header.ICMPv6ProtocolNumber: - err = m.processIPv6ICMP(ipHdr, ipHdr.Payload()) + writeBack, err = m.processIPv6ICMP(ipHdr, ipHdr.Payload()) } return } diff --git a/stack_system.go b/stack_system.go index 825c5f2..9075a78 100644 --- a/stack_system.go +++ b/stack_system.go @@ -45,6 +45,7 @@ type System struct { tcpPort6 uint16 tcpNat *TCPNat udpNat *udpnat.Service + directNat *RouteMapping bindInterface bool interfaceFinder control.InterfaceFinder frontHeadroom int @@ -159,6 +160,7 @@ func (s *System) start() error { } s.tcpNat = NewNat(s.ctx, s.udpTimeout) s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false) + s.directNat = NewRouteMapping(s.udpTimeout) return nil } @@ -361,7 +363,10 @@ func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) { writeBack = false err = s.processIPv4UDP(ipHdr, ipHdr.Payload()) case header.ICMPv4ProtocolNumber: - err = s.processIPv4ICMP(ipHdr, ipHdr.Payload()) + writeBack, err = s.processIPv4ICMP(ipHdr, ipHdr.Payload()) + } + if err != nil { + writeBack = false } return } @@ -377,7 +382,10 @@ func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) { case header.UDPProtocolNumber: err = s.processIPv6UDP(ipHdr, ipHdr.Payload()) case header.ICMPv6ProtocolNumber: - err = s.processIPv6ICMP(ipHdr, ipHdr.Payload()) + writeBack, err = s.processIPv6ICMP(ipHdr, ipHdr.Payload()) + } + if err != nil { + writeBack = false } return } @@ -601,7 +609,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error { } func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { - pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination) + _, pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination, nil) if pErr != nil { if !errors.Is(pErr, ErrDrop) { if source.IsIPv4() { @@ -643,9 +651,25 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks return true, s.ctx, writer, nil } -func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error { +func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool, error) { if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 { - return nil + return false, nil + } + sourceAddr := ipHdr.SourceAddr() + destinationAddr := ipHdr.DestinationAddr() + action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) { + return s.handler.PrepareConnection( + N.NetworkICMPv4, + M.SocksaddrFrom(sourceAddr, 0), + M.SocksaddrFrom(destinationAddr, 0), + &systemICMPDirectPacketWriter4{s.tun, s.frontHeadroom + PacketOffset, sourceAddr}, + ) + }) + if err != nil { + return false, nil + } + if action != nil { + return false, action.WritePacket(buf.As(ipHdr).ToOwned()) } icmpHdr.SetType(header.ICMPv4EchoReply) sourceAddress := ipHdr.SourceAddr() @@ -654,7 +678,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) - return nil + return true, nil } func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error { @@ -696,9 +720,25 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e return common.Error(s.tun.Write(newPacket.Bytes())) } -func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error { +func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) (bool, error) { if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 { - return nil + return false, nil + } + sourceAddr := ipHdr.SourceAddr() + destinationAddr := ipHdr.DestinationAddr() + action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func() (DirectRouteDestination, error) { + return s.handler.PrepareConnection( + N.NetworkICMPv6, + M.SocksaddrFrom(sourceAddr, 0), + M.SocksaddrFrom(destinationAddr, 0), + &systemICMPDirectPacketWriter6{s.tun, s.frontHeadroom + PacketOffset, sourceAddr}, + ) + }) + if err != nil { + return false, nil + } + if action != nil { + return false, action.WritePacket(buf.As(ipHdr).ToOwned()) } icmpHdr.SetType(header.ICMPv6EchoReply) sourceAddress := ipHdr.SourceAddr() @@ -709,7 +749,7 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error Src: ipHdr.SourceAddress(), Dst: ipHdr.DestinationAddress(), })) - return nil + return true, nil } func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error { @@ -834,3 +874,47 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S } return common.Error(w.tun.Write(newPacket.Bytes())) } + +type systemICMPDirectPacketWriter4 struct { + tun Tun + frontHeadroom int + source netip.Addr +} + +func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error { + newPacket := buf.NewSize(w.frontHeadroom + len(p)) + defer newPacket.Release() + newPacket.Resize(w.frontHeadroom, 0) + newPacket.Write(p) + ipHdr := header.IPv4(newPacket.Bytes()) + ipHdr.SetDestinationAddr(w.source) + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + if PacketOffset > 0 { + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) + } else { + newPacket.Advance(-w.frontHeadroom) + } + return common.Error(w.tun.Write(newPacket.Bytes())) +} + +type systemICMPDirectPacketWriter6 struct { + tun Tun + frontHeadroom int + source netip.Addr +} + +func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error { + newPacket := buf.NewSize(w.frontHeadroom + len(p)) + defer newPacket.Release() + newPacket.Resize(w.frontHeadroom, 0) + newPacket.Write(p) + ipHdr := header.IPv6(newPacket.Bytes()) + ipHdr.SetDestinationAddr(w.source) + if PacketOffset > 0 { + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) + } else { + newPacket.Advance(-w.frontHeadroom) + } + return common.Error(w.tun.Write(newPacket.Bytes())) +} diff --git a/stack_system_nat.go b/stack_system_nat.go index 1d0216e..6b581bc 100644 --- a/stack_system_nat.go +++ b/stack_system_nat.go @@ -78,7 +78,7 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handl if loaded { return port, nil } - pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination)) + _, pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), nil) if pErr != nil { return 0, pErr } diff --git a/stack_system_packet.go b/stack_system_packet.go index b5060b0..34fe51e 100644 --- a/stack_system_packet.go +++ b/stack_system_packet.go @@ -5,6 +5,7 @@ import ( "syscall" "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing/common" ) func PacketIPVersion(packet []byte) int { @@ -13,6 +14,7 @@ func PacketIPVersion(packet []byte) int { func PacketFillHeader(packet []byte, ipVersion int) { if PacketOffset > 0 { + common.ClearArray(packet[:3]) switch ipVersion { case header.IPv4Version: packet[3] = syscall.AF_INET diff --git a/tun.go b/tun.go index 92eab64..09497f7 100644 --- a/tun.go +++ b/tun.go @@ -18,11 +18,15 @@ import ( ) type Handler interface { - PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error + PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext DirectRouteContext) (DirectRouteDestination, error) N.TCPConnectionHandlerEx N.UDPConnectionHandlerEx } +type DirectRouteContext interface { + WritePacket(packet []byte) error +} + type Tun interface { io.ReadWriter Name() (string, error)