diff --git a/ping/ping.go b/ping/ping.go index c9fdc52..5140c2d 100644 --- a/ping/ping.go +++ b/ping/ping.go @@ -28,6 +28,7 @@ type Conn struct { destination netip.Addr source common.TypedValue[netip.Addr] closed atomic.Bool + readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) } func Connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) { @@ -49,6 +50,32 @@ func (c *Conn) connect(controlFunc control.Func) (err error) { } else { c.conn, err = connect(c.privileged, controlFunc, c.destination) } + if err != nil { + return err + } + if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn { + c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { + var ipAddr *net.IPAddr + n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob) + if err == nil { + addr = M.AddrFromNet(ipAddr) + } + return + } + } else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn { + c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { + var addrPort netip.AddrPort + n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob) + if err == nil { + addr = addrPort.Addr() + } + return + } + } else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn { + c.readMsg = unprivilegedConn.ReadMsg + } else { + return E.New("unsupported conn type: ", reflect.TypeOf(c.conn)) + } return } @@ -58,39 +85,12 @@ func (c *Conn) isLinuxUnprivileged() bool { func (c *Conn) ReadIP(buffer *buf.Buffer) error { if c.destination.Is6() || c.isLinuxUnprivileged() { - var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) - if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn { - readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { - var ipAddr *net.IPAddr - n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob) - if err == nil { - addr = M.AddrFromNet(ipAddr) - } - return - } - } else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn { - readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { - var addrPort netip.AddrPort - n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob) - if err == nil { - addr = addrPort.Addr() - } - return - } - } else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn { - readMsg = unprivilegedConn.ReadMsg - } else { - return E.New("unsupported conn type: ", reflect.TypeOf(c.conn)) - } if !c.destination.Is6() { oob := ipv4.NewControlMessage(ipv4.FlagTTL) buffer.Advance(header.IPv4MinimumSize) var ttl int // tos int - n, oobn, addr, err := readMsg(buffer.FreeBytes(), oob) - if err != nil { - return err - } + n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob) if err != nil { return err } @@ -126,7 +126,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error { hopLimit int trafficClass int ) - n, oobn, addr, err := readMsg(buffer.FreeBytes(), oob) + n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob) if err != nil { return err }