diff --git a/pkg/core/track.go b/pkg/core/track.go index b585fa98..d3f1467d 100644 --- a/pkg/core/track.go +++ b/pkg/core/track.go @@ -3,6 +3,7 @@ package core import ( "encoding/json" "errors" + "github.com/pion/rtp" ) @@ -70,9 +71,8 @@ type Sender struct { Packets int `json:"packets,omitempty"` Drops int `json:"drops,omitempty"` - buf chan *Packet - done chan struct{} - isClosed bool + buf chan *Packet + done chan struct{} } func NewSender(media *Media, codec *Codec) *Sender { @@ -99,11 +99,6 @@ func NewSender(media *Media, codec *Codec) *Sender { s.Input = func(packet *Packet) { // writing to nil chan - OK, writing to closed chan - panic s.mu.Lock() - if s.isClosed { - s.Drops++ - s.mu.Unlock() - return - } select { case s.buf <- packet: s.Bytes += len(packet.Payload) @@ -145,6 +140,7 @@ func (s *Sender) Start() { s.done = make(chan struct{}) go func() { + // for range on nil chan is OK for packet := range s.buf { s.Output(packet) } @@ -153,7 +149,7 @@ func (s *Sender) Start() { } func (s *Sender) Wait() { - if done := s.done; s.done != nil { + if done := s.done; done != nil { <-done } } @@ -171,10 +167,9 @@ func (s *Sender) State() string { func (s *Sender) Close() { // close buffer if exists s.mu.Lock() - if buf := s.buf; buf != nil && !s.isClosed { - s.isClosed = true - s.buf = nil - defer close(buf) + if s.buf != nil { + close(s.buf) // exit from for range loop + s.buf = nil // prevent writing to closed chan } s.mu.Unlock() diff --git a/pkg/core/track_test.go b/pkg/core/track_test.go new file mode 100644 index 00000000..cf877d49 --- /dev/null +++ b/pkg/core/track_test.go @@ -0,0 +1,53 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSenser(t *testing.T) { + recv := make(chan *Packet) // blocking receiver + + sender := NewSender(nil, &Codec{}) + sender.Output = func(packet *Packet) { + recv <- packet + } + require.Equal(t, "new", sender.State()) + + sender.Start() + require.Equal(t, "connected", sender.State()) + + sender.Input(&Packet{}) + sender.Input(&Packet{}) + + require.Equal(t, 2, sender.Packets) + require.Equal(t, 0, sender.Drops) + + // important to read one before close + // because goroutine in Start() can run with nil chan + // it's OK in real life, but bad for test + _, ok := <-recv + require.True(t, ok) + + sender.Close() + require.Equal(t, "closed", sender.State()) + + sender.Input(&Packet{}) + + require.Equal(t, 2, sender.Packets) + require.Equal(t, 1, sender.Drops) + + // read 2nd + _, ok = <-recv + require.True(t, ok) + + // read 3rd + select { + case <-recv: + ok = true + default: + ok = false + } + require.False(t, ok) +}