Safer Event Callbacks

Resolves #218

Change Event Callback APIs to setter functions which take care of
locking so that users don't need to know about or remember
to do this.
This commit is contained in:
Michael MacDonald
2018-11-07 10:01:54 -05:00
parent d3984899d1
commit d5cf800ebb
15 changed files with 420 additions and 131 deletions

View File

@@ -32,14 +32,12 @@ func main() {
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String()) fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String())
} })
dataChannel.Lock()
// Register channel opening handling // Register channel opening handling
dataChannel.OnOpen = func() { dataChannel.OnOpen(func() {
fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", dataChannel.Label, dataChannel.ID) fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", dataChannel.Label, dataChannel.ID)
for { for {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@@ -49,10 +47,10 @@ func main() {
err := dataChannel.Send(datachannel.PayloadString{Data: []byte(message)}) err := dataChannel.Send(datachannel.PayloadString{Data: []byte(message)})
util.Check(err) util.Check(err)
} }
} })
// Register the Onmessage to handle incoming messages // Register the OnMessage to handle incoming messages
dataChannel.Onmessage = func(payload datachannel.Payload) { dataChannel.OnMessage(func(payload datachannel.Payload) {
switch p := payload.(type) { switch p := payload.(type) {
case *datachannel.PayloadString: case *datachannel.PayloadString:
fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), dataChannel.Label, string(p.Data)) fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), dataChannel.Label, string(p.Data))
@@ -61,9 +59,7 @@ func main() {
default: default:
fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), dataChannel.Label) fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), dataChannel.Label)
} }
} })
dataChannel.Unlock()
// Create an offer to send to the browser // Create an offer to send to the browser
offer, err := peerConnection.CreateOffer(nil) offer, err := peerConnection.CreateOffer(nil)

View File

@@ -28,19 +28,16 @@ func main() {
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String()) fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String())
} })
// Register data channel creation handling // Register data channel creation handling
peerConnection.OnDataChannel = func(d *webrtc.RTCDataChannel) { peerConnection.OnDataChannel(func(d *webrtc.RTCDataChannel) {
fmt.Printf("New DataChannel %s %d\n", d.Label, d.ID) fmt.Printf("New DataChannel %s %d\n", d.Label, d.ID)
d.Lock()
defer d.Unlock()
// Register channel opening handling // Register channel opening handling
d.OnOpen = func() { d.OnOpen(func() {
fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", d.Label, d.ID) fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", d.Label, d.ID)
for { for {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@@ -50,10 +47,10 @@ func main() {
err := d.Send(datachannel.PayloadString{Data: []byte(message)}) err := d.Send(datachannel.PayloadString{Data: []byte(message)})
util.Check(err) util.Check(err)
} }
} })
// Register message handling // Register message handling
d.Onmessage = func(payload datachannel.Payload) { d.OnMessage(func(payload datachannel.Payload) {
switch p := payload.(type) { switch p := payload.(type) {
case *datachannel.PayloadString: case *datachannel.PayloadString:
fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), d.Label, string(p.Data)) fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), d.Label, string(p.Data))
@@ -62,8 +59,8 @@ func main() {
default: default:
fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), d.Label) fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), d.Label)
} }
} })
} })
// Wait for the offer to be pasted // Wait for the offer to be pasted
sd := util.Decode(util.MustReadStdin()) sd := util.Decode(util.MustReadStdin())

View File

@@ -31,7 +31,7 @@ func main() {
// Set a handler for when a new remote track starts, this handler creates a gstreamer pipeline // Set a handler for when a new remote track starts, this handler creates a gstreamer pipeline
// for the given codec // for the given codec
peerConnection.OnTrack = func(track *webrtc.RTCTrack) { peerConnection.OnTrack(func(track *webrtc.RTCTrack) {
codec := track.Codec codec := track.Codec
fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType, codec.Name) fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType, codec.Name)
pipeline := gst.CreatePipeline(codec.Name) pipeline := gst.CreatePipeline(codec.Name)
@@ -40,13 +40,13 @@ func main() {
p := <-track.Packets p := <-track.Packets
pipeline.Push(p.Raw) pipeline.Push(p.Raw)
} }
} })
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String()) fmt.Printf("Connection State has changed %s \n", connectionState.String())
} })
// Wait for the offer to be pasted // Wait for the offer to be pasted
sd := util.Decode(util.MustReadStdin()) sd := util.Decode(util.MustReadStdin())

