ConnClient: simplify UDP reads

This commit is contained in:
aler9
2020-09-23 21:37:26 +02:00
parent 927c088278
commit 46ea598b35
5 changed files with 45 additions and 44 deletions

View File

@@ -161,6 +161,14 @@ func (c *ConnClient) ReadFrame() (*InterleavedFrame, error) {
return frame, nil return frame, nil
} }
// ReadFrameUDP reads an UDP frame.
func (c *ConnClient) ReadFrameUDP(track *Track, streamType StreamType) ([]byte, error) {
if streamType == StreamTypeRtp {
return c.udpRtpListeners[track.Id].read()
}
return c.udpRtcpListeners[track.Id].read()
}
func (c *ConnClient) readFrameOrResponse() (interface{}, error) { func (c *ConnClient) readFrameOrResponse() (interface{}, error) {
c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout))
b, err := c.br.ReadByte() b, err := c.br.ReadByte()
@@ -404,9 +412,9 @@ func (c *ConnClient) setup(u *url.URL, track *Track, ht *HeaderTransport) (*Resp
// a given track with the UDP transport. It then reads a Response. // a given track with the UDP transport. It then reads a Response.
// If rtpPort and rtcpPort are zero, they will be chosen automatically. // If rtpPort and rtcpPort are zero, they will be chosen automatically.
func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
rtcpPort int) (UDPReadFunc, UDPReadFunc, *Response, error) { rtcpPort int) (*Response, error) {
if c.playing { if c.playing {
return nil, nil, nil, fmt.Errorf("can't be called when playing") return nil, fmt.Errorf("can't be called when playing")
} }
if c.streamUrl != nil && *u != *c.streamUrl { if c.streamUrl != nil && *u != *c.streamUrl {
@@ -414,20 +422,20 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
} }
if c.streamProtocol != nil && *c.streamProtocol != StreamProtocolUDP { if c.streamProtocol != nil && *c.streamProtocol != StreamProtocolUDP {
return nil, nil, nil, fmt.Errorf("cannot setup tracks with different protocols") return nil, fmt.Errorf("cannot setup tracks with different protocols")
} }
if _, ok := c.rtcpReceivers[track.Id]; ok { if _, ok := c.rtcpReceivers[track.Id]; ok {
return nil, nil, nil, fmt.Errorf("track has already been setup") return nil, fmt.Errorf("track has already been setup")
} }
if (rtpPort == 0 && rtcpPort != 0) || if (rtpPort == 0 && rtcpPort != 0) ||
(rtpPort != 0 && rtcpPort == 0) { (rtpPort != 0 && rtcpPort == 0) {
return nil, nil, nil, fmt.Errorf("rtpPort and rtcpPort must be both zero or non-zero") return nil, fmt.Errorf("rtpPort and rtcpPort must be both zero or non-zero")
} }
if rtpPort != 0 && rtcpPort != (rtpPort+1) { if rtpPort != 0 && rtcpPort != (rtpPort+1) {
return nil, nil, nil, fmt.Errorf("rtcpPort must be rtpPort + 1") return nil, fmt.Errorf("rtcpPort must be rtpPort + 1")
} }
rtpListener, rtcpListener, err := func() (*connClientUDPListener, *connClientUDPListener, error) { rtpListener, rtcpListener, err := func() (*connClientUDPListener, *connClientUDPListener, error) {
@@ -468,7 +476,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
} }
}() }()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
res, err := c.setup(u, track, &HeaderTransport{ res, err := c.setup(u, track, &HeaderTransport{
@@ -482,20 +490,20 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
if err != nil { if err != nil {
rtpListener.close() rtpListener.close()
rtcpListener.close() rtcpListener.close()
return nil, nil, nil, err return nil, err
} }
th, err := ReadHeaderTransport(res.Header["Transport"]) th, err := ReadHeaderTransport(res.Header["Transport"])
if err != nil { if err != nil {
rtpListener.close() rtpListener.close()
rtcpListener.close() rtcpListener.close()
return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err) return nil, fmt.Errorf("SETUP: transport header: %s", err)
} }
if th.ServerPorts == nil { if th.ServerPorts == nil {
rtpListener.close() rtpListener.close()
rtcpListener.close() rtcpListener.close()
return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided") return nil, fmt.Errorf("SETUP: server ports not provided")
} }
c.streamUrl = u c.streamUrl = u
@@ -514,7 +522,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
rtcpListener.publisherPort = (*th.ServerPorts)[1] rtcpListener.publisherPort = (*th.ServerPorts)[1]
c.udpRtcpListeners[track.Id] = rtcpListener c.udpRtcpListeners[track.Id] = rtcpListener
return rtpListener.read, rtcpListener.read, res, nil return res, nil
} }
// SetupTCP writes a SETUP request, that means that we want to read // SetupTCP writes a SETUP request, that means that we want to read

View File

@@ -136,15 +136,9 @@ func TestConnClientReadUDP(t *testing.T) {
tracks, _, err := conn.Describe(u) tracks, _, err := conn.Describe(u)
require.NoError(t, err) require.NoError(t, err)
var rtpReads []UDPReadFunc
var rtcpReads []UDPReadFunc
for _, track := range tracks { for _, track := range tracks {
rtpRead, rtcpRead, _, err := conn.SetupUDP(u, track, 0, 0) _, err := conn.SetupUDP(u, track, 0, 0)
require.NoError(t, err) require.NoError(t, err)
rtpReads = append(rtpReads, rtpRead)
rtcpReads = append(rtcpReads, rtcpRead)
} }
_, err = conn.Play(u) _, err = conn.Play(u)
@@ -152,6 +146,6 @@ func TestConnClientReadUDP(t *testing.T) {
go conn.LoopUDP(u) go conn.LoopUDP(u)
_, err = rtpReads[0]() _, err = conn.ReadFrameUDP(tracks[0], StreamTypeRtp)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@@ -7,9 +7,6 @@ import (
"time" "time"
) )
// UDPReadFunc is a function used to read UDP packets.
type UDPReadFunc func() ([]byte, error)
type connClientUDPListener struct { type connClientUDPListener struct {
c *ConnClient c *ConnClient
pc net.PacketConn pc net.PacketConn

View File

@@ -32,17 +32,11 @@ func main() {
panic(err) panic(err)
} }
var rtpReads []gortsplib.UDPReadFunc
var rtcpReads []gortsplib.UDPReadFunc
for _, track := range tracks { for _, track := range tracks {
rtpRead, rtcpRead, _, err := conn.SetupUDP(u, track, 0, 0) _, err := conn.SetupUDP(u, track, 0, 0)
if err != nil { if err != nil {
panic(err) panic(err)
} }
rtpReads = append(rtpReads, rtpRead)
rtcpReads = append(rtcpReads, rtcpRead)
} }
_, err = conn.Play(u) _, err = conn.Play(u)
@@ -53,39 +47,39 @@ func main() {
var wg sync.WaitGroup var wg sync.WaitGroup
// read RTP frames // read RTP frames
for trackId, rtpRead := range rtpReads { for _, track := range tracks {
wg.Add(1) wg.Add(1)
go func(trackId int, rtpRead gortsplib.UDPReadFunc) { go func(track *gortsplib.Track) {
defer wg.Done() defer wg.Done()
for { for {
buf, err := rtpRead() buf, err := conn.ReadFrameUDP(track, gortsplib.StreamTypeRtp)
if err != nil { if err != nil {
break break
} }
fmt.Printf("frame from track %d, type RTP: %v\n", trackId, buf) fmt.Printf("frame from track %d, type RTP: %v\n", track.Id, buf)
} }
}(trackId, rtpRead) }(track)
} }
// read RTCP frames // read RTCP frames
for trackId, rtcpRead := range rtcpReads { for _, track := range tracks {
wg.Add(1) wg.Add(1)
go func(trackId int, rtcpRead gortsplib.UDPReadFunc) { go func(track *gortsplib.Track) {
defer wg.Done() defer wg.Done()
for { for {
buf, err := rtcpRead() buf, err := conn.ReadFrameUDP(track, gortsplib.StreamTypeRtcp)
if err != nil { if err != nil {
break break
} }
fmt.Printf("frame from track %d, type RTCP: %v\n", trackId, buf) fmt.Printf("frame from track %d, type RTCP: %v\n", track.Id, buf)
} }
}(trackId, rtcpRead) }(track)
} }
err = conn.LoopUDP(u) err = conn.LoopUDP(u)

View File

@@ -24,10 +24,14 @@ const (
// String implements fmt.Stringer // String implements fmt.Stringer
func (sp StreamProtocol) String() string { func (sp StreamProtocol) String() string {
if sp == StreamProtocolUDP { switch sp {
case StreamProtocolUDP:
return "udp" return "udp"
case StreamProtocolTCP:
return "tcp"
} }
return "tcp" return "unknown"
} }
// StreamCast is the cast of a stream. // StreamCast is the cast of a stream.
@@ -43,10 +47,14 @@ const (
// String implements fmt.Stringer // String implements fmt.Stringer
func (sc StreamCast) String() string { func (sc StreamCast) String() string {
if sc == StreamUnicast { switch sc {
case StreamUnicast:
return "unicast" return "unicast"
case StreamMulticast:
return "multicast"
} }
return "multicast" return "unknown"
} }
// StreamType is the type of a stream. // StreamType is the type of a stream.
@@ -69,7 +77,7 @@ func (st StreamType) String() string {
case StreamTypeRtcp: case StreamTypeRtcp:
return "RTCP" return "RTCP"
} }
return "UNKNOWN" return "unknown"
} }
func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) { func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) {