diff --git a/connclient.go b/connclient.go index ebf4e1e2..b90457c0 100644 --- a/connclient.go +++ b/connclient.go @@ -45,8 +45,13 @@ const ( // ConnClientConf allows to configure a ConnClient. type ConnClientConf struct { // target address in format hostname:port + // either Host or Conn must be non-null Host string + // pre-existing TCP connection to wrap + // either Host or Conn must be non-null + Conn net.Conn + // (optional) timeout of read operations. // It defaults to 10 seconds ReadTimeout time.Duration @@ -73,7 +78,6 @@ type ConnClientConf struct { // ConnClient is a client-side RTSP connection. type ConnClient struct { conf ConnClientConf - nconn net.Conn br *bufio.Reader bw *bufio.Writer session string @@ -110,16 +114,22 @@ func NewConnClient(conf ConnClientConf) (*ConnClient, error) { conf.ListenPacket = net.ListenPacket } - nconn, err := conf.DialTimeout("tcp", conf.Host, conf.ReadTimeout) - if err != nil { - return nil, err + if conf.Host != "" && conf.Conn != nil { + return nil, fmt.Errorf("Host and Conn can't be used together") + } + + if conf.Conn == nil { + var err error + conf.Conn, err = conf.DialTimeout("tcp", conf.Host, conf.ReadTimeout) + if err != nil { + return nil, err + } } return &ConnClient{ conf: conf, - nconn: nconn, - br: bufio.NewReaderSize(nconn, clientReadBufferSize), - bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), + br: bufio.NewReaderSize(conf.Conn, clientReadBufferSize), + bw: bufio.NewWriterSize(conf.Conn, clientWriteBufferSize), rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver), udpLastFrameTimes: make(map[int]*int64), udpRtpListeners: make(map[int]*connClientUDPListener), @@ -138,7 +148,7 @@ func (c *ConnClient) Close() error { }) } - err := c.nconn.Close() + err := c.conf.Conn.Close() if c.receiverReportTerminate != nil { close(c.receiverReportTerminate) @@ -169,13 +179,13 @@ func (c *ConnClient) CloseUDPListeners() { // NetConn returns the underlying net.Conn. func (c *ConnClient) NetConn() net.Conn { - return c.nconn + return c.conf.Conn } func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) { frame := c.tcpFrames.next() - c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) + c.conf.Conn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) return base.ReadInterleavedFrameOrResponse(frame, c.br) } @@ -184,7 +194,7 @@ func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) { func (c *ConnClient) ReadFrameTCP() (int, StreamType, []byte, error) { frame := c.tcpFrames.next() - c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) + c.conf.Conn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) err := frame.Read(c.br) if err != nil { return 0, 0, nil, err @@ -227,7 +237,7 @@ func (c *ConnClient) WriteFrameTCP(trackId int, streamType StreamType, content [ Content: content, } - c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) + c.conf.Conn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) return frame.Write(c.bw) } @@ -267,7 +277,7 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) { c.cseq += 1 req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(c.cseq), 10)} - c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) + c.conf.Conn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) err := req.Write(c.bw) if err != nil { return nil, err @@ -574,13 +584,13 @@ func (c *ConnClient) SetupUDP(u *url.URL, mode TransportMode, track *Track, rtpP c.udpLastFrameTimes[track.Id] = &v } - rtpListener.remoteIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP - rtpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone + rtpListener.remoteIp = c.conf.Conn.RemoteAddr().(*net.TCPAddr).IP + rtpListener.remoteZone = c.conf.Conn.RemoteAddr().(*net.TCPAddr).Zone rtpListener.remotePort = (*th.ServerPorts)[0] c.udpRtpListeners[track.Id] = rtpListener - rtcpListener.remoteIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP - rtcpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone + rtcpListener.remoteIp = c.conf.Conn.RemoteAddr().(*net.TCPAddr).IP + rtcpListener.remoteZone = c.conf.Conn.RemoteAddr().(*net.TCPAddr).Zone rtcpListener.remotePort = (*th.ServerPorts)[1] c.udpRtcpListeners[track.Id] = rtcpListener @@ -724,7 +734,7 @@ func (c *ConnClient) LoopUDP() error { readDone := make(chan error) go func() { for { - c.nconn.SetReadDeadline(time.Now().Add(clientUDPKeepalivePeriod + c.conf.ReadTimeout)) + c.conf.Conn.SetReadDeadline(time.Now().Add(clientUDPKeepalivePeriod + c.conf.ReadTimeout)) _, err := base.ReadResponse(c.br) if err != nil { readDone <- err @@ -742,7 +752,7 @@ func (c *ConnClient) LoopUDP() error { for { select { case err := <-readDone: - c.nconn.Close() + c.conf.Conn.Close() return err case <-keepaliveTicker.C: @@ -757,7 +767,7 @@ func (c *ConnClient) LoopUDP() error { SkipResponse: true, }) if err != nil { - c.nconn.Close() + c.conf.Conn.Close() <-readDone return err } @@ -769,7 +779,7 @@ func (c *ConnClient) LoopUDP() error { last := time.Unix(atomic.LoadInt64(lastUnix), 0) if now.Sub(last) >= c.conf.ReadTimeout { - c.nconn.Close() + c.conf.Conn.Close() <-readDone return fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") }