View File

@@ -31,9 +31,9 @@ func main() {
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String()) fmt.Printf("Connection State has changed %s \n", connectionState.String())
} })
// Create a audio track // Create a audio track
opusTrack, err := peerConnection.NewRTCSampleTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1") opusTrack, err := peerConnection.NewRTCSampleTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1")

View File

@@ -31,9 +31,9 @@ func main() {
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String()) fmt.Printf("Connection State has changed %s \n", connectionState.String())
} })
// Create a audio track // Create a audio track
opusTrack, err := peerConnection.NewRTCTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1") opusTrack, err := peerConnection.NewRTCTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1")

View File

@@ -52,11 +52,11 @@ func main() {
peerConnection, err := webrtc.New(config) peerConnection, err := webrtc.New(config)
util.Check(err) util.Check(err)
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String()) fmt.Printf("Connection State has changed %s \n", connectionState.String())
} })
peerConnection.OnTrack = func(track *webrtc.RTCTrack) { peerConnection.OnTrack(func(track *webrtc.RTCTrack) {
if track.Codec.Name == webrtc.Opus { if track.Codec.Name == webrtc.Opus {
return return
} }
@@ -68,7 +68,7 @@ func main() {
err = i.AddPacket(<-track.Packets) err = i.AddPacket(<-track.Packets)
util.Check(err) util.Check(err)
} }
} })
// Janus // Janus
gateway, err := janus.Connect("ws://localhost:8188/") gateway, err := janus.Connect("ws://localhost:8188/")

View File

@@ -34,19 +34,16 @@ func main() {
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String()) fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String())
} })
// Register data channel creation handling // Register data channel creation handling
peerConnection.OnDataChannel = func(d *webrtc.RTCDataChannel) { peerConnection.OnDataChannel(func(d *webrtc.RTCDataChannel) {
fmt.Printf("New DataChannel %s %d\n", d.Label, d.ID) fmt.Printf("New DataChannel %s %d\n", d.Label, d.ID)
d.Lock()
defer d.Unlock()
// Register channel opening handling // Register channel opening handling
d.OnOpen = func() { d.OnOpen(func() {
fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", d.Label, d.ID) fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", d.Label, d.ID)
for range time.NewTicker(5 * time.Second).C { for range time.NewTicker(5 * time.Second).C {
@@ -56,10 +53,10 @@ func main() {
err := d.Send(datachannel.PayloadString{Data: []byte(message)}) err := d.Send(datachannel.PayloadString{Data: []byte(message)})
util.Check(err) util.Check(err)
} }
} })
// Register message handling // Register message handling
d.Onmessage = func(payload datachannel.Payload) { d.OnMessage(func(payload datachannel.Payload) {
switch p := payload.(type) { switch p := payload.(type) {
case *datachannel.PayloadString: case *datachannel.PayloadString:
fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), d.Label, string(p.Data)) fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), d.Label, string(p.Data))
@@ -68,8 +65,8 @@ func main() {
default: default:
fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), d.Label) fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), d.Label)
} }
} })
} })
// Exchange the offer/answer via HTTP // Exchange the offer/answer via HTTP
offerChan, answerChan := mustSignalViaHTTP(*addr) offerChan, answerChan := mustSignalViaHTTP(*addr)

View File

@@ -39,14 +39,12 @@ func main() {
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String()) fmt.Printf("ICE Connection State has changed: %s\n", connectionState.String())
} })
dataChannel.Lock()
// Register channel opening handling // Register channel opening handling
dataChannel.OnOpen = func() { dataChannel.OnOpen(func() {
fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", dataChannel.Label, dataChannel.ID) fmt.Printf("Data channel '%s'-'%d' open. Random messages will now be sent to any connected DataChannels every 5 seconds\n", dataChannel.Label, dataChannel.ID)
for range time.NewTicker(5 * time.Second).C { for range time.NewTicker(5 * time.Second).C {
@@ -56,10 +54,10 @@ func main() {
err := dataChannel.Send(datachannel.PayloadString{Data: []byte(message)}) err := dataChannel.Send(datachannel.PayloadString{Data: []byte(message)})
util.Check(err) util.Check(err)
} }
} })
// Register the Onmessage to handle incoming messages // Register the OnMessage to handle incoming messages
dataChannel.Onmessage = func(payload datachannel.Payload) { dataChannel.OnMessage(func(payload datachannel.Payload) {
switch p := payload.(type) { switch p := payload.(type) {
case *datachannel.PayloadString: case *datachannel.PayloadString:
fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), dataChannel.Label, string(p.Data)) fmt.Printf("Message '%s' from DataChannel '%s' payload '%s'\n", p.PayloadType().String(), dataChannel.Label, string(p.Data))
@@ -68,9 +66,7 @@ func main() {
default: default:
fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), dataChannel.Label) fmt.Printf("Message '%s' from DataChannel '%s' no payload \n", p.PayloadType().String(), dataChannel.Label)
} }
} })
dataChannel.Unlock()
// Create an offer to send to the browser // Create an offer to send to the browser
offer, err := peerConnection.CreateOffer(nil) offer, err := peerConnection.CreateOffer(nil)

