diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3afe968..be40dea 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,12 +26,14 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ^1.23 + go-version: ^1.25.0 - name: Build run: | make test - build_go120: - name: Linux (Go 1.20) + go test -c -o ping_test ./ping + sudo ./ping_test -test.v + build_go124: + name: Linux (Go 1.24) runs-on: ubuntu-latest steps: - name: Checkout @@ -41,13 +43,15 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ~1.20 + go-version: ~1.24 continue-on-error: true - name: Build run: | make test - build_go121: - name: Linux (Go 1.21) + go test -c -o ping_test ./ping + sudo ./ping_test -test.v + build_go123: + name: Linux (Go 1.23) runs-on: ubuntu-latest steps: - name: Checkout @@ -57,27 +61,13 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ~1.21 - continue-on-error: true - - name: Build - run: | - make test - build_go122: - name: Linux (Go 1.22) - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: ~1.22 + go-version: ~1.23 continue-on-error: true - name: Build run: | make test + go test -c -o ping_test ./ping + sudo ./ping_test -test.v build_windows: name: Windows runs-on: windows-latest @@ -94,6 +84,7 @@ jobs: - name: Build run: | make test + go test -v ./ping build_darwin: name: macOS runs-on: macos-latest @@ -109,4 +100,7 @@ jobs: continue-on-error: true - name: Build run: | - make test \ No newline at end of file + make test + go test -v ./ping + go test -c -o ping_test ./ping + sudo ./ping_test -test.v \ No newline at end of file diff --git a/Makefile b/Makefile index f7a8532..c752474 100644 --- a/Makefile +++ b/Makefile @@ -29,5 +29,5 @@ lint_install: test: go build -v . - go test -bench=. ./internal/checksum_test - #go test -v . + #go test -bench=. ./internal/checksum_test + go test -v . diff --git a/go.mod b/go.mod index 3f7f0a2..7d49739 100644 --- a/go.mod +++ b/go.mod @@ -9,20 +9,24 @@ require ( github.com/sagernet/gvisor v0.0.0-20250822052253-5558536cf237 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.7.0-beta.1 + github.com/sagernet/sing v0.7.6-0.20250823024003-88f1880f43af + github.com/stretchr/testify v1.9.0 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 - golang.org/x/net v0.26.0 - golang.org/x/sys v0.26.0 + golang.org/x/net v0.43.0 + golang.org/x/sys v0.35.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/time v0.7.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c964e1c..fceb588 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw= -github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.7.6-0.20250823024003-88f1880f43af h1:/1H30c/+j7Q9BBPuJuX6eHyzKpbGWrr7S/4DcdtNIfw= +github.com/sagernet/sing v0.7.6-0.20250823024003-88f1880f43af/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -34,14 +34,16 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/icmp.go b/icmp.go deleted file mode 100644 index a1db779..0000000 --- a/icmp.go +++ /dev/null @@ -1,29 +0,0 @@ -package tun - -import ( - "context" - "net" - "net/netip" - "os" - "runtime" - - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - - "golang.org/x/sys/unix" -) - -func NewICMPDestination(ctx context.Context, logger logger.Logger, dialer net.Dialer, network string, address netip.Addr, routeContext DirectRouteContext) (DirectRouteDestination, error) { - if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { - return NewUnprivilegedICMPDestination(ctx, logger, dialer, network, address, routeContext) - } else { - destination, err := NewPrivilegedICMPDestination(ctx, logger, dialer, network, address, routeContext) - if err != nil { - if E.IsMulti(err, os.ErrPermission, unix.EPERM) { - return NewUnprivilegedICMPDestination(ctx, logger, dialer, network, address, routeContext) - } - return nil, err - } - return destination, nil - } -} diff --git a/icmp_privileged.go b/icmp_privileged.go deleted file mode 100644 index aef7d18..0000000 --- a/icmp_privileged.go +++ /dev/null @@ -1,112 +0,0 @@ -package tun - -import ( - "context" - "net" - "net/netip" - "os" - - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -type PrivilegedICMPDestination struct { - ctx context.Context - cancel context.CancelCauseFunc - logger logger.Logger - routeContext DirectRouteContext - isIPv6 bool - localAddr atomic.TypedValue[netip.Addr] - rawConn net.Conn -} - -func NewPrivilegedICMPDestination(ctx context.Context, logger logger.Logger, dialer net.Dialer, network string, address netip.Addr, routeContext DirectRouteContext) (DirectRouteDestination, error) { - var dialNetwork string - switch network { - case N.NetworkICMPv4: - dialNetwork = "ip4:icmp" - case N.NetworkICMPv6: - dialNetwork = "ip6:icmp" - default: - return nil, E.New("unsupported network: ", network) - } - ctx, cancel := context.WithCancelCause(ctx) - rawConn, err := dialer.DialContext(ctx, dialNetwork, address.String()) - if err != nil { - cancel(err) - return nil, err - } - d := &PrivilegedICMPDestination{ - ctx: ctx, - cancel: cancel, - logger: logger, - routeContext: routeContext, - isIPv6: network == N.NetworkICMPv6, - rawConn: rawConn, - } - go d.loopRead() - return d, nil -} - -func (d *PrivilegedICMPDestination) loopRead() { - for { - buffer := buf.NewPacket() - _, err := buffer.ReadOnceFrom(d.rawConn) - if err != nil { - return - } - if !d.isIPv6 { - ipHdr := header.IPv4(buffer.Bytes()) - ipHdr.SetDestinationAddr(d.localAddr.Load()) - ipHdr.SetChecksum(0) - ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) - icmpHdr := header.ICMPv4(ipHdr.Payload()) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) - } else { - ipHdr := header.IPv6(buffer.Bytes()) - ipHdr.SetDestinationAddr(d.localAddr.Load()) - icmpHdr := header.ICMPv6(ipHdr.Payload()) - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: ipHdr.SourceAddress(), - Dst: ipHdr.DestinationAddress(), - })) - } - err = d.routeContext.WritePacket(buffer.Bytes()) - if err != nil { - d.logger.Error(err) - } - } -} - -func (d *PrivilegedICMPDestination) WritePacket(packet *buf.Buffer) error { - if !d.isIPv6 { - ipHdr := header.IPv4(packet.Bytes()) - d.localAddr.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) - icmpHdr := header.ICMPv6(ipHdr.Payload()) - _, err := d.rawConn.Write(icmpHdr) - if err != nil { - return err - } - } else { - ipHdr := header.IPv6(packet.Bytes()) - d.localAddr.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) - icmpHdr := header.ICMPv6(ipHdr.Payload()) - _, err := d.rawConn.Write(icmpHdr) - if err != nil { - return err - } - } - return nil -} - -func (d *PrivilegedICMPDestination) Close() error { - d.cancel(os.ErrClosed) - return d.rawConn.Close() -} diff --git a/icmp_privileged_gvisor.go b/icmp_privileged_gvisor.go deleted file mode 100644 index 002207f..0000000 --- a/icmp_privileged_gvisor.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build with_gvisor - -package tun - -import ( - "net/netip" - - "github.com/sagernet/gvisor/pkg/tcpip/stack" -) - -func (d *PrivilegedICMPDestination) WritePacketBuffer(packetBuffer *stack.PacketBuffer) error { - ipHdr := packetBuffer.Network() - if !d.isIPv6 { - d.localAddr.Store(netip.AddrFrom4(ipHdr.SourceAddress().As4())) - } else { - d.localAddr.Store(netip.AddrFrom16(ipHdr.SourceAddress().As16())) - } - packetSlice := packetBuffer.TransportHeader().Slice() - packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...) - _, err := d.rawConn.Write(packetSlice) - return err -} diff --git a/icmp_unprivileged.go b/icmp_unprivileged.go deleted file mode 100644 index d246771..0000000 --- a/icmp_unprivileged.go +++ /dev/null @@ -1,154 +0,0 @@ -package tun - -import ( - "context" - "net" - "net/netip" - "os" - "syscall" - "unsafe" - - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "golang.org/x/sys/unix" -) - -type UnprivilegedICMPDestination struct { - ctx context.Context - cancel context.CancelCauseFunc - logger logger.Logger - routeContext DirectRouteContext - isIPv6 bool - localAddr atomic.TypedValue[netip.Addr] - rawConn net.Conn - ipHdr bool -} - -func NewUnprivilegedICMPDestination(ctx context.Context, logger logger.Logger, dialer net.Dialer, network string, address netip.Addr, routeContext DirectRouteContext) (DirectRouteDestination, error) { - var ( - isIPv6 bool - fd int - ipHdr bool - err error - ) - var dialNetwork string - switch network { - case N.NetworkICMPv4: - dialNetwork = "ip4:icmp" - case N.NetworkICMPv6: - dialNetwork = "ip6:icmp" - isIPv6 = true - default: - return nil, E.New("unsupported network: ", network) - } - if !isIPv6 { - fd, err = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_ICMP) - } else { - fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_ICMPV6) - } - if err != nil { - return nil, err - } - name, nameLen := bufio.ToSockaddr(M.SocksaddrFrom(address, 0).AddrPort()) - err = unixConnect(fd, name, nameLen) - if err != nil { - return nil, err - } - rawConn, err := net.FileConn(os.NewFile(uintptr(fd), "datagram-oriented icmp")) - if err != nil { - syscall.Close(fd) - return nil, err - } - if dialer.Control != nil { - var syscallConn syscall.RawConn - syscallConn, err = rawConn.(syscall.Conn).SyscallConn() - if err != nil { - return nil, err - } - err = dialer.Control(dialNetwork, address.String(), syscallConn) - if err != nil { - return nil, err - } - } - d := &UnprivilegedICMPDestination{ - ctx: ctx, - logger: logger, - routeContext: routeContext, - isIPv6: network == N.NetworkICMPv6, - rawConn: rawConn, - ipHdr: ipHdr, - } - go d.loopRead() - return d, nil -} - -//go:linkname unixConnect golang.org/x/sys/unix.connect -func unixConnect(fd int, addr unsafe.Pointer, addrlen uint32) error - -func (d *UnprivilegedICMPDestination) loopRead() { - for { - buffer := buf.NewPacket() - _, err := buffer.ReadOnceFrom(d.rawConn) - if err != nil { - return - } - if d.ipHdr { - if !d.isIPv6 { - ipHdr := header.IPv4(buffer.Bytes()) - ipHdr.SetDestinationAddr(d.localAddr.Load()) - ipHdr.SetChecksum(0) - ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) - icmpHdr := header.ICMPv4(ipHdr.Payload()) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) - } else { - ipHdr := header.IPv6(buffer.Bytes()) - ipHdr.SetDestinationAddr(d.localAddr.Load()) - icmpHdr := header.ICMPv6(ipHdr.Payload()) - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: ipHdr.SourceAddress(), - Dst: ipHdr.DestinationAddress(), - })) - } - err = d.routeContext.WritePacket(buffer.Bytes()) - if err != nil { - d.logger.Error(err) - } - } else { - panic("impl no hdr version for windows and linux") - } - } -} - -func (d *UnprivilegedICMPDestination) WritePacket(packet *buf.Buffer) error { - if !d.isIPv6 { - ipHdr := header.IPv4(packet.Bytes()) - d.localAddr.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) - icmpHdr := header.ICMPv6(ipHdr.Payload()) - _, err := d.rawConn.Write(icmpHdr) - if err != nil { - return err - } - } else { - ipHdr := header.IPv6(packet.Bytes()) - d.localAddr.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) - icmpHdr := header.ICMPv6(ipHdr.Payload()) - _, err := d.rawConn.Write(icmpHdr) - if err != nil { - return err - } - } - return nil -} - -func (d *UnprivilegedICMPDestination) Close() error { - d.cancel(os.ErrClosed) - return d.rawConn.Close() -} diff --git a/icmp_unprivileged_gvisor.go b/icmp_unprivileged_gvisor.go deleted file mode 100644 index daba001..0000000 --- a/icmp_unprivileged_gvisor.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build with_gvisor - -package tun - -import ( - "net/netip" - - "github.com/sagernet/gvisor/pkg/tcpip/stack" -) - -func (d *UnprivilegedICMPDestination) WritePacketBuffer(packetBuffer *stack.PacketBuffer) error { - ipHdr := packetBuffer.Network() - if !d.isIPv6 { - d.localAddr.Store(netip.AddrFrom4(ipHdr.SourceAddress().As4())) - } else { - d.localAddr.Store(netip.AddrFrom16(ipHdr.SourceAddress().As16())) - } - packetSlice := packetBuffer.TransportHeader().Slice() - packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...) - _, err := d.rawConn.Write(packetSlice) - return err -} diff --git a/internal/gtcpip/header/icmpv6.go b/internal/gtcpip/header/icmpv6.go index 970f743..520b403 100644 --- a/internal/gtcpip/header/icmpv6.go +++ b/internal/gtcpip/header/icmpv6.go @@ -276,8 +276,8 @@ func (b ICMPv6) Payload() []byte { // ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum. type ICMPv6ChecksumParams struct { Header ICMPv6 - Src tcpip.Address - Dst tcpip.Address + Src []byte + Dst []byte PayloadCsum uint16 PayloadLen int } @@ -287,7 +287,7 @@ type ICMPv6ChecksumParams struct { func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 { h := params.Header - xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src.AsSlice(), params.Dst.AsSlice(), uint16(len(h)+params.PayloadLen)) + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen)) xsum = checksum.Combine(xsum, params.PayloadCsum) // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. diff --git a/internal/gtcpip/header/interfaces.go b/internal/gtcpip/header/interfaces.go index c2f8cdf..fc13100 100644 --- a/internal/gtcpip/header/interfaces.go +++ b/internal/gtcpip/header/interfaces.go @@ -88,12 +88,16 @@ type Network interface { SourceAddr() netip.Addr + SourceAddressSlice() []byte + // DestinationAddress returns the value of the "destination address" // field. DestinationAddress() tcpip.Address DestinationAddr() netip.Addr + DestinationAddressSlice() []byte + // Checksum returns the value of the "checksum" field. Checksum() uint16 diff --git a/ping/cmsg_unix.go b/ping/cmsg_unix.go new file mode 100644 index 0000000..222cb85 --- /dev/null +++ b/ping/cmsg_unix.go @@ -0,0 +1,16 @@ +//go:build !windows + +package ping + +import ( + "golang.org/x/net/ipv6" +) + +func parseIPv6ControlMessage(cmsg []byte) (*ipv6.ControlMessage, error) { + var controlMessage ipv6.ControlMessage + err := controlMessage.Parse(cmsg) + if err != nil { + return nil, err + } + return &controlMessage, nil +} diff --git a/ping/cmsg_windows.go b/ping/cmsg_windows.go new file mode 100644 index 0000000..be5be9b --- /dev/null +++ b/ping/cmsg_windows.go @@ -0,0 +1,46 @@ +package ping + +import ( + "encoding/binary" + "fmt" + "unsafe" + + "golang.org/x/net/ipv6" + "golang.org/x/sys/windows" +) + +const ( + IPV6_HOPLIMIT = 21 + IPV6_TCLASS = 39 + IPV6_RECVTCLASS = 40 +) + +var ( + alignedSizeofCmsghdr = (sizeofCmsghdr + cmsgAlignTo - 1) & ^(cmsgAlignTo - 1) + sizeofCmsghdr = int(unsafe.Sizeof(windows.WSACMSGHDR{})) + cmsgAlignTo = int(unsafe.Sizeof(uintptr(0))) +) + +func cmsgAlign(n int) int { + return (n + cmsgAlignTo - 1) & ^(cmsgAlignTo - 1) +} + +func parseIPv6ControlMessage(cmsg []byte) (*ipv6.ControlMessage, error) { + var controlMessage ipv6.ControlMessage + for len(cmsg) >= sizeofCmsghdr { + cmsghdr := (*windows.WSACMSGHDR)(unsafe.Pointer(unsafe.SliceData(cmsg))) + msgLen := int(cmsghdr.Len) + msgSize := cmsgAlign(msgLen) + if msgLen < sizeofCmsghdr || msgSize > len(cmsg) { + return nil, fmt.Errorf("invalid control message length %d", cmsghdr.Len) + } + switch cmsghdr.Type { + case IPV6_TCLASS: + controlMessage.TrafficClass = int(binary.NativeEndian.Uint32(cmsg[alignedSizeofCmsghdr : alignedSizeofCmsghdr+4])) + case IPV6_HOPLIMIT: + controlMessage.HopLimit = int(binary.NativeEndian.Uint32(cmsg[alignedSizeofCmsghdr : alignedSizeofCmsghdr+4])) + } + cmsg = cmsg[msgSize:] + } + return &controlMessage, nil +} diff --git a/ping/destination.go b/ping/destination.go new file mode 100644 index 0000000..553c7f8 --- /dev/null +++ b/ping/destination.go @@ -0,0 +1,75 @@ +package ping + +import ( + "errors" + "net/netip" + "os" + "runtime" + + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +var _ tun.DirectRouteDestination = (*Destination)(nil) + +type Destination struct { + logger logger.Logger + routeContext tun.DirectRouteContext + conn *Conn +} + +func ConnectDestination(logger logger.Logger, controlFunc control.Func, address netip.Addr, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + var ( + conn *Conn + err error + ) + switch runtime.GOOS { + case "darwin", "ios", "windows": + conn, err = Connect(false, controlFunc, address) + default: + conn, err = Connect(true, controlFunc, address) + if errors.Is(err, os.ErrPermission) { + conn, err = Connect(false, controlFunc, address) + } + } + if err != nil { + return nil, err + } + d := &Destination{ + logger: logger, + routeContext: routeContext, + conn: conn, + } + go d.loopRead() + return d, nil +} + +func (d *Destination) loopRead() { + for { + buffer := buf.NewPacket() + err := d.conn.ReadIP(buffer) + if err != nil { + buffer.Release() + if !E.IsClosed(err) { + d.logger.Error(E.Cause(err, "receive ICMP echo reply")) + } + return + } + err = d.routeContext.WritePacket(buffer.Bytes()) + if err != nil { + d.logger.Error(E.Cause(err, "write ICMP echo reply")) + } + buffer.Release() + } +} + +func (d *Destination) WritePacket(packet *buf.Buffer) error { + return d.conn.WriteIP(packet) +} + +func (d *Destination) Close() error { + return d.conn.Close() +} diff --git a/ping/ping.go b/ping/ping.go new file mode 100644 index 0000000..caad5f4 --- /dev/null +++ b/ping/ping.go @@ -0,0 +1,207 @@ +package ping + +import ( + "net" + "net/netip" + "reflect" + "runtime" + "time" + + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +type Conn struct { + privileged bool + conn net.Conn + destination netip.Addr + source atomic.TypedValue[netip.Addr] +} + +func Connect(privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) { + conn, err := connect(privileged, controlFunc, destination) + if err != nil { + return nil, err + } + return &Conn{ + privileged: privileged, + conn: conn, + destination: destination, + }, nil +} + +func (c *Conn) ReadIP(buffer *buf.Buffer) error { + if c.destination.Is6() || runtime.GOOS == "linux" && !c.privileged { + var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) + switch conn := c.conn.(type) { + case *net.IPConn: + readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { + var ipAddr *net.IPAddr + n, oobn, _, ipAddr, err = conn.ReadMsgIP(b, oob) + if ipAddr != nil { + addr = M.AddrFromNet(ipAddr) + } + return + } + case *net.UDPConn: + readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { + var udpAddr *net.UDPAddr + n, oobn, _, udpAddr, err = conn.ReadMsgUDP(b, oob) + if udpAddr != nil { + addr = M.AddrFromNet(udpAddr) + } + return + } + default: + return E.New("unsupported conn type: ", reflect.TypeOf(c.conn)) + } + if !c.destination.Is6() { + oob := ipv4.NewControlMessage(ipv4.FlagTTL) + buffer.Advance(header.IPv4MinimumSize) + var ttl int + // tos int + n, oobn, addr, err := readMsg(buffer.FreeBytes(), oob) + if err != nil { + return err + } + if err != nil { + return err + } + buffer.Truncate(n) + if oobn > 0 { + var controlMessage ipv4.ControlMessage + err = controlMessage.Parse(oob[:oobn]) + if err != nil { + return err + } + ttl = controlMessage.TTL + } + ipHdr := header.IPv4(buffer.ExtendHeader(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + // TOS: uint8(tos), + SrcAddr: addr, + DstAddr: c.source.Load(), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: uint8(ttl), + TotalLength: uint16(buffer.Len()), + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + } else { + oob := make([]byte, 1024) + buffer.Advance(header.IPv6MinimumSize) + var ( + hopLimit int + trafficClass int + ) + n, oobn, addr, err := readMsg(buffer.FreeBytes(), oob) + if err != nil { + return err + } + buffer.Truncate(n) + if oobn > 0 { + var controlMessage *ipv6.ControlMessage + controlMessage, err = parseIPv6ControlMessage(oob[:oobn]) + if err != nil { + return err + } + hopLimit = controlMessage.HopLimit + trafficClass = controlMessage.TrafficClass + } + icmpHdr := header.ICMPv6(buffer.Bytes()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize], + Src: addr.AsSlice(), + Dst: c.source.Load().AsSlice(), + })) + ipHdr := header.IPv6(buffer.ExtendHeader(header.IPv6MinimumSize)) + ipHdr.Encode(&header.IPv6Fields{ + TrafficClass: uint8(trafficClass), + PayloadLength: uint16(buffer.Len() - header.IPv6MinimumSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: uint8(hopLimit), + SrcAddr: addr, + DstAddr: c.source.Load(), + }) + } + } else { + _, err := buffer.ReadOnceFrom(c.conn) + if err != nil { + return err + } + if !c.destination.Is6() { + ipHdr := header.IPv4(buffer.Bytes()) + ipHdr.SetDestinationAddr(c.source.Load()) + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + icmpHdr := header.ICMPv4(ipHdr.Payload()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + } else { + ipHdr := header.IPv6(buffer.Bytes()) + ipHdr.SetDestinationAddr(c.source.Load()) + icmpHdr := header.ICMPv6(ipHdr.Payload()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: ipHdr.SourceAddressSlice(), + Dst: ipHdr.DestinationAddressSlice(), + })) + } + } + return nil +} + +func (c *Conn) ReadICMP(buffer *buf.Buffer) error { + _, err := buffer.ReadOnceFrom(c.conn) + if err != nil { + return err + } + if c.destination.Is6() || runtime.GOOS == "linux" && !c.privileged { + return nil + } + if !c.destination.Is6() { + ipHdr := header.IPv4(buffer.Bytes()) + buffer.Advance(int(ipHdr.HeaderLength())) + } else { + ipHdr := header.IPv6(buffer.Bytes()) + buffer.Advance(buffer.Len() - int(ipHdr.PayloadLength())) + } + return nil +} + +func (c *Conn) WriteIP(buffer *buf.Buffer) error { + defer buffer.Release() + if !c.destination.Is6() { + ipHdr := header.IPv4(buffer.Bytes()) + c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) + return common.Error(c.conn.Write(ipHdr.Payload())) + } else { + ipHdr := header.IPv6(buffer.Bytes()) + c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice())) + return common.Error(c.conn.Write(ipHdr.Payload())) + } +} + +func (c *Conn) WriteICMP(buffer *buf.Buffer) error { + defer buffer.Release() + return common.Error(c.conn.Write(buffer.Bytes())) +} + +func (c *Conn) SetLocalAddr(addr netip.Addr) { + c.source.Store(addr) +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *Conn) Close() error { + return c.conn.Close() +} diff --git a/ping/ping_test.go b/ping/ping_test.go new file mode 100644 index 0000000..65a6da9 --- /dev/null +++ b/ping/ping_test.go @@ -0,0 +1,193 @@ +package ping_test + +import ( + "net/netip" + "os" + "runtime" + "testing" + "time" + + "github.com/sagernet/gvisor/pkg/rand" + "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/ping" + "github.com/sagernet/sing/common/buf" + + "github.com/stretchr/testify/require" +) + +func TestPing(t *testing.T) { + t.Parallel() + const addr4 = "127.0.0.1" + t.Run("ipv4", func(t *testing.T) { + t.Run("unprivileged", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + t.Run("read-icmp", func(t *testing.T) { + testPingIPv4ReadICMP(t, false, addr4) + }) + t.Run("read-ip", func(t *testing.T) { + testPingIPv4ReadIP(t, false, addr4) + }) + }) + t.Run("privileged", func(t *testing.T) { + if runtime.GOOS != "windows" && os.Getuid() != 0 { + t.SkipNow() + } + t.Run("read-icmp", func(t *testing.T) { + testPingIPv4ReadICMP(t, true, addr4) + }) + t.Run("read-ip", func(t *testing.T) { + testPingIPv4ReadIP(t, true, addr4) + }) + }) + }) + // const addr6 = "2606:4700:4700::1001" + const addr6 = "::1" + t.Run("ipv6", func(t *testing.T) { + t.Run("unprivileged", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + t.Run("read-icmp", func(t *testing.T) { + testPingIPv6ReadICMP(t, false, addr6) + }) + t.Run("read-ip", func(t *testing.T) { + testPingIPv6ReadIP(t, false, addr6) + }) + }) + t.Run("privileged", func(t *testing.T) { + if runtime.GOOS != "windows" && os.Getuid() != 0 { + t.SkipNow() + } + t.Run("read-icmp", func(t *testing.T) { + testPingIPv6ReadICMP(t, true, addr6) + }) + t.Run("read-ip", func(t *testing.T) { + testPingIPv6ReadIP(t, true, addr6) + }) + }) + }) +} + +func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) { + conn, err := ping.Connect(privileged, nil, netip.MustParseAddr(addr)) + if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" { + t.SkipNow() + } + require.NoError(t, err) + + request := make(header.ICMPv4, header.ICMPv4MinimumSize) + request.SetType(header.ICMPv4Echo) + request.SetIdent(uint16(rand.Uint32())) + request.SetChecksum(header.ICMPv4Checksum(request, 0)) + + err = conn.WriteICMP(buf.As(request)) + require.NoError(t, err) + + conn.SetLocalAddr(netip.MustParseAddr("127.0.0.1")) + require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second))) + + response := buf.NewPacket() + err = conn.ReadIP(response) + require.NoError(t, err) + if runtime.GOOS == "linux" && privileged { + response.Reset() + err = conn.ReadIP(response) + require.NoError(t, err) + } + ipHdr := header.IPv4(response.Bytes()) + require.NotZero(t, ipHdr.TTL()) + icmpHdr := header.ICMPv4(ipHdr.Payload()) + require.Equal(t, header.ICMPv4EchoReply, icmpHdr.Type()) +} + +func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) { + conn, err := ping.Connect(privileged, nil, netip.MustParseAddr(addr)) + if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" { + t.SkipNow() + } + require.NoError(t, err) + + request := make(header.ICMPv4, header.ICMPv4MinimumSize) + request.SetType(header.ICMPv4Echo) + request.SetIdent(uint16(rand.Uint32())) + request.SetChecksum(header.ICMPv4Checksum(request, 0)) + + err = conn.WriteICMP(buf.As(request)) + require.NoError(t, err) + + require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second))) + + response := buf.NewPacket() + err = conn.ReadICMP(response) + require.NoError(t, err) + + if runtime.GOOS == "linux" && privileged { + response.Reset() + err = conn.ReadICMP(response) + require.NoError(t, err) + } + + icmpHdr := header.ICMPv4(response.Bytes()) + require.Equal(t, header.ICMPv4EchoReply, icmpHdr.Type()) +} + +func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) { + conn, err := ping.Connect(privileged, nil, netip.MustParseAddr(addr)) + if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" { + t.SkipNow() + } + require.NoError(t, err) + + request := make(header.ICMPv6, header.ICMPv6MinimumSize) + request.SetType(header.ICMPv6EchoRequest) + request.SetIdent(uint16(rand.Uint32())) + + err = conn.WriteICMP(buf.As(request)) + require.NoError(t, err) + + conn.SetLocalAddr(netip.MustParseAddr("::1")) + require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second))) + + response := buf.NewPacket() + err = conn.ReadIP(response) + require.NoError(t, err) + if runtime.GOOS == "darwin" || runtime.GOOS == "linux" && privileged { + response.Reset() + err = conn.ReadIP(response) + require.NoError(t, err) + } + ipHdr := header.IPv6(response.Bytes()) + require.NotZero(t, ipHdr.HopLimit()) + icmpHdr := header.ICMPv6(ipHdr.Payload()) + require.Equal(t, header.ICMPv6EchoReply, icmpHdr.Type()) +} + +func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) { + conn, err := ping.Connect(privileged, nil, netip.MustParseAddr(addr)) + if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" { + t.SkipNow() + } + require.NoError(t, err) + + request := make(header.ICMPv6, header.ICMPv6MinimumSize) + request.SetType(header.ICMPv6EchoRequest) + request.SetIdent(uint16(rand.Uint32())) + + err = conn.WriteICMP(buf.As(request)) + require.NoError(t, err) + + require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second))) + + response := buf.NewPacket() + err = conn.ReadICMP(response) + require.NoError(t, err) + if runtime.GOOS == "darwin" || runtime.GOOS == "linux" && privileged { + response.Reset() + err = conn.ReadICMP(response) + require.NoError(t, err) + } + icmpHdr := header.ICMPv6(response.Bytes()) + require.Equal(t, header.ICMPv6EchoReply, icmpHdr.Type()) +} diff --git a/ping/socket_unix.go b/ping/socket_unix.go new file mode 100644 index 0000000..11fa7a9 --- /dev/null +++ b/ping/socket_unix.go @@ -0,0 +1,86 @@ +//go:build unix + +package ping + +import ( + "net" + "net/netip" + "os" + "runtime" + "syscall" + + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "golang.org/x/sys/unix" +) + +func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) { + var ( + network string + fd int + err error + ) + if destination.Is4() { + network = "ip4:icmp" + if !privileged { + fd, err = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_ICMP) + } else { + fd, err = unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP) + } + } else { + network = "ip6:icmp" + if !privileged { + fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_ICMPV6) + } else { + fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_ICMPV6) + } + } + if err != nil { + return nil, E.Cause(err, "socket()") + } + file := os.NewFile(uintptr(fd), "datagram-oriented icmp") + defer file.Close() + err = unix.Connect(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(destination, 0))) + if err != nil { + return nil, E.Cause(err, "connect()") + } + + if destination.Is4() && runtime.GOOS == "linux" { + //err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTOS, 1) + //if err != nil { + // return nil, err + //} + err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1) + if err != nil { + return nil, E.Cause(err, "setsockopt()") + } + } + if destination.Is6() { + err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVHOPLIMIT, 1) + if err != nil { + return nil, E.Cause(err, "setsockopt()") + } + err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + if err != nil { + return nil, E.Cause(err, "setsockopt()") + } + } + + conn, err := net.FileConn(file) + if err != nil { + return nil, err + } + if controlFunc != nil { + var syscallConn syscall.RawConn + syscallConn, err = conn.(syscall.Conn).SyscallConn() + if err != nil { + return nil, err + } + err = controlFunc(network, destination.String(), syscallConn) + if err != nil { + return nil, err + } + } + return conn, nil +} diff --git a/ping/socket_windows.go b/ping/socket_windows.go new file mode 100644 index 0000000..daafd18 --- /dev/null +++ b/ping/socket_windows.go @@ -0,0 +1,38 @@ +package ping + +import ( + "net" + "net/netip" + "syscall" + + "github.com/sagernet/sing/common/control" + + "golang.org/x/sys/windows" +) + +func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) { + var dialer net.Dialer + dialer.Control = controlFunc + if destination.Is6() { + dialer.Control = control.Append(dialer.Control, func(network, address string, conn syscall.RawConn) error { + return control.Raw(conn, func(fd uintptr) error { + err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_HOPLIMIT, 1) + if err != nil { + return err + } + err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_RECVTCLASS, 1) + if err != nil { + return err + } + return nil + }) + }) + } + var network string + if destination.Is4() { + network = "ip4:icmp" + } else { + network = "ip6:ipv6-icmp" + } + return dialer.Dial(network, destination.String()) +} diff --git a/route_nat.go b/route_nat.go index 5997703..a4f33cc 100644 --- a/route_nat.go +++ b/route_nat.go @@ -98,8 +98,8 @@ func (w *NatWriter) RewritePacket(packet []byte) { icmpHdr.SetChecksum(0) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, - Src: ipHdr.SourceAddress(), - Dst: ipHdr.DestinationAddress(), + Src: ipHdr.SourceAddressSlice(), + Dst: ipHdr.DestinationAddressSlice(), })) } if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 { diff --git a/route_nat_non_gvisor.go b/route_nat_non_gvisor.go index 049b074..a0c6cea 100644 --- a/route_nat_non_gvisor.go +++ b/route_nat_non_gvisor.go @@ -7,6 +7,6 @@ import ( ) type DirectRouteDestination interface { - DirectRouteAction WritePacket(packet *buf.Buffer) error + Close() error } diff --git a/stack_system.go b/stack_system.go index 9075a78..64f366d 100644 --- a/stack_system.go +++ b/stack_system.go @@ -746,8 +746,8 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) (bool ipHdr.SetDestinationAddr(sourceAddress) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, - Src: ipHdr.SourceAddress(), - Dst: ipHdr.DestinationAddress(), + Src: ipHdr.SourceAddressSlice(), + Dst: ipHdr.DestinationAddressSlice(), })) return true, nil } @@ -782,8 +782,8 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e icmpHdr.SetCode(code) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize], - Src: newIPHdr.SourceAddress(), - Dst: newIPHdr.DestinationAddress(), + Src: newIPHdr.SourceAddressSlice(), + Dst: newIPHdr.DestinationAddressSlice(), PayloadCsum: checksum.Checksum(payload, 0), PayloadLen: len(payload), }))