Files
pg/rdt/rdt.go
2024-09-27 11:20:17 +08:00

630 lines
13 KiB
Go

package rdt
import (
"encoding/binary"
"errors"
"io"
"log/slog"
"math"
"net"
"strings"
"sync"
"sync/atomic"
"time"
N "github.com/sigcn/pg/net"
)
// Connection states stored in rdtConn.state. The numeric gaps are part of
// the original encoding and are kept as-is; code compares states with < and >.
const (
	// Established is the initial state: Write accepts data.
	Established = 0
	// FIN_WAIT1 is entered when Close starts and sendFIN is running.
	FIN_WAIT1 = 2
	// FIN_WAIT2 is entered after sendFIN returns, while Close waits for the
	// peer's FIN (at most 15 seconds).
	FIN_WAIT2 = 3
	// CLOSED means the connection is finished; recvData drops DATA frames
	// and the listener may replace an accepted conn in this state.
	CLOSED = 5
)
// nck is a decoded negative-acknowledgement report received from the peer.
type nck struct {
	// no is the highest frame number the report covers.
	no uint32
	// missing lists the frame numbers up to no that the peer has not
	// received; an empty list means every frame up to no arrived.
	missing []uint32
}
// Compile-time check that *rdtConn satisfies net.Conn.
var _ net.Conn = (*rdtConn)(nil)

// rdtConn is one reliable stream multiplexed over the listener's shared
// packet connection and addressed by remoteAddr.
type rdtConn struct {
	// server is true for streams created by the accepting side; send then
	// tags outgoing frames with the 0x80 flag.
	server bool
	cfg    Config
	// window caps the number of unacknowledged DATA frames kept in sendPool.
	window uint32
	// frameSize is the maximum payload bytes per DATA frame.
	frameSize  int
	c          net.PacketConn
	remoteAddr net.Addr

	// exit is closed by Close to stop the event loops and blocked Reads.
	exit chan struct{}
	// inbound carries in-order payloads from recvData to Read.
	inbound chan []byte
	// nck carries decoded NCK reports to runNCKLoop.
	nck chan nck
	// nckQuery carries peer NCK queries to runNCKQueryLoop.
	nckQuery chan uint32
	// fin carries FIN frame numbers to Close's FIN_WAIT2 goroutine.
	fin chan uint32
	// finack carries FINACK frame numbers to sendFIN.
	finack chan uint32
	// sendEvent wakes a Write blocked on a full window after resend
	// released acknowledged frames.
	sendEvent chan struct{}

	// inboundBuf holds the unread rest of a partially consumed frame.
	inboundBuf []byte
	// recvNO: last frame delivered in order; sentNO: last frame number
	// assigned by Write; ackNO: highest NCK number reported with nothing
	// missing.
	recvNO, sentNO, ackNO uint32
	// recvPool buffers out-of-order frames keyed by frame number.
	recvPool  map[uint32][]byte
	recvMutex sync.RWMutex
	// sendPool keeps sent-but-unacknowledged frames for retransmission.
	sendPool  map[uint32][]byte
	sendMutex sync.RWMutex
	// rs counts retransmitted frames.
	rs atomic.Uint32
	// state holds Established, FIN_WAIT1, FIN_WAIT2 or CLOSED.
	state     atomic.Int32
	closeOnce sync.Once
	// wClosed is the error Write returns once the connection is closing.
	wClosed      *net.OpError
	deadlineRead N.Deadline
}
// Read reads data from the connection.
// Read can be made to time out and return an error after a fixed
// time limit; see SetDeadline and SetReadDeadline.
//
// Data is delivered frame by frame. A frame larger than b is handed out
// over several calls via inboundBuf. Once the connection is closed, either
// locally or by the peer's FIN, Read first returns any frames that were
// already queued and only then io.EOF. A fired read deadline yields
// N.ErrDeadline.
func (c *rdtConn) Read(b []byte) (n int, err error) {
	if c.inboundBuf == nil {
		pkt, rerr := c.nextInbound()
		if rerr != nil {
			return 0, rerr
		}
		c.inboundBuf = pkt
	}
	n = copy(b, c.inboundBuf)
	if n == len(c.inboundBuf) {
		c.inboundBuf = nil
	} else {
		c.inboundBuf = c.inboundBuf[n:]
	}
	return n, nil
}

