diff --git a/client.go b/client.go index 79b3fd35..9b7e399e 100644 --- a/client.go +++ b/client.go @@ -1608,9 +1608,7 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp // open the firewall by sending packets to the counterpart. for _, cct := range c.tracks { - byts, _ := (&rtp.Packet{ - Header: rtp.Header{Version: 2}, - }).Marshal() + byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() cct.udpRTPListener.write(byts) byts, _ = (&rtcp.ReceiverReport{}).Marshal() diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index 622f325d..c28ad1fc 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -131,7 +131,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) // if we are the publisher, route the RTP packet to readers if ctx.Session == sh.publisher { - sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) + sh.stream.WritePacketRTP(ctx.TrackID, ctx.Packet) } } @@ -142,7 +142,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTCPCt // if we are the publisher, route the RTCP packet to readers if ctx.Session == sh.publisher { - sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) + sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Packet) } } diff --git a/examples/server/main.go b/examples/server/main.go index 4b1d7960..55be7c0b 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -130,7 +130,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) // if we are the publisher, route the RTP packet to readers if ctx.Session == sh.publisher { - sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) + sh.stream.WritePacketRTP(ctx.TrackID, ctx.Packet) } } @@ -141,7 +141,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTCPCt // if we are the publisher, route the RTCP packet to readers if ctx.Session == sh.publisher { - sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) + sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Packet) } } diff --git a/server_publish_test.go b/server_publish_test.go index 51df064f..cd888b1d 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -644,12 +644,12 @@ func TestServerPublish(t *testing.T) { }, onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { require.Equal(t, 0, ctx.TrackID) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) + require.Equal(t, &testRTPPacket, ctx.Packet) }, onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { require.Equal(t, 0, ctx.TrackID) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) - ctx.Session.WritePacketRTCP(0, []byte{0x09, 0x0A, 0x0B, 0x0C}) + require.Equal(t, &testRTCPPacket, ctx.Packet) + ctx.Session.WritePacketRTCP(0, &testRTCPPacket) }, }, RTSPAddress: "localhost:8554", @@ -770,28 +770,28 @@ func TestServerPublish(t *testing.T) { if transport == "udp" { time.Sleep(1 * time.Second) - l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[0], }) time.Sleep(500 * time.Millisecond) - l2.WriteTo([]byte{0x05, 0x06, 0x07, 0x08}, &net.UDPAddr{ + l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[1], }) } else { base.InterleavedFrame{ Channel: 0, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) base.InterleavedFrame{ Channel: 1, - Payload: []byte{0x05, 0x06, 0x07, 0x08}, + Payload: testRTCPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -807,14 +807,14 @@ func TestServerPublish(t *testing.T) { buf = make([]byte, 2048) n, _, err := l2.ReadFrom(buf) require.NoError(t, err) - require.Equal(t, []byte{0x09, 0x0A, 0x0B, 0x0C}, buf[:n]) + require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { var f base.InterleavedFrame f.Payload = make([]byte, 2048) err := f.Read(br) require.NoError(t, err) require.Equal(t, 1, f.Channel) - require.Equal(t, []byte{0x09, 0x0A, 0x0B, 0x0C}, f.Payload) + require.Equal(t, testRTCPPacketMarshaled, f.Payload) } res, err = writeReqReadRes(conn, br, base.Request{ @@ -837,7 +837,15 @@ func TestServerPublish(t *testing.T) { } func TestServerPublishNonStandardFrameSize(t *testing.T) { - payload := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5) + packet := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 97, + CSRC: []uint32{}, + }, + Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), + } + packetMarshaled, _ := packet.Marshal() frameReceived := make(chan struct{}) s := &Server{ @@ -861,7 +869,7 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) { }, onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { require.Equal(t, 0, ctx.TrackID) - require.Equal(t, payload, ctx.Payload) + require.Equal(t, &packet, ctx.Packet) close(frameReceived) }, }, @@ -937,7 +945,7 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) { base.InterleavedFrame{ Channel: 0, - Payload: payload, + Payload: packetMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) diff --git a/server_read_test.go b/server_read_test.go index 597bc7e8..e4f5b2a4 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -289,8 +289,8 @@ func TestServerRead(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WritePacketRTP(0, []byte{0x01, 0x02, 0x03, 0x04}) - stream.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) + stream.WritePacketRTP(0, &testRTPPacket) + stream.WritePacketRTCP(0, &testRTCPPacket) }() return &base.Response{ @@ -304,7 +304,7 @@ func TestServerRead(t *testing.T) { } require.Equal(t, 0, ctx.TrackID) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) + require.Equal(t, &testRTCPPacket, ctx.Packet) close(framesReceived) }, onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) { @@ -466,7 +466,7 @@ func TestServerRead(t *testing.T) { buf := make([]byte, 2048) n, _, err := l1.ReadFrom(buf) require.NoError(t, err) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n]) + require.Equal(t, testRTPPacketMarshaled, buf[:n]) // skip firewall opening if transport == "udp" { @@ -478,33 +478,33 @@ func TestServerRead(t *testing.T) { buf = make([]byte, 2048) n, _, err = l2.ReadFrom(buf) require.NoError(t, err) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n]) + require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { var f base.InterleavedFrame f.Payload = make([]byte, 2048) err := f.Read(br) require.NoError(t, err) require.Equal(t, 4, f.Channel) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload) + require.Equal(t, testRTPPacketMarshaled, f.Payload) f.Payload = make([]byte, 2048) err = f.Read(br) require.NoError(t, err) require.Equal(t, 5, f.Channel) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, f.Payload) + require.Equal(t, testRTCPPacketMarshaled, f.Payload) } // client -> server (RTCP) switch transport { case "udp": - l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[1], }) <-framesReceived case "multicast": - l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ IP: *th.Destination, Port: th.Ports[1], }) @@ -513,7 +513,7 @@ func TestServerRead(t *testing.T) { default: base.InterleavedFrame{ Channel: 5, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTCPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -614,13 +614,21 @@ func TestServerReadVLCMulticast(t *testing.T) { } func TestServerReadNonStandardFrameSize(t *testing.T) { + packet := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 97, + CSRC: []uint32{}, + }, + Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), + } + packetMarshaled, _ := packet.Marshal() + track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) require.NoError(t, err) stream := NewServerStream(Tracks{track}) - payload := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5) - s := &Server{ Handler: &testServerHandler{ onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { @@ -631,7 +639,7 @@ func TestServerReadNonStandardFrameSize(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WritePacketRTP(0, payload) + stream.WritePacketRTP(0, &packet) }() return &base.Response{ @@ -694,7 +702,7 @@ func TestServerReadNonStandardFrameSize(t *testing.T) { err = f.Read(br) require.NoError(t, err) require.Equal(t, 0, f.Channel) - require.Equal(t, payload, f.Payload) + require.Equal(t, packetMarshaled, f.Payload) } func TestServerReadTCPResponseBeforeFrames(t *testing.T) { @@ -722,7 +730,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { go func() { defer close(writerDone) - stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, &testRTPPacket) t := time.NewTicker(50 * time.Millisecond) defer t.Stop() @@ -730,7 +738,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { for { select { case <-t.C: - stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, &testRTPPacket) case <-writerTerminate: return } @@ -913,7 +921,7 @@ func TestServerReadPlayPausePlay(t *testing.T) { for { select { case <-t.C: - stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, &testRTPPacket) case <-writerTerminate: return } @@ -1033,7 +1041,7 @@ func TestServerReadPlayPausePause(t *testing.T) { for { select { case <-t.C: - stream.WritePacketRTP(0, []byte("\x00\x00\x00\x00")) + stream.WritePacketRTP(0, &testRTPPacket) case <-writerTerminate: return } @@ -1457,8 +1465,8 @@ func TestServerReadPartialTracks(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WritePacketRTP(0, []byte{0x01, 0x02, 0x03, 0x04}) - stream.WritePacketRTP(1, []byte{0x05, 0x06, 0x07, 0x08}) + stream.WritePacketRTP(0, &testRTPPacket) + stream.WritePacketRTP(1, &testRTPPacket) }() return &base.Response{ @@ -1522,7 +1530,7 @@ func TestServerReadPartialTracks(t *testing.T) { err = f.Read(br) require.NoError(t, err) require.Equal(t, 4, f.Channel) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, f.Payload) + require.Equal(t, testRTPPacketMarshaled, f.Payload) } func TestServerReadAdditionalInfos(t *testing.T) { @@ -1630,8 +1638,8 @@ func TestServerReadAdditionalInfos(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { go func() { time.Sleep(1 * time.Second) - stream.WritePacketRTP(1, []byte{0x01, 0x02, 0x03, 0x04}) - stream.WritePacketRTP(0, []byte{0x05, 0x06, 0x07, 0x08}) + stream.WritePacketRTP(1, &testRTPPacket) + stream.WritePacketRTP(0, &testRTPPacket) }() return &base.Response{ @@ -1646,7 +1654,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { require.NoError(t, err) defer s.Close() - buf, err := (&rtp.Packet{ + stream.WritePacketRTP(0, &rtp.Packet{ Header: rtp.Header{ Version: 2, PayloadType: 96, @@ -1655,9 +1663,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { SSRC: 96342362, }, Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }).Marshal() - require.NoError(t, err) - stream.WritePacketRTP(0, buf) + }) rtpInfo, ssrcs := getInfos() require.Equal(t, &headers.RTPInfo{ @@ -1682,7 +1688,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { nil, }, ssrcs) - buf, err = (&rtp.Packet{ + stream.WritePacketRTP(1, &rtp.Packet{ Header: rtp.Header{ Version: 2, PayloadType: 96, @@ -1691,9 +1697,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { SSRC: 536474323, }, Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }).Marshal() - require.NoError(t, err) - stream.WritePacketRTP(1, buf) + }) rtpInfo, ssrcs = getInfos() require.Equal(t, &headers.RTPInfo{ diff --git a/server_test.go b/server_test.go index fe0e46fb..5905d9a3 100644 --- a/server_test.go +++ b/server_test.go @@ -410,7 +410,7 @@ func TestServerHighLevelPublishRead(t *testing.T) { defer mutex.Unlock() if ctx.Session == publisher { - stream.WritePacketRTP(ctx.TrackID, ctx.Payload) + stream.WritePacketRTP(ctx.TrackID, ctx.Packet) } }, onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { @@ -418,7 +418,7 @@ func TestServerHighLevelPublishRead(t *testing.T) { defer mutex.Unlock() if ctx.Session == publisher { - stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) + stream.WritePacketRTCP(ctx.TrackID, ctx.Packet) } }, }, diff --git a/serverconn.go b/serverconn.go index ffedc8bb..8b587815 100644 --- a/serverconn.go +++ b/serverconn.go @@ -10,6 +10,9 @@ import ( "strings" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/multibuffer" @@ -238,32 +241,52 @@ func (sc *ServerConn) run() { func (sc *ServerConn) tcpProcessPlay(trackID int, isRTP bool, payload []byte) { if !isRTP { + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Payload: payload, - }) + for _, pkt := range packets { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Packet: pkt, + }) + } } } } func (sc *ServerConn) tcpProcessRecord(trackID int, isRTP bool, payload []byte) { if isRTP { + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + if err != nil { + return + } + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ Session: sc.tcpSession, TrackID: trackID, - Payload: payload, + Packet: &pkt, }) } } else { + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Payload: payload, - }) + for _, pkt := range packets { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Packet: pkt, + }) + } } } } diff --git a/serverhandler.go b/serverhandler.go index 74b23900..c69b37a3 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -1,6 +1,9 @@ package gortsplib import ( + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/aler9/gortsplib/pkg/base" ) @@ -183,7 +186,7 @@ type ServerHandlerOnSetParameter interface { type ServerHandlerOnPacketRTPCtx struct { Session *ServerSession TrackID int - Payload []byte + Packet *rtp.Packet } // ServerHandlerOnPacketRTP can be implemented by a ServerHandler. @@ -195,7 +198,7 @@ type ServerHandlerOnPacketRTP interface { type ServerHandlerOnPacketRTCPCtx struct { Session *ServerSession TrackID int - Payload []byte + Packet rtcp.Packet } // ServerHandlerOnPacketRTCP can be implemented by a ServerHandler. diff --git a/servermulticasthandler.go b/servermulticasthandler.go index 409f55e3..5d4b015a 100644 --- a/servermulticasthandler.go +++ b/servermulticasthandler.go @@ -77,14 +77,14 @@ func (h *serverMulticastHandler) runWriter() { } } -func (h *serverMulticastHandler) writeRTP(payload []byte) { +func (h *serverMulticastHandler) writePacketRTP(payload []byte) { h.writeBuffer.Push(trackTypePayload{ isRTP: true, payload: payload, }) } -func (h *serverMulticastHandler) writeRTCP(payload []byte) { +func (h *serverMulticastHandler) writePacketRTCP(payload []byte) { h.writeBuffer.Push(trackTypePayload{ isRTP: false, payload: payload, diff --git a/serversession.go b/serversession.go index e27aba78..b24e1e94 100644 --- a/serversession.go +++ b/serversession.go @@ -11,6 +11,9 @@ import ( "sync/atomic" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/liberrors" @@ -355,8 +358,7 @@ func (ss *ServerSession) run() { for trackID, track := range ss.announcedTracks { rr := track.rtcpReceiver.Report(now) if rr != nil { - byts, _ := rr.Marshal() - ss.WritePacketRTCP(trackID, byts) + ss.WritePacketRTCP(trackID, rr) } } @@ -762,7 +764,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base th.Delivery = &de v := uint(127) th.TTL = &v - d := stream.multicastHandlers[trackID].ip() + d := stream.serverMulticastHandlers[trackID].ip() th.Destination = &d th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort} @@ -870,8 +872,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base sc.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, false) // open the firewall by sending packets to the counterpart - ss.WritePacketRTCP(trackID, - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + byts, _ := (&rtcp.ReceiverReport{}).Marshal() + ss.s.udpRTCPListener.write(byts, + ss.setuppedTracks[trackID].udpRTCPAddr) } case TransportUDPMulticast: @@ -1000,10 +1003,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, true) // open the firewall by sending packets to the counterpart - ss.WritePacketRTP(trackID, - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - ss.WritePacketRTCP(trackID, - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() + ss.s.udpRTPListener.write(byts, ss.setuppedTracks[trackID].udpRTPAddr) + byts, _ = (&rtcp.ReceiverReport{}).Marshal() + ss.s.udpRTCPListener.write(byts, ss.setuppedTracks[trackID].udpRTCPAddr) } case TransportUDPMulticast: @@ -1201,8 +1204,7 @@ func (ss *ServerSession) runWriter() { } } -// WritePacketRTP writes a RTP packet to the session. -func (ss *ServerSession) WritePacketRTP(trackID int, payload []byte) { +func (ss *ServerSession) writePacketRTP(trackID int, byts []byte) { if _, ok := ss.setuppedTracks[trackID]; !ok { return } @@ -1210,12 +1212,21 @@ func (ss *ServerSession) WritePacketRTP(trackID int, payload []byte) { ss.writeBuffer.Push(trackTypePayload{ trackID: trackID, isRTP: true, - payload: payload, + payload: byts, }) } -// WritePacketRTCP writes a RTCP packet to the session. -func (ss *ServerSession) WritePacketRTCP(trackID int, payload []byte) { +// WritePacketRTP writes a RTP packet to the session. +func (ss *ServerSession) WritePacketRTP(trackID int, pkt *rtp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return + } + + ss.writePacketRTP(trackID, byts) +} + +func (ss *ServerSession) writePacketRTCP(trackID int, byts []byte) { if _, ok := ss.setuppedTracks[trackID]; !ok { return } @@ -1223,6 +1234,16 @@ func (ss *ServerSession) WritePacketRTCP(trackID int, payload []byte) { ss.writeBuffer.Push(trackTypePayload{ trackID: trackID, isRTP: false, - payload: payload, + payload: byts, }) } + +// WritePacketRTCP writes a RTCP packet to the session. +func (ss *ServerSession) WritePacketRTCP(trackID int, pkt rtcp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return + } + + ss.writePacketRTCP(trackID, byts) +} diff --git a/serverstream.go b/serverstream.go index 4994e6a3..21201c11 100644 --- a/serverstream.go +++ b/serverstream.go @@ -1,11 +1,13 @@ package gortsplib import ( - "encoding/binary" "sync" "sync/atomic" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/aler9/gortsplib/pkg/liberrors" ) @@ -25,11 +27,11 @@ type ServerStream struct { s *Server tracks Tracks - mutex sync.RWMutex - readersUnicast map[*ServerSession]struct{} - readers map[*ServerSession]struct{} - multicastHandlers []*serverMulticastHandler - trackInfos []*trackInfo + mutex sync.RWMutex + readersUnicast map[*ServerSession]struct{} + readers map[*ServerSession]struct{} + serverMulticastHandlers []*serverMulticastHandler + trackInfos []*trackInfo } // NewServerStream allocates a ServerStream. @@ -67,11 +69,11 @@ func (st *ServerStream) Close() error { ss.Close() } - if st.multicastHandlers != nil { - for _, h := range st.multicastHandlers { + if st.serverMulticastHandlers != nil { + for _, h := range st.serverMulticastHandlers { h.close() } - st.multicastHandlers = nil + st.serverMulticastHandlers = nil } st.readers = nil @@ -138,22 +140,22 @@ func (st *ServerStream) readerAdd( case TransportUDPMulticast: // allocate multicast listeners - if st.multicastHandlers == nil { - st.multicastHandlers = make([]*serverMulticastHandler, len(st.tracks)) + if st.serverMulticastHandlers == nil { + st.serverMulticastHandlers = make([]*serverMulticastHandler, len(st.tracks)) for i := range st.tracks { h, err := newServerMulticastHandler(st.s) if err != nil { - for _, h := range st.multicastHandlers { + for _, h := range st.serverMulticastHandlers { if h != nil { h.close() } } - st.multicastHandlers = nil + st.serverMulticastHandlers = nil return err } - st.multicastHandlers[i] = h + st.serverMulticastHandlers[i] = h } } } @@ -169,12 +171,12 @@ func (st *ServerStream) readerRemove(ss *ServerSession) { delete(st.readers, ss) - if len(st.readers) == 0 && st.multicastHandlers != nil { - for _, l := range st.multicastHandlers { + if len(st.readers) == 0 && st.serverMulticastHandlers != nil { + for _, l := range st.serverMulticastHandlers { l.rtpl.close() l.rtcpl.close() } - st.multicastHandlers = nil + st.serverMulticastHandlers = nil } } @@ -188,8 +190,8 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { default: // UDPMulticast for trackID := range ss.setuppedTracks { - st.multicastHandlers[trackID].rtcpl.addClient( - ss.author.ip(), st.multicastHandlers[trackID].rtcpl.port(), ss, trackID, false) + st.serverMulticastHandlers[trackID].rtcpl.addClient( + ss.author.ip(), st.serverMulticastHandlers[trackID].rtcpl.port(), ss, trackID, false) } } } @@ -203,56 +205,62 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { delete(st.readersUnicast, ss) default: // UDPMulticast - if st.multicastHandlers != nil { + if st.serverMulticastHandlers != nil { for trackID := range ss.setuppedTracks { - st.multicastHandlers[trackID].rtcpl.removeClient(ss) + st.serverMulticastHandlers[trackID].rtcpl.removeClient(ss) } } } } // WritePacketRTP writes a RTP packet to all the readers of the stream. -func (st *ServerStream) WritePacketRTP(trackID int, payload []byte) { - if len(payload) >= 8 { - track := st.trackInfos[trackID] - - sequenceNumber := binary.BigEndian.Uint16(payload[2:4]) - atomic.StoreUint32(&track.lastSequenceNumber, uint32(sequenceNumber)) - - timestamp := binary.BigEndian.Uint32(payload[4:8]) - atomic.StoreUint32(&track.lastTimeRTP, timestamp) - atomic.StoreInt64(&track.lastTimeNTP, time.Now().Unix()) - - ssrc := binary.BigEndian.Uint32(payload[8:12]) - atomic.StoreUint32(&track.lastSSRC, ssrc) +func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return } + track := st.trackInfos[trackID] + + atomic.StoreUint32(&track.lastSequenceNumber, + uint32(pkt.Header.SequenceNumber)) + + atomic.StoreUint32(&track.lastTimeRTP, pkt.Header.Timestamp) + atomic.StoreInt64(&track.lastTimeNTP, time.Now().Unix()) + + atomic.StoreUint32(&track.lastSSRC, pkt.Header.SSRC) + st.mutex.RLock() defer st.mutex.RUnlock() // send unicast for r := range st.readersUnicast { - r.WritePacketRTP(trackID, payload) + r.writePacketRTP(trackID, byts) } // send multicast - if st.multicastHandlers != nil { - st.multicastHandlers[trackID].writeRTP(payload) + if st.serverMulticastHandlers != nil { + st.serverMulticastHandlers[trackID].writePacketRTP(byts) } } // WritePacketRTCP writes a RTCP packet to all the readers of the stream. -func (st *ServerStream) WritePacketRTCP(trackID int, payload []byte) { +func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return + } + st.mutex.RLock() defer st.mutex.RUnlock() // send unicast for r := range st.readersUnicast { - r.WritePacketRTCP(trackID, payload) + r.writePacketRTCP(trackID, byts) } // send multicast - if st.multicastHandlers != nil { - st.multicastHandlers[trackID].writeRTCP(payload) + if st.serverMulticastHandlers != nil { + st.serverMulticastHandlers[trackID].writePacketRTCP(byts) } } diff --git a/serverudpl.go b/serverudpl.go index 7515f702..f662b671 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -210,7 +210,7 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ Session: clientData.ss, TrackID: clientData.trackID, - Payload: payload, + Packet: &pkt, }) } } @@ -224,17 +224,20 @@ func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) if clientData.isPublishing { now := time.Now() atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) + for _, pkt := range packets { clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt) } } if h, ok := u.s.Handler.(ServerHandlerOnPacketRTCP); ok { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: clientData.ss, - TrackID: clientData.trackID, - Payload: payload, - }) + for _, pkt := range packets { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: clientData.ss, + TrackID: clientData.trackID, + Packet: pkt, + }) + } } }