View File

@@ -33,7 +33,7 @@ func main() {
// Set a handler for when a new remote track starts, this handler saves buffers to disk as // Set a handler for when a new remote track starts, this handler saves buffers to disk as
// an ivf file, since we could have multiple video tracks we provide a counter. // an ivf file, since we could have multiple video tracks we provide a counter.
// In your application this is where you would handle/process video // In your application this is where you would handle/process video
peerConnection.OnTrack = func(track *webrtc.RTCTrack) { peerConnection.OnTrack(func(track *webrtc.RTCTrack) {
if track.Codec.Name == webrtc.VP8 { if track.Codec.Name == webrtc.VP8 {
fmt.Println("Got VP8 track, saving to disk as output.ivf") fmt.Println("Got VP8 track, saving to disk as output.ivf")
i, err := ivfwriter.New("output.ivf") i, err := ivfwriter.New("output.ivf")
@@ -43,13 +43,13 @@ func main() {
util.Check(err) util.Check(err)
} }
} }
} })
// Set the handler for ICE connection state // Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected // This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String()) fmt.Printf("Connection State has changed %s \n", connectionState.String())
} })
// Wait for the offer to be pasted // Wait for the offer to be pasted
sd := util.Decode(util.MustReadStdin()) sd := util.Decode(util.MustReadStdin())

View File

