Files
sing-tun/icmp_privileged.go
2025-08-22 14:21:21 +08:00

113 lines
3.0 KiB
Go

package tun
import (
"context"
"net"
"net/netip"
"os"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type PrivilegedICMPDestination struct {
ctx context.Context
cancel context.CancelCauseFunc
logger logger.Logger
routeContext DirectRouteContext
isIPv6 bool
localAddr atomic.TypedValue[netip.Addr]
rawConn net.Conn
}
func NewPrivilegedICMPDestination(ctx context.Context, logger logger.Logger, dialer net.Dialer, network string, address netip.Addr, routeContext DirectRouteContext) (DirectRouteDestination, error) {
var dialNetwork string
switch network {
case N.NetworkICMPv4:
dialNetwork = "ip4:icmp"
case N.NetworkICMPv6:
dialNetwork = "ip6:icmp"
default:
return nil, E.New("unsupported network: ", network)
}
ctx, cancel := context.WithCancelCause(ctx)
rawConn, err := dialer.DialContext(ctx, dialNetwork, address.String())
if err != nil {
cancel(err)
return nil, err
}
d := &PrivilegedICMPDestination{
ctx: ctx,
cancel: cancel,
logger: logger,
routeContext: routeContext,
isIPv6: network == N.NetworkICMPv6,
rawConn: rawConn,
}
go d.loopRead()
return d, nil
}
func (d *PrivilegedICMPDestination) loopRead() {
for {
buffer := buf.NewPacket()
_, err := buffer.ReadOnceFrom(d.rawConn)
if err != nil {
return
}
if !d.isIPv6 {
ipHdr := header.IPv4(buffer.Bytes())
ipHdr.SetDestinationAddr(d.localAddr.Load())
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(ipHdr.Payload())
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
} else {
ipHdr := header.IPv6(buffer.Bytes())
ipHdr.SetDestinationAddr(d.localAddr.Load())
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
}))
}
err = d.routeContext.WritePacket(buffer.Bytes())
if err != nil {
d.logger.Error(err)
}
}
}
func (d *PrivilegedICMPDestination) WritePacket(packet *buf.Buffer) error {
if !d.isIPv6 {
ipHdr := header.IPv4(packet.Bytes())
d.localAddr.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
icmpHdr := header.ICMPv6(ipHdr.Payload())
_, err := d.rawConn.Write(icmpHdr)
if err != nil {
return err
}
} else {
ipHdr := header.IPv6(packet.Bytes())
d.localAddr.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
icmpHdr := header.ICMPv6(ipHdr.Payload())
_, err := d.rawConn.Write(icmpHdr)
if err != nil {
return err
}
}
return nil
}
func (d *PrivilegedICMPDestination) Close() error {
d.cancel(os.ErrClosed)
return d.rawConn.Close()
}