// nextInbound waits for the next received frame.
//
// Close closes exit and inbound back to back, so both can be ready at the
// same time. A single select would then pick one at random and could report
// EOF while frames received before the peer's FIN are still queued.
// Queued frames are therefore always preferred over the exit and deadline
// signals.
func (c *rdtConn) nextInbound() ([]byte, error) {
	// Fast path: a frame is already queued (or inbound is closed and drained).
	select {
	case pkt, ok := <-c.inbound:
		if !ok {
			return nil, io.EOF
		}
		return pkt, nil
	default:
	}
	select {
	case pkt, ok := <-c.inbound:
		if !ok {
			return nil, io.EOF
		}
		return pkt, nil
	case _, ok := <-c.deadlineRead.Deadline():
		if ok {
			return nil, N.ErrDeadline
		}
		// Deadline channel closed: the connection is being closed.
	case <-c.exit:
	}
	// Closed while waiting: hand out a frame queued in the meantime, if
	// any, before reporting EOF.
	select {
	case pkt, ok := <-c.inbound:
		if ok {
			return pkt, nil
		}
	default:
	}
	return nil, io.EOF
}
// Write writes data to the connection.
//
// b is split into DATA frames of at most frameSize bytes. Each frame is
// numbered, kept in sendPool until the peer acknowledges it, and sent
// immediately. When window frames are outstanding, Write blocks until
// resend frees space. Write deadlines are not supported (see
// SetWriteDeadline), so a Write on a full window blocks until the peer
// acknowledges or the connection is closed.
//
// If the connection is closing, Write returns c.wClosed together with the
// number of bytes already queued for transmission.
func (c *rdtConn) Write(b []byte) (n int, err error) {
	if c.state.Load() > 0 {
		return 0, c.wClosed
	}
	// The frame header stores the payload length in 16 bits, and a
	// non-positive size would never make progress.
	if c.frameSize <= 0 || c.frameSize > math.MaxUint16 {
		return 0, errors.New("rdt: invalid frame size")
	}
	for start := 0; start < len(b); start += c.frameSize {
		end := min(start+c.frameSize, len(b))
		c.sendMutex.Lock()
		// Wait for window space. resend signals sendEvent after releasing
		// acknowledged frames. Close closes it, which is reported as !ok.
		for len(c.sendPool) >= int(c.window) {
			c.sendMutex.Unlock()
			if _, ok := <-c.sendEvent; !ok || c.state.Load() > 0 {
				return start, c.wClosed
			}
			c.sendMutex.Lock()
		}
		c.sentNO++
		pkt := c.buildFrame(0, c.sentNO, uint16(end-start), b[start:end])
		c.sendPool[c.sentNO] = pkt
		c.sendMutex.Unlock()
		c.send(pkt)
	}
	return len(b), nil
}
// Close closes the connection.
// Any blocked Read or Write operations will be unblocked and return errors.
//
// Only the first call does anything (closeOnce); Close always returns nil.
// Sequence:
//  1. state = FIN_WAIT1, then sendFIN runs the FIN/FINACK exchange (its
//     error is ignored).
//  2. state = FIN_WAIT2, and a goroutine waits up to 15 seconds for the
//     peer's FIN on c.fin. It then marks the conn CLOSED, closes fin and
//     nckQuery, and resets the receive side.
//  3. The send side and local wake-up channels are closed and reset.
func (c *rdtConn) Close() error {
	c.closeOnce.Do(func() {
		c.state.Store(FIN_WAIT1)
		c.sendFIN()
		c.state.Store(FIN_WAIT2)
		go func() {
			// NOTE(review): the timer is never stopped when c.fin wins.
			wait := time.NewTimer(15 * time.Second)
			select {
			case <-wait.C:
				slog.Warn("FIN_WAIT2 timeout")
			case <-c.fin:
			}
			c.state.Store(CLOSED)
			// fin and nckQuery are closed only here, after FIN_WAIT2, so
			// recv can still post to them while the peer finishes.
			close(c.fin)
			close(c.nckQuery)
			// NOTE(review): recvNO is written without recvMutex while the
			// listener read loop may still read it.
			c.recvNO = 0
			c.recvMutex.Lock()
			clear(c.recvPool)
			c.recvMutex.Unlock()
		}()
		// Closing exit stops the event loops and wakes blocked Reads;
		// closing sendEvent wakes a Write parked on a full window.
		close(c.exit)
		close(c.inbound)
		close(c.nck)
		close(c.finack)
		close(c.sendEvent)
		c.deadlineRead.Close()
		c.inboundBuf = nil
		c.sentNO = 0
		c.sendMutex.Lock()
		clear(c.sendPool)
		c.sendMutex.Unlock()
	})
	return nil
}
// LocalAddr reports the local address of the underlying packet connection,
// which every stream of the listener shares.
func (c *rdtConn) LocalAddr() net.Addr {
	pc := c.c
	return pc.LocalAddr()
}
// RemoteAddr reports the peer address this stream exchanges frames with.
func (c *rdtConn) RemoteAddr() net.Addr {
	peer := c.remoteAddr
	return peer
}
// SetDeadline sets the read and write deadlines associated with the
// connection, as net.Conn requires. It calls SetReadDeadline and then
// SetWriteDeadline with the same t, and combines their errors with
// errors.Join.
//
// Write deadlines are currently not implemented: SetWriteDeadline returns
// errors.ErrUnsupported. The result is therefore non-nil even though the
// read deadline has been applied. Callers that only need read timeouts
// should use SetReadDeadline.
//
// Per the net.Conn convention, a zero t means no timeout; honoring it is
// up to N.Deadline.
func (c *rdtConn) SetDeadline(t time.Time) error {
	return errors.Join(c.SetReadDeadline(t), c.SetWriteDeadline(t))
}
// SetReadDeadline sets the deadline for future Read calls and for any
// currently blocked Read call. The deadline is handed to c.deadlineRead,
// whose channel Read waits on; a fired deadline makes Read return
// N.ErrDeadline. It always returns nil.
func (c *rdtConn) SetReadDeadline(t time.Time) error {
	dl := &c.deadlineRead
	dl.SetDeadline(t)
	return nil
}
// SetWriteDeadline is not supported by rdtConn: it ignores t and always
// returns errors.ErrUnsupported.
//
// Write blocks only while the send window is full. It then waits until the
// peer acknowledges frames or the connection is closed; there is currently
// no way to bound that wait with a deadline.
func (c *rdtConn) SetWriteDeadline(t time.Time) error {
	return errors.ErrUnsupported
}
// resend handles one NCK report from the peer:
//   - every frame listed as missing that is still in sendPool is sent again
//     and counted in c.rs;
//   - every other frame numbered <= nck.no is dropped from sendPool as
//     acknowledged;
//   - a parked Write is then told that window space may be available.
//
// The wake-up is non-blocking. sendEvent is only read by a Write waiting
// on a full window, and a blocking send here would stall runNCKLoop while
// no Write waits. It would then fill the nck queue and block the
// listener's shared read loop. A missed wake-up is harmless: Write
// re-checks the pool, and the check loop keeps querying while frames are
// outstanding.
func (c *rdtConn) resend(nck nck) {
	missing := make(map[uint32]struct{}, len(nck.missing))
	for _, no := range nck.missing {
		c.sendMutex.RLock()
		pkt, ok := c.sendPool[no]
		c.sendMutex.RUnlock()
		if ok {
			c.send(pkt)
			c.rs.Add(1)
		}
		missing[no] = struct{}{}
	}
	c.sendMutex.Lock()
	for no := range c.sendPool {
		if _, lost := missing[no]; !lost && no <= nck.no {
			delete(c.sendPool, no)
		}
	}
	c.sendMutex.Unlock()
	// Close closes sendEvent; a send on a closed channel panics even inside
	// a select, so the recover is still required.
	defer func() { recover() }()
	select {
	case c.sendEvent <- struct{}{}:
	default:
	}
}
// send transmits a fully built frame to the peer.
//
// Frames from the accepting side get the 0x80 flag set in the command byte,
// in place. The remote listener uses that flag to route them to the stream
// it opened. The WriteTo result is discarded; delivery is presumably left
// to the NCK/resend machinery.
func (c *rdtConn) send(pkt []byte) {
	if c.server {
		pkt[0] |= 0x80
	}
	frameNo := binary.BigEndian.Uint32(pkt[1:5])
	slog.Debug("RDTSend", "cmd", pkt[0], "no", frameNo, "peer", c.remoteAddr, "len", len(pkt))
	_, _ = c.c.WriteTo(pkt, c.RemoteAddr())
}
// buildFrame encodes a frame as
//
//	cmd (1 byte) | no (4 bytes, big-endian) | length (2 bytes, big-endian) | data
//
// length is written as given; callers pass len(data). The result is a new
// slice that does not alias data. It is allocated once at its final size
// instead of being grown by several appends, since this runs for every
// frame sent.
func (c *rdtConn) buildFrame(cmd byte, no uint32, length uint16, data []byte) []byte {
	pkt := make([]byte, 7, 7+len(data))
	pkt[0] = cmd
	binary.BigEndian.PutUint32(pkt[1:5], no)
	binary.BigEndian.PutUint16(pkt[5:7], length)
	return append(pkt, data...)
}
// recv dispatches one frame received from the peer. The frame layout is
//
//	cmd (1) | no (4, big-endian) | length (2, big-endian) | payload (length)
//
// Commands: 0 DATA, 1 QueryNCK, 2 NCK, 21 FIN, 22 FINACK; other commands
// are ignored. A recover catches panics such as sends on channels that
// Close has already closed.
//
// DATA and NCK frames whose declared length exceeds the packet are dropped.
// The bound is computed in int: 7+length in uint16 arithmetic wraps for
// length > 65528.
//
// On an accepted FIN, Close runs on its own goroutine. recv is called from
// the listener's single packet read loop, and Close waits in sendFIN for
// our FINACK. That FINACK can only be delivered by this same read loop, so
// a synchronous Close always timed out and stalled every stream meanwhile.
func (c *rdtConn) recv(pkt []byte) {
	defer func() {
		if err := recover(); err != nil {
			slog.Debug("Recv", "recover", err)
		}
	}()
	if len(pkt) < 7 {
		slog.Debug("RDTRecv", "drop", "short frame", "peer", c.remoteAddr, "len", len(pkt))
		return
	}
	no := binary.BigEndian.Uint32(pkt[1:5])
	l := binary.BigEndian.Uint16(pkt[5:7])
	slog.Debug("RDTRecv", "cmd", pkt[0], "no", no, "peer", c.remoteAddr, "len", len(pkt))
	end := 7 + int(l)
	if (pkt[0] == 0 || pkt[0] == 2) && end > len(pkt) {
		slog.Debug("RDTRecv", "drop", "truncated frame", "peer", c.remoteAddr, "len", len(pkt), "declared", l)
		return
	}
	switch pkt[0] {
	case 0: // DATA
		c.recvData(no, pkt[7:end])
	case 1: // QueryNCK
		c.nckQuery <- no
	case 2: // NCK
		report := nck{no: no}
		for off := 7; off+4 <= end; off += 4 {
			report.missing = append(report.missing, binary.BigEndian.Uint32(pkt[off:off+4]))
		}
		if len(report.missing) == 0 && report.no > c.ackNO {
			c.ackNO = report.no
		}
		c.nck <- report
	case 21: // FIN
		if c.recvNO < no {
			// Not everything up to the peer's last frame has arrived yet:
			// ask for the gaps instead of acknowledging the FIN.
			c.sendNCK(no)
			return
		}
		c.send(c.buildFrame(22, no, 0, nil)) // FINACK
		// Non-blocking: one pending FIN is enough for the FIN_WAIT2 waiter,
		// and peer retransmissions must not block the read loop.
		select {
		case c.fin <- no:
		default:
		}
		go c.Close()
	case 22: // FINACK
		c.finack <- no
	}
}
// recvData accepts DATA frame no with payload data.
//
//   - Frames for a CLOSED connection are dropped.
//   - Every window-th frame number triggers an NCK covering up to
//     min(recvNO+window, no).
//   - Duplicates (no <= recvNO) are ignored.
//   - The next expected frame is delivered to inbound, followed by any
//     buffered successors that have now become contiguous. Each one is
//     removed from recvPool as it is delivered.
//   - Frames more than 2*window ahead are rejected with an NCK.
//   - Other frames are buffered in recvPool.
//
// The flush loop removes the frame it just delivered (recvNO+1). The
// previous version deleted key recvNO, which left the last flushed frame
// of every burst in recvPool.
func (c *rdtConn) recvData(no uint32, data []byte) {
	if c.state.Load() == CLOSED {
		slog.Warn("drop packet for closed conn")
		return
	}
	if no%c.window == 0 {
		c.sendNCK(min(c.recvNO+c.window, no))
	}
	if no <= c.recvNO {
		return
	}
	if no == c.recvNO+1 {
		c.inbound <- data
		c.recvNO++
		for {
			next := c.recvNO + 1
			c.recvMutex.Lock()
			buffered, ok := c.recvPool[next]
			if ok {
				delete(c.recvPool, next)
			}
			c.recvMutex.Unlock()
			if !ok {
				return
			}
			c.inbound <- buffered
			c.recvNO++
		}
	}
	if no-c.recvNO > c.window*2 {
		c.sendNCK(c.recvNO + c.window)
		return
	}
	c.recvMutex.Lock()
	c.recvPool[no] = data
	c.recvMutex.Unlock()
}
// askNCK sends a QueryNCK frame (cmd 1) asking the peer to report which
// frames up to no it is missing.
func (c *rdtConn) askNCK(no uint32) {
	query := c.buildFrame(1, no, 0, nil)
	c.send(query)
}
// sendNCK reports to the peer which frames in (recvNO, no] have not been
// received. It sends an NCK frame (cmd 2) whose payload lists the missing
// frame numbers as big-endian uint32 values. An empty list tells the peer
// that everything up to no arrived.
//
// If more than window frames are missing, nothing is sent. window is
// derived from the MTU in newConn, presumably so that one NCK always fits
// in a single frame.
//
// The scan uses a uint64 index, so no == math.MaxUint32 cannot wrap the
// loop into an endless scan. It stops as soon as the limit is exceeded;
// no may come straight from a peer query. The debug log therefore reports
// at most window+1 missing frames.
func (c *rdtConn) sendNCK(no uint32) {
	var (
		missing int
		noData  []byte
	)
	limit := int(c.window)
	c.recvMutex.RLock()
	for i := uint64(c.recvNO) + 1; i <= uint64(no); i++ {
		if _, ok := c.recvPool[uint32(i)]; ok {
			continue
		}
		missing++
		if missing > limit {
			break
		}
		noData = binary.BigEndian.AppendUint32(noData, uint32(i))
	}
	c.recvMutex.RUnlock()
	if missing > limit {
		slog.Debug("NCKOverflow", "missing", missing, "ackcount", c.window, "no", no, "recvno", c.recvNO)
		return
	}
	c.send(c.buildFrame(2, no, uint16(missing*4), noData))
}
// sendFIN transmits a FIN frame (cmd 21) carrying the last assigned frame
// number (sentNO) and waits for the matching FINACK (cmd 22) on c.finack.
//
// FIN is retransmitted up to 5 times, 50ms apart. It returns:
//   - nil when a FINACK for sentNO arrives;
//   - "timeout" when the retransmissions are exhausted first;
//   - io.ErrClosedPipe if finack is closed while waiting.
//
// The timeout is signalled on a private channel rather than by posting 0
// on finack. The old sentinel could not be told apart from a real FINACK
// for frame 0 (a connection that never sent data). The retransmit
// goroutine is stopped on every return path.
func (c *rdtConn) sendFIN() error {
	stop := make(chan struct{})
	defer close(stop)
	timeout := make(chan struct{})
	go func() {
		defer close(timeout)
		for range 5 {
			c.send(c.buildFrame(21, c.sentNO, 0, nil))
			select {
			case <-stop:
				return
			case <-time.After(50 * time.Millisecond):
			}
		}
	}()
	for {
		select {
		case finack, ok := <-c.finack:
			if !ok {
				return io.ErrClosedPipe
			}
			if finack == c.sentNO {
				return nil
			}
			// FINACK for a stale frame number: keep waiting.
		case <-timeout:
			return errors.New("timeout")
		}
	}
}
// runCheckLoop queries the peer every cfg.Interval while frames are still
// unacknowledged. The query asks for NCKs up to min(sentNO, ackNO+window).
// It returns as soon as c.exit is closed.
//
// sentNO is read under sendMutex, which is the lock Write holds when it
// updates sentNO. A timer is used instead of Sleep so the loop reacts to
// exit without first finishing a full interval. A timer, unlike a ticker,
// does not panic on a non-positive interval.
func (c *rdtConn) runCheckLoop() {
	timer := time.NewTimer(c.cfg.Interval)
	defer timer.Stop()
	for {
		select {
		case <-c.exit:
			return
		case <-timer.C:
		}
		c.sendMutex.RLock()
		pending := len(c.sendPool)
		sent := c.sentNO
		c.sendMutex.RUnlock()
		if pending > 0 {
			c.askNCK(min(sent, c.ackNO+c.window))
		}
		timer.Reset(c.cfg.Interval)
	}
}
// runNCKLoop feeds every NCK report received from the peer to resend.
// It returns once c.exit or c.nck is closed.
func (c *rdtConn) runNCKLoop() {
	for {
		var (
			report nck
			ok     bool
		)
		select {
		case <-c.exit:
			return
		case report, ok = <-c.nck:
		}
		if !ok {
			return
		}
		c.resend(report)
	}
}
// runNCKQueryLoop answers each QueryNCK from the peer by sending an NCK up
// to the queried frame number. It returns once c.exit or c.nckQuery is
// closed.
func (c *rdtConn) runNCKQueryLoop() {
	for {
		var (
			upTo uint32
			ok   bool
		)
		select {
		case <-c.exit:
			return
		case upTo, ok = <-c.nckQuery:
		}
		if !ok {
			return
		}
		c.sendNCK(upTo)
	}
}
// startEventLoopGroup launches the per-connection background goroutines.
// These are the NCK handler, the NCK-query responder and the periodic
// check loop, started in that order. Each one returns when c.exit is closed.
func (c *rdtConn) startEventLoopGroup() {
	for _, loop := range []func(){c.runNCKLoop, c.runNCKQueryLoop, c.runCheckLoop} {
		go loop()
	}
}
// RDTListener is a reliable data transmission listener that multiplexes
// streams over a single net.PacketConn.
//
// Streams created by OpenStream are kept in openConnMap. Streams created
// because a peer sent a DATA frame first are kept in acceptConnMap and
// handed out by Accept. Both maps are keyed by the peer address string.
type RDTListener struct {
	cfg Config
	c   net.PacketConn

	// accept queues newly accepted streams for Accept.
	accept             chan *rdtConn
	acceptConnMap      map[string]*rdtConn
	acceptConnMapMutex sync.RWMutex

	openConnMap      map[string]*rdtConn
	openConnMapMutex sync.RWMutex

	// exitSig is closed by Close.
	exitSig chan struct{}
}
// Accept returns the next stream opened by a remote peer (0-RTT: the
// stream is created by the peer's first DATA frame). After the listener is
// closed it returns a *net.OpError with Op "accept".
func (l *RDTListener) Accept() (net.Conn, error) {
	select {
	case conn, ok := <-l.accept:
		if ok {
			return conn, nil
		}
	case <-l.exitSig:
	}
	return nil, &net.OpError{
		Op:  "accept",
		Net: l.c.LocalAddr().Network(),
		Err: errors.New("closed"),
	}
}
// Addr returns the local address of the packet connection the listener
// runs on.
func (l *RDTListener) Addr() net.Addr {
	pc := l.c
	return pc.LocalAddr()
}
func (l *RDTListener) Close() error {
close(l.exitSig)
l.acceptConnMapMutex.RLock()
for _, v := range l.acceptConnMap {
v.Close()
}
l.acceptConnMapMutex.RUnlock()
l.openConnMapMutex.RLock()
for _, v := range l.openConnMap {
v.Close()
}
l.openConnMapMutex.RUnlock()
return l.c.Close()
}
// OpenStream opens a stream to addr (0-RTT: no handshake; data can be
// written immediately). The stream is registered in openConnMap under the
// address string, replacing any previous entry. Its event loops are started
// while the map lock is still held. The error is always nil.
func (l *RDTListener) OpenStream(addr net.Addr) (net.Conn, error) {
	conn := l.newConn(addr)
	l.openConnMapMutex.Lock()
	l.openConnMap[addr.String()] = conn
	conn.startEventLoopGroup()
	l.openConnMapMutex.Unlock()
	return conn, nil
}
// runPacketReadLoop reads datagrams from the shared packet connection and
// dispatches each one to recvPacket until the listener is closed.
//
// The read buffer is sized to the configured MTU, never below the
// previous 1500 bytes. A full DATA frame is MTU bytes (7-byte header plus
// frameSize = MTU-7 payload). A fixed 1500-byte buffer truncated such
// frames when a larger MTU was configured, so they could never be
// delivered.
func (l *RDTListener) runPacketReadLoop() {
	buf := make([]byte, max(l.cfg.MTU, 1500))
	for {
		select {
		case <-l.exitSig:
			return
		default:
		}
		n, addr, err := l.c.ReadFrom(buf)
		if err != nil {
			// net.ErrClosed covers the standard library conns; the message
			// check is kept for PacketConn implementations that do not wrap it.
			if errors.Is(err, net.ErrClosed) || strings.Contains(err.Error(), "use of closed network connection") {
				return
			}
			// NOTE(review): any other read error crashes the whole process.
			panic(err)
		}
		if n < 7 {
			slog.Error("RDT received invalid packet")
			continue
		}
		// The frame outlives this iteration (queued payloads, pools), so
		// give recvPacket its own copy of the datagram.
		l.recvPacket(append([]byte(nil), buf[:n]...), addr)
	}
}
// recvPacket routes one datagram from addr to the stream it belongs to.
//
//   - Frames with the 0x80 flag were sent by the remote accepting side. The
//     flag is stripped and the frame goes to the stream this listener
//     opened to addr, if any.
//   - All other frames, including flagged ones with no opened stream, go to
//     the accepted stream for addr while it is not CLOSED.
//   - A DATA frame (cmd 0) from an unknown or CLOSED peer creates a new
//     accepted stream, which is published to Accept.
//   - Any other frame from an unknown peer is dropped.
//
// acceptConnMapMutex is released before the frame is handed to a stream or
// the new stream is queued for Accept. Both can block (channel sends inside
// recv, a full accept queue), and holding the write lock there would also
// block Close. Publishing also gives up when the listener is closed, so the
// read loop cannot hang on an accept queue nobody drains.
func (l *RDTListener) recvPacket(pkt []byte, addr net.Addr) {
	key := addr.String()
	if pkt[0]&0x80 == 0x80 {
		pkt[0] &^= 0x80
		l.openConnMapMutex.RLock()
		conn, ok := l.openConnMap[key]
		l.openConnMapMutex.RUnlock()
		if ok {
			conn.recv(pkt)
			return
		}
	}
	l.acceptConnMapMutex.RLock()
	conn, ok := l.acceptConnMap[key]
	l.acceptConnMapMutex.RUnlock()
	if ok && conn.state.Load() < CLOSED {
		conn.recv(pkt)
		return
	}
	l.acceptConnMapMutex.Lock()
	conn, ok = l.acceptConnMap[key]
	if ok && conn.state.Load() < CLOSED {
		l.acceptConnMapMutex.Unlock()
		conn.recv(pkt)
		return
	}
	if pkt[0] != 0 {
		l.acceptConnMapMutex.Unlock()
		return
	}
	conn = l.newConn(addr)
	conn.server = true
	l.acceptConnMap[key] = conn
	l.acceptConnMapMutex.Unlock()
	select {
	case l.accept <- conn:
	case <-l.exitSig:
		// Listener is shutting down; nobody will Accept this stream.
		return
	}
	conn.startEventLoopGroup()
	conn.recv(pkt)
}
// newConn allocates a stream to remoteAddr that shares the listener's
// packet connection. Its event loops are not started here; callers do that
// with startEventLoopGroup.
//
// frameSize is the MTU minus the 7-byte frame header. window is the
// maximum number of unacknowledged DATA frames. It is (MTU-7)/4, which is
// also how many 4-byte frame numbers fit in one NCK payload.
//
// sendEvent has a one-slot buffer. A "window space freed" notification
// posted while no Write is parked is then kept for the next waiter instead
// of being lost. Write re-checks the pool after every wake-up, so a stale
// notification is harmless.
func (l *RDTListener) newConn(remoteAddr net.Addr) *rdtConn {
	return &rdtConn{
		cfg:        l.cfg,
		window:     uint32((l.cfg.MTU - 7) / 4),
		frameSize:  l.cfg.MTU - 7,
		c:          l.c,
		remoteAddr: remoteAddr,
		exit:       make(chan struct{}),
		inbound:    make(chan []byte, 1024),
		nck:        make(chan nck, 256),
		nckQuery:   make(chan uint32, 256),
		fin:        make(chan uint32, 5),
		finack:     make(chan uint32, 5),
		sendEvent:  make(chan struct{}, 1),
		recvPool:   map[uint32][]byte{},
		sendPool:   map[uint32][]byte{},
		wClosed: &net.OpError{
			Op:     "write",
			Net:    l.c.LocalAddr().Network(),
			Source: l.c.LocalAddr(),
			Addr:   remoteAddr,
			Err:    errors.New("closed"),
		},
	}
}
// Listen starts a reliable-stream listener on top of conn.
//
// The defaults (MTU 1428, check interval 100ms) are set before opts are
// applied, so options may override them; the first option error aborts
// Listen. When StatsServerListen is non-empty, a TCP listener is opened on
// that address and served by runStatsHTTPServer. The packet read loop is
// started before Listen returns.
func Listen(conn net.PacketConn, opts ...Option) (*RDTListener, error) {
	cfg := Config{
		MTU:      1428,
		Interval: 100 * time.Millisecond,
	}
	for _, apply := range opts {
		if err := apply(&cfg); err != nil {
			return nil, err
		}
	}
	l := &RDTListener{
		cfg:           cfg,
		c:             conn,
		accept:        make(chan *rdtConn, 512),
		openConnMap:   map[string]*rdtConn{},
		acceptConnMap: map[string]*rdtConn{},
		exitSig:       make(chan struct{}),
	}
	if statsAddr := cfg.StatsServerListen; statsAddr != "" {
		httpListener, err := net.Listen("tcp", statsAddr)
		if err != nil {
			return nil, err
		}
		go runStatsHTTPServer(httpListener, l)
	}
	go l.runPacketReadLoop()
	return l, nil
}