diff --git a/core/engine/conn.go b/core/engine/conn.go index 9dbdfec..ed963f3 100644 --- a/core/engine/conn.go +++ b/core/engine/conn.go @@ -52,7 +52,7 @@ func (e *Engine) addConnByID(id string) (PacketChan, error) { return conn, nil } - peerChan := make(chan Payload, ChanSize) + peerChan := make(PacketChan, ChanSize) e.routeTable.id.Store(id, peerChan) go func() { @@ -109,10 +109,11 @@ func (e *Engine) addConn(peerChan PacketChan, id string) { return } - buff := make([]byte, len(msg)) - copy(buff, msg) + payload := e.payloadPool.Get() + payload.Data = e.bufferPool.Get(len(msg)) + copy(payload.Data, msg) mr.ReleaseMsg(msg) - e.devWriter <- Payload{Data: buff} + e.devWriter <- payload } }() @@ -120,6 +121,8 @@ func (e *Engine) addConn(peerChan PacketChan, id string) { select { case payload := <-peerChan: err := mw.WriteMsg(payload.Data) + e.bufferPool.Put(payload.Data) + e.payloadPool.Put(payload) if err != nil { e.log.Errorf("Peer [%s] write msg error: %s", id, err) return diff --git a/core/engine/engine.go b/core/engine/engine.go index 2bd8c37..651ff1a 100644 --- a/core/engine/engine.go +++ b/core/engine/engine.go @@ -4,8 +4,10 @@ import ( "context" "net/netip" + pool "github.com/libp2p/go-buffer-pool" "github.com/libp2p/go-msgio" "github.com/wlynxg/NetHive/core/route" + "github.com/wlynxg/NetHive/pkgs/xpool" "github.com/wlynxg/NetHive/core/config" "github.com/wlynxg/NetHive/core/device" @@ -27,7 +29,7 @@ const ( VPNStreamProtocol = "/NetHive/vpn" ) -type PacketChan chan Payload +type PacketChan chan *Payload type Engine struct { log *mlog.Logger @@ -48,6 +50,9 @@ type Engine struct { devReader PacketChan errChan chan error + bufferPool *pool.BufferPool + payloadPool xpool.Pool[*Payload] + routeTable struct { m xsync.Map[string, netip.Prefix] id xsync.Map[string, PacketChan] @@ -69,6 +74,11 @@ func Run(ctx context.Context, cfg *config.Config) (*Engine, error) { e.devWriter = make(PacketChan, ChanSize) e.devReader = make(PacketChan, ChanSize) + e.bufferPool = &pool.BufferPool{} + e.payloadPool = xpool.New[*Payload](func() *Payload { + return &Payload{} + }) + pk, err := cfg.PrivateKey.PrivKey() if err != nil { return nil, err @@ -220,10 +230,11 @@ func (e *Engine) VPNHandler(stream network.Stream) { return } - buff := make([]byte, len(msg)) - copy(buff, msg) + payload := e.payloadPool.Get() + payload.Data = e.bufferPool.Get(len(msg)) + copy(payload.Data, msg) mr.ReleaseMsg(msg) - e.devWriter <- Payload{Data: buff} + e.devWriter <- payload } }() @@ -231,9 +242,11 @@ func (e *Engine) VPNHandler(stream network.Stream) { select { case payload := <-peerChan: err := mw.WriteMsg(payload.Data) + e.bufferPool.Put(payload.Data) + e.payloadPool.Put(payload) if err != nil { e.log.Errorf("Peer [%s] write msg error: %s", id, err) - return + continue } } } diff --git a/core/engine/routine.go b/core/engine/routine.go index bafc99a..70e873f 100644 --- a/core/engine/routine.go +++ b/core/engine/routine.go @@ -1,24 +1,27 @@ package engine import ( - "fmt" - "github.com/wlynxg/NetHive/core/protocol" "net/netip" + + "github.com/wlynxg/NetHive/core/protocol" ) // RoutineTUNReader loop to read packets from TUN func (e *Engine) RoutineTUNReader() { var ( - buff = make([]byte, BuffSize) + buff []byte err error n int ) for { + buff = e.bufferPool.Get(BuffSize) n, err = e.device.Read(buff) if err != nil { - e.errChan <- fmt.Errorf("[RoutineTUNReader]: %s", err) - return + e.bufferPool.Put(buff) + e.log.Warnf("[RoutineTUNReader]: %s", err) + continue } + ip, err := protocol.ParseIP(buff[:n]) if err != nil { e.log.Warnf("[RoutineTUNReader] drop packet, because %s", err) @@ -30,16 +33,16 @@ func (e *Engine) RoutineTUNReader() { continue } - payload := Payload{ - Src: ip.Src(), - Dst: ip.Dst(), - Data: make([]byte, n), - } - copy(payload.Data, buff[:n]) + payload := e.payloadPool.Get() + payload.Src = ip.Src() + payload.Dst = ip.Dst() + payload.Data = buff[:n] select { case e.devReader <- payload: default: e.log.Warnf("[RoutineTUNReader] drop packet: %s, because the sending queue is already full", payload.Dst) + e.bufferPool.Put(payload.Data) + e.payloadPool.Put(payload) } } } @@ -47,12 +50,15 @@ func (e *Engine) RoutineTUNReader() { // RoutineTUNWriter loop writing packets to TUN func (e *Engine) RoutineTUNWriter() { var ( - payload Payload + payload *Payload err error ) for payload = range e.devWriter { _, err = e.device.Write(payload.Data) + e.bufferPool.Put(payload.Data) + e.payloadPool.Put(payload) + if err != nil { e.log.Errorf("[RoutineTUNWriter]: %s", err) e.log.Errorf("[err packet]: %v", payload.Data) @@ -63,7 +69,7 @@ func (e *Engine) RoutineTUNWriter() { // RoutineRouteTableWriter loop sending the data packet to the corresponding channel according to the routing table func (e *Engine) RoutineRouteTableWriter() { var ( - payload Payload + payload *Payload ok bool conn PacketChan ) @@ -81,7 +87,6 @@ func (e *Engine) RoutineRouteTableWriter() { defer e.routeTable.addr.Delete(value.Addr()) e.addConn(conn, key) }() - } select { case conn <- payload: @@ -111,6 +116,8 @@ func (e *Engine) RoutineRouteTableWriter() { case conn <- payload: default: e.log.Warnf("[RoutineRouteTableWriter] drop packet: %s, because the sending queue is already full", payload.Dst) + e.bufferPool.Put(payload.Data) + e.payloadPool.Put(payload) } } } diff --git a/pkgs/xpool/xpool.go b/pkgs/xpool/xpool.go new file mode 100644 index 0000000..a978760 --- /dev/null +++ b/pkgs/xpool/xpool.go @@ -0,0 +1,21 @@ +package xpool + +import ( + "sync" +) + +type Pool[T any] struct { + pool sync.Pool +} + +func New[T any](fn func() T) Pool[T] { + return Pool[T]{ + pool: sync.Pool{New: func() interface{} { return fn() }}, + } +} +func (p *Pool[T]) Get() T { + return p.pool.Get().(T) +} +func (p *Pool[T]) Put(x T) { + p.pool.Put(x) +}