diff --git a/hls/hls_sink.go b/hls/hls_sink.go index ff0abcc..dbed775 100644 --- a/hls/hls_sink.go +++ b/hls/hls_sink.go @@ -84,13 +84,13 @@ func (s *M3U8Sink) RefreshPlayTime() { } func (s *M3U8Sink) Close() { + s.BaseSink.Close() + stream.SinkManager.Remove(s.ID) + if s.playTimer != nil { s.playTimer.Stop() s.playTimer = nil } - - stream.SinkManager.Remove(s.ID) - s.BaseSink.Close() } func NewM3U8Sink(id stream.SinkID, sourceId string, cb func(m3u8 []byte), sessionId string) stream.Sink { diff --git a/rtc/rtc_sink.go b/rtc/rtc_sink.go index 11351b7..53df783 100644 --- a/rtc/rtc_sink.go +++ b/rtc/rtc_sink.go @@ -132,12 +132,12 @@ func (s *Sink) StartStreaming(transStream stream.TransStream) error { } func (s *Sink) Close() { + s.BaseSink.Close() + if s.peer != nil { s.peer.Close() s.peer = nil } - - s.BaseSink.Close() } func (s *Sink) Write(index int, data [][]byte, ts int64) error { diff --git a/rtmp/rtmp_sink.go b/rtmp/rtmp_sink.go index 9ae96e7..f1188e2 100644 --- a/rtmp/rtmp_sink.go +++ b/rtmp/rtmp_sink.go @@ -22,8 +22,8 @@ func (s *Sink) StopStreaming(stream stream.TransStream) { } func (s *Sink) Close() { - s.stack = nil s.BaseSink.Close() + s.stack = nil } func NewSink(id stream.SinkID, sourceId string, conn net.Conn, stack *librtmp.Stack) stream.Sink { diff --git a/stream/sink.go b/stream/sink.go index 0189be4..e539904 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -187,40 +187,28 @@ func (s *BaseSink) DesiredVideoCodecId() utils.AVCodecID { func (s *BaseSink) Close() { log.Sugar.Debugf("closing the %s sink. id: %s. current session state: %s", s.Protocol, SinkId2String(s.ID), s.State) - if SessionStateClosed == s.State { - return - } - - if s.Conn != nil { - s.Conn.Close() - s.Conn = nil - } - - // Sink未添加到任何队列, 不做处理 - if s.State < SessionStateWaiting { - return - } - - // 更新Sink状态 - var state SessionState - { - s.Lock() - defer s.UnLock() - if s.State == SessionStateClosed { - return - } - - state = s.State + s.Lock() + defer func() { s.State = SessionStateClosed - } + s.UnLock() - if state == SessionStateTransferring { + // 最后断开网络连接, 确保从source删除sink之前, 推流是安全的. + if s.Conn != nil { + s.Conn.Close() + s.Conn = nil + } + }() + + // 已经关闭或Sink未添加到任何队列, 不做处理 + if SessionStateClosed == s.State || s.State < SessionStateWaiting { + return + } else if s.State == SessionStateTransferring { // 从source中删除sink, 如果source为nil, 已经结束推流. if source := SourceManager.Find(s.SourceID); source != nil { source.RemoveSink(s) } - } else if state == SessionStateWaiting { - // 从等待队列中删除Sink + } else if s.State == SessionStateWaiting { + // 从等待队列中删除sink RemoveSinkFromWaitingQueue(s.SourceID, s.ID) go HookPlayDoneEvent(s) diff --git a/stream/source.go b/stream/source.go index 05a9c38..a32b775 100644 --- a/stream/source.go +++ b/stream/source.go @@ -45,7 +45,7 @@ type Source interface { // 匹配拉流期望的编码器, 创建TransStream或向已经存在TransStream添加Sink AddSink(sink Sink) - // RemoveSink 删除Sink + // RemoveSink 同步删除Sink RemoveSink(sink Sink) RemoveSinkWithID(id SinkID) @@ -441,18 +441,13 @@ func (s *PublishSource) doAddSink(sink Sink) bool { sink.SetTransStreamID(transStreamId) - err := sink.StartStreaming(transStream) - if err != nil { - log.Sugar.Errorf("添加sink失败,开始推流发生err: %s sink: %s source: %s ", err.Error(), SinkId2String(sink.GetID()), s.ID) - return false - } - { sink.Lock() defer sink.UnLock() if SessionStateClosed == sink.GetState() { log.Sugar.Warnf("添加sink失败, sink已经断开连接 %s", sink.String()) + return false } else { sink.SetState(SessionStateTransferring) } @@ -469,6 +464,12 @@ func (s *PublishSource) doAddSink(sink Sink) bool { log.Sugar.Infof("sink count: %d source: %s", s.sinkCount, s.ID) } + err := sink.StartStreaming(transStream) + if err != nil { + log.Sugar.Errorf("添加sink失败,开始推流发生err: %s sink: %s source: %s ", err.Error(), SinkId2String(sink.GetID()), s.ID) + return false + } + s.sinks[sink.GetID()] = sink s.TransStreamSinks[transStreamId][sink.GetID()] = sink @@ -510,9 +511,16 @@ func (s *PublishSource) AddSink(sink Sink) { } func (s *PublishSource) RemoveSink(sink Sink) { + group := sync.WaitGroup{} + group.Add(1) + s.PostEvent(func() { s.doRemoveSink(sink) + + group.Done() }) + + group.Wait() } func (s *PublishSource) RemoveSinkWithID(id SinkID) { diff --git a/stream/source_utils.go b/stream/source_utils.go index 9f5c096..6382e48 100644 --- a/stream/source_utils.go +++ b/stream/source_utils.go @@ -19,7 +19,6 @@ type SourceType byte type TransStreamProtocol uint32 // SessionState 推拉流Session的状态 -// 包含握手和Hook授权阶段 type SessionState uint32 const (