Files
webrtc/internal/network/port-receive.go
2018-07-21 12:27:38 -07:00

188 lines
4.7 KiB
Go

package network
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"time"
"github.com/pions/pkg/stun"
"github.com/pions/webrtc/internal/dtls"
"github.com/pions/webrtc/internal/sctp"
"github.com/pions/webrtc/internal/srtp"
"github.com/pions/webrtc/pkg/ice"
"github.com/pions/webrtc/pkg/rtp"
)
// incomingPacket is one datagram read from the Port's connection by the
// reader goroutine in networkLoop and handed to its processing loop.
type incomingPacket struct {
	srcAddr *net.UDPAddr // address the datagram was received from
	buffer  []byte       // private copy of exactly the bytes returned by ReadFrom
}
// handleSRTP decrypts an SRTP datagram received on this Port and forwards the
// resulting RTP packet to the buffer transport registered for its SSRC.
//
// Packets whose second byte is in 192-223 (RTCP packet types, RFC 5761 §4) are
// currently discarded. An SRTP context is created lazily per listening address
// and SSRC from the first 16 bytes and the remainder of certPair.ServerWriteKey,
// using certPair.Profile, and cached in p.srtpContexts under
// p.srtpContextsLock. If no buffer transport exists for the SSRC yet, b is
// called with the SSRC and payload type; a nil result drops the packet.
// Delivery is non-blocking: if the transport cannot accept the packet
// immediately it is dropped. Failures are printed and the packet is dropped.
//
// NOTE(review): p.bufferTransports is read and written without a lock;
// presumably only networkLoop's goroutine calls this — confirm.
func (p *Port) handleSRTP(b BufferTransportGenerator, certPair *dtls.CertPair, buffer []byte) {
	if len(buffer) > 4 {
		var rtcpPacketType uint8
		r := bytes.NewReader([]byte{buffer[1]})
		if err := binary.Read(r, binary.BigEndian, &rtcpPacketType); err != nil {
			fmt.Println("Failed to check packet for RTCP")
			return
		}
		if rtcpPacketType >= 192 && rtcpPacketType <= 223 {
			fmt.Println("Discarding RTCP packet TODO")
			return
		}
	}

	packet := &rtp.Packet{}
	if err := packet.Unmarshal(buffer); err != nil {
		fmt.Println("Failed to unmarshal RTP packet")
		return
	}

	contextMapKey := p.ListeningAddr.String() + ":" + fmt.Sprint(packet.SSRC)
	p.srtpContextsLock.Lock()
	srtpContext, ok := p.srtpContexts[contextMapKey]
	if !ok {
		var err error
		srtpContext, err = srtp.CreateContext([]byte(certPair.ServerWriteKey[0:16]), []byte(certPair.ServerWriteKey[16:]), certPair.Profile, packet.SSRC)
		if err != nil {
			// Release the lock before bailing out; returning with it held
			// would block every later caller of handleSRTP on this Port.
			p.srtpContextsLock.Unlock()
			fmt.Println("Failed to build SRTP context")
			return
		}
		p.srtpContexts[contextMapKey] = srtpContext
	}
	p.srtpContextsLock.Unlock()

	if ok := srtpContext.DecryptPacket(packet); !ok {
		fmt.Println("Failed to decrypt packet")
		return
	}

	bufferTransport := p.bufferTransports[packet.SSRC]
	if bufferTransport == nil {
		bufferTransport = b(packet.SSRC, packet.PayloadType)
		if bufferTransport == nil {
			return
		}
		p.bufferTransports[packet.SSRC] = bufferTransport
	}

	// Never block the network loop on a slow consumer; drop instead.
	select {
	case bufferTransport <- packet:
	default:
	}
}
// handleICE answers a STUN Binding request received on this Port.
//
// Datagrams that do not parse as a STUN Binding request are ignored. For a
// valid request a Binding success response is sent back to the sender,
// carrying its XOR-mapped address, a MESSAGE-INTEGRITY attribute keyed with
// remoteKey, and a FINGERPRINT. If sending fails the error is printed and
// nothing else happens. On success the Port's ICE state becomes Completed,
// iceTimer is re-armed for iceTimeout, and iceNotifier is called.
//
// iceTimer must only be received from on the caller's goroutine, because this
// function may drain a pending tick from it.
func (p *Port) handleICE(in *incomingPacket, remoteKey []byte, iceTimer *time.Timer, iceNotifier ICENotifier) {
	m, err := stun.NewMessage(in.buffer)
	if err != nil || m.Class != stun.ClassRequest || m.Method != stun.MethodBinding {
		return
	}

	dstAddr := &stun.TransportAddr{IP: in.srcAddr.IP, Port: in.srcAddr.Port}
	if err := stun.BuildAndSend(p.conn, dstAddr, stun.ClassSuccessResponse, stun.MethodBinding, m.TransactionID,
		&stun.XorMappedAddress{
			XorAddress: stun.XorAddress{
				IP:   dstAddr.IP,
				Port: dstAddr.Port,
			},
		},
		&stun.MessageIntegrity{
			Key: remoteKey,
		},
		&stun.Fingerprint{},
	); err != nil {
		fmt.Println(err)
		return
	}

	p.ICEState = ice.ConnectionStateCompleted

	// Stop and drain before Reset, as time.Timer.Reset requires. If the timer
	// already fired but its tick has not been consumed yet, a bare Reset would
	// leave that stale tick in the channel. The caller would then report ICE as
	// failed right after this successful Binding exchange.
	if !iceTimer.Stop() {
		select {
		case <-iceTimer.C:
		default:
		}
	}
	iceTimer.Reset(iceTimeout)
	iceNotifier(p)
}
// Receive-path tuning for Port.
const (
	// iceTimeout is how long networkLoop waits after the last answered STUN
	// Binding request before it marks the Port's ICE state as failed.
	iceTimeout = 10 * time.Second

	// receiveMTU is the size of the scratch buffer each ReadFrom reads into.
	receiveMTU = 8192
)
// networkLoop services the Port's connection until reading from it fails.
//
// A reader goroutine copies each datagram from p.conn into a 15-slot channel.
// It drops the datagram if the channel is full, and closes the channel when
// ReadFrom returns an error. The processing loop demultiplexes each datagram
// by its first byte:
//   - 20-63 from a peer with a DTLS state: fed to that state. Any decrypted
//     data goes to sctp.Parse, and the first available certificate pair is
//     recorded as an authenticated connection.
//   - STUN: handed to handleICE.
//   - anything else: treated as SRTP once a certificate pair exists.
//
// The first non-DTLS datagram from an unknown peer also creates a DTLS state
// for that peer and starts its handshake.
//
// If the ICE timer fires, the Port's ICE state becomes Failed and iceNotifier
// is called. When the connection is finished, the DTLS listener for
// p.ListeningAddr is removed and every DTLS state is closed.
func (p *Port) networkLoop(remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransportGenerator, iceNotifier ICENotifier) {
	incomingPackets := make(chan *incomingPacket, 15)
	go func() {
		// Closing signals the processing loop that this port is finished.
		defer close(incomingPackets)
		buffer := make([]byte, receiveMTU)
		for {
			n, _, srcAddr, err := p.conn.ReadFrom(buffer)
			if err != nil {
				return
			}
			udpAddr, ok := srcAddr.(*net.UDPAddr)
			if !ok {
				// Everything downstream needs a UDP source address; skip
				// rather than panic this goroutine on an unexpected type.
				continue
			}
			bufferCopy := make([]byte, n)
			copy(bufferCopy, buffer[:n])
			select {
			case incomingPackets <- &incomingPacket{buffer: bufferCopy, srcAddr: udpAddr}:
			default:
				// Processing loop is behind; drop instead of stalling reads.
			}
		}
	}()

	var certPair *dtls.CertPair
	// No ICE timeout until the first Binding request is answered; handleICE
	// then re-arms this timer with iceTimeout.
	iceTimer := time.NewTimer(time.Hour * 8760)

	// The loop below only exits via return, so cleanup must be deferred to
	// run at all.
	defer func() {
		iceTimer.Stop()
		dtls.RemoveListener(p.ListeningAddr.String())
		for _, d := range p.dtlsStates {
			d.Close()
		}
	}()

	for {
		select {
		case <-iceTimer.C:
			p.ICEState = ice.ConnectionStateFailed
			iceNotifier(p)
		case in, inValid := <-incomingPackets:
			if !inValid {
				// incomingPackets channel has closed, this port is finished processing
				return
			}

			dtlsState := p.dtlsStates[in.srcAddr.String()]
			// First byte 20-63 is a DTLS record (RFC 5764 §5.1.2 / RFC 7983);
			// 64 and up belongs to other protocols and must not go to DTLS.
			if dtlsState != nil && len(in.buffer) > 0 && in.buffer[0] >= 20 && in.buffer[0] <= 63 {
				decrypted := dtlsState.HandleDTLSPacket(in.buffer)
				if len(decrypted) > 0 {
					sctp.Parse(decrypted)
				}
				if certPair == nil {
					certPair = dtlsState.GetCertPair()
					if certPair != nil {
						p.authedConnections = append(p.authedConnections, &authedConnection{
							pair: certPair,
							peer: in.srcAddr,
						})
					}
				}
				continue
			}

			if packetType, err := stun.GetPacketType(in.buffer); err == nil && packetType == stun.PacketTypeSTUN {
				p.handleICE(in, remoteKey, iceTimer, iceNotifier)
			} else if certPair == nil {
				fmt.Println("SRTP packet, but unable to handle DTLS handshake has not completed")
			} else {
				p.handleSRTP(b, certPair, in.buffer)
			}

			if dtlsState == nil {
				d, err := dtls.NewState(tlscfg, true, p.ListeningAddr.String(), in.srcAddr.String())
				if err != nil {
					fmt.Println(err)
					continue
				}
				d.DoHandshake()
				p.dtlsStates[in.srcAddr.String()] = d
			}
		}
	}
}