diff --git a/connclient.go b/connclient.go index d7a84b5d..45df2701 100644 --- a/connclient.go +++ b/connclient.go @@ -96,6 +96,7 @@ type ConnClient struct { response *base.Response frame *base.InterleavedFrame tcpFrameBuffer *multibuffer.MultiBuffer + writeFrameFunc func(trackId int, streamType StreamType, content []byte) error reportWriterTerminate chan struct{} reportWriterDone chan struct{} @@ -240,17 +241,14 @@ func (c *ConnClient) ReadFrameTCP() (int, StreamType, []byte, error) { return c.frame.TrackId, c.frame.StreamType, c.frame.Content, nil } -// WriteFrameUDP writes an UDP frame. -func (c *ConnClient) WriteFrameUDP(trackId int, streamType StreamType, content []byte) error { +func (c *ConnClient) writeFrameUDP(trackId int, streamType StreamType, content []byte) error { if streamType == StreamTypeRtp { return c.udpRtpListeners[trackId].write(content) } return c.udpRtcpListeners[trackId].write(content) } -// WriteFrameTCP writes an interleaved frame. -// this can't be used when reading. -func (c *ConnClient) WriteFrameTCP(trackId int, streamType StreamType, content []byte) error { +func (c *ConnClient) writeFrameTCP(trackId int, streamType StreamType, content []byte) error { frame := base.InterleavedFrame{ TrackId: trackId, StreamType: streamType, @@ -261,6 +259,12 @@ func (c *ConnClient) WriteFrameTCP(trackId int, streamType StreamType, content [ return frame.Write(c.bw) } +// WriteFrame writes a frame. +// This can be used only after Record(). +func (c *ConnClient) WriteFrame(trackId int, streamType StreamType, content []byte) error { + return c.writeFrameFunc(trackId, streamType, content) +} + // Do writes a Request and reads a Response. // Interleaved frames sent before the response are ignored. func (c *ConnClient) Do(req *base.Request) (*base.Response, error) { @@ -418,7 +422,7 @@ func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) { } } -// build an URL by merging baseUrl with the control attribute from track.Media +// build an URL by merging baseUrl with the control attribute from track.Media. func (c *ConnClient) urlForTrack(baseUrl *base.URL, mode headers.TransportMode, track *Track) *base.URL { control := func() string { // if we're reading, get control from track ID @@ -696,13 +700,7 @@ func (c *ConnClient) Play() (*base.Response, error) { case <-reportWriterTicker.C: for trackId := range c.rtcpReceivers { frame := c.rtcpReceivers[trackId].Report() - - if *c.streamProtocol == StreamProtocolUDP { - c.udpRtcpListeners[trackId].write(frame) - - } else { - c.WriteFrameTCP(trackId, StreamTypeRtcp, frame) - } + c.WriteFrame(trackId, StreamTypeRtcp, frame) } } } @@ -766,6 +764,12 @@ func (c *ConnClient) Record() (*base.Response, error) { c.state = connClientStateRecord + if *c.streamProtocol == StreamProtocolUDP { + c.writeFrameFunc = c.writeFrameUDP + } else { + c.writeFrameFunc = c.writeFrameTCP + } + return nil, nil } diff --git a/dialer_test.go b/dialer_test.go index eb82b2e9..83729581 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -249,7 +249,7 @@ func TestDialPublishUDP(t *testing.T) { break } - err = conn.WriteFrameUDP(track.Id, StreamTypeRtp, buf[:n]) + err = conn.WriteFrame(track.Id, StreamTypeRtp, buf[:n]) if err != nil { break } @@ -349,7 +349,7 @@ func TestDialPublishTCP(t *testing.T) { break } - err = conn.WriteFrameTCP(track.Id, StreamTypeRtp, buf[:n]) + err = conn.WriteFrame(track.Id, StreamTypeRtp, buf[:n]) if err != nil { break } diff --git a/examples/client-publish-tcp.go b/examples/client-publish-tcp.go index eda640bd..02245cd1 100644 --- a/examples/client-publish-tcp.go +++ b/examples/client-publish-tcp.go @@ -57,7 +57,7 @@ func main() { } // write frames to the server - err = conn.WriteFrameTCP(track.Id, gortsplib.StreamTypeRtp, buf[:n]) + err = conn.WriteFrame(track.Id, gortsplib.StreamTypeRtp, buf[:n]) if err != nil { fmt.Println("connection is closed (%s)", err) break diff --git a/examples/client-publish-udp.go b/examples/client-publish-udp.go index 335aaebb..2135cee7 100644 --- a/examples/client-publish-udp.go +++ b/examples/client-publish-udp.go @@ -69,7 +69,7 @@ func main() { } // write frames to the server - err = conn.WriteFrameUDP(track.Id, gortsplib.StreamTypeRtp, buf[:n]) + err = conn.WriteFrame(track.Id, gortsplib.StreamTypeRtp, buf[:n]) if err != nil { break }