mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-09-26 20:51:13 +08:00
ping: Fix unprivileged response on linux
This commit is contained in:
25
ping/ping.go
25
ping/ping.go
@@ -32,7 +32,7 @@ type Conn struct {
|
||||
}
|
||||
|
||||
func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
|
||||
conn, err := connect(privileged, controlFunc, destination)
|
||||
conn, err := connect0(ctx, privileged, controlFunc, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -45,6 +45,14 @@ func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func connect0(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "android") && !privileged {
|
||||
return newUnprivilegedConn(ctx, controlFunc, destination)
|
||||
} else {
|
||||
return connect(privileged, controlFunc, destination)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
if c.destination.Is6() || (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged {
|
||||
var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
|
||||
@@ -53,20 +61,22 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
||||
var ipAddr *net.IPAddr
|
||||
n, oobn, _, ipAddr, err = conn.ReadMsgIP(b, oob)
|
||||
if ipAddr != nil {
|
||||
if err == nil {
|
||||
addr = M.AddrFromNet(ipAddr)
|
||||
}
|
||||
return
|
||||
}
|
||||
case *net.UDPConn:
|
||||
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
||||
var udpAddr *net.UDPAddr
|
||||
n, oobn, _, udpAddr, err = conn.ReadMsgUDP(b, oob)
|
||||
if udpAddr != nil {
|
||||
addr = M.AddrFromNet(udpAddr)
|
||||
var addrPort netip.AddrPort
|
||||
n, oobn, _, addrPort, err = conn.ReadMsgUDPAddrPort(b, oob)
|
||||
if err == nil {
|
||||
addr = addrPort.Addr()
|
||||
}
|
||||
return
|
||||
}
|
||||
case *UnprivilegedConn:
|
||||
readMsg = conn.ReadMsg
|
||||
default:
|
||||
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
|
||||
}
|
||||
@@ -124,6 +134,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
trafficClass = controlMessage.TrafficClass
|
||||
}
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
|
||||
Src: addr.AsSlice(),
|
||||
@@ -151,12 +162,14 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
|
||||
} else {
|
||||
ipHdr := header.IPv6(buffer.Bytes())
|
||||
ipHdr.SetDestinationAddr(c.source.Load())
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
|
173
ping/socket_linux_unprivileged.go
Normal file
173
ping/socket_linux_unprivileged.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type UnprivilegedConn struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
controlFunc control.Func
|
||||
destination netip.Addr
|
||||
receiveChan chan *unprivilegedResponse
|
||||
readDeadline atomic.TypedValue[time.Time]
|
||||
writeDeadline atomic.TypedValue[time.Time]
|
||||
}
|
||||
|
||||
type unprivilegedResponse struct {
|
||||
Buffer *buf.Buffer
|
||||
Cmsg *buf.Buffer
|
||||
Addr netip.Addr
|
||||
}
|
||||
|
||||
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
|
||||
conn, err := connect(false, controlFunc, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.Close()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &UnprivilegedConn{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
controlFunc: controlFunc,
|
||||
destination: destination,
|
||||
receiveChan: make(chan *unprivilegedResponse),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case packet := <-c.receiveChan:
|
||||
n = copy(b, packet.Buffer.Bytes())
|
||||
packet.Buffer.Release()
|
||||
packet.Cmsg.Release()
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
||||
select {
|
||||
case packet := <-c.receiveChan:
|
||||
n = copy(b, packet.Buffer.Bytes())
|
||||
oobn = copy(oob, packet.Cmsg.Bytes())
|
||||
addr = packet.Addr
|
||||
packet.Buffer.Release()
|
||||
packet.Cmsg.Release()
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return 0, 0, netip.Addr{}, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
|
||||
conn, err := connect(false, c.controlFunc, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var identifier uint16
|
||||
if !c.destination.Is6() {
|
||||
icmpHdr := header.ICMPv4(b)
|
||||
identifier = icmpHdr.Ident()
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(b)
|
||||
identifier = icmpHdr.Ident()
|
||||
}
|
||||
if readDeadline := c.readDeadline.Load(); !readDeadline.IsZero() {
|
||||
conn.SetReadDeadline(readDeadline)
|
||||
}
|
||||
if writeDeadline := c.writeDeadline.Load(); !writeDeadline.IsZero() {
|
||||
conn.SetWriteDeadline(writeDeadline)
|
||||
}
|
||||
n, err = conn.Write(b)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
go c.fetchResponse(conn, identifier)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) fetchResponse(conn net.Conn, identifier uint16) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
conn.Close()
|
||||
}()
|
||||
buffer := buf.NewPacket()
|
||||
cmsgBuffer := buf.NewSize(1024)
|
||||
n, oobN, _, addr, err := conn.(*net.UDPConn).ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
cmsgBuffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
cmsgBuffer.Truncate(oobN)
|
||||
if !c.destination.Is6() {
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(identifier)
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetIdent(identifier)
|
||||
// offload checksum here since we don't have source address here
|
||||
}
|
||||
select {
|
||||
case c.receiveChan <- &unprivilegedResponse{
|
||||
Buffer: buffer,
|
||||
Cmsg: cmsgBuffer,
|
||||
Addr: addr.Addr(),
|
||||
}:
|
||||
case <-c.ctx.Done():
|
||||
buffer.Release()
|
||||
cmsgBuffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) Close() error {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) LocalAddr() net.Addr {
|
||||
return M.Socksaddr{}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
|
||||
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
|
||||
c.readDeadline.Store(t)
|
||||
c.writeDeadline.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
|
||||
c.readDeadline.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
|
||||
c.writeDeadline.Store(t)
|
||||
return nil
|
||||
}
|
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@@ -77,7 +78,7 @@ func connect(privileged bool, controlFunc control.Func, destination netip.Addr)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "connect()")
|
||||
}
|
||||
|
||||
|
||||
conn, err := net.FileConn(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
Reference in New Issue
Block a user