From 2882bacdf254cc35a7b8215a09ce3b73dcb2ea81 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 30 Oct 2021 16:15:04 +0200 Subject: [PATCH] server: split WriteFrame into WritePacketRTP and WritePacketRTCP --- examples/server-tls/main.go | 4 +-- examples/server/main.go | 4 +-- server_publish_test.go | 2 +- server_read_test.go | 24 +++++++------- server_test.go | 8 ++--- serverconn.go | 12 +++---- serverhandler.go | 12 +++---- serversession.go | 62 +++++++++++++++++++++++-------------- serverstream.go | 42 ++++++++++++++++--------- serverudpl.go | 12 +++---- 10 files changed, 105 insertions(+), 77 deletions(-) diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index a565aa20..1166321b 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -131,7 +131,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) // if we are the publisher, route packet to readers if ctx.Session == sh.publisher { - sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTP, ctx.Payload) + sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) } } @@ -142,7 +142,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx // if we are the publisher, route packet to readers if ctx.Session == sh.publisher { - sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTCP, ctx.Payload) + sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) } } diff --git a/examples/server/main.go b/examples/server/main.go index 3fb63cb1..5b8c249d 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -130,7 +130,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) // if we are the publisher, route packet to readers if ctx.Session == sh.publisher { - sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTP, ctx.Payload) + sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) } } @@ -141,7 +141,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx // if we are the publisher, route packet to readers if ctx.Session == sh.publisher { - sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTCP, ctx.Payload) + sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) } } diff --git a/server_publish_test.go b/server_publish_test.go index 14cbc0c1..20b556a4 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -651,7 +651,7 @@ func TestServerPublish(t *testing.T) { onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { require.Equal(t, 0, ctx.TrackID) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) - ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) + ctx.Session.WritePacketRTCP(0, []byte{0x09, 0x0A, 0x0B, 0x0C}) }, }, } diff --git a/server_read_test.go b/server_read_test.go index 24a0ed00..f69fa08e 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -295,8 +295,8 @@ func TestServerRead(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) - stream.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) + stream.WritePacketRTP(0, []byte{0x01, 0x02, 0x03, 0x04}) + stream.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) }() return &base.Response{ @@ -673,7 +673,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { go func() { defer close(writerDone) - stream.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) t := time.NewTicker(50 * time.Millisecond) defer t.Stop() @@ -681,7 +681,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { for { select { case <-t.C: - stream.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) case <-writerTerminate: return } @@ -857,7 +857,7 @@ func TestServerReadPlayPausePlay(t *testing.T) { for { select { case <-t.C: - stream.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) case <-writerTerminate: return } @@ -973,7 +973,7 @@ func TestServerReadPlayPausePause(t *testing.T) { for { select { case <-t.C: - stream.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) case <-writerTerminate: return } @@ -1477,8 +1477,8 @@ func TestServerReadNonSetuppedPath(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WriteFrame(1, base.StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) - stream.WriteFrame(0, base.StreamTypeRTP, []byte{0x05, 0x06, 0x07, 0x08}) + stream.WritePacketRTP(1, []byte{0x01, 0x02, 0x03, 0x04}) + stream.WritePacketRTP(0, []byte{0x05, 0x06, 0x07, 0x08}) }() return &base.Response{ @@ -1642,8 +1642,8 @@ func TestServerReadAdditionalInfos(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WriteFrame(1, base.StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) - stream.WriteFrame(0, base.StreamTypeRTP, []byte{0x05, 0x06, 0x07, 0x08}) + stream.WritePacketRTP(1, []byte{0x01, 0x02, 0x03, 0x04}) + stream.WritePacketRTP(0, []byte{0x05, 0x06, 0x07, 0x08}) }() return &base.Response{ @@ -1669,7 +1669,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { Payload: []byte{0x01, 0x02, 0x03, 0x04}, }).Marshal() require.NoError(t, err) - stream.WriteFrame(0, StreamTypeRTP, buf) + stream.WritePacketRTP(0, buf) rtpInfo, ssrcs := getInfos() require.Equal(t, &headers.RTPInfo{ @@ -1705,7 +1705,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { Payload: []byte{0x01, 0x02, 0x03, 0x04}, }).Marshal() require.NoError(t, err) - stream.WriteFrame(1, StreamTypeRTP, buf) + stream.WritePacketRTP(1, buf) rtpInfo, ssrcs = getInfos() require.Equal(t, &headers.RTPInfo{ diff --git a/server_test.go b/server_test.go index f79e6459..36e078c4 100644 --- a/server_test.go +++ b/server_test.go @@ -48,8 +48,8 @@ type testServerHandler struct { onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) - onPacketRTP func(*ServerHandlerOnPacketRTPCtx) - onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx) + onPacketRTP func(*ServerHandlerOnPacketRTPCtx) + onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx) onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error) onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error) } @@ -404,7 +404,7 @@ func TestServerHighLevelPublishRead(t *testing.T) { defer mutex.Unlock() if ctx.Session == publisher { - stream.WriteFrame(ctx.TrackID, StreamTypeRTP, ctx.Payload) + stream.WritePacketRTP(ctx.TrackID, ctx.Payload) } }, onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { @@ -412,7 +412,7 @@ func TestServerHighLevelPublishRead(t *testing.T) { defer mutex.Unlock() if ctx.Session == publisher { - stream.WriteFrame(ctx.TrackID, StreamTypeRTCP, ctx.Payload) + stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) } }, }, diff --git a/serverconn.go b/serverconn.go index 18cef584..be438584 100644 --- a/serverconn.go +++ b/serverconn.go @@ -162,17 +162,17 @@ func (sc *ServerConn) run() { if streamType == StreamTypeRTP { if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Payload: frame.Payload, + Session: sc.tcpSession, + TrackID: trackID, + Payload: frame.Payload, }) } } else { if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Payload: frame.Payload, + Session: sc.tcpSession, + TrackID: trackID, + Payload: frame.Payload, }) } } diff --git a/serverhandler.go b/serverhandler.go index 511bf8c1..74b23900 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -181,9 +181,9 @@ type ServerHandlerOnSetParameter interface { // ServerHandlerOnPacketRTPCtx is the context of a RTP packet. type ServerHandlerOnPacketRTPCtx struct { - Session *ServerSession - TrackID int - Payload []byte + Session *ServerSession + TrackID int + Payload []byte } // ServerHandlerOnPacketRTP can be implemented by a ServerHandler. @@ -193,9 +193,9 @@ type ServerHandlerOnPacketRTP interface { // ServerHandlerOnPacketRTCPCtx is the context of a RTCP packet. type ServerHandlerOnPacketRTCPCtx struct { - Session *ServerSession - TrackID int - Payload []byte + Session *ServerSession + TrackID int + Payload []byte } // ServerHandlerOnPacketRTCP can be implemented by a ServerHandler. diff --git a/serversession.go b/serversession.go index 042d8bd9..3af883c1 100644 --- a/serversession.go +++ b/serversession.go @@ -339,7 +339,7 @@ func (ss *ServerSession) run() { now := time.Now() for trackID, track := range ss.announcedTracks { r := track.rtcpReceiver.Report(now) - ss.WriteFrame(trackID, StreamTypeRTCP, r) + ss.WritePacketRTCP(trackID, r) } case <-ss.ctx.Done(): @@ -863,7 +863,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base sc.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, false) // open the firewall by sending packets to the counterpart - ss.WriteFrame(trackID, StreamTypeRTCP, + ss.WritePacketRTCP(trackID, []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } @@ -905,7 +905,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base path, query := base.PathSplitQuery(pathAndQuery) - // allow to use WriteFrame() before response + // allow to use WritePacket*() before response if *ss.setuppedTransport == TransportTCP { ss.tcpConn = sc } @@ -938,9 +938,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, true) // open the firewall by sending packets to the counterpart - ss.WriteFrame(trackID, StreamTypeRTP, + ss.WritePacketRTP(trackID, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - ss.WriteFrame(trackID, StreamTypeRTCP, + ss.WritePacketRTCP(trackID, []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } @@ -1065,8 +1065,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, liberrors.ErrServerUnhandledRequest{Req: req} } -// WriteFrame writes a frame to the session. -func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload []byte) { +// WritePacketRTP writes a RTP packet to the session. +func (ss *ServerSession) WritePacketRTP(trackID int, payload []byte) { if _, ok := ss.setuppedTracks[trackID]; !ok { return } @@ -1075,25 +1075,41 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload case TransportUDP: track := ss.setuppedTracks[trackID] - if streamType == StreamTypeRTP { - ss.s.udpRTPListener.write(payload, &net.UDPAddr{ - IP: ss.author.ip(), - Zone: ss.author.zone(), - Port: track.udpRTPPort, - }) - } else { - ss.s.udpRTCPListener.write(payload, &net.UDPAddr{ - IP: ss.author.ip(), - Zone: ss.author.zone(), - Port: track.udpRTCPPort, - }) - } + ss.s.udpRTPListener.write(payload, &net.UDPAddr{ + IP: ss.author.ip(), + Zone: ss.author.zone(), + Port: track.udpRTPPort, + }) case TransportTCP: channel := ss.setuppedTracks[trackID].tcpChannel - if streamType == base.StreamTypeRTCP { - channel++ - } + + ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ + Channel: channel, + Payload: payload, + }) + } +} + +// WritePacketRTCP writes a RTCP packet to the session. +func (ss *ServerSession) WritePacketRTCP(trackID int, payload []byte) { + if _, ok := ss.setuppedTracks[trackID]; !ok { + return + } + + switch *ss.setuppedTransport { + case TransportUDP: + track := ss.setuppedTracks[trackID] + + ss.s.udpRTCPListener.write(payload, &net.UDPAddr{ + IP: ss.author.ip(), + Zone: ss.author.zone(), + Port: track.udpRTCPPort, + }) + + case TransportTCP: + channel := ss.setuppedTracks[trackID].tcpChannel + channel++ ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ Channel: channel, diff --git a/serverstream.go b/serverstream.go index ad2e5702..98f9f17e 100644 --- a/serverstream.go +++ b/serverstream.go @@ -221,9 +221,9 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { } } -// WriteFrame writes a frame to all the readers of the stream. -func (st *ServerStream) WriteFrame(trackID int, streamType StreamType, payload []byte) { - if streamType == StreamTypeRTP && len(payload) >= 8 { +// WritePacketRTP writes a RTP packet to all the readers of the stream. +func (st *ServerStream) WritePacketRTP(trackID int, payload []byte) { + if len(payload) >= 8 { track := st.trackInfos[trackID] sequenceNumber := binary.BigEndian.Uint16(payload[2:4]) @@ -242,21 +242,33 @@ func (st *ServerStream) WriteFrame(trackID int, streamType StreamType, payload [ // send unicast for r := range st.readersUnicast { - r.WriteFrame(trackID, streamType, payload) + r.WritePacketRTP(trackID, payload) } // send multicast if st.multicastListeners != nil { - if streamType == StreamTypeRTP { - st.multicastListeners[trackID].rtpListener.write(payload, &net.UDPAddr{ - IP: st.multicastListeners[trackID].rtpListener.ip(), - Port: st.multicastListeners[trackID].rtpListener.port(), - }) - } else { - st.multicastListeners[trackID].rtcpListener.write(payload, &net.UDPAddr{ - IP: st.multicastListeners[trackID].rtpListener.ip(), - Port: st.multicastListeners[trackID].rtcpListener.port(), - }) - } + st.multicastListeners[trackID].rtpListener.write(payload, &net.UDPAddr{ + IP: st.multicastListeners[trackID].rtpListener.ip(), + Port: st.multicastListeners[trackID].rtpListener.port(), + }) + } +} + +// WritePacketRTCP writes a RTCP packet to all the readers of the stream. +func (st *ServerStream) WritePacketRTCP(trackID int, payload []byte) { + st.mutex.RLock() + defer st.mutex.RUnlock() + + // send unicast + for r := range st.readersUnicast { + r.WritePacketRTCP(trackID, payload) + } + + // send multicast + if st.multicastListeners != nil { + st.multicastListeners[trackID].rtcpListener.write(payload, &net.UDPAddr{ + IP: st.multicastListeners[trackID].rtcpListener.ip(), + Port: st.multicastListeners[trackID].rtcpListener.port(), + }) } } diff --git a/serverudpl.go b/serverudpl.go index ca26b73f..d118afee 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -211,17 +211,17 @@ func (u *serverUDPListener) run() { if u.streamType == StreamTypeRTP { if h, ok := u.s.Handler.(ServerHandlerOnPacketRTP); ok { h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: clientData.ss, - TrackID: clientData.trackID, - Payload: buf[:n], + Session: clientData.ss, + TrackID: clientData.trackID, + Payload: buf[:n], }) } } else { if h, ok := u.s.Handler.(ServerHandlerOnPacketRTCP); ok { h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: clientData.ss, - TrackID: clientData.trackID, - Payload: buf[:n], + Session: clientData.ss, + TrackID: clientData.trackID, + Payload: buf[:n], }) } }