diff --git a/connmux/README.md b/connmux/README.md new file mode 100644 index 0000000..15b6911 --- /dev/null +++ b/connmux/README.md @@ -0,0 +1,41 @@ +# connmux +A connection multiplexing library + +### Example +#### client +``` +c, err := net.Dial("tcp", "192.168.3.99:7676") +if err != nil { + panic(err) +} + +session := connmux.Mux(c, connmux.DefaultGenerateSeq) +defer session.Close() +for { + muxC, err := session.Accept() + if err != nil { + panic(err) + } + go handleConn(muxC) +} +``` +#### server +``` +l, err := net.Listen("tcp", ":7676") +if err != nil { + panic(err) +} +c, err := l.Accept() +if err != nil { + panic(err) +} +session := connmux.Mux(c, nil) +defer session.Close() + +muxConn, err := session.OpenStream() +if err != nil { + panic(err) +} + +// ... +``` \ No newline at end of file diff --git a/connmux/connmux.go b/connmux/connmux.go new file mode 100644 index 0000000..4c12d04 --- /dev/null +++ b/connmux/connmux.go @@ -0,0 +1,280 @@ +package connmux + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log/slog" + "net" + "sync" + "sync/atomic" + "time" +) + +type GenerateSeq func() uint32 + +var ( + defaultSeq atomic.Uint32 + DefaultGenerateSeq = func() uint32 { + return defaultSeq.Add(1) + } +) + +type muxConn struct { + exit chan struct{} + inbound chan []byte + seq uint32 + s *MuxSession + + buf []byte +} + +func (c *muxConn) Read(b []byte) (n int, err error) { + select { + case <-c.exit: + return 0, io.ErrClosedPipe + default: + } + + if c.buf != nil { + n = copy(b, c.buf) + if n < len(c.buf) { + c.buf = c.buf[n:] + } else { + c.buf = nil + } + return + } + + wsb, ok := <-c.inbound + if !ok { + return 0, io.EOF + } + n = copy(b, wsb) + if n < len(wsb) { + c.buf = wsb[n:] + } + return +} + +func (c *muxConn) Write(p []byte) (int, error) { + select { + case <-c.exit: + return 0, io.ErrClosedPipe + default: + } + b := []byte{0, 0} + b = append(b, binary.BigEndian.AppendUint16(nil, uint16(len(p)))...) + b = append(b, binary.BigEndian.AppendUint32(nil, c.seq)...) + b = append(b, p...) + c.s.mut.Lock() + defer c.s.mut.Unlock() + n, err := c.s.c.Write(b) + if err != nil { + return max(0, n-8), err + } + return max(0, n-8), nil +} + +func (c *muxConn) Close() error { + b := []byte{0, 1} // FIN + b = append(b, binary.BigEndian.AppendUint16(nil, uint16(0))...) + b = append(b, binary.BigEndian.AppendUint32(nil, c.seq)...) + if _, err := c.s.c.Write(b); err != nil { + slog.Warn("MuxConnFIN", "err", err) + } + c.s.mut.Lock() + delete(c.s.dials, c.seq) + c.s.mut.Unlock() + c.close() + slog.Debug("ClientSideMuxConnClosed", "seq", c.seq) + return nil +} + +func (c *muxConn) close() { + close(c.exit) +} + +// LocalAddr returns the local network address, if known. +func (c *muxConn) LocalAddr() net.Addr { + if la, ok := c.s.c.(interface{ LocalAddr() net.Addr }); ok { + return la.LocalAddr() + } + return nil +} + +// RemoteAddr returns the remote network address, if known. +func (c *muxConn) RemoteAddr() net.Addr { + if la, ok := c.s.c.(interface{ RemoteAddr() net.Addr }); ok { + return la.RemoteAddr() + } + return nil +} + +func (c *muxConn) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + return errors.Join(err1, err2) +} + +// SetReadDeadline sets the deadline for future Read calls +// and any currently-blocked Read call. +// A zero value for t means Read will not time out. +func (c *muxConn) SetReadDeadline(t time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls +// and any currently-blocked Write call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means Write will not time out. +func (c *muxConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type MuxSession struct { + mut sync.Mutex + closeOnce sync.Once + closed atomic.Bool + exit chan struct{} + accept chan net.Conn + generateSeq GenerateSeq + c io.ReadWriteCloser + accepts map[uint32]*muxConn + dials map[uint32]*muxConn +} + +// Accept waits for and returns the next connection to the listener. +func (l *MuxSession) Accept() (net.Conn, error) { + select { + case <-l.exit: + return nil, io.ErrClosedPipe + case c, ok := <-l.accept: + if ok { + return c, nil + } + return nil, io.ErrClosedPipe + } +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *MuxSession) Close() error { + l.closeOnce.Do(func() { + close(l.exit) + close(l.accept) + l.closed.Store(true) + }) + return l.c.Close() +} + +func (l *MuxSession) Closed() bool { + return l.closed.Load() +} + +// Addr returns the listener's network address. +func (l *MuxSession) Addr() net.Addr { + if la, ok := l.c.(interface{ LocalAddr() net.Addr }); ok { + return la.LocalAddr() + } + return nil +} + +func (l *MuxSession) run() { + for { + select { + case <-l.exit: + return + default: + } + header := make([]byte, 8) + _, err := io.ReadFull(l.c, header) + if err != nil { + err = fmt.Errorf("read header: %w", err) + slog.Debug("MuxSessionClosed", "err", err) + l.Close() + return + } + if header[0] != 0 { + err = fmt.Errorf("unsupport connmux version %d", header[0]) + slog.Debug("MuxSessionClosed", "err", err) + l.Close() + return + } + + length := binary.BigEndian.Uint16(header[2:4]) + seq := binary.BigEndian.Uint32(header[4:8]) + cmd := header[1] + slog.Debug("ReadHeader", "header", header) + + data := make([]byte, length) + _, err = io.ReadFull(l.c, data) + if err != nil { + err = fmt.Errorf("read data: %w", err) + slog.Debug("MuxSessionClosed", "err", err) + l.Close() + return + } + if cmd == 0 { + if c, ok := l.dials[seq]; ok { + c.inbound <- data + continue + } + if c, ok := l.accepts[seq]; ok { + c.inbound <- data + continue + } + l.accepts[seq] = &muxConn{ + exit: make(chan struct{}), + inbound: make(chan []byte, 128), + seq: seq, + s: l, + } + l.accept <- l.accepts[seq] + l.accepts[seq].inbound <- data + continue + } + + if cmd == 1 { + if c, ok := l.accepts[seq]; ok { + c.close() + delete(l.accepts, seq) + slog.Debug("ServerSideMuxConnClosed", "seq", c.seq) + } + continue + } + err = fmt.Errorf("unsupport connmux cmd %d", cmd) + slog.Error("MuxSessionClosed", "err", err) + l.Close() + return + } +} + +func (d *MuxSession) OpenStream() (net.Conn, error) { + if d.generateSeq == nil { + return nil, errors.New("seq generator is nil") + } + c := &muxConn{ + exit: make(chan struct{}), + inbound: make(chan []byte, 128), + seq: d.generateSeq(), + s: d, + } + d.dials[c.seq] = c + return c, nil +} + +func Mux(conn io.ReadWriteCloser, generateSeq GenerateSeq) *MuxSession { + l := &MuxSession{ + exit: make(chan struct{}), + c: conn, + generateSeq: generateSeq, + accept: make(chan net.Conn), + accepts: make(map[uint32]*muxConn), + dials: make(map[uint32]*muxConn), + } + go l.run() + return l +} diff --git a/peermap/peermap.go b/peermap/peermap.go index 86a838d..39f5123 100644 --- a/peermap/peermap.go +++ b/peermap/peermap.go @@ -32,9 +32,10 @@ var ( ) type Peer struct { - conn *websocket.Conn - exitSig chan struct{} - peerMap *PeerMap + conn *websocket.Conn + exitSig chan struct{} + closeOnce sync.Once + peerMap *PeerMap networkSecret auth.JSONSecret networkContext *networkContext @@ -110,8 +111,10 @@ func (p *Peer) Write(b []byte) (n int, err error) { } func (p *Peer) Close() error { - close(p.exitSig) - close(p.connData) + p.closeOnce.Do(func() { + close(p.exitSig) + close(p.connData) + }) return p.close() } @@ -200,6 +203,10 @@ func (p *Peer) readMessageLoop() { } else if p.networkContext.ratelimiter != nil { p.networkContext.ratelimiter.WaitN(context.Background(), len(b)) } + if b[0] == peer.CONTROL_CONN { + p.connData <- b[1:] + continue + } tgtPeerID := peer.ID(b[2 : b[1]+2]) slog.Debug("PeerEvent", "op", b[0], "from", p.id, "to", tgtPeerID) tgtPeer, err := p.peerMap.getPeer(p.networkSecret.Network, tgtPeerID) @@ -207,20 +214,17 @@ func (p *Peer) readMessageLoop() { slog.Debug("FindPeer failed", "detail", err) continue } - switch b[0] { - case peer.CONTROL_LEAD_DISCO: + if b[0] == peer.CONTROL_LEAD_DISCO { p.leadDisco(tgtPeer) - case peer.CONTROL_CONN: - p.connData <- b[1:] - default: - data := b[b[1]+2:] - bb := make([]byte, 2+len(p.id)+len(data)) - bb[0] = b[0] - bb[1] = p.id.Len() - copy(bb[2:p.id.Len()+2], p.id.Bytes()) - copy(bb[p.id.Len()+2:], data) - _ = tgtPeer.write(bb) + continue } + data := b[b[1]+2:] + bb := make([]byte, 2+len(p.id)+len(data)) + bb[0] = b[0] + bb[1] = p.id.Len() + copy(bb[2:p.id.Len()+2], p.id.Bytes()) + copy(bb[p.id.Len()+2:], data) + _ = tgtPeer.write(bb) } }