diff --git a/stream/source.go b/stream/source.go index a98fc9f..3beeddd 100644 --- a/stream/source.go +++ b/stream/source.go @@ -80,6 +80,8 @@ type Source interface { // PostEvent 切换到主协程执行当前函数 PostEvent(cb func()) + ExecuteSyncEvent(cb func()) + // LastPacketTime 返回最近收流时间戳 LastPacketTime() time.Time @@ -484,16 +486,9 @@ func (s *PublishSource) AddSink(sink Sink) { } func (s *PublishSource) RemoveSink(sink Sink) { - group := sync.WaitGroup{} - group.Add(1) - - s.PostEvent(func() { + s.ExecuteSyncEvent(func() { s.doRemoveSink(sink) - - group.Done() }) - - group.Wait() } func (s *PublishSource) RemoveSinkWithID(id SinkID) { @@ -507,19 +502,13 @@ func (s *PublishSource) RemoveSinkWithID(id SinkID) { func (s *PublishSource) FindSink(id SinkID) Sink { var result Sink - group := sync.WaitGroup{} - group.Add(1) - - s.PostEvent(func() { + s.ExecuteSyncEvent(func() { sink, ok := s.sinks[id] if ok { result = sink } - - group.Done() }) - group.Wait() return result } @@ -663,16 +652,9 @@ func (s *PublishSource) Close() { } // 同步执行, 确保close后, 主协程已经退出, 不会再处理任何推拉流、查询等任何事情. - group := sync.WaitGroup{} - group.Add(1) - - s.PostEvent(func() { + s.ExecuteSyncEvent(func() { s.DoClose() - - group.Done() }) - - group.Wait() } // 解析完所有track后, 创建各种输出流 @@ -886,6 +868,18 @@ func (s *PublishSource) PostEvent(cb func()) { s.mainContextEvents <- cb } +func (s *PublishSource) ExecuteSyncEvent(cb func()) { + group := sync.WaitGroup{} + group.Add(1) + + s.PostEvent(func() { + cb() + group.Done() + }) + + group.Wait() +} + func (s *PublishSource) CreateTime() time.Time { return s.createTime } @@ -897,17 +891,12 @@ func (s *PublishSource) SetCreateTime(time time.Time) { func (s *PublishSource) Sinks() []Sink { var sinks []Sink - group := sync.WaitGroup{} - group.Add(1) - s.PostEvent(func() { + s.ExecuteSyncEvent(func() { for _, sink := range s.sinks { sinks = append(sinks, sink) } - - group.Done() }) - group.Wait() return sinks } @@ -924,5 +913,7 @@ func (s *PublishSource) GetStreamEndInfo() *StreamEndInfo { } func (s *PublishSource) ProbeTimeout() { - s.TransDemuxer.ProbeComplete() + if s.TransDemuxer != nil { + s.TransDemuxer.ProbeComplete() + } } diff --git a/stream/source_utils.go b/stream/source_utils.go index 529649a..4fd8663 100644 --- a/stream/source_utils.go +++ b/stream/source_utils.go @@ -328,7 +328,7 @@ func LoopEvent(source Source) { } var ok bool - source.PostEvent(func() { + source.ExecuteSyncEvent(func() { source.ProbeTimeout() ok = len(source.OriginTracks()) > 0 })