Files
pg/net/udp.go

184 lines
3.6 KiB
Go

package net
import (
"context"
"errors"
"log/slog"
"net"
"sync"
"sync/atomic"
"time"
)
var _ net.Conn = (*UDPConn)(nil)
type UDPConn struct {
removeConn func()
remoteAddr net.Addr
c net.PacketConn
closeOnce sync.Once
inbound chan []byte
closeChan chan struct{}
lastActiveTime atomic.Value
}
func (c *UDPConn) init() {
c.inbound = make(chan []byte, 512)
c.closeChan = make(chan struct{})
c.lastActiveTime.Store(time.Now())
ticker := time.NewTicker(6 * time.Second)
go func() { // create a timer to trace timeout udp conn, and close it
defer ticker.Stop()
for range ticker.C {
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
c.Close()
break
}
}
}()
}
func (c *UDPConn) Read(p []byte) (int, error) {
select {
case b := <-c.inbound:
c.lastActiveTime.Store(time.Now())
return copy(p, b), nil
case <-c.closeChan:
return 0, net.ErrClosed
}
}
func (c *UDPConn) Write(p []byte) (int, error) {
c.lastActiveTime.Store(time.Now())
return c.c.WriteTo(p, c.remoteAddr)
}
func (c *UDPConn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}
func (c *UDPConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *UDPConn) Close() error {
c.closeOnce.Do(func() {
close(c.closeChan)
close(c.inbound)
c.removeConn()
slog.Log(context.Background(), -2, "UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
})
return nil
}
func (c *UDPConn) SetDeadline(t time.Time) error {
return errors.ErrUnsupported
}
func (c *UDPConn) SetReadDeadline(t time.Time) error {
return errors.ErrUnsupported
}
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
return errors.ErrUnsupported
}
type UDPListener struct {
PacketConn net.PacketConn
buf []byte
initOnce sync.Once
closeOnce sync.Once
udpChan chan *UDPConn
connMap map[string]*UDPConn
connMapMu sync.RWMutex
}
func (l *UDPListener) init() {
l.initOnce.Do(func() {
l.buf = make([]byte, 65535)
l.udpChan = make(chan *UDPConn, 8)
l.connMap = make(map[string]*UDPConn)
go l.readUDP()
})
}
func (l *UDPListener) readUDP() {
read := func() error {
read:
n, peerAddr, err := l.PacketConn.ReadFrom(l.buf)
if err != nil {
return err
}
l.connMapMu.RLock()
conn, ok := l.connMap[peerAddr.String()]
l.connMapMu.RUnlock()
if ok {
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
l.connMapMu.Lock()
conn, ok = l.connMap[peerAddr.String()]
if ok {
l.connMapMu.Unlock()
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
defer l.connMapMu.Unlock()
conn = &UDPConn{remoteAddr: peerAddr, c: l.PacketConn, removeConn: func() {
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
delete(l.connMap, peerAddr.String())
}}
conn.init()
l.connMap[peerAddr.String()] = conn
conn.inbound <- append([]byte(nil), l.buf[:n]...)
l.udpChan <- conn
return nil
}
for {
if err := read(); err != nil {
return
}
}
}
func (l *UDPListener) Accept() (net.Conn, error) {
return l.AcceptContext(context.Background())
}
func (l *UDPListener) AcceptContext(ctx context.Context) (net.Conn, error) {
l.init()
select {
case c := <-l.udpChan:
return c, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (l *UDPListener) Close() error {
if l.PacketConn == nil {
return nil
}
l.closeOnce.Do(func() {
l.PacketConn.Close()
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
for _, c := range l.connMap {
go c.Close()
}
})
return nil
}
func (l *UDPListener) Addr() net.Addr {
l.init()
if l.PacketConn == nil {
return nil
}
return l.PacketConn.LocalAddr()
}