@@ -63,7 +63,7 @@ func main() {
var outboundSamplesLock sync.RWMutex var outboundSamplesLock sync.RWMutex
// Set a handler for when a new remote track starts, this just distributes all our packets // Set a handler for when a new remote track starts, this just distributes all our packets
// to connected peers // to connected peers
peerConnection.OnTrack = func(track *webrtc.RTCTrack) { peerConnection.OnTrack(func(track *webrtc.RTCTrack) {
// Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
// This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it
go func() { go func() {
@@ -91,7 +91,7 @@ func main() {
} }
outboundSamplesLock.RUnlock() outboundSamplesLock.RUnlock()
} }
} })
// Set the remote SessionDescription // Set the remote SessionDescription
check(peerConnection.SetRemoteDescription(webrtc.RTCSessionDescription{ check(peerConnection.SetRemoteDescription(webrtc.RTCSessionDescription{

View File

@@ -5,11 +5,6 @@ import (
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
) )
// RTCSample contains media, and the amount of samples in it
//
// Deprecated: use RTCSample from github.com/pions/webrtc/pkg/media instead
type RTCSample = media.RTCSample
// RTCTrack represents a track that is communicated // RTCTrack represents a track that is communicated
type RTCTrack struct { type RTCTrack struct {
ID string ID string

View File

@@ -89,24 +89,67 @@ type RTCDataChannel struct {
// OnError func() // OnError func()
// OnClose func() // OnClose func()
// Onmessage designates an event handler which is invoked on a message onMessageHandler func(datachannel.Payload)
// arrival over the sctp transport from a remote peer. onOpenHandler func()
//
// Deprecated: use OnMessage instead.
Onmessage func(datachannel.Payload)
// OnMessage designates an event handler which is invoked on a message
// arrival over the sctp transport from a remote peer.
OnMessage func(datachannel.Payload)
// OnOpen designates an event handler which is invoked when
// the underlying data transport has been established (or re-established).
OnOpen func()
// Deprecated: Will be removed when networkManager is deprecated. // Deprecated: Will be removed when networkManager is deprecated.
rtcPeerConnection *RTCPeerConnection rtcPeerConnection *RTCPeerConnection
} }
// OnOpen sets an event handler which is invoked when
// the underlying data transport has been established (or re-established).
func (d *RTCDataChannel) OnOpen(f func()) {
d.Lock()
defer d.Unlock()
d.onOpenHandler = f
}
func (d *RTCDataChannel) onOpen() (done chan struct{}) {
d.RLock()
hdlr := d.onOpenHandler
d.RUnlock()
done = make(chan struct{})
if hdlr == nil {
close(done)
return
}
go func() {
hdlr()
close(done)
}()
return
}
// OnMessage sets an event handler which is invoked on a message
// arrival over the sctp transport from a remote peer.
func (d *RTCDataChannel) OnMessage(f func(p datachannel.Payload)) {
d.Lock()
defer d.Unlock()
d.onMessageHandler = f
}
func (d *RTCDataChannel) onMessage(p datachannel.Payload) {
d.RLock()
hdlr := d.onMessageHandler
d.RUnlock()
if hdlr == nil || p == nil {
return
}
hdlr(p)
}
// Onmessage sets an event handler which is invoked on a message
// arrival over the sctp transport from a remote peer.
//
// Deprecated: use OnMessage instead.
func (d *RTCDataChannel) Onmessage(f func(p datachannel.Payload)) {
d.OnMessage(f)
}
// func (d *RTCDataChannel) generateID() error { // func (d *RTCDataChannel) generateID() error {
// // TODO: base on DTLS role, currently static at "true". // // TODO: base on DTLS role, currently static at "true".
// client := true // client := true
@@ -141,12 +184,3 @@ func (d *RTCDataChannel) Send(p datachannel.Payload) error {
} }
return nil return nil
} }
func (d *RTCDataChannel) doOnOpen() {
d.RLock()
onOpen := d.OnOpen
d.RUnlock()
if onOpen != nil {
onOpen()
}
}

View File

@@ -1,7 +1,14 @@
package webrtc package webrtc
import ( import (
"crypto/rand"
"encoding/binary"
"math/big"
"testing" "testing"
"time"
"github.com/pions/webrtc/pkg/datachannel"
"github.com/stretchr/testify/assert"
) )
func TestGenerateDataChannelID(t *testing.T) { func TestGenerateDataChannelID(t *testing.T) {
@@ -33,3 +40,99 @@ func TestGenerateDataChannelID(t *testing.T) {
} }
} }
} }
func TestRTCDataChannel_EventHandlers(t *testing.T) {
dc := &RTCDataChannel{}
onOpenCalled := make(chan bool)
onMessageCalled := make(chan bool)
// Verify that the noop case works
assert.NotPanics(t, func() { dc.onOpen() })
assert.NotPanics(t, func() { dc.onMessage(nil) })
dc.OnOpen(func() {
onOpenCalled <- true
})
dc.OnMessage(func(p datachannel.Payload) {
go func() {
onMessageCalled <- true
}()
})
// Verify that the handlers deal with nil inputs
assert.NotPanics(t, func() { dc.onMessage(nil) })
// Verify that the set handlers are called
assert.NotPanics(t, func() { dc.onOpen() })
assert.NotPanics(t, func() { dc.onMessage(&datachannel.PayloadString{Data: []byte("o hai")}) })
allTrue := func(vals []bool) bool {
for _, val := range vals {
if !val {
return false
}
}
return true
}
assert.True(t, allTrue([]bool{
<-onOpenCalled,
<-onMessageCalled,
}))
}
func TestRTCDataChannel_MessagesAreOrdered(t *testing.T) {
dc := &RTCDataChannel{}
max := 512
out := make(chan int)
inner := func(p datachannel.Payload) {
// randomly sleep
// NB: The big.Int/crypto.Rand is overkill but makes the linter happy
randInt, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil {
t.Fatalf("Failed to get random sleep duration: %s", err)
}
time.Sleep(time.Duration(randInt.Int64()) * time.Microsecond)
switch p := p.(type) {
case *datachannel.PayloadBinary:
s, _ := binary.Varint(p.Data)
out <- int(s)
}
}
dc.OnMessage(func(p datachannel.Payload) {
inner(p)
})
go func() {
for i := 1; i <= max; i++ {
buf := make([]byte, 8)
binary.PutVarint(buf, int64(i))
dc.onMessage(&datachannel.PayloadBinary{Data: buf})
// Change the registered handler a couple of times to make sure
// that everything continues to work, we don't lose messages, etc.
if i%2 == 0 {
hdlr := func(p datachannel.Payload) {
inner(p)
}
dc.OnMessage(hdlr)
}
}
}()
values := make([]int, 0, max)
for v := range out {
values = append(values, v)
if len(values) == max {
close(out)
}
}
expected := make([]int, max)
for i := 1; i <= max; i++ {
expected[i-1] = i
}
assert.EqualValues(t, expected, values)
}

