diff --git a/ping/socket_linux_unprivileged.go b/ping/socket_linux_unprivileged.go index 78d1026..efaf11f 100644 --- a/ping/socket_linux_unprivileged.go +++ b/ping/socket_linux_unprivileged.go @@ -18,14 +18,14 @@ import ( ) type UnprivilegedConn struct { - ctx context.Context - cancel context.CancelFunc - controlFunc control.Func - destination netip.Addr - receiveChan chan *unprivilegedResponse - readDeadline pipe.Deadline - natMap map[uint16]net.Conn - natMapMutex sync.Mutex + ctx context.Context + cancel context.CancelFunc + controlFunc control.Func + destination netip.Addr + receiveChan chan *unprivilegedResponse + readDeadline pipe.Deadline + mappingAccess sync.Mutex + mapping map[uint16]net.Conn } type unprivilegedResponse struct { @@ -48,7 +48,7 @@ func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destinat destination: destination, receiveChan: make(chan *unprivilegedResponse), readDeadline: pipe.MakeDeadline(), - natMap: make(map[uint16]net.Conn), + mapping: make(map[uint16]net.Conn), }, nil } @@ -59,10 +59,10 @@ func (c *UnprivilegedConn) Read(b []byte) (n int, err error) { packet.Buffer.Release() packet.Cmsg.Release() return - case <-c.ctx.Done(): - return 0, os.ErrClosed case <-c.readDeadline.Wait(): return 0, os.ErrDeadlineExceeded + case <-c.ctx.Done(): + return 0, os.ErrClosed } } @@ -75,10 +75,10 @@ func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr neti packet.Buffer.Release() packet.Cmsg.Release() return - case <-c.ctx.Done(): - return 0, 0, netip.Addr{}, os.ErrClosed case <-c.readDeadline.Wait(): return 0, 0, netip.Addr{}, os.ErrDeadlineExceeded + case <-c.ctx.Done(): + return 0, 0, netip.Addr{}, os.ErrClosed } } @@ -92,26 +92,23 @@ func (c *UnprivilegedConn) Write(b []byte) (n int, err error) { identifier = icmpHdr.Ident() } - c.natMapMutex.Lock() - if err = c.ctx.Err(); err != nil { - c.natMapMutex.Unlock() - return 0, err + c.mappingAccess.Lock() + if c.ctx.Err() != nil { + return 0, c.ctx.Err() } - conn, ok := c.natMap[identifier] - if !ok { + conn, loaded := c.mapping[identifier] + if !loaded { conn, err = connect(false, c.controlFunc, c.destination) if err != nil { - c.natMapMutex.Unlock() - return 0, err + c.mappingAccess.Unlock() + return } go c.fetchResponse(conn.(*net.UDPConn), identifier) } - c.natMapMutex.Unlock() - + c.mappingAccess.Unlock() n, err = conn.Write(b) if err != nil { c.removeConn(conn.(*net.UDPConn), identifier) - return } return } @@ -154,22 +151,20 @@ func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) { } func (c *UnprivilegedConn) removeConn(conn *net.UDPConn, identifier uint16) { - c.natMapMutex.Lock() + c.mappingAccess.Lock() + defer c.mappingAccess.Unlock() _ = conn.Close() - if c.natMap[identifier] == conn { - delete(c.natMap, identifier) - } - c.natMapMutex.Unlock() + delete(c.mapping, identifier) } func (c *UnprivilegedConn) Close() error { - c.natMapMutex.Lock() + c.mappingAccess.Lock() + defer c.mappingAccess.Unlock() c.cancel() - for _, conn := range c.natMap { + for _, conn := range c.mapping { _ = conn.Close() } - common.ClearMap(c.natMap) - c.natMapMutex.Unlock() + common.ClearMap(c.mapping) return nil } @@ -182,7 +177,7 @@ func (c *UnprivilegedConn) RemoteAddr() net.Addr { } func (c *UnprivilegedConn) SetDeadline(t time.Time) error { - return c.SetReadDeadline(t) + return os.ErrInvalid } func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error { @@ -191,5 +186,5 @@ func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error { } func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error { - return nil + return os.ErrInvalid }