diff --git a/rtmp/rtmp_server_test.go b/rtmp/rtmp_server_test.go index 7290e13..f3b7ab4 100644 --- a/rtmp/rtmp_server_test.go +++ b/rtmp/rtmp_server_test.go @@ -1,10 +1,24 @@ package rtmp import ( + "github.com/yangjiechina/avformat/utils" + "github.com/yangjiechina/live-server/stream" "net" "testing" ) +func CreateTransStream(protocol stream.Protocol, streams []utils.AVStream) stream.ITransStream { + if stream.ProtocolRtmp == protocol { + return &TransStream{} + } + + return nil +} + +func init() { + stream.TransStreamFactory = CreateTransStream +} + func TestServer(t *testing.T) { impl := serverImpl{} addr := "0.0.0.0:1935" diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index 9b269bc..6c1246f 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -16,6 +16,9 @@ type Session interface { func NewSession(conn net.Conn) Session { impl := &sessionImpl{} + impl.Protocol = stream.ProtocolRtmpStr + impl.RemoteAddr = conn.RemoteAddr().String() + stack := librtmp.NewStack(impl) impl.stack = stack impl.conn = conn @@ -29,13 +32,11 @@ type sessionImpl struct { //publisher/sink handle interface{} conn net.Conn - - streamId string } func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookState) { - s.streamId = app + "/" + stream_ - publisher := NewPublisher(s.streamId) + s.SessionImpl.Stream = app + "/" + stream_ + publisher := NewPublisher(s.SessionImpl.Stream) s.stack.SetOnPublishHandler(publisher) s.stack.SetOnTransDeMuxerHandler(publisher) //stream.SessionImpl统一处理, Source是否已经存在, Hook回调.... @@ -48,7 +49,7 @@ func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookSta } func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) { - s.streamId = app + "/" + stream_ + s.SessionImpl.Stream = app + "/" + stream_ sink := NewSink(stream.GenerateSinkId(s.conn), s.conn) s.SessionImpl.OnPlay(sink, nil, func() { diff --git a/stream/session.go b/stream/session.go index 22900f0..b14b4a4 100644 --- a/stream/session.go +++ b/stream/session.go @@ -18,21 +18,21 @@ type Session interface { type SessionImpl struct { hookImpl - stream string //stream id - protocol string //推拉流协议 - remoteAddr string //peer地址 + Stream string //stream id + Protocol string //推拉流协议 + RemoteAddr string //peer地址 } // AddInfoParams 为每个需要通知的时间添加必要的信息 func (s *SessionImpl) AddInfoParams(data map[string]interface{}) { - data["stream"] = s.stream - data["protocol"] = s.protocol - data["remoteAddr"] = s.remoteAddr + data["stream"] = s.Stream + data["protocol"] = s.Protocol + data["remoteAddr"] = s.RemoteAddr } func (s *SessionImpl) OnPublish(source_ ISource, pra map[string]interface{}, success func(), failure func(state utils.HookState)) { //streamId 已经被占用 - source := SourceManager.Find(s.stream) + source := SourceManager.Find(s.Stream) if source != nil { failure(utils.HookStateOccupy) return @@ -76,11 +76,11 @@ func (s *SessionImpl) OnPublishDone() { func (s *SessionImpl) OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state utils.HookState)) { f := func() { - source := SourceManager.Find(s.stream) + source := SourceManager.Find(s.Stream) if source == nil { - AddSinkToWaitingQueue(s.stream, nil) + AddSinkToWaitingQueue(s.Stream, sink) } else { - source.AddSink(nil) + source.AddSink(sink) } } diff --git a/stream/sink.go b/stream/sink.go index 0af9e54..df8d12c 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -54,8 +54,14 @@ func GenerateSinkId(conn net.Conn) SinkId { return conn.RemoteAddr().String() } -func AddSinkToWaitingQueue(streamId string, sink ISink) { +var waitingSinks map[string]map[SinkId]ISink +func init() { + waitingSinks = make(map[string]map[SinkId]ISink, 1024) +} + +func AddSinkToWaitingQueue(streamId string, sink ISink) { + waitingSinks[streamId][sink.Id()] = sink } func RemoveSinkFromWaitingQueue(streamId, sinkId SinkId) ISink { @@ -63,14 +69,24 @@ func RemoveSinkFromWaitingQueue(streamId, sinkId SinkId) ISink { } func PopWaitingSinks(streamId string) []ISink { - return nil + source, ok := waitingSinks[streamId] + if !ok { + return nil + } + + sinks := make([]ISink, len(source)) + var index = 0 + for _, sink := range source { + sinks[index] = sink + } + return sinks } type SinkImpl struct { - Id_ SinkId - sourceId string - Protocol_ Protocol - enableVideo bool + Id_ SinkId + sourceId string + Protocol_ Protocol + disableVideo bool DesiredAudioCodecId_ utils.AVCodecID DesiredVideoCodecId_ utils.AVCodecID @@ -111,11 +127,11 @@ func (s *SinkImpl) SetState(state int) { } func (s *SinkImpl) EnableVideo() bool { - return s.enableVideo + return !s.disableVideo } func (s *SinkImpl) SetEnableVideo(enable bool) { - s.enableVideo = enable + s.disableVideo = !enable } func (s *SinkImpl) DesiredAudioCodecId() utils.AVCodecID { diff --git a/stream/source.go b/stream/source.go index 278de4f..aa5d522 100644 --- a/stream/source.go +++ b/stream/source.go @@ -23,6 +23,8 @@ const ( ProtocolRtsp = Protocol(3) ProtocolHls = Protocol(4) ProtocolRtc = Protocol(5) + + ProtocolRtmpStr = "rtmp" ) // SessionState 推拉流Session状态 @@ -173,10 +175,13 @@ func (s *SourceImpl) AddSink(sink ISink) bool { index++ } - transStreamId := GenerateTransStreamId(sink.Protocol(), streams[:]...) + transStreamId := GenerateTransStreamId(sink.Protocol(), streams[:index]...) transStream, ok := s.transStreams[transStreamId] - if ok { - transStream = TransStreamFactory(sink.Protocol(), streams[:]) + if !ok { + transStream = TransStreamFactory(sink.Protocol(), streams[:index]) + if s.transStreams == nil { + s.transStreams = make(map[TransStreamId]ITransStream, 10) + } s.transStreams[transStreamId] = transStream for i := 0; i < index; i++ { @@ -206,9 +211,9 @@ func (s *SourceImpl) OnDeMuxStream(stream utils.AVStream) { s.originStreams.Add(stream) s.allStreams.Add(stream) - if len(s.originStreams.All()) == 1 { - s.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, s.writeHeader) - } + //if len(s.originStreams.All()) == 1 { + // s.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, s.writeHeader) + //} //为每个Stream创建对于的Buffer if AppConfig.GOPCache > 0 { @@ -220,7 +225,9 @@ func (s *SourceImpl) OnDeMuxStream(stream utils.AVStream) { // 从DeMuxer解析完Stream后, 处理等待Sinks func (s *SourceImpl) writeHeader() { utils.Assert(!s.completed) - s.probeTimer.Stop() + if s.probeTimer != nil { + s.probeTimer.Stop() + } s.completed = true sinks := PopWaitingSinks(s.Id_)