From 07aefbcd5d1146b706993c2ace0a7ef508ca2f95 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 8 Dec 2020 11:54:38 +0100 Subject: [PATCH] add ClientConf.OnRequest, ClientConf.OnResponse --- clientconf.go | 8 +++++++- clientconn.go | 24 ++++++++++++++++-------- clientconnpublish.go | 4 ++-- clientconnread.go | 6 +++--- clientconnudpl.go | 6 +++--- server.go | 18 +++++++++--------- serverconf.go | 2 +- serverconn.go | 8 ++++---- 8 files changed, 45 insertions(+), 31 deletions(-) diff --git a/clientconf.go b/clientconf.go index a2f01c7a..844e7959 100644 --- a/clientconf.go +++ b/clientconf.go @@ -58,6 +58,12 @@ type ClientConf struct { // It defaults to 1. ReadBufferCount int + // callback called before every request. + OnRequest func(req *base.Request) + + // callback called after very response. + OnResponse func(res *base.Response) + // function used to initialize the TCP client. // It defaults to net.DialTimeout. DialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) @@ -95,7 +101,7 @@ func (c ClientConf) Dial(host string) (*ClientConn, error) { } return &ClientConn{ - c: c, + conf: c, nconn: nconn, br: bufio.NewReaderSize(nconn, clientReadBufferSize), bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), diff --git a/clientconn.go b/clientconn.go index 8815e979..13a11c73 100644 --- a/clientconn.go +++ b/clientconn.go @@ -63,7 +63,7 @@ func (s clientConnState) String() string { // ClientConn is a client-side RTSP connection. type ClientConn struct { - c ClientConf + conf ClientConf nconn net.Conn br *bufio.Reader bw *bufio.Writer @@ -146,7 +146,7 @@ func (c *ClientConn) Tracks() Tracks { } func (c *ClientConn) readFrameTCPOrResponse() (interface{}, error) { - c.nconn.SetReadDeadline(time.Now().Add(c.c.ReadTimeout)) + c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) f := base.InterleavedFrame{ Content: c.tcpFrameBuffer.Next(), } @@ -175,7 +175,11 @@ func (c *ClientConn) Do(req *base.Request) (*base.Response, error) { c.cseq++ req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(c.cseq), 10)} - c.nconn.SetWriteDeadline(time.Now().Add(c.c.WriteTimeout)) + if c.conf.OnRequest != nil { + c.conf.OnRequest(req) + } + + c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) err := req.Write(c.bw) if err != nil { return nil, err @@ -205,6 +209,10 @@ func (c *ClientConn) Do(req *base.Request) (*base.Response, error) { return nil, err } + if c.conf.OnResponse != nil { + c.conf.OnResponse(res) + } + // get session from response if v, ok := res.Header["Session"]; ok { sx, err := headers.ReadSession(v) @@ -298,7 +306,7 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) { if res.StatusCode != base.StatusOK { // redirect - if !c.c.RedirectDisable && + if !c.conf.RedirectDisable && res.StatusCode >= base.StatusMovedPermanently && res.StatusCode <= base.StatusUseProxy && len(res.Header["Location"]) == 1 { @@ -310,7 +318,7 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) { return nil, nil, err } - nc, err := c.c.Dial(u.Host) + nc, err := c.conf.Dial(u.Host) if err != nil { return nil, nil, err } @@ -385,8 +393,8 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track, } // protocol set by conf - if c.c.StreamProtocol != nil { - return *c.c.StreamProtocol + if c.conf.StreamProtocol != nil { + return *c.conf.StreamProtocol } // try udp @@ -493,7 +501,7 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track, // switch protocol automatically if res.StatusCode == base.StatusUnsupportedTransport && c.streamProtocol == nil && - c.c.StreamProtocol == nil { + c.conf.StreamProtocol == nil { v := StreamProtocolTCP c.streamProtocol = &v diff --git a/clientconnpublish.go b/clientconnpublish.go index c74cbfc7..a6e0f3e6 100644 --- a/clientconnpublish.go +++ b/clientconnpublish.go @@ -164,7 +164,7 @@ func (c *ClientConn) backgroundRecordTCP() { for trackID := range c.rtcpSenders { r := c.rtcpSenders[trackID].Report(now) if r != nil { - c.nconn.SetWriteDeadline(time.Now().Add(c.c.WriteTimeout)) + c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) frame := base.InterleavedFrame{ TrackID: trackID, StreamType: StreamTypeRtcp, @@ -199,7 +199,7 @@ func (c *ClientConn) WriteFrame(trackID int, streamType StreamType, content []by return c.udpRtcpListeners[trackID].write(content) } - c.nconn.SetWriteDeadline(now.Add(c.c.WriteTimeout)) + c.nconn.SetWriteDeadline(now.Add(c.conf.WriteTimeout)) frame := base.InterleavedFrame{ TrackID: trackID, StreamType: streamType, diff --git a/clientconnread.go b/clientconnread.go index 3eb44b7e..9ce369e7 100644 --- a/clientconnread.go +++ b/clientconnread.go @@ -126,7 +126,7 @@ func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) { for _, lastUnix := range c.udpLastFrameTimes { last := time.Unix(atomic.LoadInt64(lastUnix), 0) - if now.Sub(last) >= c.c.ReadTimeout { + if now.Sub(last) >= c.conf.ReadTimeout { c.nconn.SetReadDeadline(time.Now()) <-readerDone returnError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") @@ -180,7 +180,7 @@ func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) { for { select { case <-deadlineTicker.C: - c.nconn.SetReadDeadline(time.Now().Add(c.c.ReadTimeout)) + c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) @@ -192,7 +192,7 @@ func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) { now := time.Now() for trackID := range c.rtcpReceivers { r := c.rtcpReceivers[trackID].Report(now) - c.nconn.SetWriteDeadline(time.Now().Add(c.c.WriteTimeout)) + c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout)) frame := base.InterleavedFrame{ TrackID: trackID, StreamType: StreamTypeRtcp, diff --git a/clientconnudpl.go b/clientconnudpl.go index bbea1459..559a4118 100644 --- a/clientconnudpl.go +++ b/clientconnudpl.go @@ -31,7 +31,7 @@ type clientConnUDPListener struct { } func newClientConnUDPListener(c *ClientConn, port int) (*clientConnUDPListener, error) { - pc, err := c.c.ListenPacket("udp", ":"+strconv.FormatInt(int64(port), 10)) + pc, err := c.conf.ListenPacket("udp", ":"+strconv.FormatInt(int64(port), 10)) if err != nil { return nil, err } @@ -44,7 +44,7 @@ func newClientConnUDPListener(c *ClientConn, port int) (*clientConnUDPListener, return &clientConnUDPListener{ c: c, pc: pc, - udpFrameBuffer: multibuffer.New(c.c.ReadBufferCount, clientConnUDPReadBufferSize), + udpFrameBuffer: multibuffer.New(c.conf.ReadBufferCount, clientConnUDPReadBufferSize), }, nil } @@ -92,7 +92,7 @@ func (l *clientConnUDPListener) run() { } func (l *clientConnUDPListener) write(buf []byte) error { - l.pc.SetWriteDeadline(time.Now().Add(l.c.c.WriteTimeout)) + l.pc.SetWriteDeadline(time.Now().Add(l.c.conf.WriteTimeout)) _, err := l.pc.WriteTo(buf, &net.UDPAddr{ IP: l.remoteIP, Zone: l.remoteZone, diff --git a/server.go b/server.go index 6270b010..8d1658b4 100644 --- a/server.go +++ b/server.go @@ -13,7 +13,7 @@ type ServerHandler interface { } type Server struct { - c ServerConf + conf ServerConf listener *net.TCPListener } @@ -27,24 +27,24 @@ func (s *Server) Accept() (*ServerConn, error) { return nil, err } - if s.c.ReadTimeout == 0 { - s.c.ReadTimeout = 10 * time.Second + if s.conf.ReadTimeout == 0 { + s.conf.ReadTimeout = 10 * time.Second } - if s.c.WriteTimeout == 0 { - s.c.WriteTimeout = 10 * time.Second + if s.conf.WriteTimeout == 0 { + s.conf.WriteTimeout = 10 * time.Second } - if s.c.ReadBufferCount == 0 { - s.c.ReadBufferCount = 1 + if s.conf.ReadBufferCount == 0 { + s.conf.ReadBufferCount = 1 } sc := &ServerConn{ - c: s.c, + conf: s.conf, nconn: nconn, br: bufio.NewReaderSize(nconn, serverReadBufferSize), bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), request: &base.Request{}, frame: &base.InterleavedFrame{}, - tcpFrameBuffer: multibuffer.New(s.c.ReadBufferCount, clientTCPFrameReadBufferSize), + tcpFrameBuffer: multibuffer.New(s.conf.ReadBufferCount, clientTCPFrameReadBufferSize), } return sc, nil diff --git a/serverconf.go b/serverconf.go index 3645552a..6f69d226 100644 --- a/serverconf.go +++ b/serverconf.go @@ -61,7 +61,7 @@ func (c ServerConf) Serve(address string, handler ServerHandler) (*Server, error } s := &Server{ - c: c, + conf: c, listener: listener, } diff --git a/serverconn.go b/serverconn.go index 8717dfcc..61b11e66 100644 --- a/serverconn.go +++ b/serverconn.go @@ -16,7 +16,7 @@ const ( // ServerConn is a server-side RTSP connection. type ServerConn struct { - c ServerConf + conf ServerConf nconn net.Conn br *bufio.Reader bw *bufio.Writer @@ -51,7 +51,7 @@ func (s *ServerConn) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) { s.frame.Content = s.tcpFrameBuffer.Next() if timeout { - s.nconn.SetReadDeadline(time.Now().Add(s.c.ReadTimeout)) + s.nconn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout)) } return base.ReadInterleavedFrameOrRequest(s.frame, s.request, s.br) @@ -59,7 +59,7 @@ func (s *ServerConn) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) { // WriteResponse writes a Response. func (s *ServerConn) WriteResponse(res *base.Response) error { - s.nconn.SetWriteDeadline(time.Now().Add(s.c.WriteTimeout)) + s.nconn.SetWriteDeadline(time.Now().Add(s.conf.WriteTimeout)) return res.Write(s.bw) } @@ -71,6 +71,6 @@ func (s *ServerConn) WriteFrameTCP(trackID int, streamType StreamType, content [ Content: content, } - s.nconn.SetWriteDeadline(time.Now().Add(s.c.WriteTimeout)) + s.nconn.SetWriteDeadline(time.Now().Add(s.conf.WriteTimeout)) return frame.Write(s.bw) }