diff --git a/internal/network/port-receive.go b/internal/network/port-receive.go index 667e43fb..f768e92a 100644 --- a/internal/network/port-receive.go +++ b/internal/network/port-receive.go @@ -21,7 +21,11 @@ type incomingPacket struct { buffer []byte } -func (p *Port) handleSRTP(b BufferTransportGenerator, certPair *dtls.CertPair, buffer []byte) { +func (p *Port) handleSRTP(b BufferTransportGenerator, buffer []byte) { + if p.certPair == nil { + fmt.Printf("Got SRTP packet but no DTLS state to handle it %v \n", p.certPair) + return + } if len(buffer) > 4 { var rtcpPacketType uint8 @@ -48,7 +52,7 @@ func (p *Port) handleSRTP(b BufferTransportGenerator, certPair *dtls.CertPair, b srtpContext, ok := p.srtpContexts[contextMapKey] if !ok { var err error - srtpContext, err = srtp.CreateContext([]byte(certPair.ServerWriteKey[0:16]), []byte(certPair.ServerWriteKey[16:]), certPair.Profile, packet.SSRC) + srtpContext, err = srtp.CreateContext([]byte(p.certPair.ServerWriteKey[0:16]), []byte(p.certPair.ServerWriteKey[16:]), p.certPair.Profile, packet.SSRC) if err != nil { fmt.Println("Failed to build SRTP context") return @@ -115,33 +119,26 @@ func (p *Port) handleSCTP(raw []byte, a *sctp.Association) { } } -func (p *Port) handleDTLS(raw []byte, srcAddr *net.UDPAddr, certPair *dtls.CertPair) bool { - if len(raw) < 0 || (raw[0] < 19 || raw[0] > 65) { - return false - } - +func (p *Port) handleDTLS(raw []byte, srcAddr *net.UDPAddr) { dtlsState := p.dtlsStates[srcAddr.String()] association := p.sctpAssocations[srcAddr.String()] if dtlsState == nil || association == nil { fmt.Printf("Got DTLS packet but no DTLS/SCTP state to handle it %v %v \n", dtlsState, association) - return true } if decrypted := dtlsState.HandleDTLSPacket(raw); len(decrypted) > 0 { p.handleSCTP(decrypted, association) } - if certPair == nil { - certPair = dtlsState.GetCertPair() - if certPair != nil { + if certPair := dtlsState.GetCertPair(); certPair != nil && p.certPair == nil { + p.certPair = certPair + if p.certPair != nil { p.authedConnections = append(p.authedConnections, &authedConnection{ - pair: certPair, + pair: p.certPair, peer: srcAddr, }) } } - - return true } const iceTimeout = time.Second * 10 @@ -168,7 +165,6 @@ func (p *Port) networkLoop(remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransp } }() - var certPair *dtls.CertPair // Never timeout originally, only start timer after we get an ICE ping iceTimer := time.NewTimer(time.Hour * 8760) for { @@ -189,16 +185,18 @@ func (p *Port) networkLoop(remoteKey []byte, tlscfg *dtls.TLSCfg, b BufferTransp return } - if p.handleDTLS(in.buffer, in.srcAddr, certPair) { + if len(in.buffer) == 0 { + fmt.Println("Inbound buffer is not long enough to demux") continue } - if packetType, err := stun.GetPacketType(in.buffer); err == nil && packetType == stun.PacketTypeSTUN { + // https://tools.ietf.org/html/rfc5764#page-14 + if 127 < in.buffer[0] && in.buffer[0] < 192 { + p.handleSRTP(b, in.buffer) + } else if 19 < in.buffer[0] && in.buffer[0] < 64 { + p.handleDTLS(in.buffer, in.srcAddr) + } else if in.buffer[0] < 2 { p.handleICE(in, remoteKey, iceTimer, iceNotifier) - } else if certPair == nil { - fmt.Println("SRTP packet, but unable to handle DTLS handshake has not completed") - } else { - p.handleSRTP(b, certPair, in.buffer) } if _, ok := p.dtlsStates[in.srcAddr.String()]; !ok { diff --git a/internal/network/port.go b/internal/network/port.go index 48debb28..450acef7 100644 --- a/internal/network/port.go +++ b/internal/network/port.go @@ -41,7 +41,8 @@ type Port struct { association *sctp.Association - conn *ipv4.PacketConn + conn *ipv4.PacketConn + certPair *dtls.CertPair } // NewPort creates a new Port