diff --git a/session/session.go b/session/session.go index 66a1c8a4..d727de7f 100644 --- a/session/session.go +++ b/session/session.go @@ -1,6 +1,7 @@ package session import ( + "context" "sync" "time" @@ -16,8 +17,7 @@ type session struct { logger log.Logger - sessionActivate sync.Mutex - active bool + active bool location string peer string @@ -33,8 +33,8 @@ type session struct { txBitrate *average.SlidingWindow txBytes uint64 - tickerStop chan struct{} - sessionClose sync.Once + tickerStop context.CancelFunc + running bool lock sync.Mutex topRxBitrate float64 @@ -44,6 +44,9 @@ type session struct { } func (s *session) Init(id, reference string, closeCallback func(*session), inactive, timeout time.Duration, logger log.Logger) { + s.lock.Lock() + defer s.lock.Unlock() + s.id = id s.reference = reference s.createdAt = time.Now() @@ -71,35 +74,40 @@ func (s *session) Init(id, reference string, closeCallback func(*session), inact s.callback = func(s *session) {} } - s.tickerStop = make(chan struct{}) - s.sessionClose = sync.Once{} + s.running = true pendingTimeout := inactive if timeout < pendingTimeout { pendingTimeout = timeout } - s.lock.Lock() - defer s.lock.Unlock() - s.stale = time.AfterFunc(pendingTimeout, func() { s.close() }) } func (s *session) close() { - s.sessionClose.Do(func() { - s.lock.Lock() - s.stale.Stop() - s.lock.Unlock() + s.lock.Lock() + defer s.lock.Unlock() - s.closedAt = time.Now() + if !s.running { + return + } - close(s.tickerStop) - s.rxBitrate.Stop() - s.txBitrate.Stop() - go s.callback(s) - }) + s.running = false + + s.stale.Stop() + + s.closedAt = time.Now() + + if s.tickerStop != nil { + s.tickerStop() + s.tickerStop = nil + } + + s.rxBitrate.Stop() + s.txBitrate.Stop() + go s.callback(s) } func (s *session) Register(location, peer string) { @@ -116,15 +124,18 @@ func (s *session) Register(location, peer string) { } func (s *session) Activate() bool { - s.sessionActivate.Lock() - defer s.sessionActivate.Unlock() + s.lock.Lock() + defer s.lock.Unlock() if s.active { return false } s.active = true - go s.ticker() + ctx, cancel := context.WithCancel(context.Background()) + s.tickerStop = cancel + + go s.ticker(ctx) return true } @@ -237,13 +248,13 @@ func (s *session) SetTopTxBitrate(bitrate float64) { s.topTxBitrate = bitrate } -func (s *session) ticker() { +func (s *session) ticker(ctx context.Context) { ticker := time.NewTicker(time.Second) defer ticker.Stop() for { select { - case <-s.tickerStop: + case <-ctx.Done(): return case <-ticker.C: s.lock.Lock()