mirror of
https://github.com/sigcn/pg.git
synced 2025-09-27 01:05:51 +08:00
184 lines
3.6 KiB
Go
184 lines
3.6 KiB
Go
package net
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
var _ net.Conn = (*UDPConn)(nil)
|
|
|
|
type UDPConn struct {
|
|
removeConn func()
|
|
remoteAddr net.Addr
|
|
c net.PacketConn
|
|
|
|
closeOnce sync.Once
|
|
inbound chan []byte
|
|
closeChan chan struct{}
|
|
lastActiveTime atomic.Value
|
|
}
|
|
|
|
func (c *UDPConn) init() {
|
|
c.inbound = make(chan []byte, 512)
|
|
c.closeChan = make(chan struct{})
|
|
c.lastActiveTime.Store(time.Now())
|
|
ticker := time.NewTicker(6 * time.Second)
|
|
go func() { // create a timer to trace timeout udp conn, and close it
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
|
|
c.Close()
|
|
break
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (c *UDPConn) Read(p []byte) (int, error) {
|
|
select {
|
|
case b := <-c.inbound:
|
|
c.lastActiveTime.Store(time.Now())
|
|
return copy(p, b), nil
|
|
case <-c.closeChan:
|
|
return 0, net.ErrClosed
|
|
}
|
|
}
|
|
|
|
func (c *UDPConn) Write(p []byte) (int, error) {
|
|
c.lastActiveTime.Store(time.Now())
|
|
return c.c.WriteTo(p, c.remoteAddr)
|
|
}
|
|
|
|
func (c *UDPConn) LocalAddr() net.Addr {
|
|
return c.c.LocalAddr()
|
|
}
|
|
|
|
func (c *UDPConn) RemoteAddr() net.Addr {
|
|
return c.remoteAddr
|
|
}
|
|
|
|
func (c *UDPConn) Close() error {
|
|
c.closeOnce.Do(func() {
|
|
close(c.closeChan)
|
|
close(c.inbound)
|
|
c.removeConn()
|
|
slog.Log(context.Background(), -2, "UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (c *UDPConn) SetDeadline(t time.Time) error {
|
|
return errors.ErrUnsupported
|
|
}
|
|
|
|
func (c *UDPConn) SetReadDeadline(t time.Time) error {
|
|
return errors.ErrUnsupported
|
|
}
|
|
|
|
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
|
|
return errors.ErrUnsupported
|
|
}
|
|
|
|
type UDPListener struct {
|
|
PacketConn net.PacketConn
|
|
|
|
buf []byte
|
|
initOnce sync.Once
|
|
closeOnce sync.Once
|
|
udpChan chan *UDPConn
|
|
|
|
connMap map[string]*UDPConn
|
|
connMapMu sync.RWMutex
|
|
}
|
|
|
|
func (l *UDPListener) init() {
|
|
l.initOnce.Do(func() {
|
|
l.buf = make([]byte, 65535)
|
|
l.udpChan = make(chan *UDPConn, 8)
|
|
l.connMap = make(map[string]*UDPConn)
|
|
go l.readUDP()
|
|
})
|
|
}
|
|
|
|
func (l *UDPListener) readUDP() {
|
|
read := func() error {
|
|
read:
|
|
n, peerAddr, err := l.PacketConn.ReadFrom(l.buf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l.connMapMu.RLock()
|
|
conn, ok := l.connMap[peerAddr.String()]
|
|
l.connMapMu.RUnlock()
|
|
if ok {
|
|
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
|
goto read
|
|
}
|
|
l.connMapMu.Lock()
|
|
conn, ok = l.connMap[peerAddr.String()]
|
|
if ok {
|
|
l.connMapMu.Unlock()
|
|
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
|
goto read
|
|
}
|
|
defer l.connMapMu.Unlock()
|
|
conn = &UDPConn{remoteAddr: peerAddr, c: l.PacketConn, removeConn: func() {
|
|
l.connMapMu.Lock()
|
|
defer l.connMapMu.Unlock()
|
|
delete(l.connMap, peerAddr.String())
|
|
}}
|
|
conn.init()
|
|
l.connMap[peerAddr.String()] = conn
|
|
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
|
l.udpChan <- conn
|
|
return nil
|
|
}
|
|
for {
|
|
if err := read(); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *UDPListener) Accept() (net.Conn, error) {
|
|
return l.AcceptContext(context.Background())
|
|
}
|
|
|
|
func (l *UDPListener) AcceptContext(ctx context.Context) (net.Conn, error) {
|
|
l.init()
|
|
select {
|
|
case c := <-l.udpChan:
|
|
return c, nil
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (l *UDPListener) Close() error {
|
|
if l.PacketConn == nil {
|
|
return nil
|
|
}
|
|
l.closeOnce.Do(func() {
|
|
l.PacketConn.Close()
|
|
l.connMapMu.Lock()
|
|
defer l.connMapMu.Unlock()
|
|
for _, c := range l.connMap {
|
|
go c.Close()
|
|
}
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (l *UDPListener) Addr() net.Addr {
|
|
l.init()
|
|
if l.PacketConn == nil {
|
|
return nil
|
|
}
|
|
return l.PacketConn.LocalAddr()
|
|
}
|