Files
sing-tun/ping/socket_linux_unprivileged.go
2025-08-27 11:07:26 +08:00

190 lines
4.4 KiB
Go

package ping
import (
"context"
"net"
"net/netip"
"os"
"sync"
"time"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/pipe"
)
type UnprivilegedConn struct {
ctx context.Context
cancel context.CancelFunc
controlFunc control.Func
destination netip.Addr
receiveChan chan *unprivilegedResponse
readDeadline pipe.Deadline
mappingAccess sync.Mutex
mapping map[uint16]net.Conn
}
type unprivilegedResponse struct {
Buffer *buf.Buffer
Cmsg *buf.Buffer
Addr netip.Addr
}
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
conn, err := connect(false, controlFunc, destination)
if err != nil {
return nil, err
}
conn.Close()
ctx, cancel := context.WithCancel(ctx)
return &UnprivilegedConn{
ctx: ctx,
cancel: cancel,
controlFunc: controlFunc,
destination: destination,
receiveChan: make(chan *unprivilegedResponse),
readDeadline: pipe.MakeDeadline(),
mapping: make(map[uint16]net.Conn),
}, nil
}
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
select {
case packet := <-c.receiveChan:
n = copy(b, packet.Buffer.Bytes())
packet.Buffer.Release()
packet.Cmsg.Release()
return
case <-c.readDeadline.Wait():
return 0, os.ErrDeadlineExceeded
case <-c.ctx.Done():
return 0, os.ErrClosed
}
}
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
select {
case packet := <-c.receiveChan:
n = copy(b, packet.Buffer.Bytes())
oobn = copy(oob, packet.Cmsg.Bytes())
addr = packet.Addr
packet.Buffer.Release()
packet.Cmsg.Release()
return
case <-c.readDeadline.Wait():
return 0, 0, netip.Addr{}, os.ErrDeadlineExceeded
case <-c.ctx.Done():
return 0, 0, netip.Addr{}, os.ErrClosed
}
}
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
var identifier uint16
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(b)
identifier = icmpHdr.Ident()
} else {
icmpHdr := header.ICMPv6(b)
identifier = icmpHdr.Ident()
}
c.mappingAccess.Lock()
if c.ctx.Err() != nil {
return 0, c.ctx.Err()
}
conn, loaded := c.mapping[identifier]
if !loaded {
conn, err = connect(false, c.controlFunc, c.destination)
if err != nil {
c.mappingAccess.Unlock()
return
}
go c.fetchResponse(conn.(*net.UDPConn), identifier)
c.mapping[identifier] = conn
}
c.mappingAccess.Unlock()
n, err = conn.Write(b)
if err != nil {
c.removeConn(conn.(*net.UDPConn), identifier)
}
return
}
func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) {
defer c.removeConn(conn, identifier)
for {
buffer := buf.NewPacket()
cmsgBuffer := buf.NewSize(1024)
n, oobN, _, addr, err := conn.ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
if err != nil {
buffer.Release()
cmsgBuffer.Release()
return
}
buffer.Truncate(n)
cmsgBuffer.Truncate(oobN)
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(identifier)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(identifier)
// offload checksum here since we don't have source address here
}
select {
case c.receiveChan <- &unprivilegedResponse{
Buffer: buffer,
Cmsg: cmsgBuffer,
Addr: addr.Addr(),
}:
case <-c.ctx.Done():
buffer.Release()
cmsgBuffer.Release()
return
}
}
}
func (c *UnprivilegedConn) removeConn(conn *net.UDPConn, identifier uint16) {
c.mappingAccess.Lock()
defer c.mappingAccess.Unlock()
_ = conn.Close()
delete(c.mapping, identifier)
}
func (c *UnprivilegedConn) Close() error {
c.mappingAccess.Lock()
defer c.mappingAccess.Unlock()
c.cancel()
for _, conn := range c.mapping {
_ = conn.Close()
}
common.ClearMap(c.mapping)
return nil
}
func (c *UnprivilegedConn) LocalAddr() net.Addr {
return M.Socksaddr{}
}
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
}
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}