From c8d56fdcfc7f828ae08ada235d4b851b982c47d7 Mon Sep 17 00:00:00 2001 From: csznet Date: Mon, 8 Apr 2024 21:19:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=B8=BBwg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- forward/forward.go | 5 +---- main.go | 6 ++++-- utils/utils.go | 10 ++++++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/forward/forward.go b/forward/forward.go index c9b4600..0bd6d3d 100644 --- a/forward/forward.go +++ b/forward/forward.go @@ -37,8 +37,7 @@ var bufPool = sync.Pool{ } // 开启转发,负责分发具体转发 -func Run(stats *ConnectionStats, wg *sync.WaitGroup) { - defer wg.Done() +func Run(stats *ConnectionStats) { defer releaseResources(stats) // 在函数返回时释放资源 var ctx, cancel = context.WithCancel(context.Background()) var innerWg sync.WaitGroup @@ -165,14 +164,12 @@ func (cs *ConnectionStats) handleTCPConnection(wg *sync.WaitGroup, clientConn ne } } }() - copyWG.Wait() } // UDP转发 func (cs *ConnectionStats) handleUDPConnection(wg *sync.WaitGroup, localConn *net.UDPConn, remoteAddr *net.UDPAddr, ctx context.Context) { defer wg.Done() - for { select { case <-ctx.Done(): diff --git a/main.go b/main.go index 2567d9f..20e2775 100644 --- a/main.go +++ b/main.go @@ -49,10 +49,12 @@ func main() { } // 设置 WaitGroup 计数为连接数 conf.Wg.Add(len(largeStats.Connections)) - // 并发执行多个转发 for _, stats := range largeStats.Connections { - go forward.Run(stats, &conf.Wg) + go func(s *forward.ConnectionStats) { + forward.Run(s) + conf.Wg.Done() + }(stats) } conf.Wg.Wait() defer close(conf.Ch) diff --git a/utils/utils.go b/utils/utils.go index 569d4c1..b254a6b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -28,7 +28,10 @@ func AddForward(newF conf.ConnectionStats) bool { TotalBytesLock: sync.Mutex{}, } conf.Wg.Add(1) - go forward.Run(stats, &conf.Wg) + go func() { + forward.Run(stats) + conf.Wg.Done() + }() return true } return false @@ -62,7 +65,10 @@ func ExStatus(f conf.ConnectionStats) bool { TotalBytesLock: sync.Mutex{}, } conf.Wg.Add(1) - go forward.Run(stats, &conf.Wg) + go func() { + forward.Run(stats) + conf.Wg.Done() + }() return true } else { conf.Ch <- f.LocalPort + f.Protocol