Add ICE connection state change notification and timeouts

This commit is contained in:
Sean DuBois
2018-06-30 02:57:47 -07:00
parent 6cd9d069c1
commit 5bf9d5af34
10 changed files with 286 additions and 123 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/pions/webrtc" "github.com/pions/webrtc"
"github.com/pions/webrtc/examples/gstreamer-receive/gst" "github.com/pions/webrtc/examples/gstreamer-receive/gst"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
) )
@@ -48,6 +49,12 @@ func startWebrtc(pipeline *gst.Pipeline) {
} }
} }
// Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String())
}
// Set the remote SessionDescription // Set the remote SessionDescription
if err := peerConnection.SetRemoteDescription(string(sd)); err != nil { if err := peerConnection.SetRemoteDescription(string(sd)); err != nil {
panic(err) panic(err)

View File

@@ -9,6 +9,7 @@ import (
"github.com/pions/webrtc" "github.com/pions/webrtc"
"github.com/pions/webrtc/examples/gstreamer-send/gst" "github.com/pions/webrtc/examples/gstreamer-send/gst"
"github.com/pions/webrtc/pkg/ice"
) )
func main() { func main() {
@@ -47,6 +48,12 @@ func main() {
panic(err) panic(err)
} }
// Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String())
}
// Get the LocalDescription and take it to base64 so we can paste in browser // Get the LocalDescription and take it to base64 so we can paste in browser
localDescriptionStr := peerConnection.LocalDescription.Marshal() localDescriptionStr := peerConnection.LocalDescription.Marshal()
fmt.Println(base64.StdEncoding.EncodeToString([]byte(localDescriptionStr))) fmt.Println(base64.StdEncoding.EncodeToString([]byte(localDescriptionStr)))

View File

@@ -8,6 +8,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/pions/webrtc" "github.com/pions/webrtc"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
) )
@@ -49,6 +50,12 @@ func main() {
} }
} }
// Set the handler for ICE connection state
// This will notify you when the peer has connected/disconnected
peerConnection.OnICEConnectionStateChange = func(connectionState ice.ConnectionState) {
fmt.Printf("Connection State has changed %s \n", connectionState.String())
}
// Set the remote SessionDescription // Set the remote SessionDescription
if err := peerConnection.SetRemoteDescription(string(sd)); err != nil { if err := peerConnection.SetRemoteDescription(string(sd)); err != nil {
panic(err) panic(err)

View File

@@ -125,11 +125,11 @@ type CertPair struct {
} }
// HandleDTLSPacket checks if the packet is a DTLS packet, and if it is passes to the DTLS session // HandleDTLSPacket checks if the packet is a DTLS packet, and if it is passes to the DTLS session
func (d *State) HandleDTLSPacket(packet []byte, size int) (certPair *CertPair) { func (d *State) HandleDTLSPacket(packet []byte) (certPair *CertPair) {
packetRaw := C.CBytes(packet) packetRaw := C.CBytes(packet)
defer C.free(unsafe.Pointer(packetRaw)) defer C.free(unsafe.Pointer(packetRaw))
if ret := C.dtls_handle_incoming(d.dtlsSession, d.rawSrc, d.rawDst, packetRaw, C.int(size)); ret != nil { if ret := C.dtls_handle_incoming(d.dtlsSession, d.rawSrc, d.rawDst, packetRaw, C.int(len(packet))); ret != nil {
certPair = &CertPair{ certPair = &CertPair{
ClientWriteKey: []byte(C.GoStringN(&ret.client_write_key[0], ret.key_length)), ClientWriteKey: []byte(C.GoStringN(&ret.client_write_key[0], ret.key_length)),
ServerWriteKey: []byte(C.GoStringN(&ret.server_write_key[0], ret.key_length)), ServerWriteKey: []byte(C.GoStringN(&ret.server_write_key[0], ret.key_length)),

View File

@@ -1,42 +0,0 @@
package ice
import "net"
// HostInterfaces generates a slice of all the IPs associated with interfaces
func HostInterfaces() (ips []string) {
ifaces, err := net.Interfaces()
if err != nil {
return ips
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return ips
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
ips = append(ips, ip.String())
}
}
return ips
}

View File

@@ -1,7 +1,13 @@
package network package network
import "github.com/pions/webrtc/pkg/rtp" import (
"github.com/pions/webrtc/pkg/rtp"
)
// BufferTransportGenerator generates a new channel for the associated SSRC // BufferTransportGenerator generates a new channel for the associated SSRC
// This channel is used to send RTP packets to users of pion-WebRTC // This channel is used to send RTP packets to users of pion-WebRTC
type BufferTransportGenerator func(uint32) chan<- *rtp.Packet type BufferTransportGenerator func(uint32) chan<- *rtp.Packet
// ICENotifier
// Notify the RTCPeerConnection if ICE has changed state
type ICENotifier func(*Port)

View File

@@ -5,14 +5,21 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"time"
"github.com/pions/pkg/stun" "github.com/pions/pkg/stun"
"github.com/pions/webrtc/internal/dtls" "github.com/pions/webrtc/internal/dtls"
"github.com/pions/webrtc/internal/srtp" "github.com/pions/webrtc/internal/srtp"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
) )
func (p *Port) handleSRTP(srcString string, b BufferTransportGenerator, certPair *dtls.CertPair, buffer []byte, bufferSize int) { type incomingPacket struct {
srcAddr *net.UDPAddr
buffer []byte
}
func (p *Port) handleSRTP(b BufferTransportGenerator, certPair *dtls.CertPair, buffer []byte) {
if len(buffer) > 4 { if len(buffer) > 4 {
var rtcpPacketType uint8 var rtcpPacketType uint8
@@ -28,18 +35,13 @@ func (p *Port) handleSRTP(srcString string, b BufferTransportGenerator, certPair
} }
} }
// Make copy of packet
// buffer[:n] can't be modified outside of network loop
rawPacket := make([]byte, bufferSize)
copy(rawPacket, buffer[:bufferSize])
packet := &rtp.Packet{} packet := &rtp.Packet{}
if err := packet.Unmarshal(rawPacket); err != nil { if err := packet.Unmarshal(buffer); err != nil {
fmt.Println("Failed to unmarshal RTP packet") fmt.Println("Failed to unmarshal RTP packet")
return return
} }
contextMapKey := srcString + ":" + fmt.Sprint(packet.SSRC) contextMapKey := p.ListeningAddr.String() + ":" + fmt.Sprint(packet.SSRC)
p.srtpContextsLock.Lock() p.srtpContextsLock.Lock()
srtpContext, ok := p.srtpContexts[contextMapKey] srtpContext, ok := p.srtpContexts[contextMapKey]
if !ok { if !ok {
@@ -63,7 +65,6 @@ func (p *Port) handleSRTP(srcString string, b BufferTransportGenerator, certPair
if bufferTransport == nil { if bufferTransport == nil {
bufferTransport = b(packet.SSRC) bufferTransport = b(packet.SSRC)
if bufferTransport == nil { if bufferTransport == nil {
fmt.Println("Failed to generate buffer transport, onTrack should be defined")
return return
} }
p.bufferTransports[packet.SSRC] = bufferTransport p.bufferTransports[packet.SSRC] = bufferTransport
@@ -71,67 +72,99 @@ func (p *Port) handleSRTP(srcString string, b BufferTransportGenerator, certPair
bufferTransport <- packet bufferTransport <- packet
} }
func (p *Port) networkLoop(srcString string, remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransportGenerator) { func (p *Port) handleICE(in *incomingPacket, remoteKey []byte, iceTimer *time.Timer, iceNotifier ICENotifier) {
const MTU = 8192 if m, err := stun.NewMessage(in.buffer); err == nil && m.Class == stun.ClassRequest && m.Method == stun.MethodBinding {
buffer := make([]byte, MTU) dstAddr := &stun.TransportAddr{IP: in.srcAddr.IP, Port: in.srcAddr.Port}
if err := stun.BuildAndSend(p.conn, dstAddr, stun.ClassSuccessResponse, stun.MethodBinding, m.TransactionID,
var certPair *dtls.CertPair &stun.XorMappedAddress{
for { XorAddress: stun.XorAddress{
n, _, rawDstAddr, err := p.conn.ReadFrom(buffer) IP: dstAddr.IP,
if err != nil { Port: dstAddr.Port,
fmt.Printf("Failed to read packet: %s \n", err.Error()) },
continue },
} &stun.MessageIntegrity{
Key: remoteKey,
d, haveHandshaked := p.dtlsStates[rawDstAddr.String()] },
if haveHandshaked && buffer[0] >= 20 && buffer[0] <= 64 { &stun.Fingerprint{},
tmpCertPair := d.HandleDTLSPacket(buffer, n) ); err != nil {
if tmpCertPair != nil { fmt.Println(err)
certPair = tmpCertPair
p.authedConnections = append(p.authedConnections, &authedConnection{
pair: certPair,
peer: rawDstAddr,
})
}
continue
}
if packetType, err := stun.GetPacketType(buffer[:n]); err == nil && packetType == stun.PacketTypeSTUN {
if m, err := stun.NewMessage(buffer[:n]); err == nil && m.Class == stun.ClassRequest && m.Method == stun.MethodBinding {
dstAddr := &stun.TransportAddr{IP: rawDstAddr.(*net.UDPAddr).IP, Port: rawDstAddr.(*net.UDPAddr).Port}
if err := stun.BuildAndSend(p.conn, dstAddr, stun.ClassSuccessResponse, stun.MethodBinding, m.TransactionID,
&stun.XorMappedAddress{
XorAddress: stun.XorAddress{
IP: dstAddr.IP,
Port: dstAddr.Port,
},
},
&stun.MessageIntegrity{
Key: remoteKey,
},
&stun.Fingerprint{},
); err != nil {
fmt.Println(err)
}
}
} else { } else {
if certPair == nil { p.ICEState = ice.Completed
fmt.Println("SRTP packet, but unable to handle DTLS handshake has not completed") iceTimer.Reset(iceTimeout)
continue iceNotifier(p)
} }
p.handleSRTP(srcString, b, certPair, buffer, n) }
} }
if !haveHandshaked { const iceTimeout = time.Second * 10
d, err := dtls.NewState(tlscfg, true, srcString, rawDstAddr.String()) const MTU = 8192
if err != nil {
fmt.Println(err) func (p *Port) networkLoop(remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransportGenerator, iceNotifier ICENotifier) {
continue incomingPackets := make(chan *incomingPacket, 15)
} go func() {
buffer := make([]byte, MTU)
d.DoHandshake() for {
p.dtlsStates[rawDstAddr.String()] = d n, _, srcAddr, err := p.conn.ReadFrom(buffer)
if err != nil {
close(incomingPackets)
break
}
bufferCopy := make([]byte, n)
copy(bufferCopy, buffer[:n])
select {
case incomingPackets <- &incomingPacket{buffer: bufferCopy, srcAddr: srcAddr.(*net.UDPAddr)}:
default:
}
}
}()
var certPair *dtls.CertPair
iceTimer := time.NewTimer(iceTimeout)
for {
select {
case <-iceTimer.C:
p.ICEState = ice.Failed
iceNotifier(p)
case in, inValid := <-incomingPackets:
if !inValid {
// incomingPackets channel has closed, this port is finished processing
return
}
dtlsState := p.dtlsStates[in.srcAddr.String()]
if dtlsState != nil && in.buffer[0] >= 20 && in.buffer[0] <= 64 {
tmpCertPair := dtlsState.HandleDTLSPacket(in.buffer)
if tmpCertPair != nil {
certPair = tmpCertPair
p.authedConnections = append(p.authedConnections, &authedConnection{
pair: certPair,
peer: in.srcAddr,
})
}
continue
}
if packetType, err := stun.GetPacketType(in.buffer); err == nil && packetType == stun.PacketTypeSTUN {
p.handleICE(in, remoteKey, iceTimer, iceNotifier)
} else if certPair == nil {
fmt.Println("SRTP packet, but unable to handle DTLS handshake has not completed")
} else {
p.handleSRTP(b, certPair, in.buffer)
}
if dtlsState == nil {
d, err := dtls.NewState(tlscfg, true, p.ListeningAddr.String(), in.srcAddr.String())
if err != nil {
fmt.Println(err)
continue
}
d.DoHandshake()
p.dtlsStates[in.srcAddr.String()] = d
}
} }
} }
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/pions/pkg/stun" "github.com/pions/pkg/stun"
"github.com/pions/webrtc/internal/dtls" "github.com/pions/webrtc/internal/dtls"
"github.com/pions/webrtc/internal/srtp" "github.com/pions/webrtc/internal/srtp"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -20,6 +21,7 @@ type authedConnection struct {
// Port represents a UDP listener that handles incoming/outgoing traffic // Port represents a UDP listener that handles incoming/outgoing traffic
type Port struct { type Port struct {
ListeningAddr *stun.TransportAddr ListeningAddr *stun.TransportAddr
ICEState ice.ConnectionState
dtlsStates map[string]*dtls.State dtlsStates map[string]*dtls.State
@@ -39,7 +41,7 @@ type Port struct {
} }
// NewPort creates a new Port // NewPort creates a new Port
func NewPort(address string, remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransportGenerator) (*Port, error) { func NewPort(address string, remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransportGenerator, i ICENotifier) (*Port, error) {
listener, err := net.ListenPacket("udp4", address) listener, err := net.ListenPacket("udp4", address)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -65,10 +67,11 @@ func NewPort(address string, remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTran
srtpContexts: make(map[string]*srtp.Context), srtpContexts: make(map[string]*srtp.Context),
} }
go p.networkLoop(srcString, remoteKey, tlscfg, b) go p.networkLoop(remoteKey, tlscfg, b, i)
return p, nil return p, nil
} }
// Stop closes the listening port and cleans up any state // Stop closes the listening port and cleans up any state
func (p *Port) Stop() { func (p *Port) Close() error {
return p.conn.Close()
} }

90
pkg/ice/ice.go Normal file
View File

@@ -0,0 +1,90 @@
package ice
import "net"
// State is an enum showing the state of a ICE Connection
type ConnectionState int
// List of supported States
const (
// New ICE agent is gathering addresses
New = iota + 1
// Checking ICE agent has been given local and remote candidates, and is attempting to find a match
Checking
// Connected ICE agent has a pairing, but is still checking other pairs
Connected
// Completed ICE agent has finished
Completed
// Failed ICE agent never could sucessfully connect
Failed
// Failed ICE agent connected sucessfully, but has entered a failed state
Disconnected
// Closed ICE agent has finished and is no longer handling requests
Closed
)
func (c ConnectionState) String() string {
switch c {
case New:
return "New"
case Checking:
return "Checking"
case Connected:
return "Connected"
case Completed:
return "Completed"
case Failed:
return "Failed"
case Disconnected:
return "Disconnected"
case Closed:
return "Closed"
default:
return "Invalid"
}
}
// HostInterfaces generates a slice of all the IPs associated with interfaces
func HostInterfaces() (ips []string) {
ifaces, err := net.Interfaces()
if err != nil {
return ips
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return ips
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
ips = append(ips, ip.String())
}
}
return ips
}

View File

@@ -3,12 +3,13 @@ package webrtc
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"sync"
"github.com/pions/webrtc/internal/dtls" "github.com/pions/webrtc/internal/dtls"
"github.com/pions/webrtc/internal/ice"
"github.com/pions/webrtc/internal/network" "github.com/pions/webrtc/internal/network"
"github.com/pions/webrtc/internal/sdp" "github.com/pions/webrtc/internal/sdp"
"github.com/pions/webrtc/internal/util" "github.com/pions/webrtc/internal/util"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/rtp" "github.com/pions/webrtc/pkg/rtp"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -30,16 +31,18 @@ const (
// RTCPeerConnection represents a WebRTC connection between itself and a remote peer // RTCPeerConnection represents a WebRTC connection between itself and a remote peer
type RTCPeerConnection struct { type RTCPeerConnection struct {
Ontrack func(mediaType TrackType, buffers <-chan *rtp.Packet) Ontrack func(mediaType TrackType, buffers <-chan *rtp.Packet)
LocalDescription *sdp.SessionDescription LocalDescription *sdp.SessionDescription
OnICEConnectionStateChange func(iceConnectionState ice.ConnectionState)
tlscfg *dtls.TLSCfg tlscfg *dtls.TLSCfg
iceUsername string iceUsername string
icePassword string icePassword string
iceState ice.ConnectionState
// TODO mutex portsLock sync.RWMutex
ports []*network.Port ports []*network.Port
} }
// Public // Public
@@ -56,14 +59,18 @@ func (r *RTCPeerConnection) CreateOffer() error {
if r.tlscfg != nil { if r.tlscfg != nil {
return errors.Errorf("tlscfg is already defined, CreateOffer can only be called once") return errors.Errorf("tlscfg is already defined, CreateOffer can only be called once")
} }
r.tlscfg = dtls.NewTLSCfg() r.tlscfg = dtls.NewTLSCfg()
r.iceUsername = util.RandSeq(16) r.iceUsername = util.RandSeq(16)
r.icePassword = util.RandSeq(32) r.icePassword = util.RandSeq(32)
r.portsLock.Lock()
defer r.portsLock.Unlock()
candidates := []string{} candidates := []string{}
basePriority := uint16(rand.Uint32() & (1<<16 - 1)) basePriority := uint16(rand.Uint32() & (1<<16 - 1))
for id, c := range ice.HostInterfaces() { for id, c := range ice.HostInterfaces() {
port, err := network.NewPort(c+":0", []byte(r.icePassword), r.tlscfg, r.generateChannel) port, err := network.NewPort(c+":0", []byte(r.icePassword), r.tlscfg, r.generateChannel, r.iceStateChange)
if err != nil { if err != nil {
return err return err
} }
@@ -96,6 +103,21 @@ func (r *RTCPeerConnection) AddTrack(mediaType TrackType) (buffers chan<- []byte
return trackInput, nil return trackInput, nil
} }
// Close ends the RTCPeerConnection
func (r *RTCPeerConnection) Close() error {
r.portsLock.Lock()
defer r.portsLock.Unlock()
// Walk all ports remove and close them
for _, p := range r.ports {
if err := p.Close(); err != nil {
return err
}
}
r.ports = nil
return nil
}
// Private // Private
func (r *RTCPeerConnection) generateChannel(ssrc uint32) (buffers chan<- *rtp.Packet) { func (r *RTCPeerConnection) generateChannel(ssrc uint32) (buffers chan<- *rtp.Packet) {
if r.Ontrack == nil { if r.Ontrack == nil {
@@ -106,3 +128,33 @@ func (r *RTCPeerConnection) generateChannel(ssrc uint32) (buffers chan<- *rtp.Pa
go r.Ontrack(VP8, bufferTransport) // TODO look up media via SSRC in remote SD go r.Ontrack(VP8, bufferTransport) // TODO look up media via SSRC in remote SD
return bufferTransport return bufferTransport
} }
// Private
func (r *RTCPeerConnection) iceStateChange(p *network.Port) {
updateAndNotify := func(newState ice.ConnectionState) {
if r.OnICEConnectionStateChange != nil && r.iceState != newState {
r.OnICEConnectionStateChange(newState)
}
r.iceState = newState
}
if p.ICEState == ice.Failed {
if err := p.Close(); err != nil {
fmt.Println(errors.Wrap(err, "Failed to close Port when ICE went to failed"))
}
r.portsLock.Lock()
defer r.portsLock.Unlock()
for i := len(r.ports) - 1; i >= 0; i-- {
if r.ports[i] == p {
r.ports = append(r.ports[:i], r.ports[i+1:]...)
}
}
if len(r.ports) == 0 {
updateAndNotify(ice.Disconnected)
}
} else {
updateAndNotify(ice.Connected)
}
}