diff --git a/client.go b/client.go index 8d8a6eda..79b3fd35 100644 --- a/client.go +++ b/client.go @@ -21,6 +21,9 @@ import ( "sync/atomic" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" @@ -132,9 +135,9 @@ type Client struct { // called after every response. OnResponse func(*base.Response) // called when a RTP packet arrives. - OnPacketRTP func(int, []byte) + OnPacketRTP func(int, *rtp.Packet) // called when a RTCP packet arrives. - OnPacketRTCP func(int, []byte) + OnPacketRTCP func(int, rtcp.Packet) // // RTSP parameters @@ -249,11 +252,11 @@ type Client struct { func (c *Client) Start(scheme string, host string) error { // callbacks if c.OnPacketRTP == nil { - c.OnPacketRTP = func(trackID int, payload []byte) { + c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { } } if c.OnPacketRTCP == nil { - c.OnPacketRTCP = func(trackID int, payload []byte) { + c.OnPacketRTCP = func(trackID int, pkt rtcp.Packet) { } } @@ -757,17 +760,37 @@ func (c *Client) runReader() { atomic.StoreInt64(c.tcpLastFrameTime, now.Unix()) if isRTP { - c.tracks[trackID].rtcpReceiver.ProcessPacketRTP(now, payload) - c.OnPacketRTP(trackID, payload) + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + if err != nil { + return + } + + c.tracks[trackID].rtcpReceiver.ProcessPacketRTP(now, &pkt) + c.OnPacketRTP(trackID, &pkt) } else { - c.tracks[trackID].rtcpReceiver.ProcessPacketRTCP(now, payload) - c.OnPacketRTCP(trackID, payload) + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + + for _, pkt := range packets { + c.tracks[trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt) + c.OnPacketRTCP(trackID, pkt) + } } } } else { processFunc = func(trackID int, isRTP bool, payload []byte) { if !isRTP { - c.OnPacketRTCP(trackID, payload) + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + + for _, pkt := range packets { + c.OnPacketRTCP(trackID, pkt) + } } } } @@ -1585,11 +1608,13 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp // open the firewall by sending packets to the counterpart. for _, cct := range c.tracks { - cct.udpRTPListener.write( - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + byts, _ := (&rtp.Packet{ + Header: rtp.Header{Version: 2}, + }).Marshal() + cct.udpRTPListener.write(byts) - cct.udpRTCPListener.write( - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + byts, _ = (&rtcp.ReceiverReport{}).Marshal() + cct.udpRTCPListener.write(byts) } } @@ -1806,10 +1831,6 @@ func (c *Client) runWriter() { case TransportUDP, TransportUDPMulticast: writeFunc = func(trackID int, isRTP bool, payload []byte) { if isRTP { - if c.tracks[trackID].rtcpSender != nil { - c.tracks[trackID].rtcpSender.ProcessPacketRTP(time.Now(), payload) - } - c.tracks[trackID].udpRTPListener.write(payload) } else { c.tracks[trackID].udpRTCPListener.write(payload) @@ -1821,10 +1842,6 @@ func (c *Client) runWriter() { writeFunc = func(trackID int, isRTP bool, payload []byte) { if isRTP { - if c.tracks[trackID].rtcpSender != nil { - c.tracks[trackID].rtcpSender.ProcessPacketRTP(time.Now(), payload) - } - f := c.tracks[trackID].tcpRTPFrame f.Payload = payload f.Write(&buf) @@ -1854,7 +1871,7 @@ func (c *Client) runWriter() { } // WritePacketRTP writes a RTP packet. -func (c *Client) WritePacketRTP(trackID int, payload []byte) error { +func (c *Client) WritePacketRTP(trackID int, pkt *rtp.Packet) error { c.writeMutex.RLock() defer c.writeMutex.RUnlock() @@ -1867,16 +1884,25 @@ func (c *Client) WritePacketRTP(trackID int, payload []byte) error { } } + byts, err := pkt.Marshal() + if err != nil { + return err + } + + if c.tracks[trackID].rtcpSender != nil { + c.tracks[trackID].rtcpSender.ProcessPacketRTP(time.Now(), pkt) + } + c.writeBuffer.Push(trackTypePayload{ trackID: trackID, isRTP: true, - payload: payload, + payload: byts, }) return nil } // WritePacketRTCP writes a RTCP packet. -func (c *Client) WritePacketRTCP(trackID int, payload []byte) error { +func (c *Client) WritePacketRTCP(trackID int, pkt rtcp.Packet) error { c.writeMutex.RLock() defer c.writeMutex.RUnlock() @@ -1889,10 +1915,15 @@ func (c *Client) WritePacketRTCP(trackID int, payload []byte) error { } } + byts, err := pkt.Marshal() + if err != nil { + return err + } + c.writeBuffer.Push(trackTypePayload{ trackID: trackID, isRTP: false, - payload: payload, + payload: byts, }) return nil } diff --git a/client_publish_test.go b/client_publish_test.go index ef0df2f1..469a4ec4 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -18,6 +18,39 @@ import ( "github.com/aler9/gortsplib/pkg/rtcpreceiver" ) +var testRTPPacket = rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 97, + CSRC: []uint32{}, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, +} + +var testRTPPacketMarshaled = func() []byte { + byts, _ := testRTPPacket.Marshal() + return byts +}() + +var testRTCPPacket = rtcp.SourceDescription{ + Chunks: []rtcp.SourceDescriptionChunk{ + { + Source: 1234, + Items: []rtcp.SourceDescriptionItem{ + { + Type: rtcp.SDESCNAME, + Text: "myname", + }, + }, + }, + }, +} + +var testRTCPPacketMarshaled = func() []byte { + byts, _ := testRTCPPacket.Marshal() + return byts +}() + func TestClientPublishSerial(t *testing.T) { for _, transport := range []string{ "udp", @@ -138,31 +171,37 @@ func TestClientPublishSerial(t *testing.T) { _, err = conn.Write(bb.Bytes()) require.NoError(t, err) - // client -> server + // client -> server (RTP) if transport == "udp" { buf := make([]byte, 2048) n, _, err := l1.ReadFrom(buf) require.NoError(t, err) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n]) + var pkt rtp.Packet + err = pkt.Unmarshal(buf[:n]) + require.NoError(t, err) + require.Equal(t, testRTPPacket, pkt) } else { var f base.InterleavedFrame f.Payload = make([]byte, 2048) err = f.Read(br) require.NoError(t, err) require.Equal(t, 0, f.Channel) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload) + var pkt rtp.Packet + err = pkt.Unmarshal(f.Payload) + require.NoError(t, err) + require.Equal(t, testRTPPacket, pkt) } // server -> client (RTCP) if transport == "udp" { - l2.WriteTo([]byte{0x05, 0x06, 0x07, 0x08}, &net.UDPAddr{ + l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[1], }) } else { base.InterleavedFrame{ Channel: 1, - Payload: []byte{0x05, 0x06, 0x07, 0x08}, + Payload: testRTCPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -194,9 +233,9 @@ func TestClientPublishSerial(t *testing.T) { v := TransportTCP return &v }(), - OnPacketRTCP: func(trackID int, payload []byte) { + OnPacketRTCP: func(trackID int, pkt rtcp.Packet) { require.Equal(t, 0, trackID) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload) + require.Equal(t, &testRTCPPacket, pkt) close(recvDone) }, } @@ -214,16 +253,14 @@ func TestClientPublishSerial(t *testing.T) { c.Wait() }() - err = c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err = c.WritePacketRTP(0, &testRTPPacket) require.NoError(t, err) <-recvDone c.Close() <-done - err = c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err = c.WritePacketRTP(0, &testRTPPacket) require.Error(t, err) }) } @@ -376,8 +413,7 @@ func TestClientPublishParallel(t *testing.T) { defer t.Stop() for range t.C { - err := c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err := c.WritePacketRTP(0, &testRTPPacket) if err != nil { return } @@ -531,8 +567,7 @@ func TestClientPublishPauseSerial(t *testing.T) { require.NoError(t, err) defer c.Close() - err = c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err = c.WritePacketRTP(0, &testRTPPacket) require.NoError(t, err) _, err = c.Pause() @@ -541,8 +576,7 @@ func TestClientPublishPauseSerial(t *testing.T) { _, err = c.Record() require.NoError(t, err) - err = c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err = c.WritePacketRTP(0, &testRTPPacket) require.NoError(t, err) }) } @@ -677,8 +711,7 @@ func TestClientPublishPauseParallel(t *testing.T) { defer t.Stop() for range t.C { - err := c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err := c.WritePacketRTP(0, &testRTPPacket) if err != nil { return } @@ -794,7 +827,10 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { err = f.Read(br) require.NoError(t, err) require.Equal(t, 0, f.Channel) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload) + var pkt rtp.Packet + err = pkt.Unmarshal(f.Payload) + require.NoError(t, err) + require.Equal(t, testRTPPacket, pkt) req, err = readRequest(br) require.NoError(t, err) @@ -817,8 +853,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { require.NoError(t, err) defer c.Close() - err = c.WritePacketRTP(0, - []byte{0x01, 0x02, 0x03, 0x04}) + err = c.WritePacketRTP(0, &testRTPPacket) require.NoError(t, err) } @@ -915,14 +950,17 @@ func TestClientPublishRTCPReport(t *testing.T) { buf := make([]byte, 2048) n, _, err := l1.ReadFrom(buf) require.NoError(t, err) - rr.ProcessPacketRTP(time.Now(), buf[:n]) + var pkt rtp.Packet + err = pkt.Unmarshal(buf[:n]) + require.NoError(t, err) + rr.ProcessPacketRTP(time.Now(), &pkt) buf = make([]byte, 2048) n, _, err = l2.ReadFrom(buf) require.NoError(t, err) - pkt, err := rtcp.Unmarshal(buf[:n]) + packets, err := rtcp.Unmarshal(buf[:n]) require.NoError(t, err) - sr, ok := pkt[0].(*rtcp.SenderReport) + sr, ok := packets[0].(*rtcp.SenderReport) require.True(t, ok) require.Equal(t, &rtcp.SenderReport{ SSRC: 753621, @@ -931,7 +969,7 @@ func TestClientPublishRTCPReport(t *testing.T) { PacketCount: 1, OctetCount: 4, }, sr) - rr.ProcessPacketRTCP(time.Now(), buf[:n]) + rr.ProcessPacketRTCP(time.Now(), packets[0]) close(reportReceived) @@ -958,7 +996,7 @@ func TestClientPublishRTCPReport(t *testing.T) { require.NoError(t, err) defer c.Close() - byts, _ := (&rtp.Packet{ + err = c.WritePacketRTP(0, &rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -968,8 +1006,7 @@ func TestClientPublishRTCPReport(t *testing.T) { SSRC: 753621, }, Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }).Marshal() - err = c.WritePacketRTP(0, byts) + }) require.NoError(t, err) <-reportReceived @@ -1056,14 +1093,14 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { base.InterleavedFrame{ Channel: 0, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) base.InterleavedFrame{ Channel: 1, - Payload: []byte{0x05, 0x06, 0x07, 0x08}, + Payload: testRTCPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -1086,10 +1123,10 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { v := TransportTCP return &v }(), - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { t.Errorf("should not happen") }, - OnPacketRTCP: func(trackID int, payload []byte) { + OnPacketRTCP: func(trackID int, pkt rtcp.Packet) { close(rtcpReceived) }, } diff --git a/client_read_test.go b/client_read_test.go index f227b453..047b772f 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -317,18 +317,18 @@ func TestClientRead(t *testing.T) { _, err = conn.Write(bb.Bytes()) require.NoError(t, err) - // server -> client + // server -> client (RTP) switch transport { case "udp": time.Sleep(1 * time.Second) - l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) case "multicast": time.Sleep(1 * time.Second) - l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("224.1.0.1"), Port: 25000, }) @@ -336,7 +336,7 @@ func TestClientRead(t *testing.T) { case "tcp", "tls": base.InterleavedFrame{ Channel: 0, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -353,7 +353,9 @@ func TestClientRead(t *testing.T) { buf = make([]byte, 2048) n, _, err := l2.ReadFrom(buf) require.NoError(t, err) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n]) + packets, err := rtcp.Unmarshal(buf[:n]) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) close(packetRecv) case "tcp", "tls": @@ -362,7 +364,9 @@ func TestClientRead(t *testing.T) { err := f.Read(br) require.NoError(t, err) require.Equal(t, 1, f.Channel) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, f.Payload) + packets, err := rtcp.Unmarshal(f.Payload) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) close(packetRecv) } @@ -401,7 +405,7 @@ func TestClientRead(t *testing.T) { }(), } - c.OnPacketRTP = func(trackID int, payload []byte) { + c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { // ignore multicast loopback if transport == "multicast" { counter++ @@ -411,9 +415,9 @@ func TestClientRead(t *testing.T) { } require.Equal(t, 0, trackID) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) + require.Equal(t, &testRTPPacket, pkt) - err := c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) + err := c.WritePacketRTCP(0, &testRTCPPacket) require.NoError(t, err) } @@ -427,7 +431,14 @@ func TestClientRead(t *testing.T) { } func TestClientReadNonStandardFrameSize(t *testing.T) { - refPayload := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5) + refRTPPacket := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + CSRC: []uint32{}, + }, + Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), + } l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) @@ -519,9 +530,10 @@ func TestClientReadNonStandardFrameSize(t *testing.T) { _, err = conn.Write(bb.Bytes()) require.NoError(t, err) + byts, _ := refRTPPacket.Marshal() base.InterleavedFrame{ Channel: 0, - Payload: refPayload, + Payload: byts, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -530,14 +542,14 @@ func TestClientReadNonStandardFrameSize(t *testing.T) { packetRecv := make(chan struct{}) c := &Client{ - ReadBufferSize: 4500, + ReadBufferSize: 4500 + 4, Transport: func() *Transport { v := TransportTCP return &v }(), - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { require.Equal(t, 0, trackID) - require.Equal(t, refPayload, payload) + require.Equal(t, &refRTPPacket, pkt) close(packetRecv) }, } @@ -632,7 +644,7 @@ func TestClientReadPartial(t *testing.T) { base.InterleavedFrame{ Channel: 0, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -656,9 +668,9 @@ func TestClientReadPartial(t *testing.T) { v := TransportTCP return &v }(), - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { require.Equal(t, 0, trackID) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) + require.Equal(t, &testRTPPacket, pkt) close(packetRecv) }, } @@ -916,11 +928,12 @@ func TestClientReadAnyPort(t *testing.T) { time.Sleep(500 * time.Millisecond) - l1a.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + l1a.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) + // read RTCP if ca == "random" { // skip firewall opening buf := make([]byte, 2048) @@ -930,7 +943,9 @@ func TestClientReadAnyPort(t *testing.T) { buf = make([]byte, 2048) n, _, err := l1b.ReadFrom(buf) require.NoError(t, err) - require.Equal(t, buf[:n], []byte{0x05, 0x06, 0x07, 0x08}) + packets, err := rtcp.Unmarshal(buf[:n]) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) close(serverRecv) } }() @@ -939,8 +954,8 @@ func TestClientReadAnyPort(t *testing.T) { c := &Client{ AnyPortEnable: true, - OnPacketRTP: func(trackID int, payload []byte) { - require.Equal(t, payload, []byte{0x01, 0x02, 0x03, 0x04}) + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { + require.Equal(t, &testRTPPacket, pkt) close(packetRecv) }, } @@ -952,7 +967,7 @@ func TestClientReadAnyPort(t *testing.T) { <-packetRecv if ca == "random" { - c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) + c.WritePacketRTCP(0, &testRTCPPacket) <-serverRecv } }) @@ -1061,7 +1076,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { base.InterleavedFrame{ Channel: 0, - Payload: []byte("\x00\x00\x00\x00"), + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -1070,7 +1085,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { packetRecv := make(chan struct{}) c := Client{ - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { close(packetRecv) }, } @@ -1279,7 +1294,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { base.InterleavedFrame{ Channel: 0, - Payload: []byte("\x00\x00\x00\x00"), + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -1301,7 +1316,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { c := &Client{ ReadTimeout: 1 * time.Second, - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { close(packetRecv) }, } @@ -1409,7 +1424,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { base.InterleavedFrame{ Channel: 2, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -1433,7 +1448,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { v := TransportTCP return &v }(), - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { require.Equal(t, 0, trackID) close(packetRecv) }, @@ -1577,7 +1592,7 @@ func TestClientReadRedirect(t *testing.T) { require.NoError(t, err) defer l1.Close() - l1.WriteTo([]byte("\x00\x00\x00\x00"), &net.UDPAddr{ + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) @@ -1586,7 +1601,7 @@ func TestClientReadRedirect(t *testing.T) { packetRecv := make(chan struct{}) c := Client{ - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { close(packetRecv) }, } @@ -1622,14 +1637,14 @@ func TestClientReadPause(t *testing.T) { select { case <-t.C: if inTH.Protocol == headers.TransportProtocolUDP { - l1.WriteTo([]byte("\x00\x00\x00\x00"), &net.UDPAddr{ + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[0], }) } else { base.InterleavedFrame{ Channel: 0, - Payload: []byte("\x00\x00\x00\x00"), + Payload: testRTPPacketMarshaled, }.Write(&bb) conn.Write(bb.Bytes()) } @@ -1797,7 +1812,7 @@ func TestClientReadPause(t *testing.T) { v := TransportTCP return &v }(), - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { if atomic.SwapInt32(&firstFrame, 1) == 0 { close(packetRecv) } @@ -1930,7 +1945,7 @@ func TestClientReadRTCPReport(t *testing.T) { rs := rtcpsender.New(90000) - byts, _ := (&rtp.Packet{ + pkt := rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -1940,15 +1955,18 @@ func TestClientReadRTCPReport(t *testing.T) { SSRC: 753621, }, Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }).Marshal() + } + byts, _ := pkt.Marshal() _, err = l1.WriteTo(byts, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[0], }) require.NoError(t, err) - rs.ProcessPacketRTP(time.Now(), byts) + rs.ProcessPacketRTP(time.Now(), &pkt) - _, err = l2.WriteTo(rs.Report(time.Now()), &net.UDPAddr{ + sr := rs.Report(time.Now()) + byts, _ = sr.Marshal() + _, err = l2.WriteTo(byts, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[1], }) @@ -1957,9 +1975,9 @@ func TestClientReadRTCPReport(t *testing.T) { buf = make([]byte, 2048) n, _, err := l2.ReadFrom(buf) require.NoError(t, err) - pkt, err := rtcp.Unmarshal(buf[:n]) + packets, err := rtcp.Unmarshal(buf[:n]) require.NoError(t, err) - rr, ok := pkt[0].(*rtcp.ReceiverReport) + rr, ok := packets[0].(*rtcp.ReceiverReport) require.True(t, ok) require.Equal(t, &rtcp.ReceiverReport{ SSRC: rr.SSRC, @@ -2109,7 +2127,7 @@ func TestClientReadErrorTimeout(t *testing.T) { if transport == "udp" || transport == "auto" { // write a packet to skip the protocol autodetection feature - l1.WriteTo([]byte("\x01\x02\x03\x04"), &net.UDPAddr{ + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) @@ -2251,14 +2269,14 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { base.InterleavedFrame{ Channel: 6, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) base.InterleavedFrame{ Channel: 0, - Payload: []byte{0x05, 0x06, 0x07, 0x08}, + Payload: testRTPPacketMarshaled, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) @@ -2281,7 +2299,7 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { v := TransportTCP return &v }(), - OnPacketRTP: func(trackID int, payload []byte) { + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { close(recv) }, } diff --git a/clientudpl.go b/clientudpl.go index 11e7ff9c..448e9f57 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -7,6 +7,8 @@ import ( "sync/atomic" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" "golang.org/x/net/ipv4" "github.com/aler9/gortsplib/pkg/multibuffer" @@ -35,6 +37,7 @@ type clientUDPListener struct { readBuffer *multibuffer.MultiBuffer lastPacketTime *int64 processFunc func(time.Time, []byte) + rtpPkt rtp.Packet readerDone chan struct{} } @@ -136,7 +139,7 @@ func (l *clientUDPListener) start(forPlay bool) { l.processFunc = l.processPlayRTCP } } else { - l.processFunc = l.processRecord + l.processFunc = l.processRecordRTCP } l.running = true @@ -174,24 +177,43 @@ func (l *clientUDPListener) runReader() { } func (l *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { - l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTP(now, payload) - l.c.OnPacketRTP(l.trackID, payload) + err := l.rtpPkt.Unmarshal(payload) + if err != nil { + return + } + + l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTP(now, &l.rtpPkt) + l.c.OnPacketRTP(l.trackID, &l.rtpPkt) } func (l *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { - l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTCP(now, payload) - l.c.OnPacketRTCP(l.trackID, payload) + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + + for _, pkt := range packets { + l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt) + l.c.OnPacketRTCP(l.trackID, pkt) + } } -func (l *clientUDPListener) processRecord(now time.Time, payload []byte) { - l.c.OnPacketRTCP(l.trackID, payload) +func (l *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) { + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + + for _, pkt := range packets { + l.c.OnPacketRTCP(l.trackID, pkt) + } } -func (l *clientUDPListener) write(buf []byte) error { +func (l *clientUDPListener) write(payload []byte) error { // no mutex is needed here since Write() has an internal lock. // https://github.com/golang/go/issues/27203#issuecomment-534386117 l.pc.SetWriteDeadline(time.Now().Add(l.c.WriteTimeout)) - _, err := l.pc.WriteTo(buf, l.remoteWriteAddr) + _, err := l.pc.WriteTo(payload, l.remoteWriteAddr) return err } diff --git a/examples/client-publish-aac/main.go b/examples/client-publish-aac/main.go index 60713d02..8b3459fa 100644 --- a/examples/client-publish-aac/main.go +++ b/examples/client-publish-aac/main.go @@ -5,6 +5,7 @@ import ( "net" "github.com/aler9/gortsplib" + "github.com/pion/rtp" ) // This example shows how to @@ -50,6 +51,7 @@ func main() { defer c.Close() buf = make([]byte, 2048) + var pkt rtp.Packet for { // read packets from the source n, _, err := pc.ReadFrom(buf) @@ -57,8 +59,14 @@ func main() { panic(err) } + // marshal RTP packets + err = pkt.Unmarshal(buf[:n]) + if err != nil { + panic(err) + } + // route RTP packets to the server - err = c.WritePacketRTP(0, buf[:n]) + err = c.WritePacketRTP(0, &pkt) if err != nil { panic(err) } diff --git a/examples/client-publish-h264/main.go b/examples/client-publish-h264/main.go index 14d4fda6..4d4fb03a 100644 --- a/examples/client-publish-h264/main.go +++ b/examples/client-publish-h264/main.go @@ -6,6 +6,7 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/rtph264" + "github.com/pion/rtp" ) // This example shows how to @@ -51,6 +52,7 @@ func main() { defer c.Close() buf := make([]byte, 2048) + var pkt rtp.Packet for { // read packets from the source n, _, err := pc.ReadFrom(buf) @@ -58,8 +60,14 @@ func main() { panic(err) } + // marshal RTP packets + err = pkt.Unmarshal(buf[:n]) + if err != nil { + panic(err) + } + // route RTP packets to the server - err = c.WritePacketRTP(0, buf[:n]) + err = c.WritePacketRTP(0, &pkt) if err != nil { panic(err) } diff --git a/examples/client-publish-options/main.go b/examples/client-publish-options/main.go index 51ec66bf..7dcaaf15 100644 --- a/examples/client-publish-options/main.go +++ b/examples/client-publish-options/main.go @@ -7,6 +7,7 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/rtph264" + "github.com/pion/rtp" ) // This example shows how to @@ -60,6 +61,7 @@ func main() { defer c.Close() buf := make([]byte, 2048) + var pkt rtp.Packet for { // read packets from the source n, _, err := pc.ReadFrom(buf) @@ -67,8 +69,14 @@ func main() { panic(err) } + // marshal RTP packets + err = pkt.Unmarshal(buf[:n]) + if err != nil { + panic(err) + } + // route RTP packets to the server - err = c.WritePacketRTP(0, buf[:n]) + err = c.WritePacketRTP(0, &pkt) if err != nil { panic(err) } diff --git a/examples/client-publish-opus/main.go b/examples/client-publish-opus/main.go index bf7000a6..e3c81829 100644 --- a/examples/client-publish-opus/main.go +++ b/examples/client-publish-opus/main.go @@ -5,6 +5,7 @@ import ( "net" "github.com/aler9/gortsplib" + "github.com/pion/rtp" ) // This example shows how to @@ -50,6 +51,7 @@ func main() { defer c.Close() buf = make([]byte, 2048) + var pkt rtp.Packet for { // read packets from the source n, _, err := pc.ReadFrom(buf) @@ -57,8 +59,14 @@ func main() { panic(err) } + // marshal RTP packets + err = pkt.Unmarshal(buf[:n]) + if err != nil { + panic(err) + } + // route RTP packets to the server - err = c.WritePacketRTP(0, buf[:n]) + err = c.WritePacketRTP(0, &pkt) if err != nil { panic(err) } diff --git a/examples/client-publish-pause/main.go b/examples/client-publish-pause/main.go index dab273fd..33962c97 100644 --- a/examples/client-publish-pause/main.go +++ b/examples/client-publish-pause/main.go @@ -7,6 +7,7 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/rtph264" + "github.com/pion/rtp" ) // This example shows how to @@ -55,6 +56,7 @@ func main() { for { go func() { buf := make([]byte, 2048) + var pkt rtp.Packet for { // read packets from the source n, _, err := pc.ReadFrom(buf) @@ -62,8 +64,14 @@ func main() { break } + // marshal RTP packets + err = pkt.Unmarshal(buf[:n]) + if err != nil { + panic(err) + } + // route RTP packets to the server - err = c.WritePacketRTP(0, buf[:n]) + err = c.WritePacketRTP(0, &pkt) if err != nil { break } diff --git a/examples/client-read-aac/main.go b/examples/client-read-aac/main.go index 99d79079..73c114df 100644 --- a/examples/client-read-aac/main.go +++ b/examples/client-read-aac/main.go @@ -61,20 +61,13 @@ func main() { dec := rtpaac.NewDecoder(clockRate) // called when a RTP packet arrives - c.OnPacketRTP = func(trackID int, payload []byte) { + c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { if trackID != aacTrack { return } - // parse RTP packet - var pkt rtp.Packet - err := pkt.Unmarshal(payload) - if err != nil { - return - } - // decode AAC AUs from the RTP packet - aus, _, err := dec.Decode(&pkt) + aus, _, err := dec.Decode(pkt) if err != nil { return } diff --git a/examples/client-read-h264-convert-to-jpeg/main.go b/examples/client-read-h264-convert-to-jpeg/main.go index 3036c747..e8d6468d 100644 --- a/examples/client-read-h264-convert-to-jpeg/main.go +++ b/examples/client-read-h264-convert-to-jpeg/main.go @@ -99,20 +99,13 @@ func main() { // called when a RTP packet arrives saveCount := 0 - c.OnPacketRTP = func(trackID int, payload []byte) { + c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { if trackID != h264trID { return } - // parse RTP packet - var pkt rtp.Packet - err := pkt.Unmarshal(payload) - if err != nil { - return - } - // decode H264 NALUs from the RTP packet - nalus, _, err := rtpDec.Decode(&pkt) + nalus, _, err := rtpDec.Decode(pkt) if err != nil { return } diff --git a/examples/client-read-h264-decode/main.go b/examples/client-read-h264-decode/main.go index 5288d5a3..de64a681 100644 --- a/examples/client-read-h264-decode/main.go +++ b/examples/client-read-h264-decode/main.go @@ -75,20 +75,13 @@ func main() { h264dec.decode(h264tr.PPS()) // called when a RTP packet arrives - c.OnPacketRTP = func(trackID int, payload []byte) { + c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { if trackID != h264trID { return } - // parse RTP packet - var pkt rtp.Packet - err := pkt.Unmarshal(payload) - if err != nil { - return - } - // decode H264 NALUs from the RTP packet - nalus, _, err := rtpDec.Decode(&pkt) + nalus, _, err := rtpDec.Decode(pkt) if err != nil { return } diff --git a/examples/client-read-h264-save-to-disk/main.go b/examples/client-read-h264-save-to-disk/main.go index e4e39c64..bb9443a4 100644 --- a/examples/client-read-h264-save-to-disk/main.go +++ b/examples/client-read-h264-save-to-disk/main.go @@ -67,20 +67,13 @@ func main() { } // called when a RTP packet arrives - c.OnPacketRTP = func(trackID int, payload []byte) { + c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { if trackID != h264Track { return } - // parse RTP packet - var pkt rtp.Packet - err := pkt.Unmarshal(payload) - if err != nil { - return - } - // decode H264 NALUs from the RTP packet - nalus, pts, err := dec.DecodeUntilMarker(&pkt) + nalus, pts, err := dec.DecodeUntilMarker(pkt) if err != nil { return } diff --git a/examples/client-read-options/main.go b/examples/client-read-options/main.go index f8829183..dfb59dde 100644 --- a/examples/client-read-options/main.go +++ b/examples/client-read-options/main.go @@ -5,6 +5,8 @@ import ( "time" "github.com/aler9/gortsplib" + "github.com/pion/rtcp" + "github.com/pion/rtp" ) // This example shows how to @@ -21,12 +23,12 @@ func main() { // timeout of write operations WriteTimeout: 10 * time.Second, // called when a RTP packet arrives - OnPacketRTP: func(trackID int, payload []byte) { - log.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { + log.Printf("RTP packet from track %d, payload type %d\n", trackID, pkt.Header.PayloadType) }, // called when a RTCP packet arrives - OnPacketRTCP: func(trackID int, payload []byte) { - log.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTCP: func(trackID int, pkt rtcp.Packet) { + log.Printf("RTCP packet from track %d, type %T\n", trackID, pkt) }, } diff --git a/examples/client-read-partial/main.go b/examples/client-read-partial/main.go index a67a39d7..f8d49a2d 100644 --- a/examples/client-read-partial/main.go +++ b/examples/client-read-partial/main.go @@ -5,6 +5,8 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/base" + "github.com/pion/rtcp" + "github.com/pion/rtp" ) // This example shows how to @@ -15,12 +17,12 @@ import ( func main() { c := gortsplib.Client{ // called when a RTP packet arrives - OnPacketRTP: func(trackID int, payload []byte) { - log.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { + log.Printf("RTP packet from track %d, payload type %d\n", trackID, pkt.Header.PayloadType) }, // called when a RTCP packet arrives - OnPacketRTCP: func(trackID int, payload []byte) { - log.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTCP: func(trackID int, pkt rtcp.Packet) { + log.Printf("RTCP packet from track %d, type %T\n", trackID, pkt) }, } diff --git a/examples/client-read-pause/main.go b/examples/client-read-pause/main.go index ea3031fe..29df792c 100644 --- a/examples/client-read-pause/main.go +++ b/examples/client-read-pause/main.go @@ -5,6 +5,8 @@ import ( "time" "github.com/aler9/gortsplib" + "github.com/pion/rtcp" + "github.com/pion/rtp" ) // This example shows how to @@ -16,12 +18,12 @@ import ( func main() { c := gortsplib.Client{ // called when a RTP packet arrives - OnPacketRTP: func(trackID int, payload []byte) { - log.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { + log.Printf("RTP packet from track %d, payload type %d\n", trackID, pkt.Header.PayloadType) }, // called when a RTCP packet arrives - OnPacketRTCP: func(trackID int, payload []byte) { - log.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTCP: func(trackID int, pkt rtcp.Packet) { + log.Printf("RTCP packet from track %d, type %T\n", trackID, pkt) }, } diff --git a/examples/client-read/main.go b/examples/client-read/main.go index bf35aedc..22f5fda7 100644 --- a/examples/client-read/main.go +++ b/examples/client-read/main.go @@ -4,6 +4,8 @@ import ( "log" "github.com/aler9/gortsplib" + "github.com/pion/rtcp" + "github.com/pion/rtp" ) // This example shows how to @@ -12,12 +14,12 @@ import ( func main() { c := gortsplib.Client{ // called when a RTP packet arrives - OnPacketRTP: func(trackID int, payload []byte) { - log.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTP: func(trackID int, pkt *rtp.Packet) { + log.Printf("RTP packet from track %d, payload type %d\n", trackID, pkt.Header.PayloadType) }, // called when a RTCP packet arrives - OnPacketRTCP: func(trackID int, payload []byte) { - log.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + OnPacketRTCP: func(trackID int, pkt rtcp.Packet) { + log.Printf("RTCP packet from track %d, type %T\n", trackID, pkt) }, } diff --git a/go.mod b/go.mod index 01759481..543853cc 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.15 require ( github.com/asticode/go-astits v1.10.0 github.com/icza/bitio v1.0.0 - github.com/pion/rtcp v1.2.4 - github.com/pion/rtp v1.6.1 + github.com/pion/rtcp v1.2.9 + github.com/pion/rtp v1.7.4 github.com/pion/sdp/v3 v3.0.2 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 golang.org/x/net v0.0.0-20210610132358-84b48f89b13b ) diff --git a/go.sum b/go.sum index cdfc0b4a..cba2811c 100644 --- a/go.sum +++ b/go.sum @@ -10,10 +10,10 @@ github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lTo github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.4 h1:NT3H5LkUGgaEapvp0HGik+a+CpflRF7KTD7H+o7OWIM= -github.com/pion/rtcp v1.2.4/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0= -github.com/pion/rtp v1.6.1 h1:2Y2elcVBrahYnHKN2X7rMHX/r1R4TEBMP1LaVu/wNhk= -github.com/pion/rtp v1.6.1/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= +github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U= +github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= +github.com/pion/rtp v1.7.4 h1:4dMbjb1SuynU5OpA3kz1zHK+u+eOCQjW3MAeVHf1ODA= +github.com/pion/rtp v1.7.4/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/sdp/v3 v3.0.2 h1:UNnSPVaMM+Pdu/mR9UvAyyo6zkdYbKeuOooCwZvTl/g= github.com/pion/sdp/v3 v3.0.2/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk= github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE= @@ -21,8 +21,9 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/net v0.0.0-20210610132358-84b48f89b13b h1:k+E048sYJHyVnsr1GDrRZWQ32D2C7lWs9JRc0bel53A= golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/rtcpreceiver/rtcpreceiver.go b/pkg/rtcpreceiver/rtcpreceiver.go index ae5af036..d43c666f 100644 --- a/pkg/rtcpreceiver/rtcpreceiver.go +++ b/pkg/rtcpreceiver/rtcpreceiver.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pion/rtcp" + "github.com/pion/rtp" ) func randUint32() uint32 { @@ -53,7 +54,7 @@ func New(receiverSSRC *uint32, clockRate int) *RTCPReceiver { // Report generates a RTCP receiver report. // It returns nil if no RTCP sender reports have been passed to ProcessPacketRTCP yet. -func (rr *RTCPReceiver) Report(ts time.Time) []byte { +func (rr *RTCPReceiver) Report(ts time.Time) rtcp.Packet { rr.mutex.Lock() defer rr.mutex.Unlock() @@ -85,91 +86,73 @@ func (rr *RTCPReceiver) Report(ts time.Time) []byte { rr.totalLostSinceReport = 0 rr.totalSinceReport = 0 - byts, err := report.Marshal() - if err != nil { - panic(err) - } - - return byts + return report } // ProcessPacketRTP extracts the needed data from RTP packets. -func (rr *RTCPReceiver) ProcessPacketRTP(ts time.Time, payload []byte) { +func (rr *RTCPReceiver) ProcessPacketRTP(ts time.Time, pkt *rtp.Packet) { rr.mutex.Lock() defer rr.mutex.Unlock() - // do not parse the entire packet, extract only the fields we need - if len(payload) >= 8 { - sequenceNumber := uint16(payload[2])<<8 | uint16(payload[3]) - rtpTime := uint32(payload[4])<<24 | uint32(payload[5])<<16 | uint32(payload[6])<<8 | uint32(payload[7]) + // first packet + if !rr.firstRTPReceived { + rr.firstRTPReceived = true + rr.totalSinceReport = 1 + rr.lastSequenceNumber = pkt.Header.SequenceNumber + rr.lastRTPTimeRTP = pkt.Header.Timestamp + rr.lastRTPTimeTime = ts - // first packet - if !rr.firstRTPReceived { - rr.firstRTPReceived = true - rr.totalSinceReport = 1 - rr.lastSequenceNumber = sequenceNumber - rr.lastRTPTimeRTP = rtpTime - rr.lastRTPTimeTime = ts + // subsequent packets + } else { + diff := int32(pkt.Header.SequenceNumber) - int32(rr.lastSequenceNumber) - // subsequent packets - } else { - diff := int32(sequenceNumber) - int32(rr.lastSequenceNumber) - - // following packet or following packet after an overflow - if diff > 0 || diff < -0x0FFF { - // overflow - if diff < -0x0FFF { - rr.sequenceNumberCycles++ - } - - // detect lost packets - if sequenceNumber != (rr.lastSequenceNumber + 1) { - rr.totalLost += uint32(uint16(diff) - 1) - rr.totalLostSinceReport += uint32(uint16(diff) - 1) - - // allow up to 24 bits - if rr.totalLost > 0xFFFFFF { - rr.totalLost = 0xFFFFFF - } - if rr.totalLostSinceReport > 0xFFFFFF { - rr.totalLostSinceReport = 0xFFFFFF - } - } - - // compute jitter - // https://tools.ietf.org/html/rfc3550#page-39 - D := ts.Sub(rr.lastRTPTimeTime).Seconds()*rr.clockRate - - (float64(rtpTime) - float64(rr.lastRTPTimeRTP)) - if D < 0 { - D = -D - } - rr.jitter += (D - rr.jitter) / 16 - - rr.totalSinceReport += uint32(uint16(diff)) - rr.lastSequenceNumber = sequenceNumber - rr.lastRTPTimeRTP = rtpTime - rr.lastRTPTimeTime = ts + // following packet or following packet after an overflow + if diff > 0 || diff < -0x0FFF { + // overflow + if diff < -0x0FFF { + rr.sequenceNumberCycles++ } - // ignore invalid packets (diff = 0) or reordered packets (diff < 0) + + // detect lost packets + if pkt.Header.SequenceNumber != (rr.lastSequenceNumber + 1) { + rr.totalLost += uint32(uint16(diff) - 1) + rr.totalLostSinceReport += uint32(uint16(diff) - 1) + + // allow up to 24 bits + if rr.totalLost > 0xFFFFFF { + rr.totalLost = 0xFFFFFF + } + if rr.totalLostSinceReport > 0xFFFFFF { + rr.totalLostSinceReport = 0xFFFFFF + } + } + + // compute jitter + // https://tools.ietf.org/html/rfc3550#page-39 + D := ts.Sub(rr.lastRTPTimeTime).Seconds()*rr.clockRate - + (float64(pkt.Header.Timestamp) - float64(rr.lastRTPTimeRTP)) + if D < 0 { + D = -D + } + rr.jitter += (D - rr.jitter) / 16 + + rr.totalSinceReport += uint32(uint16(diff)) + rr.lastSequenceNumber = pkt.Header.SequenceNumber + rr.lastRTPTimeRTP = pkt.Header.Timestamp + rr.lastRTPTimeTime = ts } + // ignore invalid packets (diff = 0) or reordered packets (diff < 0) } } // ProcessPacketRTCP extracts the needed data from RTCP packets. -func (rr *RTCPReceiver) ProcessPacketRTCP(ts time.Time, payload []byte) { - rr.mutex.Lock() - defer rr.mutex.Unlock() +func (rr *RTCPReceiver) ProcessPacketRTCP(ts time.Time, pkt rtcp.Packet) { + if sr, ok := (pkt).(*rtcp.SenderReport); ok { + rr.mutex.Lock() + defer rr.mutex.Unlock() - // we can afford to unmarshal all RTCP packets - // since they are sent with a frequency much lower than the one of RTP packets - packets, err := rtcp.Unmarshal(payload) - if err == nil { - for _, packet := range packets { - if sr, ok := (packet).(*rtcp.SenderReport); ok { - rr.senderSSRC = sr.SSRC - rr.lastSenderReport = uint32(sr.NTPTime >> 16) - rr.lastSenderReportTime = ts - } - } + rr.senderSSRC = sr.SSRC + rr.lastSenderReport = uint32(sr.NTPTime >> 16) + rr.lastSenderReportTime = ts } } diff --git a/pkg/rtcpreceiver/rtcpreceiver_test.go b/pkg/rtcpreceiver/rtcpreceiver_test.go index 36d23b9b..0efc54ed 100644 --- a/pkg/rtcpreceiver/rtcpreceiver_test.go +++ b/pkg/rtcpreceiver/rtcpreceiver_test.go @@ -13,7 +13,7 @@ func TestRTCPReceiverBase(t *testing.T) { v := uint32(0x65f83afb) rr := New(&v, 90000) - require.Equal(t, []byte(nil), rr.Report(time.Now())) + require.Equal(t, nil, rr.Report(time.Now())) srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -22,9 +22,8 @@ func TestRTCPReceiverBase(t *testing.T) { PacketCount: 714, OctetCount: 859127, } - byts, _ := srPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTCP(ts, byts) + rr.ProcessPacketRTCP(ts, &srPkt) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -37,9 +36,8 @@ func TestRTCPReceiverBase(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -52,9 +50,8 @@ func TestRTCPReceiverBase(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.ReceiverReport{ SSRC: 0x65f83afb, @@ -67,9 +64,8 @@ func TestRTCPReceiverBase(t *testing.T) { }, }, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) - require.Equal(t, expected, rr.Report(ts)) + require.Equal(t, &expectedPkt, rr.Report(ts)) } func TestRTCPReceiverOverflow(t *testing.T) { @@ -83,9 +79,8 @@ func TestRTCPReceiverOverflow(t *testing.T) { PacketCount: 714, OctetCount: 859127, } - byts, _ := srPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTCP(ts, byts) + rr.ProcessPacketRTCP(ts, &srPkt) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -98,9 +93,8 @@ func TestRTCPReceiverOverflow(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -113,9 +107,8 @@ func TestRTCPReceiverOverflow(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.ReceiverReport{ SSRC: 0x65f83afb, @@ -128,9 +121,8 @@ func TestRTCPReceiverOverflow(t *testing.T) { }, }, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, expected, rr.Report(ts)) + require.Equal(t, &expectedPkt, rr.Report(ts)) } func TestRTCPReceiverPacketLost(t *testing.T) { @@ -144,9 +136,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) { PacketCount: 714, OctetCount: 859127, } - byts, _ := srPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTCP(ts, byts) + rr.ProcessPacketRTCP(ts, &srPkt) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -159,9 +150,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -174,9 +164,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.ReceiverReport{ SSRC: 0x65f83afb, @@ -194,9 +183,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) { }, }, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, expected, rr.Report(ts)) + require.Equal(t, &expectedPkt, rr.Report(ts)) } func TestRTCPReceiverOverflowPacketLost(t *testing.T) { @@ -210,9 +198,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { PacketCount: 714, OctetCount: 859127, } - byts, _ := srPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTCP(ts, byts) + rr.ProcessPacketRTCP(ts, &srPkt) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -225,9 +212,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -240,9 +226,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.ReceiverReport{ SSRC: 0x65f83afb, @@ -260,9 +245,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { }, }, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, expected, rr.Report(ts)) + require.Equal(t, &expectedPkt, rr.Report(ts)) } func TestRTCPReceiverReorderedPackets(t *testing.T) { @@ -276,9 +260,8 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { PacketCount: 714, OctetCount: 859127, } - byts, _ := srPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTCP(ts, byts) + rr.ProcessPacketRTCP(ts, &srPkt) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -291,9 +274,8 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -306,9 +288,8 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.ReceiverReport{ SSRC: 0x65f83afb, @@ -321,9 +302,8 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { }, }, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, expected, rr.Report(ts)) + require.Equal(t, &expectedPkt, rr.Report(ts)) } func TestRTCPReceiverJitter(t *testing.T) { @@ -337,9 +317,8 @@ func TestRTCPReceiverJitter(t *testing.T) { PacketCount: 714, OctetCount: 859127, } - byts, _ := srPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTCP(ts, byts) + rr.ProcessPacketRTCP(ts, &srPkt) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -352,9 +331,8 @@ func TestRTCPReceiverJitter(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -367,9 +345,8 @@ func TestRTCPReceiverJitter(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - rr.ProcessPacketRTP(ts, byts) + rr.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.ReceiverReport{ SSRC: 0x65f83afb, @@ -383,7 +360,6 @@ func TestRTCPReceiverJitter(t *testing.T) { }, }, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) - require.Equal(t, expected, rr.Report(ts)) + require.Equal(t, &expectedPkt, rr.Report(ts)) } diff --git a/pkg/rtcpsender/rtcpsender.go b/pkg/rtcpsender/rtcpsender.go index 8d70dc6a..d1dc97b5 100644 --- a/pkg/rtcpsender/rtcpsender.go +++ b/pkg/rtcpsender/rtcpsender.go @@ -32,7 +32,7 @@ func New(clockRate int) *RTCPSender { // Report generates a RTCP sender report. // It returns nil if no packets has been passed to ProcessPacketRTP yet. -func (rs *RTCPSender) Report(ts time.Time) []byte { +func (rs *RTCPSender) Report(ts time.Time) rtcp.Packet { rs.mutex.Lock() defer rs.mutex.Unlock() @@ -40,7 +40,7 @@ func (rs *RTCPSender) Report(ts time.Time) []byte { return nil } - report := &rtcp.SenderReport{ + return &rtcp.SenderReport{ SSRC: rs.senderSSRC, NTPTime: func() uint64 { // seconds since 1st January 1900 @@ -55,33 +55,22 @@ func (rs *RTCPSender) Report(ts time.Time) []byte { PacketCount: rs.packetCount, OctetCount: rs.octetCount, } - - byts, err := report.Marshal() - if err != nil { - panic(err) - } - - return byts } // ProcessPacketRTP extracts the needed data from RTP packets. -func (rs *RTCPSender) ProcessPacketRTP(ts time.Time, payload []byte) { +func (rs *RTCPSender) ProcessPacketRTP(ts time.Time, pkt *rtp.Packet) { rs.mutex.Lock() defer rs.mutex.Unlock() - pkt := rtp.Packet{} - err := pkt.Unmarshal(payload) - if err == nil { - if !rs.firstRTPReceived { - rs.firstRTPReceived = true - rs.senderSSRC = pkt.SSRC - } - - // always update time to minimize errors - rs.lastRTPTimeRTP = pkt.Timestamp - rs.lastRTPTimeTime = ts - - rs.packetCount++ - rs.octetCount += uint32(len(pkt.Payload)) + if !rs.firstRTPReceived { + rs.firstRTPReceived = true + rs.senderSSRC = pkt.SSRC } + + // always update time to minimize errors + rs.lastRTPTimeRTP = pkt.Timestamp + rs.lastRTPTimeTime = ts + + rs.packetCount++ + rs.octetCount += uint32(len(pkt.Payload)) } diff --git a/pkg/rtcpsender/rtcpsender_test.go b/pkg/rtcpsender/rtcpsender_test.go index 9ac225c2..2394b5e0 100644 --- a/pkg/rtcpsender/rtcpsender_test.go +++ b/pkg/rtcpsender/rtcpsender_test.go @@ -12,7 +12,7 @@ import ( func TestRTCPSender(t *testing.T) { rs := New(90000) - require.Equal(t, []byte(nil), rs.Report(time.Now())) + require.Equal(t, nil, rs.Report(time.Now())) rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -25,9 +25,8 @@ func TestRTCPSender(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ := rtpPkt.Marshal() ts := time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rs.ProcessPacketRTP(ts, byts) + rs.ProcessPacketRTP(ts, &rtpPkt) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -40,9 +39,8 @@ func TestRTCPSender(t *testing.T) { }, Payload: []byte("\x00\x00"), } - byts, _ = rtpPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 15, 20, 500000000, time.UTC) - rs.ProcessPacketRTP(ts, byts) + rs.ProcessPacketRTP(ts, &rtpPkt) expectedPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -51,7 +49,6 @@ func TestRTCPSender(t *testing.T) { PacketCount: 2, OctetCount: 4, } - expected, _ := expectedPkt.Marshal() ts = time.Date(2008, 0o5, 20, 22, 16, 20, 600000000, time.UTC) - require.Equal(t, expected, rs.Report(ts)) + require.Equal(t, &expectedPkt, rs.Report(ts)) } diff --git a/server_read_test.go b/server_read_test.go index a5fb8283..597bc7e8 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -1648,7 +1648,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { buf, err := (&rtp.Packet{ Header: rtp.Header{ - Version: 0x80, + Version: 2, PayloadType: 96, SequenceNumber: 556, Timestamp: 984512368, @@ -1684,7 +1684,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { buf, err = (&rtp.Packet{ Header: rtp.Header{ - Version: 0x80, + Version: 2, PayloadType: 96, SequenceNumber: 87, Timestamp: 756436454, diff --git a/serversession.go b/serversession.go index efd99a50..e27aba78 100644 --- a/serversession.go +++ b/serversession.go @@ -355,7 +355,8 @@ func (ss *ServerSession) run() { for trackID, track := range ss.announcedTracks { rr := track.rtcpReceiver.Report(now) if rr != nil { - ss.WritePacketRTCP(trackID, rr) + byts, _ := rr.Marshal() + ss.WritePacketRTCP(trackID, byts) } } diff --git a/serverudpl.go b/serverudpl.go index 53a12d6b..7515f702 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -8,6 +8,8 @@ import ( "sync/atomic" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" "golang.org/x/net/ipv4" "github.com/aler9/gortsplib/pkg/multibuffer" @@ -194,9 +196,15 @@ func (u *serverUDPListener) runReader() { } func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + if err != nil { + return + } + now := time.Now() atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) - clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTP(now, payload) + clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTP(now, &pkt) if h, ok := u.s.Handler.(ServerHandlerOnPacketRTP); ok { h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ @@ -208,10 +216,17 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { } func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + if clientData.isPublishing { now := time.Now() atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) - clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTCP(now, payload) + for _, pkt := range packets { + clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt) + } } if h, ok := u.s.Handler.(ServerHandlerOnPacketRTCP); ok {