View File

@@ -100,20 +100,12 @@ type RTCPeerConnection struct {
// OnIceCandidateError func() // FIXME NOT-USED // OnIceCandidateError func() // FIXME NOT-USED
// OnSignalingStateChange func() // FIXME NOT-USED // OnSignalingStateChange func() // FIXME NOT-USED
// OnIceConnectionStateChange designates an event handler which is called
// when an ice connection state is changed.
OnICEConnectionStateChange func(ice.ConnectionState)
// OnIceGatheringStateChange func() // FIXME NOT-USED // OnIceGatheringStateChange func() // FIXME NOT-USED
// OnConnectionStateChange func() // FIXME NOT-USED // OnConnectionStateChange func() // FIXME NOT-USED
// OnTrack designates an event handler which is called when remote track onICEConnectionStateChangeHandler func(ice.ConnectionState)
// arrives from a remote peer. onTrackHandler func(*RTCTrack)
OnTrack func(*RTCTrack) onDataChannelHandler func(*RTCDataChannel)
// OnDataChannel designates an event handler which is invoked when a data
// channel message arrives from a remote peer.
OnDataChannel func(*RTCDataChannel)
// Deprecated: Internal mechanism which will be removed. // Deprecated: Internal mechanism which will be removed.
networkManager *network.Manager networkManager *network.Manager
@@ -233,6 +225,90 @@ func (pc *RTCPeerConnection) initConfiguration(configuration RTCConfiguration) e
return nil return nil
} }
// OnDataChannel sets an event handler which is invoked when a data
// channel message arrives from a remote peer.
func (pc *RTCPeerConnection) OnDataChannel(f func(*RTCDataChannel)) {
pc.Lock()
defer pc.Unlock()
pc.onDataChannelHandler = f
}
func (pc *RTCPeerConnection) onDataChannel(dc *RTCDataChannel) (done chan struct{}) {
pc.RLock()
hdlr := pc.onDataChannelHandler
pc.RUnlock()
done = make(chan struct{})
if hdlr == nil || dc == nil {
close(done)
return
}
// Run this synchronously to allow setup done in onDataChannelFn()
// to complete before datachannel event handlers might be called.
go func() {
hdlr(dc)
dc.onOpen() // TODO: move to ChannelAck handling
close(done)
}()
return
}
// OnTrack sets an event handler which is called when remote track
// arrives from a remote peer.
func (pc *RTCPeerConnection) OnTrack(f func(*RTCTrack)) {
pc.Lock()
defer pc.Unlock()
pc.onTrackHandler = f
}
func (pc *RTCPeerConnection) onTrack(t *RTCTrack) (done chan struct{}) {
pc.RLock()
hdlr := pc.onTrackHandler
pc.RUnlock()
done = make(chan struct{})
if hdlr == nil || t == nil {
close(done)
return
}
go func() {
hdlr(t)
close(done)
}()
return
}
// OnICEConnectionStateChange sets an event handler which is called
// when an ICE connection state is changed.
func (pc *RTCPeerConnection) OnICEConnectionStateChange(f func(ice.ConnectionState)) {
pc.Lock()
defer pc.Unlock()
pc.onICEConnectionStateChangeHandler = f
}
func (pc *RTCPeerConnection) onICEConnectionStateChange(cs ice.ConnectionState) (done chan struct{}) {
pc.RLock()
hdlr := pc.onICEConnectionStateChangeHandler
pc.RUnlock()
done = make(chan struct{})
if hdlr == nil {
close(done)
return
}
go func() {
hdlr(cs)
close(done)
}()
return
}
// SetConfiguration updates the configuration of this RTCPeerConnection object. // SetConfiguration updates the configuration of this RTCPeerConnection object.
func (pc *RTCPeerConnection) SetConfiguration(configuration RTCConfiguration) error { func (pc *RTCPeerConnection) SetConfiguration(configuration RTCConfiguration) error {
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2)
@@ -764,9 +840,12 @@ func (pc *RTCPeerConnection) Close() error {
/* Everything below is private */ /* Everything below is private */
func (pc *RTCPeerConnection) generateChannel(ssrc uint32, payloadType uint8) (buffers chan<- *rtp.Packet) { func (pc *RTCPeerConnection) generateChannel(ssrc uint32, payloadType uint8) (buffers chan<- *rtp.Packet) {
if pc.OnTrack == nil { pc.RLock()
if pc.onTrackHandler == nil {
pc.RUnlock()
return nil return nil
} }
pc.RUnlock()
sdpCodec, err := pc.CurrentLocalDescription.parsed.GetCodecForPayloadType(payloadType) sdpCodec, err := pc.CurrentLocalDescription.parsed.GetCodecForPayloadType(payloadType)
if err != nil { if err != nil {
@@ -794,54 +873,43 @@ func (pc *RTCPeerConnection) generateChannel(ssrc uint32, payloadType uint8) (bu
// TODO: Register the receiving Track // TODO: Register the receiving Track
go pc.OnTrack(track) pc.onTrack(track)
return bufferTransport return bufferTransport
} }
func (pc *RTCPeerConnection) iceStateChange(newState ice.ConnectionState) { func (pc *RTCPeerConnection) iceStateChange(newState ice.ConnectionState) {
pc.Lock() pc.Lock()
defer pc.Unlock()
if pc.OnICEConnectionStateChange != nil {
pc.OnICEConnectionStateChange(newState)
}
pc.IceConnectionState = newState pc.IceConnectionState = newState
pc.Unlock()
pc.onICEConnectionStateChange(newState)
} }
func (pc *RTCPeerConnection) dataChannelEventHandler(e network.DataChannelEvent) { func (pc *RTCPeerConnection) dataChannelEventHandler(e network.DataChannelEvent) {
pc.Lock()
defer pc.Unlock()
switch event := e.(type) { switch event := e.(type) {
case *network.DataChannelCreated: case *network.DataChannelCreated:
id := event.StreamIdentifier() id := event.StreamIdentifier()
newDataChannel := &RTCDataChannel{ID: &id, Label: event.Label, rtcPeerConnection: pc, ReadyState: RTCDataChannelStateOpen} newDataChannel := &RTCDataChannel{ID: &id, Label: event.Label, rtcPeerConnection: pc, ReadyState: RTCDataChannelStateOpen}
pc.Lock()
pc.dataChannels[e.StreamIdentifier()] = newDataChannel pc.dataChannels[e.StreamIdentifier()] = newDataChannel
if pc.OnDataChannel != nil { pc.Unlock()
go func() {
pc.OnDataChannel(newDataChannel) // This should actually be called when processing the SDP answer. // NB: We block here waiting for the callback to finish before
if newDataChannel.OnOpen != nil { // proceeding, in order to guarantee that all user setup of the channel
go newDataChannel.doOnOpen() // has completed before moving on to process more events.
} <-pc.onDataChannel(newDataChannel)
}()
} else {
fmt.Println("OnDataChannel is unset, discarding message")
}
case *network.DataChannelMessage: case *network.DataChannelMessage:
if datachannel, ok := pc.dataChannels[e.StreamIdentifier()]; ok { pc.RLock()
datachannel.RLock() if dc, ok := pc.dataChannels[e.StreamIdentifier()]; ok {
defer datachannel.RUnlock() pc.RUnlock()
dc.onMessage(event.Payload)
if datachannel.Onmessage != nil {
go datachannel.Onmessage(event.Payload)
} else {
fmt.Printf("Onmessage has not been set for Datachannel %s %d \n", datachannel.Label, e.StreamIdentifier())
}
} else { } else {
pc.RUnlock()
fmt.Printf("No datachannel found for streamIdentifier %d \n", e.StreamIdentifier()) fmt.Printf("No datachannel found for streamIdentifier %d \n", e.StreamIdentifier())
} }
case *network.DataChannelOpen: case *network.DataChannelOpen:
pc.RLock()
defer pc.RUnlock()
for _, dc := range pc.dataChannels { for _, dc := range pc.dataChannels {
dc.Lock() dc.Lock()
err := dc.sendOpenChannelMessage() err := dc.sendOpenChannelMessage()
@@ -853,7 +921,7 @@ func (pc *RTCPeerConnection) dataChannelEventHandler(e network.DataChannelEvent)
dc.ReadyState = RTCDataChannelStateOpen dc.ReadyState = RTCDataChannelStateOpen
dc.Unlock() dc.Unlock()
go dc.doOnOpen() // TODO: move to ChannelAck handling dc.onOpen() // TODO: move to ChannelAck handling
} }
default: default:
fmt.Printf("Unhandled DataChannelEvent %v \n", event) fmt.Printf("Unhandled DataChannelEvent %v \n", event)

View File

@@ -9,6 +9,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/pions/webrtc/internal/network"
"github.com/pions/webrtc/pkg/datachannel"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/media" "github.com/pions/webrtc/pkg/media"
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
@@ -342,3 +348,100 @@ func TestRTCPeerConnection_NewRTCSampleTrack(t *testing.T) {
track.Samples <- media.RTCSample{} track.Samples <- media.RTCSample{}
}) })
} }
func TestRTCPeerConnection_EventHandlers(t *testing.T) {
pc, err := New(RTCConfiguration{})
assert.Nil(t, err)
onTrackCalled := make(chan bool)
onICEConnectionStateChangeCalled := make(chan bool)
onDataChannelCalled := make(chan bool)
// Verify that the noop case works
assert.NotPanics(t, func() { pc.onTrack(nil) })
assert.NotPanics(t, func() { pc.onICEConnectionStateChange(ice.ConnectionStateNew) })
assert.NotPanics(t, func() { pc.onDataChannel(nil) })
pc.OnTrack(func(t *RTCTrack) {
onTrackCalled <- true
})
pc.OnICEConnectionStateChange(func(cs ice.ConnectionState) {
onICEConnectionStateChangeCalled <- true
})
pc.OnDataChannel(func(dc *RTCDataChannel) {
onDataChannelCalled <- true
})
// Verify that the handlers deal with nil inputs
assert.NotPanics(t, func() { pc.onTrack(nil) })
assert.NotPanics(t, func() { pc.onDataChannel(nil) })
// Verify that the set handlers are called
assert.NotPanics(t, func() { pc.onTrack(&RTCTrack{}) })
assert.NotPanics(t, func() { pc.onICEConnectionStateChange(ice.ConnectionStateNew) })
assert.NotPanics(t, func() { pc.onDataChannel(&RTCDataChannel{}) })
allTrue := func(vals []bool) bool {
for _, val := range vals {
if !val {
return false
}
}
return true
}
assert.True(t, allTrue([]bool{
<-onTrackCalled,
<-onICEConnectionStateChangeCalled,
<-onDataChannelCalled,
}))
}
func TestRTCPeerConnection_OnDataChannelSync(t *testing.T) {
// This is a special case, where we need to ensure that any DataChannel setup
// in the supplied handler completes before allowing the calling code to
// resume running.
//
// This test also validates that the locking in RTCPeerConnection.dataChannelEventHandler()
// correctly interacts with the locking in the event handlers.
pc, err := New(RTCConfiguration{})
assert.Nil(t, err)
onOpenCalled := make(chan bool)
onDataChannelCalled := make(chan bool)
onMessageCalled := make(chan bool)
pc.OnDataChannel(func(dc *RTCDataChannel) {
onDataChannelCalled <- true
dc.OnOpen(func() {
onOpenCalled <- true
})
dc.OnMessage(func(p datachannel.Payload) {
onMessageCalled <- true
})
})
go func() {
dcEvents := []network.DataChannelEvent{
// NB: This order seems odd, but it matches what's emitted
// by networkManager
&network.DataChannelOpen{},
&network.DataChannelCreated{},
&network.DataChannelMessage{Payload: &datachannel.PayloadString{Data: []byte("o hai")}},
}
for _, event := range dcEvents {
pc.dataChannelEventHandler(event)
}
}()
// NB: If RTCPeerConnection.dataChannelEventHandler() does not correctly wait for
// OnDataChannel() to complete, this will hang until timeout because the handlers aren't set
// before the events are processed.
assert.EqualValues(t,
[]bool{true, true, true},
[]bool{<-onDataChannelCalled, <-onOpenCalled, <-onMessageCalled},
)
}