Files
go2rtc/pkg/srtp/session.go
2025-04-07 16:56:38 +03:00

161 lines
3.1 KiB
Go

package srtp
import (
"net"
"time"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/srtp/v3"
)
type Session struct {
Local *Endpoint
Remote *Endpoint
OnReadRTP func(packet *rtp.Packet)
Recv int // bytes recv
Send int // bytes send
conn net.PacketConn // local conn endpoint
PayloadType uint8
RTCPInterval time.Duration
senderRTCP rtcp.SenderReport
senderTime time.Time
}
type Endpoint struct {
Addr string
Port uint16
MasterKey []byte
MasterSalt []byte
SSRC uint32
addr net.Addr
srtp *srtp.Context
}
func (e *Endpoint) init() (err error) {
e.addr = &net.UDPAddr{IP: net.ParseIP(e.Addr), Port: int(e.Port)}
e.srtp, err = srtp.CreateContext(e.MasterKey, e.MasterSalt, profile(e.MasterKey))
return
}
func profile(key []byte) srtp.ProtectionProfile {
switch len(key) {
case 16:
return srtp.ProtectionProfileAes128CmHmacSha1_80
//case 32:
// return srtp.ProtectionProfileAes256CmHmacSha1_80
}
return 0
}
func (s *Session) init() error {
if err := s.Local.init(); err != nil {
return err
}
if err := s.Remote.init(); err != nil {
return err
}
s.senderRTCP.SSRC = s.Local.SSRC
s.senderTime = time.Now().Add(s.RTCPInterval)
return nil
}
func (s *Session) WriteRTP(packet *rtp.Packet) (int, error) {
if s.Local.srtp == nil {
return 0, nil // before init call
}
if now := time.Now(); now.After(s.senderTime) {
s.senderRTCP.NTPTime = uint64(now.UnixNano())
s.senderTime = now.Add(s.RTCPInterval)
_, _ = s.WriteRTCP(&s.senderRTCP)
}
clone := rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: packet.Marker,
PayloadType: s.PayloadType,
SequenceNumber: packet.SequenceNumber,
Timestamp: packet.Timestamp,
SSRC: s.Local.SSRC,
},
Payload: packet.Payload,
}
b, err := clone.Marshal()
if err != nil {
return 0, err
}
s.senderRTCP.PacketCount++
s.senderRTCP.RTPTime = clone.Timestamp
s.senderRTCP.OctetCount += uint32(len(clone.Payload))
if b, err = s.Local.srtp.EncryptRTP(nil, b, nil); err != nil {
return 0, err
}
return s.conn.WriteTo(b, s.Remote.addr)
}
func (s *Session) WriteRTCP(packet rtcp.Packet) (int, error) {
b, err := packet.Marshal()
if err != nil {
return 0, err
}
b, err = s.Local.srtp.EncryptRTCP(nil, b, nil)
if err != nil {
return 0, err
}
return s.conn.WriteTo(b, s.Remote.addr)
}
func (s *Session) ReadRTP(b []byte) {
packet := &rtp.Packet{}
b, err := s.Remote.srtp.DecryptRTP(nil, b, &packet.Header)
if err != nil {
return
}
if err = packet.Unmarshal(b); err != nil {
return
}
if s.OnReadRTP != nil {
s.OnReadRTP(packet)
}
}
func (s *Session) ReadRTCP(b []byte) {
header := rtcp.Header{}
b, err := s.Remote.srtp.DecryptRTCP(nil, b, &header)
if err != nil {
return
}
//packets, err := rtcp.Unmarshal(b)
//if err != nil {
// return
//}
//if report, ok := packets[0].(*rtcp.SenderReport); ok {
// log.Printf("[srtp] rtcp type=%d report=%v", header.Type, report)
//}
if header.Type != rtcp.TypeSenderReport {
return
}
receiverRTCP := rtcp.ReceiverReport{SSRC: s.Local.SSRC}
_, _ = s.WriteRTCP(&receiverRTCP)
}