diff --git a/internal/staticsources/hls/source_test.go b/internal/staticsources/hls/source_test.go index 98fb6dba..6e98a1c6 100644 --- a/internal/staticsources/hls/source_test.go +++ b/internal/staticsources/hls/source_test.go @@ -102,14 +102,20 @@ func TestSource(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, ResolvedSource: "http://localhost:5780/stream.m3u8", Conf: &conf.Path{}, + ReloadConf: reloadConf, }) close(done) }() <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil } diff --git a/internal/staticsources/mpegts/source.go b/internal/staticsources/mpegts/source.go index 4ff448c7..348a7d5b 100644 --- a/internal/staticsources/mpegts/source.go +++ b/internal/staticsources/mpegts/source.go @@ -66,15 +66,19 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { readerErr <- s.runReader(nc) }() - select { - case err = <-readerErr: - nc.Close() - return err + for { + select { + case err = <-readerErr: + nc.Close() + return err - case <-params.Context.Done(): - nc.Close() - <-readerErr - return fmt.Errorf("terminated") + case <-params.ReloadConf: + + case <-params.Context.Done(): + nc.Close() + <-readerErr + return fmt.Errorf("terminated") + } } } diff --git a/internal/staticsources/mpegts/source_test.go b/internal/staticsources/mpegts/source_test.go index 5637fe94..60e10d6b 100644 --- a/internal/staticsources/mpegts/source_test.go +++ b/internal/staticsources/mpegts/source_test.go @@ -69,11 +69,14 @@ func TestSourceUDP(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, ResolvedSource: src, Conf: &conf.Path{}, + ReloadConf: reloadConf, }) close(done) }() @@ -128,6 +131,9 @@ func TestSourceUDP(t *testing.T) { require.NoError(t, err) <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil }) } } diff --git a/internal/staticsources/rtmp/source_test.go b/internal/staticsources/rtmp/source_test.go index 845ce0ac..16f7f402 100644 --- a/internal/staticsources/rtmp/source_test.go +++ b/internal/staticsources/rtmp/source_test.go @@ -51,46 +51,6 @@ func TestSource(t *testing.T) { defer ln.Close() - go func() { - for { - nconn, err := ln.Accept() - require.NoError(t, err) - defer nconn.Close() - - conn := &rtmp.ServerConn{ - RW: nconn, - } - err = conn.Initialize() - require.NoError(t, err) - - if auth == "auth" { - err = conn.CheckCredentials("myuser", "mypass") - if err != nil { - continue - } - } - - err = conn.Accept() - require.NoError(t, err) - - w := &rtmp.Writer{ - Conn: conn, - VideoTrack: test.FormatH264, - AudioTrack: test.FormatMPEG4Audio, - } - err = w.Initialize() - require.NoError(t, err) - - err = w.WriteH264(2*time.Second, 2*time.Second, [][]byte{{5, 2, 3, 4}}) - require.NoError(t, err) - - err = w.WriteH264(3*time.Second, 3*time.Second, [][]byte{{5, 2, 3, 4}}) - require.NoError(t, err) - - break - } - }() - var source string if encryption == "plain" { @@ -121,6 +81,8 @@ func TestSource(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, @@ -128,11 +90,53 @@ func TestSource(t *testing.T) { Conf: &conf.Path{ SourceFingerprint: "33949E05FFFB5FF3E8AA16F8213A6251B4D9363804BA53233C4DA9A46D6F2739", }, + ReloadConf: reloadConf, }) close(done) }() + for { + nconn, err := ln.Accept() + require.NoError(t, err) + defer nconn.Close() + + conn := &rtmp.ServerConn{ + RW: nconn, + } + err = conn.Initialize() + require.NoError(t, err) + + if auth == "auth" { + err = conn.CheckCredentials("myuser", "mypass") + if err != nil { + continue + } + } + + err = conn.Accept() + require.NoError(t, err) + + w := &rtmp.Writer{ + Conn: conn, + VideoTrack: test.FormatH264, + AudioTrack: test.FormatMPEG4Audio, + } + err = w.Initialize() + require.NoError(t, err) + + err = w.WriteH264(2*time.Second, 2*time.Second, [][]byte{{5, 2, 3, 4}}) + require.NoError(t, err) + + err = w.WriteH264(3*time.Second, 3*time.Second, [][]byte{{5, 2, 3, 4}}) + require.NoError(t, err) + + break + } + <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil }) } } diff --git a/internal/staticsources/rtp/source.go b/internal/staticsources/rtp/source.go index e91068bd..55a1f337 100644 --- a/internal/staticsources/rtp/source.go +++ b/internal/staticsources/rtp/source.go @@ -80,15 +80,19 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { readerErr <- s.runReader(&desc, nc) }() - select { - case err = <-readerErr: - nc.Close() - return err + for { + select { + case err = <-readerErr: + nc.Close() + return err - case <-params.Context.Done(): - nc.Close() - <-readerErr - return fmt.Errorf("terminated") + case <-params.ReloadConf: + + case <-params.Context.Done(): + nc.Close() + <-readerErr + return fmt.Errorf("terminated") + } } } diff --git a/internal/staticsources/rtp/source_test.go b/internal/staticsources/rtp/source_test.go index ab1bf6ed..8c955de6 100644 --- a/internal/staticsources/rtp/source_test.go +++ b/internal/staticsources/rtp/source_test.go @@ -69,6 +69,8 @@ func TestSourceUDP(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, @@ -83,6 +85,7 @@ func TestSourceUDP(t *testing.T) { "a=rtpmap:96 H264/90000\n" + "a=fmtp:96 profile-level-id=42e01e;packetization-mode=1\n", }, + ReloadConf: reloadConf, }) close(done) }() @@ -139,6 +142,9 @@ func TestSourceUDP(t *testing.T) { } <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil }) } } diff --git a/internal/staticsources/rtsp/source_test.go b/internal/staticsources/rtsp/source_test.go index 2f9f2609..f461972a 100644 --- a/internal/staticsources/rtsp/source_test.go +++ b/internal/staticsources/rtsp/source_test.go @@ -169,16 +169,22 @@ func TestSource(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, ResolvedSource: ur, Conf: cnf, + ReloadConf: reloadConf, }) close(done) }() <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil }) } } diff --git a/internal/staticsources/srt/source_test.go b/internal/staticsources/srt/source_test.go index da943a5d..cc7561fe 100644 --- a/internal/staticsources/srt/source_test.go +++ b/internal/staticsources/srt/source_test.go @@ -1,7 +1,6 @@ package srt import ( - "bufio" "context" "testing" "time" @@ -20,42 +19,8 @@ func TestSource(t *testing.T) { require.NoError(t, err) defer ln.Close() - go func() { - req, err2 := ln.Accept2() - require.NoError(t, err2) - - require.Equal(t, "sidname", req.StreamId()) - err2 = req.SetPassphrase("ttest1234567") - require.NoError(t, err2) - - conn, err2 := req.Accept() - require.NoError(t, err2) - defer conn.Close() - - track := &mpegts.Track{ - Codec: &mpegts.CodecH264{}, - } - - bw := bufio.NewWriter(conn) - w := &mpegts.Writer{W: bw, Tracks: []*mpegts.Track{track}} - err2 = w.Initialize() - require.NoError(t, err2) - - err2 = w.WriteH264(track, 0, 0, [][]byte{{ // IDR - 5, 1, - }}) - require.NoError(t, err2) - - err2 = bw.Flush() - require.NoError(t, err2) - - // wait for internal SRT queue to be written - time.Sleep(500 * time.Millisecond) - }() - p := &test.StaticSourceParent{} p.Initialize() - defer p.Close() so := &Source{ ReadTimeout: conf.Duration(10 * time.Second), @@ -68,14 +33,50 @@ func TestSource(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, ResolvedSource: "srt://127.0.0.1:9002?streamid=sidname&passphrase=ttest1234567", Conf: &conf.Path{}, + ReloadConf: reloadConf, }) close(done) }() + req, err2 := ln.Accept2() + require.NoError(t, err2) + + require.Equal(t, "sidname", req.StreamId()) + err2 = req.SetPassphrase("ttest1234567") + require.NoError(t, err2) + + conn, err2 := req.Accept() + require.NoError(t, err2) + defer conn.Close() + + track := &mpegts.Track{Codec: &mpegts.CodecH264{}} + + w := &mpegts.Writer{W: conn, Tracks: []*mpegts.Track{track}} + err2 = w.Initialize() + require.NoError(t, err2) + + err2 = w.WriteH264(track, 0, 0, [][]byte{{ // IDR + 5, 1, + }}) + require.NoError(t, err2) + + err = w.WriteH264(track, 0, 0, [][]byte{{ // non-IDR + 5, 2, + }}) + require.NoError(t, err) + <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil + + // stop test reader before 2nd H264 packet is received to avoid a crash + p.Close() } diff --git a/internal/staticsources/webrtc/source.go b/internal/staticsources/webrtc/source.go index 191fb5c0..6f565e5c 100644 --- a/internal/staticsources/webrtc/source.go +++ b/internal/staticsources/webrtc/source.go @@ -2,6 +2,8 @@ package webrtc import ( + "context" + "fmt" "net/http" "net/url" "strings" @@ -65,12 +67,12 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { if err != nil { return err } - defer client.Close() //nolint:errcheck var stream *stream.Stream medias, err := webrtc.ToStream(client.PeerConnection(), &stream) if err != nil { + client.Close() //nolint:errcheck return err } @@ -79,6 +81,7 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { GenerateRTPPackets: true, }) if rres.Err != nil { + client.Close() //nolint:errcheck return rres.Err } @@ -88,7 +91,26 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { client.StartReading() - return client.Wait(params.Context) + readErr := make(chan error) + + go func() { + readErr <- client.Wait(context.Background()) + }() + + for { + select { + case err = <-readErr: + client.Close() //nolint:errcheck + return err + + case <-params.ReloadConf: + + case <-params.Context.Done(): + client.Close() //nolint:errcheck + <-readErr + return fmt.Errorf("terminated") + } + } } // APISourceDescribe implements StaticSource. diff --git a/internal/staticsources/webrtc/source_test.go b/internal/staticsources/webrtc/source_test.go index b0330cd7..e6e88a1d 100644 --- a/internal/staticsources/webrtc/source_test.go +++ b/internal/staticsources/webrtc/source_test.go @@ -138,14 +138,20 @@ func TestSource(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() + reloadConf := make(chan *conf.Path) + go func() { so.Run(defs.StaticSourceRunParams{ //nolint:errcheck Context: ctx, ResolvedSource: "whep://localhost:9003/my/resource", Conf: &conf.Path{}, + ReloadConf: reloadConf, }) close(done) }() <-p.Unit + + // the source must be listening on ReloadConf + reloadConf <- nil }