ping: Reduce the cost of readMsg

This commit is contained in:
wwqgtxx
2025-08-27 08:45:12 +08:00
parent 4c43f4af12
commit f9bbb15bfb

View File

@@ -28,6 +28,7 @@ type Conn struct {
destination netip.Addr
source common.TypedValue[netip.Addr]
closed atomic.Bool
readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
}
func Connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
@@ -49,6 +50,32 @@ func (c *Conn) connect(controlFunc control.Func) (err error) {
} else {
c.conn, err = connect(c.privileged, controlFunc, c.destination)
}
if err != nil {
return err
}
if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn {
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
var ipAddr *net.IPAddr
n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob)
if err == nil {
addr = M.AddrFromNet(ipAddr)
}
return
}
} else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn {
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
var addrPort netip.AddrPort
n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob)
if err == nil {
addr = addrPort.Addr()
}
return
}
} else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn {
c.readMsg = unprivilegedConn.ReadMsg
} else {
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
}
return
}
@@ -58,39 +85,12 @@ func (c *Conn) isLinuxUnprivileged() bool {
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
if c.destination.Is6() || c.isLinuxUnprivileged() {
var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn {
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
var ipAddr *net.IPAddr
n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob)
if err == nil {
addr = M.AddrFromNet(ipAddr)
}
return
}
} else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn {
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
var addrPort netip.AddrPort
n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob)
if err == nil {
addr = addrPort.Addr()
}
return
}
} else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn {
readMsg = unprivilegedConn.ReadMsg
} else {
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
}
if !c.destination.Is6() {
oob := ipv4.NewControlMessage(ipv4.FlagTTL)
buffer.Advance(header.IPv4MinimumSize)
var ttl int
// tos int
n, oobn, addr, err := readMsg(buffer.FreeBytes(), oob)
if err != nil {
return err
}
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
if err != nil {
return err
}
@@ -126,7 +126,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
hopLimit int
trafficClass int
)
n, oobn, addr, err := readMsg(buffer.FreeBytes(), oob)
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
if err != nil {
return err
}