mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-09-26 20:51:13 +08:00
293 lines
8.3 KiB
Go
293 lines
8.3 KiB
Go
package ping
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/netip"
|
|
"reflect"
|
|
"runtime"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/control"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
)
|
|
|
|
type Conn struct {
|
|
ctx context.Context
|
|
privileged bool
|
|
conn net.Conn
|
|
destination netip.Addr
|
|
source common.TypedValue[netip.Addr]
|
|
closed atomic.Bool
|
|
readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
|
|
}
|
|
|
|
func Connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
|
|
c := &Conn{
|
|
ctx: ctx,
|
|
privileged: privileged,
|
|
destination: destination,
|
|
}
|
|
err := c.connect(controlFunc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func (c *Conn) connect(controlFunc control.Func) (err error) {
|
|
if c.isLinuxUnprivileged() {
|
|
c.conn, err = newUnprivilegedConn(c.ctx, controlFunc, c.destination)
|
|
} else {
|
|
c.conn, err = connect(c.privileged, controlFunc, c.destination)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn {
|
|
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
|
var ipAddr *net.IPAddr
|
|
n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob)
|
|
if err == nil {
|
|
addr = M.AddrFromNet(ipAddr)
|
|
}
|
|
return
|
|
}
|
|
} else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn {
|
|
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
|
var addrPort netip.AddrPort
|
|
n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob)
|
|
if err == nil {
|
|
addr = addrPort.Addr()
|
|
}
|
|
return
|
|
}
|
|
} else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn {
|
|
c.readMsg = unprivilegedConn.ReadMsg
|
|
} else {
|
|
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *Conn) isLinuxUnprivileged() bool {
|
|
return (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged
|
|
}
|
|
|
|
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
|
if c.destination.Is6() || c.isLinuxUnprivileged() {
|
|
if !c.destination.Is6() {
|
|
oob := ipv4.NewControlMessage(ipv4.FlagTTL)
|
|
buffer.Advance(header.IPv4MinimumSize)
|
|
var ttl int
|
|
// tos int
|
|
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
buffer.Truncate(n)
|
|
if oobn > 0 {
|
|
var controlMessage ipv4.ControlMessage
|
|
err = controlMessage.Parse(oob[:oobn])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ttl = controlMessage.TTL
|
|
}
|
|
if !c.isLinuxUnprivileged() {
|
|
icmpHdr := header.ICMPv4(buffer.Bytes())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
|
}
|
|
ipHdr := header.IPv4(buffer.ExtendHeader(header.IPv4MinimumSize))
|
|
ipHdr.Encode(&header.IPv4Fields{
|
|
// TOS: uint8(tos),
|
|
SrcAddr: addr,
|
|
DstAddr: c.source.Load(),
|
|
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
|
TTL: uint8(ttl),
|
|
TotalLength: uint16(buffer.Len()),
|
|
})
|
|
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
|
} else {
|
|
oob := make([]byte, 1024)
|
|
buffer.Advance(header.IPv6MinimumSize)
|
|
var (
|
|
hopLimit int
|
|
trafficClass int
|
|
)
|
|
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
buffer.Truncate(n)
|
|
if oobn > 0 {
|
|
var controlMessage *ipv6.ControlMessage
|
|
controlMessage, err = parseIPv6ControlMessage(oob[:oobn])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
hopLimit = controlMessage.HopLimit
|
|
trafficClass = controlMessage.TrafficClass
|
|
}
|
|
icmpHdr := header.ICMPv6(buffer.Bytes())
|
|
if !c.isLinuxUnprivileged() {
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
}
|
|
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
Header: icmpHdr,
|
|
Src: addr.AsSlice(),
|
|
Dst: c.source.Load().AsSlice(),
|
|
}))
|
|
ipHdr := header.IPv6(buffer.ExtendHeader(header.IPv6MinimumSize))
|
|
ipHdr.Encode(&header.IPv6Fields{
|
|
TrafficClass: uint8(trafficClass),
|
|
PayloadLength: uint16(buffer.Len() - header.IPv6MinimumSize),
|
|
TransportProtocol: header.ICMPv6ProtocolNumber,
|
|
HopLimit: uint8(hopLimit),
|
|
SrcAddr: addr,
|
|
DstAddr: c.source.Load(),
|
|
})
|
|
}
|
|
} else {
|
|
_, err := buffer.ReadOnceFrom(c.conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !c.destination.Is6() {
|
|
ipHdr := header.IPv4(buffer.Bytes())
|
|
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
|
// MacOS have different TotalLen and FragOff in ipv4 header from socket api:
|
|
// https://stackoverflow.com/questions/13829712/mac-changes-ip-total-length-field/15881825#15881825
|
|
// but in the tun api still same data format as other system
|
|
ipHdr.SetTotalLength(ipHdr.TotalLengthDarwinRaw())
|
|
ipHdr.SetFlagsFragmentOffset(ipHdr.FlagsDarwinRaw(), ipHdr.FragmentOffsetDarwinRaw())
|
|
}
|
|
if !ipHdr.IsValid(buffer.Len()) {
|
|
return E.New("invalid IPv4 header received")
|
|
}
|
|
ipHdr.SetDestinationAddr(c.source.Load())
|
|
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
|
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
|
if !c.isLinuxUnprivileged() {
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
}
|
|
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
|
} else {
|
|
ipHdr := header.IPv6(buffer.Bytes())
|
|
if !ipHdr.IsValid(buffer.Len()) {
|
|
return E.New("invalid IPv6 header received")
|
|
}
|
|
ipHdr.SetDestinationAddr(c.source.Load())
|
|
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
|
if !c.isLinuxUnprivileged() {
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
}
|
|
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
Header: icmpHdr,
|
|
Src: ipHdr.SourceAddressSlice(),
|
|
Dst: ipHdr.DestinationAddressSlice(),
|
|
}))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
|
|
_, err := buffer.ReadOnceFrom(c.conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !c.isLinuxUnprivileged() {
|
|
if !c.destination.Is6() {
|
|
ipHdr := header.IPv4(buffer.Bytes())
|
|
buffer.Advance(int(ipHdr.HeaderLength()))
|
|
|
|
icmpHdr := header.ICMPv4(buffer.Bytes())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
|
} else {
|
|
icmpHdr := header.ICMPv6(buffer.Bytes())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
Header: icmpHdr,
|
|
Src: c.destination.AsSlice(),
|
|
Dst: c.source.Load().AsSlice(),
|
|
}))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) WriteIP(buffer *buf.Buffer) error {
|
|
defer buffer.Release()
|
|
if !c.destination.Is6() {
|
|
ipHdr := header.IPv4(buffer.Bytes())
|
|
if !c.isLinuxUnprivileged() {
|
|
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
|
}
|
|
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
|
|
return common.Error(c.conn.Write(ipHdr.Payload()))
|
|
} else {
|
|
ipHdr := header.IPv6(buffer.Bytes())
|
|
if !c.isLinuxUnprivileged() {
|
|
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
Header: icmpHdr,
|
|
Src: ipHdr.SourceAddressSlice(),
|
|
Dst: ipHdr.DestinationAddressSlice(),
|
|
}))
|
|
}
|
|
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
|
|
return common.Error(c.conn.Write(ipHdr.Payload()))
|
|
}
|
|
}
|
|
|
|
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
|
|
defer buffer.Release()
|
|
if !c.isLinuxUnprivileged() {
|
|
if !c.destination.Is6() {
|
|
icmpHdr := header.ICMPv4(buffer.Bytes())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
|
} else {
|
|
icmpHdr := header.ICMPv6(buffer.Bytes())
|
|
icmpHdr.SetIdent(^icmpHdr.Ident())
|
|
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
Header: icmpHdr,
|
|
Src: c.source.Load().AsSlice(),
|
|
Dst: c.destination.AsSlice(),
|
|
}))
|
|
}
|
|
}
|
|
return common.Error(c.conn.Write(buffer.Bytes()))
|
|
}
|
|
|
|
func (c *Conn) SetLocalAddr(addr netip.Addr) {
|
|
c.source.Store(addr)
|
|
}
|
|
|
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
|
return c.conn.SetReadDeadline(t)
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
defer c.closed.Store(true)
|
|
return c.conn.Close()
|
|
}
|
|
|
|
func (c *Conn) IsClosed() bool {
|
|
return c.closed.Load()
|
|
}
|