diff --git a/datachannel.go b/datachannel.go index 94e4ddf5..6e41fdc1 100644 --- a/datachannel.go +++ b/datachannel.go @@ -49,6 +49,8 @@ type DataChannel struct { onMessageHandler func(DataChannelMessage) openHandlerOnce sync.Once onOpenHandler func() + dialHandlerOnce sync.Once + onDialHandler func() onCloseHandler func() onBufferedAmountLow func() onErrorHandler func(error) @@ -175,6 +177,7 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error { dc.OnBufferedAmountLow(d.onBufferedAmountLow) d.mu.Unlock() + d.onDial() d.handleOpen(dc, false, d.negotiated) return nil } @@ -228,6 +231,30 @@ func (d *DataChannel) onOpen() { } } +// OnDial sets an event handler which is invoked when the +// peer has been dialed, but before said peer has responsed +func (d *DataChannel) OnDial(f func()) { + d.mu.Lock() + d.dialHandlerOnce = sync.Once{} + d.onDialHandler = f + d.mu.Unlock() + + if d.ReadyState() == DataChannelStateOpen { + // If the data channel is already open, call the handler immediately. + go d.dialHandlerOnce.Do(f) + } +} + +func (d *DataChannel) onDial() { + d.mu.RLock() + handler := d.onDialHandler + d.mu.RUnlock() + + if handler != nil { + go d.dialHandlerOnce.Do(handler) + } +} + // OnClose sets an event handler which is invoked when // the underlying data transport has been closed. func (d *DataChannel) OnClose(f func()) { diff --git a/datachannel_go_test.go b/datachannel_go_test.go index 126e2eb5..a93ca288 100644 --- a/datachannel_go_test.go +++ b/datachannel_go_test.go @@ -33,12 +33,17 @@ func TestDataChannel_EventHandlers(t *testing.T) { api := NewAPI() dc := &DataChannel{api: api} + onDialCalled := make(chan struct{}) onOpenCalled := make(chan struct{}) onMessageCalled := make(chan struct{}) // Verify that the noop case works assert.NotPanics(t, func() { dc.onOpen() }) + dc.OnDial(func() { + close(onDialCalled) + }) + dc.OnOpen(func() { close(onOpenCalled) }) @@ -48,10 +53,12 @@ func TestDataChannel_EventHandlers(t *testing.T) { }) // Verify that the set handlers are called + assert.NotPanics(t, func() { dc.onDial() }) assert.NotPanics(t, func() { dc.onOpen() }) assert.NotPanics(t, func() { dc.onMessage(DataChannelMessage{Data: []byte("o hai")}) }) // Wait for all handlers to be called + <-onDialCalled <-onOpenCalled <-onMessageCalled } @@ -578,3 +585,74 @@ func TestDataChannel_NonStandardSessionDescription(t *testing.T) { <-onDataChannelCalled closePairNow(t, offerPC, answerPC) } + +func TestDataChannel_Dial(t *testing.T) { + t.Run("handler should be called once, by dialing peer only", func(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + dialCalls := make(chan bool, 2) + wg := new(sync.WaitGroup) + wg.Add(2) + + offerPC, answerPC, err := newPair() + if err != nil { + t.Fatalf("Failed to create a PC pair for testing") + } + + answerPC.OnDataChannel(func(d *DataChannel) { + if d.Label() != expectedLabel { + return + } + + d.OnDial(func() { + // only dialing side should fire OnDial + t.Fatalf("answering side should not call on dial") + }) + + d.OnOpen(wg.Done) + }) + + d, err := offerPC.CreateDataChannel(expectedLabel, nil) + assert.NoError(t, err) + d.OnDial(func() { + dialCalls <- true + wg.Done() + }) + + assert.NoError(t, signalPair(offerPC, answerPC)) + + wg.Wait() + closePairNow(t, offerPC, answerPC) + + assert.Len(t, dialCalls, 1) + }) + + t.Run("handler should be called immediately if already dialed", func(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + done := make(chan bool) + + offerPC, answerPC, err := newPair() + if err != nil { + t.Fatalf("Failed to create a PC pair for testing") + } + + d, err := offerPC.CreateDataChannel(expectedLabel, nil) + assert.NoError(t, err) + d.OnOpen(func() { + // when the offer DC has been opened, its guaranteed to have dialed since it has + // received a response to said dial. this test represents an unrealistic usage, + // but its the best way to guarantee we "missed" the dial event and still invoke + // the handler. + d.OnDial(func() { + done <- true + }) + }) + + assert.NoError(t, signalPair(offerPC, answerPC)) + + closePair(t, offerPC, answerPC, done) + }) +}