diff --git a/connclient.go b/connclient.go index e976226d..930d4da4 100644 --- a/connclient.go +++ b/connclient.go @@ -161,6 +161,14 @@ func (c *ConnClient) ReadFrame() (*InterleavedFrame, error) { return frame, nil } +// ReadFrameUDP reads an UDP frame. +func (c *ConnClient) ReadFrameUDP(track *Track, streamType StreamType) ([]byte, error) { + if streamType == StreamTypeRtp { + return c.udpRtpListeners[track.Id].read() + } + return c.udpRtcpListeners[track.Id].read() +} + func (c *ConnClient) readFrameOrResponse() (interface{}, error) { c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) b, err := c.br.ReadByte() @@ -404,9 +412,9 @@ func (c *ConnClient) setup(u *url.URL, track *Track, ht *HeaderTransport) (*Resp // a given track with the UDP transport. It then reads a Response. // If rtpPort and rtcpPort are zero, they will be chosen automatically. func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, - rtcpPort int) (UDPReadFunc, UDPReadFunc, *Response, error) { + rtcpPort int) (*Response, error) { if c.playing { - return nil, nil, nil, fmt.Errorf("can't be called when playing") + return nil, fmt.Errorf("can't be called when playing") } if c.streamUrl != nil && *u != *c.streamUrl { @@ -414,20 +422,20 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, } if c.streamProtocol != nil && *c.streamProtocol != StreamProtocolUDP { - return nil, nil, nil, fmt.Errorf("cannot setup tracks with different protocols") + return nil, fmt.Errorf("cannot setup tracks with different protocols") } if _, ok := c.rtcpReceivers[track.Id]; ok { - return nil, nil, nil, fmt.Errorf("track has already been setup") + return nil, fmt.Errorf("track has already been setup") } if (rtpPort == 0 && rtcpPort != 0) || (rtpPort != 0 && rtcpPort == 0) { - return nil, nil, nil, fmt.Errorf("rtpPort and rtcpPort must be both zero or non-zero") + return nil, fmt.Errorf("rtpPort and rtcpPort must be both zero or non-zero") } if rtpPort != 0 && rtcpPort != (rtpPort+1) { - return nil, nil, nil, fmt.Errorf("rtcpPort must be rtpPort + 1") + return nil, fmt.Errorf("rtcpPort must be rtpPort + 1") } rtpListener, rtcpListener, err := func() (*connClientUDPListener, *connClientUDPListener, error) { @@ -468,7 +476,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, } }() if err != nil { - return nil, nil, nil, err + return nil, err } res, err := c.setup(u, track, &HeaderTransport{ @@ -482,20 +490,20 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, if err != nil { rtpListener.close() rtcpListener.close() - return nil, nil, nil, err + return nil, err } th, err := ReadHeaderTransport(res.Header["Transport"]) if err != nil { rtpListener.close() rtcpListener.close() - return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err) + return nil, fmt.Errorf("SETUP: transport header: %s", err) } if th.ServerPorts == nil { rtpListener.close() rtcpListener.close() - return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided") + return nil, fmt.Errorf("SETUP: server ports not provided") } c.streamUrl = u @@ -514,7 +522,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, rtcpListener.publisherPort = (*th.ServerPorts)[1] c.udpRtcpListeners[track.Id] = rtcpListener - return rtpListener.read, rtcpListener.read, res, nil + return res, nil } // SetupTCP writes a SETUP request, that means that we want to read diff --git a/connclient_test.go b/connclient_test.go index 62e3c443..6db5b053 100644 --- a/connclient_test.go +++ b/connclient_test.go @@ -136,15 +136,9 @@ func TestConnClientReadUDP(t *testing.T) { tracks, _, err := conn.Describe(u) require.NoError(t, err) - var rtpReads []UDPReadFunc - var rtcpReads []UDPReadFunc - for _, track := range tracks { - rtpRead, rtcpRead, _, err := conn.SetupUDP(u, track, 0, 0) + _, err := conn.SetupUDP(u, track, 0, 0) require.NoError(t, err) - - rtpReads = append(rtpReads, rtpRead) - rtcpReads = append(rtcpReads, rtcpRead) } _, err = conn.Play(u) @@ -152,6 +146,6 @@ func TestConnClientReadUDP(t *testing.T) { go conn.LoopUDP(u) - _, err = rtpReads[0]() + _, err = conn.ReadFrameUDP(tracks[0], StreamTypeRtp) require.NoError(t, err) } diff --git a/connclientudpl.go b/connclientudpl.go index 39e325af..6832f08a 100644 --- a/connclientudpl.go +++ b/connclientudpl.go @@ -7,9 +7,6 @@ import ( "time" ) -// UDPReadFunc is a function used to read UDP packets. -type UDPReadFunc func() ([]byte, error) - type connClientUDPListener struct { c *ConnClient pc net.PacketConn diff --git a/examples/read-udp.go b/examples/read-udp.go index cf66b753..506854a8 100644 --- a/examples/read-udp.go +++ b/examples/read-udp.go @@ -32,17 +32,11 @@ func main() { panic(err) } - var rtpReads []gortsplib.UDPReadFunc - var rtcpReads []gortsplib.UDPReadFunc - for _, track := range tracks { - rtpRead, rtcpRead, _, err := conn.SetupUDP(u, track, 0, 0) + _, err := conn.SetupUDP(u, track, 0, 0) if err != nil { panic(err) } - - rtpReads = append(rtpReads, rtpRead) - rtcpReads = append(rtcpReads, rtcpRead) } _, err = conn.Play(u) @@ -53,39 +47,39 @@ func main() { var wg sync.WaitGroup // read RTP frames - for trackId, rtpRead := range rtpReads { + for _, track := range tracks { wg.Add(1) - go func(trackId int, rtpRead gortsplib.UDPReadFunc) { + go func(track *gortsplib.Track) { defer wg.Done() for { - buf, err := rtpRead() + buf, err := conn.ReadFrameUDP(track, gortsplib.StreamTypeRtp) if err != nil { break } - fmt.Printf("frame from track %d, type RTP: %v\n", trackId, buf) + fmt.Printf("frame from track %d, type RTP: %v\n", track.Id, buf) } - }(trackId, rtpRead) + }(track) } // read RTCP frames - for trackId, rtcpRead := range rtcpReads { + for _, track := range tracks { wg.Add(1) - go func(trackId int, rtcpRead gortsplib.UDPReadFunc) { + go func(track *gortsplib.Track) { defer wg.Done() for { - buf, err := rtcpRead() + buf, err := conn.ReadFrameUDP(track, gortsplib.StreamTypeRtcp) if err != nil { break } - fmt.Printf("frame from track %d, type RTCP: %v\n", trackId, buf) + fmt.Printf("frame from track %d, type RTCP: %v\n", track.Id, buf) } - }(trackId, rtcpRead) + }(track) } err = conn.LoopUDP(u) diff --git a/utils.go b/utils.go index 58a2d5b0..840b0487 100644 --- a/utils.go +++ b/utils.go @@ -24,10 +24,14 @@ const ( // String implements fmt.Stringer func (sp StreamProtocol) String() string { - if sp == StreamProtocolUDP { + switch sp { + case StreamProtocolUDP: return "udp" + + case StreamProtocolTCP: + return "tcp" } - return "tcp" + return "unknown" } // StreamCast is the cast of a stream. @@ -43,10 +47,14 @@ const ( // String implements fmt.Stringer func (sc StreamCast) String() string { - if sc == StreamUnicast { + switch sc { + case StreamUnicast: return "unicast" + + case StreamMulticast: + return "multicast" } - return "multicast" + return "unknown" } // StreamType is the type of a stream. @@ -69,7 +77,7 @@ func (st StreamType) String() string { case StreamTypeRtcp: return "RTCP" } - return "UNKNOWN" + return "unknown" } func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) {