From 60774779fdd8dfa4bffe4a4a94cf88f3cc0479ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 11 Jun 2025 17:10:11 +0800 Subject: [PATCH] Add loopback address support --- redirect_nftables.go | 50 +++++- redirect_nftables_exprs.go | 86 ++++++++-- redirect_nftables_rules.go | 339 ++++++++++++++++++++++++++++--------- stack_gvisor.go | 34 ++-- stack_gvisor_tcp.go | 57 ++++++- stack_system.go | 120 ++++++++----- tun.go | 2 + 7 files changed, 514 insertions(+), 174 deletions(-) diff --git a/redirect_nftables.go b/redirect_nftables.go index 5369518..4a21fa8 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -46,6 +46,11 @@ func (r *autoRedirect) setupNFTables() error { return err } + err = r.nftablesCreateLoopbackAddressSets(nft, table) + if err != nil { + return err + } + skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo") if !skipOutput { chainOutput := nft.AddChain(&nftables.Chain{ @@ -61,8 +66,23 @@ func (r *autoRedirect) setupNFTables() error { return err } r.nftablesCreateUnreachable(nft, table, chainOutput) - r.nftablesCreateRedirect(nft, table, chainOutput) - + err = r.nftablesCreateRedirect(nft, table, chainOutput) + if err != nil { + return err + } + if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 { + chainOutputRoute := nft.AddChain(&nftables.Chain{ + Name: "output_route", + Table: table, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeRoute, + }) + err = r.nftablesCreateLoopbackReroute(nft, table, chainOutputRoute) + if err != nil { + return err + } + } chainOutputUDP := nft.AddChain(&nftables.Chain{ Name: "output_udp_icmp", Table: table, @@ -77,7 +97,7 @@ func (r *autoRedirect) setupNFTables() error { r.nftablesCreateUnreachable(nft, table, chainOutputUDP) r.nftablesCreateMark(nft, table, chainOutputUDP) } else { - r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{ + err = r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{ Key: expr.MetaKeyOIFNAME, Register: 1, }, &expr.Cmp{ @@ -85,6 +105,9 @@ func (r *autoRedirect) setupNFTables() error { Register: 1, Data: nftablesIfname(r.tunOptions.Name), }) + if err != nil { + return err + } } } @@ -100,12 +123,25 @@ func (r *autoRedirect) setupNFTables() error { return err } r.nftablesCreateUnreachable(nft, table, chainPreRouting) - r.nftablesCreateRedirect(nft, table, chainPreRouting) + err = r.nftablesCreateRedirect(nft, table, chainPreRouting) + if err != nil { + return err + } if r.tunOptions.AutoRedirectMarkMode { r.nftablesCreateMark(nft, table, chainPreRouting) - } - - if r.tunOptions.AutoRedirectMarkMode { + if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 { + chainPreRoutingFilter := nft.AddChain(&nftables.Chain{ + Name: "prerouting_filter", + Table: table, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 1), + Type: nftables.ChainTypeFilter, + }) + err = r.nftablesCreateLoopbackReroute(nft, table, chainPreRoutingFilter) + if err != nil { + return err + } + } chainPreRoutingUDP := nft.AddChain(&nftables.Chain{ Name: "prerouting_udp", Table: table, diff --git a/redirect_nftables_exprs.go b/redirect_nftables_exprs.go index 86a9868..b18f362 100644 --- a/redirect_nftables_exprs.go +++ b/redirect_nftables_exprs.go @@ -7,6 +7,7 @@ import ( "github.com/metacubex/nftables" "github.com/metacubex/nftables/expr" + "github.com/metacubex/sing/common" "go4.org/netipx" ) @@ -21,6 +22,20 @@ func nftablesCreateExcludeDestinationIPSet( nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, id uint32, name string, family nftables.TableFamily, invert bool, ) { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: append( + nftablesCreateDestinationIPSetExprs(id, name, family, invert), + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + ), + }) +} + +func nftablesCreateDestinationIPSetExprs(id uint32, name string, family nftables.TableFamily, invert bool) []expr.Any { exprs := []expr.Any{ &expr.Meta{ Key: expr.MetaKeyNFPROTO, @@ -53,22 +68,63 @@ func nftablesCreateExcludeDestinationIPSet( }, ) } - exprs = append(exprs, - &expr.Lookup{ - SourceRegister: 1, - SetID: id, - SetName: name, - Invert: invert, - }, - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }) - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chain, - Exprs: exprs, + exprs = append(exprs, &expr.Lookup{ + SourceRegister: 1, + SetID: id, + SetName: name, + Invert: invert, }) + return exprs +} + +func nftablesCreateIPConst( + nft *nftables.Conn, table *nftables.Table, id uint32, name string, family nftables.TableFamily, addressList []netip.Addr, +) (*nftables.Set, error) { + var keyType nftables.SetDatatype + if family == nftables.TableFamilyIPv4 { + keyType = nftables.TypeIPAddr + } else { + keyType = nftables.TypeIP6Addr + } + mySet := &nftables.Set{ + Table: table, + ID: id, + Name: name, + KeyType: keyType, + Constant: true, + } + if id == 0 { + mySet.Anonymous = true + } + setElements := common.Map(addressList, func(addr netip.Addr) nftables.SetElement { return nftables.SetElement{Key: addr.AsSlice()} }) + if id == 0 { + err := nft.AddSet(mySet, setElements) + if err != nil { + return nil, err + } + return mySet, nil + } else { + err := nft.AddSet(mySet, nil) + if err != nil { + return nil, err + } + } + for len(setElements) > 0 { + toAdd := setElements + if len(toAdd) > 1000 { + toAdd = toAdd[:1000] + } + setElements = setElements[len(toAdd):] + err := nft.SetAddElements(mySet, toAdd) + if err != nil { + return nil, err + } + err = nft.Flush() + if err != nil { + return nil, err + } + } + return mySet, nil } func nftablesCreateIPSet( diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index 90e0225..4f451a2 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -117,8 +117,61 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets( return nil } +func (r *autoRedirect) nftablesCreateLoopbackAddressSets( + nft *nftables.Conn, table *nftables.Table, +) error { + if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 { + _, err := nftablesCreateIPConst(nft, table, 7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, r.tunOptions.Inet4LoopbackAddress) + if err != nil { + return err + } + } + if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 { + _, err := nftablesCreateIPConst(nft, table, 8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, r.tunOptions.Inet6LoopbackAddress) + if err != nil { + return err + } + } + return nil +} + func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error { if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput { + if chain.Type == nftables.ChainTypeRoute { + ipProto := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeInetProto, + } + err := nft.AddSet(ipProto, []nftables.SetElement{ + {Key: []byte{unix.IPPROTO_UDP}}, + {Key: []byte{unix.IPPROTO_ICMP}}, + {Key: []byte{unix.IPPROTO_ICMPV6}}, + }) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: ipProto.ID, + SetName: ipProto.Name, + Invert: true, + }, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, @@ -161,6 +214,25 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft } } if chain.Hooknum == nftables.ChainHookPrerouting { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: nftablesIfname(r.tunOptions.Name), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) if len(r.tunOptions.IncludeInterface) > 0 { if len(r.tunOptions.IncludeInterface) > 1 { includeInterface := &nftables.Set{ @@ -436,44 +508,6 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft } } - if r.tunOptions.AutoRedirectMarkMode && - ((chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeRoute) || - (chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeFilter)) { - ipProto := &nftables.Set{ - Table: table, - Anonymous: true, - Constant: true, - KeyType: nftables.TypeInetProto, - } - err := nft.AddSet(ipProto, []nftables.SetElement{ - {Key: []byte{unix.IPPROTO_UDP}}, - {Key: []byte{unix.IPPROTO_ICMP}}, - {Key: []byte{unix.IPPROTO_ICMPV6}}, - }) - if err != nil { - return err - } - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{ - Key: expr.MetaKeyL4PROTO, - Register: 1, - }, - &expr.Lookup{ - SourceRegister: 1, - SetID: ipProto.ID, - SetName: ipProto.Name, - Invert: true, - }, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }, - }, - }) - } - if r.enableIPv4 { nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false) } @@ -527,6 +561,9 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta SourceRegister: true, }, &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, }, }) } @@ -534,57 +571,193 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta func (r *autoRedirect) nftablesCreateRedirect( nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, exprs ...expr.Any, -) { - if r.enableIPv4 && !r.enableIPv6 { - exprs = append(exprs, - &expr.Meta{ - Key: expr.MetaKeyNFPROTO, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte{uint8(nftables.TableFamilyIPv4)}, - }) - } else if !r.enableIPv4 && r.enableIPv6 { - exprs = append(exprs, - &expr.Meta{ - Key: expr.MetaKeyNFPROTO, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte{uint8(nftables.TableFamilyIPv6)}, - }) +) error { + exprsRedirect := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Counter{}, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(r.redirectPort()), + }, + &expr.Redir{ + RegisterProtoMin: 1, + Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED, + }, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, } - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chain, - Exprs: append(exprs, - &expr.Meta{ - Key: expr.MetaKeyL4PROTO, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte{unix.IPPROTO_TCP}, - }, - &expr.Counter{}, + if len(r.tunOptions.Inet4LoopbackAddress) == 0 && len(r.tunOptions.Inet6LoopbackAddress) == 0 { + if r.enableIPv4 && !r.enableIPv6 { + exprs = append(exprs, + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{uint8(nftables.TableFamilyIPv4)}, + }) + } else if !r.enableIPv4 && r.enableIPv6 { + exprs = append(exprs, + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{uint8(nftables.TableFamilyIPv6)}, + }) + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: append(exprs, exprsRedirect...), + }) + } else { + if r.enableIPv4 { + exprs4 := exprs + if len(r.tunOptions.Inet4LoopbackAddress) > 0 { + exprs4 = append(exprs4, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, true)...) + } else { + exprs4 = append(exprs4, &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{uint8(nftables.TableFamilyIPv4)}, + }) + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: append(exprs4, exprsRedirect...), + }) + } + if r.enableIPv6 { + exprs6 := exprs + if len(r.tunOptions.Inet6LoopbackAddress) > 0 { + exprs6 = append(exprs6, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, true)...) + } else { + exprs6 = append(exprs6, &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{uint8(nftables.TableFamilyIPv6)}, + }) + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: append(exprs6, exprsRedirect...), + }) + } + } + return nil +} + +func (r *autoRedirect) nftablesCreateLoopbackReroute( + nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, +) error { + exprs := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), + }, + } + var exprs4 []expr.Any + if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 { + exprs4 = append(exprs, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, false)...) + } + var exprs6 []expr.Any + if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 { + exprs6 = append(exprs, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, false)...) + } + var exprsCreateMark []expr.Any + if chain.Hooknum == nftables.ChainHookPrerouting { + exprsCreateMark = []expr.Any{ &expr.Immediate{ Register: 1, - Data: binaryutil.BigEndian.PutUint16(r.redirectPort()), + Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), }, - &expr.Redir{ - RegisterProtoMin: 1, - Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, }, - &expr.Verdict{ - Kind: expr.VerdictReturn, + &expr.Counter{}, + } + } else { + exprsCreateMark = []expr.Any{ + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), }, - ), - }) + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Ct{ + Key: expr.CtKeyMARK, + Register: 1, + SourceRegister: true, + }, + &expr.Counter{}, + } + } + if len(exprs4) > 0 { + exprs4 = append(exprs4, exprsCreateMark...) + } + if len(exprs6) > 0 { + exprs6 = append(exprs6, exprsCreateMark...) + } + if len(exprs4) > 0 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: exprs4, + }) + } + if len(exprs6) > 0 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: exprs6, + }) + } + return nil } func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily( diff --git a/stack_gvisor.go b/stack_gvisor.go index 2fe5867..57672f0 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -25,14 +25,16 @@ const WithGVisor = true const DefaultNIC tcpip.NICID = 1 type GVisor struct { - ctx context.Context - tun GVisorTun - udpTimeout int64 - broadcastAddr netip.Addr - handler Handler - logger logger.Logger - stack *stack.Stack - endpoint stack.LinkEndpoint + ctx context.Context + tun GVisorTun + inet4LoopbackAddress []netip.Addr + inet6LoopbackAddress []netip.Addr + udpTimeout int64 + broadcastAddr netip.Addr + handler Handler + logger logger.Logger + stack *stack.Stack + endpoint stack.LinkEndpoint } type GVisorTun interface { @@ -49,12 +51,14 @@ func NewGVisor( } gStack := &GVisor{ - ctx: options.Context, - tun: gTun, - udpTimeout: options.UDPTimeout, - broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), - handler: options.Handler, - logger: options.Logger, + ctx: options.Context, + tun: gTun, + inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress, + inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress, + udpTimeout: options.UDPTimeout, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), + handler: options.Handler, + logger: options.Logger, } return gStack, nil } @@ -69,7 +73,7 @@ func (t *GVisor) Start() error { if err != nil { return err } - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket) + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler).HandlePacket) t.stack = ipStack diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index a49780a..1f3165b 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -4,34 +4,77 @@ package tun import ( "context" + "net/netip" "time" "github.com/metacubex/gvisor/pkg/tcpip" "github.com/metacubex/gvisor/pkg/tcpip/adapters/gonet" + "github.com/metacubex/gvisor/pkg/tcpip/header" "github.com/metacubex/gvisor/pkg/tcpip/stack" "github.com/metacubex/gvisor/pkg/tcpip/transport/tcp" "github.com/metacubex/gvisor/pkg/waiter" + "github.com/metacubex/sing-tun/internal/gtcpip/checksum" + "github.com/metacubex/sing/common" + "github.com/metacubex/sing/common/bufio" M "github.com/metacubex/sing/common/metadata" ) type TCPForwarder struct { - ctx context.Context - stack *stack.Stack - handler Handler - forwarder *tcp.Forwarder + ctx context.Context + stack *stack.Stack + handler Handler + inet4LoopbackAddress []tcpip.Address + inet6LoopbackAddress []tcpip.Address + tun GVisorTun + forwarder *tcp.Forwarder } func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder { + return NewTCPForwarderWithLoopback(ctx, stack, handler, nil, nil, nil) +} + +func NewTCPForwarderWithLoopback(ctx context.Context, stack *stack.Stack, handler Handler, inet4LoopbackAddress []netip.Addr, inet6LoopbackAddress []netip.Addr, tun GVisorTun) *TCPForwarder { forwarder := &TCPForwarder{ - ctx: ctx, - stack: stack, - handler: handler, + ctx: ctx, + stack: stack, + handler: handler, + inet4LoopbackAddress: common.Map(inet4LoopbackAddress, AddressFromAddr), + inet6LoopbackAddress: common.Map(inet6LoopbackAddress, AddressFromAddr), + tun: tun, } forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward) return forwarder } func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + for _, inet4LoopbackAddress := range f.inet4LoopbackAddress { + if id.LocalAddress == inet4LoopbackAddress { + ipHdr := pkt.Network().(header.IPv4) + ipHdr.SetDestinationAddressWithChecksumUpdate(ipHdr.SourceAddress()) + ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress) + tcpHdr := header.TCP(pkt.TransportHeader().Slice()) + tcpHdr.SetChecksum(0) + tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( + header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()), + ))) + bufio.WriteVectorised(f.tun, pkt.AsSlices()) + return true + } + } + for _, inet6LoopbackAddress := range f.inet6LoopbackAddress { + if id.LocalAddress == inet6LoopbackAddress { + ipHdr := pkt.Network().(header.IPv6) + ipHdr.SetDestinationAddress(ipHdr.SourceAddress()) + ipHdr.SetSourceAddress(inet6LoopbackAddress) + tcpHdr := header.TCP(pkt.TransportHeader().Slice()) + tcpHdr.SetChecksum(0) + tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( + header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()), + ))) + bufio.WriteVectorised(f.tun, pkt.AsSlices()) + return true + } + } return f.forwarder.HandlePacket(id, pkt) } diff --git a/stack_system.go b/stack_system.go index 915472f..18ecafb 100644 --- a/stack_system.go +++ b/stack_system.go @@ -21,30 +21,32 @@ import ( var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25") type System struct { - ctx context.Context - tun Tun - tunName string - mtu int - handler Handler - logger logger.Logger - inet4Prefixes []netip.Prefix - inet6Prefixes []netip.Prefix - inet4ServerAddress netip.Addr - inet4Address netip.Addr - inet6ServerAddress netip.Addr - inet6Address netip.Addr - broadcastAddr netip.Addr - udpTimeout int64 - tcpListener net.Listener - tcpListener6 net.Listener - tcpPort uint16 - tcpPort6 uint16 - tcpNat *TCPNat - bindInterface bool - interfaceFinder control.InterfaceFinder - enforceBind bool - frontHeadroom int - txChecksumOffload bool + ctx context.Context + tun Tun + tunName string + mtu int + handler Handler + logger logger.Logger + inet4Prefixes []netip.Prefix + inet6Prefixes []netip.Prefix + inet4ServerAddress netip.Addr + inet4Address netip.Addr + inet6ServerAddress netip.Addr + inet6Address netip.Addr + broadcastAddr netip.Addr + inet4LoopbackAddress []netip.Addr + inet6LoopbackAddress []netip.Addr + udpTimeout int64 + tcpListener net.Listener + tcpListener6 net.Listener + tcpPort uint16 + tcpPort6 uint16 + tcpNat *TCPNat + bindInterface bool + interfaceFinder control.InterfaceFinder + enforceBind bool + frontHeadroom int + txChecksumOffload bool } type Session struct { @@ -56,19 +58,21 @@ type Session struct { func NewSystem(options StackOptions) (Stack, error) { stack := &System{ - ctx: options.Context, - tun: options.Tun, - tunName: options.TunOptions.Name, - mtu: int(options.TunOptions.MTU), - udpTimeout: options.UDPTimeout, - handler: options.Handler, - logger: options.Logger, - inet4Prefixes: options.TunOptions.Inet4Address, - inet6Prefixes: options.TunOptions.Inet6Address, - broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), - bindInterface: options.ForwarderBindInterface, - interfaceFinder: options.InterfaceFinder, - enforceBind: options.EnforceBindInterface, + ctx: options.Context, + tun: options.Tun, + tunName: options.TunOptions.Name, + mtu: int(options.TunOptions.MTU), + inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress, + inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress, + udpTimeout: options.UDPTimeout, + handler: options.Handler, + logger: options.Logger, + inet4Prefixes: options.TunOptions.Inet4Address, + inet6Prefixes: options.TunOptions.Inet6Address, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), + bindInterface: options.ForwarderBindInterface, + interfaceFinder: options.InterfaceFinder, + enforceBind: options.EnforceBindInterface, } if len(options.TunOptions.Inet4Address) > 0 { if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) { @@ -371,11 +375,22 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err ipHdr.SetDestinationAddr(session.Source.Addr()) tcpHdr.SetDestinationPort(session.Source.Port()) } else { - natPort := s.tcpNat.Lookup(source, destination) - ipHdr.SetSourceAddr(s.inet4Address) - tcpHdr.SetSourcePort(natPort) - ipHdr.SetDestinationAddr(s.inet4ServerAddress) - tcpHdr.SetDestinationPort(s.tcpPort) + var loopback bool + for _, inet4LoopbackAddress := range s.inet4LoopbackAddress { + if destination.Addr() == inet4LoopbackAddress { + ipHdr.SetDestinationAddr(ipHdr.SourceAddr()) + ipHdr.SetSourceAddr(inet4LoopbackAddress) + loopback = true + break + } + } + if !loopback { + natPort := s.tcpNat.Lookup(source, destination) + ipHdr.SetSourceAddr(s.inet4Address) + tcpHdr.SetSourcePort(natPort) + ipHdr.SetDestinationAddr(s.inet4ServerAddress) + tcpHdr.SetDestinationPort(s.tcpPort) + } } if !s.txChecksumOffload { tcpHdr.SetChecksum(0) @@ -451,11 +466,22 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err ipHdr.SetDestinationAddr(session.Source.Addr()) tcpHdr.SetDestinationPort(session.Source.Port()) } else { - natPort := s.tcpNat.Lookup(source, destination) - ipHdr.SetSourceAddr(s.inet6Address) - tcpHdr.SetSourcePort(natPort) - ipHdr.SetDestinationAddr(s.inet6ServerAddress) - tcpHdr.SetDestinationPort(s.tcpPort6) + var loopback bool + for _, inet6LoopbackAddress := range s.inet6LoopbackAddress { + if destination.Addr() == inet6LoopbackAddress { + ipHdr.SetDestinationAddr(ipHdr.SourceAddr()) + ipHdr.SetSourceAddr(inet6LoopbackAddress) + loopback = true + break + } + } + if !loopback { + natPort := s.tcpNat.Lookup(source, destination) + ipHdr.SetSourceAddr(s.inet6Address) + tcpHdr.SetSourcePort(natPort) + ipHdr.SetDestinationAddr(s.inet6ServerAddress) + tcpHdr.SetDestinationPort(s.tcpPort6) + } } if !s.txChecksumOffload { tcpHdr.SetChecksum(0) diff --git a/tun.go b/tun.go index 2934421..3698ae7 100644 --- a/tun.go +++ b/tun.go @@ -68,6 +68,8 @@ type Options struct { AutoRedirectMarkMode bool AutoRedirectInputMark uint32 AutoRedirectOutputMark uint32 + Inet4LoopbackAddress []netip.Addr + Inet6LoopbackAddress []netip.Addr StrictRoute bool Inet4RouteAddress []netip.Prefix Inet6RouteAddress []netip.Prefix