server: allow calling ServerSession.WritePacketRTCP() inside OnRecord and OnPlay (#99)

This commit is contained in:
aler9
2022-02-18 23:19:33 +01:00
parent cbc228acbf
commit 86fb4181c7
4 changed files with 109 additions and 41 deletions

View File

@@ -638,6 +638,10 @@ func TestServerPublish(t *testing.T) {
}, nil, nil }, nil, nil
}, },
onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) {
// send RTCP packets directly to the session.
// these are sent after the response, only if onRecord returns StatusOK.
ctx.Session.WritePacketRTCP(0, &testRTCPPacket)
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
@@ -766,6 +770,28 @@ func TestServerPublish(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
// server -> client (direct)
if transport == "udp" {
buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err := f.Read(br)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
}
// skip firewall opening
if transport == "udp" {
buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
}
// client -> server // client -> server
if transport == "udp" { if transport == "udp" {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@@ -799,12 +825,7 @@ func TestServerPublish(t *testing.T) {
// server -> client (RTCP) // server -> client (RTCP)
if transport == "udp" { if transport == "udp" {
// skip firewall opening
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
buf = make([]byte, 2048)
n, _, err := l2.ReadFrom(buf) n, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.Equal(t, testRTCPPacketMarshaled, buf[:n])

View File

@@ -287,10 +287,19 @@ func TestServerRead(t *testing.T) {
}, stream, nil }, stream, nil
}, },
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
// send RTCP packets directly to the session.
// these are sent after the response, only if onPlay returns StatusOK.
if transport != "multicast" {
ctx.Session.WritePacketRTCP(0, &testRTCPPacket)
}
// the session is added to the stream only after onPlay returns
// with StatusOK; therefore we must wait before calling
// ServerStream.WritePacket*()
go func() { go func() {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
stream.WritePacketRTP(0, &testRTPPacket)
stream.WritePacketRTCP(0, &testRTCPPacket) stream.WritePacketRTCP(0, &testRTCPPacket)
stream.WritePacketRTP(0, &testRTPPacket)
}() }()
return &base.Response{ return &base.Response{
@@ -461,13 +470,6 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
// server -> client
if transport == "udp" || transport == "multicast" {
buf := make([]byte, 2048)
n, _, err := l1.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTPPacketMarshaled, buf[:n])
// skip firewall opening // skip firewall opening
if transport == "udp" { if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
@@ -475,23 +477,63 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
// server -> client (direct)
switch transport {
case "udp":
buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
case "tcp", "tls":
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err := f.Read(br)
require.NoError(t, err)
switch f.Channel {
case 4:
require.Equal(t, testRTPPacketMarshaled, f.Payload)
case 5:
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
default:
t.Errorf("should not happen")
}
}
// server -> client (through stream)
if transport == "udp" || transport == "multicast" {
buf := make([]byte, 2048)
n, _, err := l1.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTPPacketMarshaled, buf[:n])
buf = make([]byte, 2048) buf = make([]byte, 2048)
n, _, err = l2.ReadFrom(buf) n, _, err = l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else { } else {
var f base.InterleavedFrame var f base.InterleavedFrame
for i := 0; i < 2; i++ {
f.Payload = make([]byte, 2048) f.Payload = make([]byte, 2048)
err := f.Read(br) err := f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 4, f.Channel)
switch f.Channel {
case 4:
require.Equal(t, testRTPPacketMarshaled, f.Payload) require.Equal(t, testRTPPacketMarshaled, f.Payload)
f.Payload = make([]byte, 2048) case 5:
err = f.Read(br)
require.NoError(t, err)
require.Equal(t, 5, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload) require.Equal(t, testRTCPPacketMarshaled, f.Payload)
default:
t.Errorf("should not happen")
}
}
} }
// client -> server (RTCP) // client -> server (RTCP)

View File

@@ -826,6 +826,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerPathHasChanged{Prev: *ss.setuppedPath, Cur: path} }, liberrors.ErrServerPathHasChanged{Prev: *ss.setuppedPath, Cur: path}
} }
// allocate writeBuffer before calling OnPlay().
// in this way it's possible to call ServerSession.WritePacket*()
// inside the callback.
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.writeBuffer = ringbuffer.New(uint64(ss.s.ReadBufferCount))
}
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
Session: ss, Session: ss,
Conn: sc, Conn: sc,
@@ -835,7 +843,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}) })
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
if ss.State() == ServerSessionStatePrePlay { if ss.state != ServerSessionStatePlay {
ss.writeBuffer = nil ss.writeBuffer = nil
} }
return res, err return res, err
@@ -851,13 +859,12 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
case TransportUDP: case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.writeBuffer = ringbuffer.New(uint64(ss.s.ReadBufferCount))
ss.writerRunning = true ss.writerRunning = true
ss.writerDone = make(chan struct{}) ss.writerDone = make(chan struct{})
go ss.runWriter() go ss.runWriter()
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
// readers can send RTCP packets // readers can send RTCP packets only
sc.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, false) sc.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, false)
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
@@ -876,10 +883,11 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
err = errSwitchReadFunc err = errSwitchReadFunc
ss.writeBuffer = ringbuffer.New(uint64(ss.s.ReadBufferCount)) // runWriter() is called by ServerConn after the response has been sent
// runWriter() is called by conn after sending the response
} }
ss.setuppedStream.readerSetActive(ss)
// add RTP-Info // add RTP-Info
var trackIDs []int var trackIDs []int
for trackID := range ss.setuppedTracks { for trackID := range ss.setuppedTracks {
@@ -917,8 +925,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
res.Header["RTP-Info"] = ri.Write() res.Header["RTP-Info"] = ri.Write()
} }
ss.setuppedStream.readerSetActive(ss)
return res, err return res, err
case base.Record: case base.Record:
@@ -955,6 +961,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerPathHasChanged{Prev: *ss.setuppedPath, Cur: path} }, liberrors.ErrServerPathHasChanged{Prev: *ss.setuppedPath, Cur: path}
} }
// allocate writeBuffer before calling OnRecord().
// in this way it's possible to call ServerSession.WritePacket*()
// inside the callback.
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writeBuffer = ringbuffer.New(uint64(8))
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss, Session: ss,
Conn: sc, Conn: sc,
@@ -964,6 +978,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}) })
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
ss.writeBuffer = nil
return res, err return res, err
} }
@@ -974,10 +989,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.udpReceiverReportTimer = time.NewTimer(ss.s.udpReceiverReportPeriod) ss.udpReceiverReportTimer = time.NewTimer(ss.s.udpReceiverReportPeriod)
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writeBuffer = ringbuffer.New(uint64(8))
ss.writerRunning = true ss.writerRunning = true
ss.writerDone = make(chan struct{}) ss.writerDone = make(chan struct{})
go ss.runWriter() go ss.runWriter()
@@ -1000,10 +1011,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
err = errSwitchReadFunc err = errSwitchReadFunc
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writeBuffer = ringbuffer.New(uint64(8))
// runWriter() is called by conn after sending the response // runWriter() is called by conn after sending the response
} }

View File

@@ -224,10 +224,8 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet) {
atomic.StoreUint32(&track.lastSequenceNumber, atomic.StoreUint32(&track.lastSequenceNumber,
uint32(pkt.Header.SequenceNumber)) uint32(pkt.Header.SequenceNumber))
atomic.StoreUint32(&track.lastTimeRTP, pkt.Header.Timestamp) atomic.StoreUint32(&track.lastTimeRTP, pkt.Header.Timestamp)
atomic.StoreInt64(&track.lastTimeNTP, time.Now().Unix()) atomic.StoreInt64(&track.lastTimeNTP, time.Now().Unix())
atomic.StoreUint32(&track.lastSSRC, pkt.Header.SSRC) atomic.StoreUint32(&track.lastSSRC, pkt.Header.SSRC)
st.mutex.RLock() st.mutex.RLock()