add ClientConf.OnRequest, ClientConf.OnResponse

This commit is contained in:
aler9
2020-12-08 11:54:38 +01:00
parent eb7ebc5543
commit 07aefbcd5d
8 changed files with 45 additions and 31 deletions

View File

@@ -58,6 +58,12 @@ type ClientConf struct {
// It defaults to 1. // It defaults to 1.
ReadBufferCount int ReadBufferCount int
// callback called before every request.
OnRequest func(req *base.Request)
// callback called after very response.
OnResponse func(res *base.Response)
// function used to initialize the TCP client. // function used to initialize the TCP client.
// It defaults to net.DialTimeout. // It defaults to net.DialTimeout.
DialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) DialTimeout func(network, address string, timeout time.Duration) (net.Conn, error)
@@ -95,7 +101,7 @@ func (c ClientConf) Dial(host string) (*ClientConn, error) {
} }
return &ClientConn{ return &ClientConn{
c: c, conf: c,
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(nconn, clientReadBufferSize), br: bufio.NewReaderSize(nconn, clientReadBufferSize),
bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), bw: bufio.NewWriterSize(nconn, clientWriteBufferSize),

View File

@@ -63,7 +63,7 @@ func (s clientConnState) String() string {
// ClientConn is a client-side RTSP connection. // ClientConn is a client-side RTSP connection.
type ClientConn struct { type ClientConn struct {
c ClientConf conf ClientConf
nconn net.Conn nconn net.Conn
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
@@ -146,7 +146,7 @@ func (c *ClientConn) Tracks() Tracks {
} }
func (c *ClientConn) readFrameTCPOrResponse() (interface{}, error) { func (c *ClientConn) readFrameTCPOrResponse() (interface{}, error) {
c.nconn.SetReadDeadline(time.Now().Add(c.c.ReadTimeout)) c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout))
f := base.InterleavedFrame{ f := base.InterleavedFrame{
Content: c.tcpFrameBuffer.Next(), Content: c.tcpFrameBuffer.Next(),
} }
@@ -175,7 +175,11 @@ func (c *ClientConn) Do(req *base.Request) (*base.Response, error) {
c.cseq++ c.cseq++
req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(c.cseq), 10)} req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(c.cseq), 10)}
c.nconn.SetWriteDeadline(time.Now().Add(c.c.WriteTimeout)) if c.conf.OnRequest != nil {
c.conf.OnRequest(req)
}
c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout))
err := req.Write(c.bw) err := req.Write(c.bw)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -205,6 +209,10 @@ func (c *ClientConn) Do(req *base.Request) (*base.Response, error) {
return nil, err return nil, err
} }
if c.conf.OnResponse != nil {
c.conf.OnResponse(res)
}
// get session from response // get session from response
if v, ok := res.Header["Session"]; ok { if v, ok := res.Header["Session"]; ok {
sx, err := headers.ReadSession(v) sx, err := headers.ReadSession(v)
@@ -298,7 +306,7 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) {
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
// redirect // redirect
if !c.c.RedirectDisable && if !c.conf.RedirectDisable &&
res.StatusCode >= base.StatusMovedPermanently && res.StatusCode >= base.StatusMovedPermanently &&
res.StatusCode <= base.StatusUseProxy && res.StatusCode <= base.StatusUseProxy &&
len(res.Header["Location"]) == 1 { len(res.Header["Location"]) == 1 {
@@ -310,7 +318,7 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) {
return nil, nil, err return nil, nil, err
} }
nc, err := c.c.Dial(u.Host) nc, err := c.conf.Dial(u.Host)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -385,8 +393,8 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track,
} }
// protocol set by conf // protocol set by conf
if c.c.StreamProtocol != nil { if c.conf.StreamProtocol != nil {
return *c.c.StreamProtocol return *c.conf.StreamProtocol
} }
// try udp // try udp
@@ -493,7 +501,7 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track,
// switch protocol automatically // switch protocol automatically
if res.StatusCode == base.StatusUnsupportedTransport && if res.StatusCode == base.StatusUnsupportedTransport &&
c.streamProtocol == nil && c.streamProtocol == nil &&
c.c.StreamProtocol == nil { c.conf.StreamProtocol == nil {
v := StreamProtocolTCP v := StreamProtocolTCP
c.streamProtocol = &v c.streamProtocol = &v

View File

@@ -164,7 +164,7 @@ func (c *ClientConn) backgroundRecordTCP() {
for trackID := range c.rtcpSenders { for trackID := range c.rtcpSenders {
r := c.rtcpSenders[trackID].Report(now) r := c.rtcpSenders[trackID].Report(now)
if r != nil { if r != nil {
c.nconn.SetWriteDeadline(time.Now().Add(c.c.WriteTimeout)) c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout))
frame := base.InterleavedFrame{ frame := base.InterleavedFrame{
TrackID: trackID, TrackID: trackID,
StreamType: StreamTypeRtcp, StreamType: StreamTypeRtcp,
@@ -199,7 +199,7 @@ func (c *ClientConn) WriteFrame(trackID int, streamType StreamType, content []by
return c.udpRtcpListeners[trackID].write(content) return c.udpRtcpListeners[trackID].write(content)
} }
c.nconn.SetWriteDeadline(now.Add(c.c.WriteTimeout)) c.nconn.SetWriteDeadline(now.Add(c.conf.WriteTimeout))
frame := base.InterleavedFrame{ frame := base.InterleavedFrame{
TrackID: trackID, TrackID: trackID,
StreamType: streamType, StreamType: streamType,

View File

@@ -126,7 +126,7 @@ func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) {
for _, lastUnix := range c.udpLastFrameTimes { for _, lastUnix := range c.udpLastFrameTimes {
last := time.Unix(atomic.LoadInt64(lastUnix), 0) last := time.Unix(atomic.LoadInt64(lastUnix), 0)
if now.Sub(last) >= c.c.ReadTimeout { if now.Sub(last) >= c.conf.ReadTimeout {
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readerDone <-readerDone
returnError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") returnError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)")
@@ -180,7 +180,7 @@ func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) {
for { for {
select { select {
case <-deadlineTicker.C: case <-deadlineTicker.C:
c.nconn.SetReadDeadline(time.Now().Add(c.c.ReadTimeout)) c.nconn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout))
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
@@ -192,7 +192,7 @@ func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) {
now := time.Now() now := time.Now()
for trackID := range c.rtcpReceivers { for trackID := range c.rtcpReceivers {
r := c.rtcpReceivers[trackID].Report(now) r := c.rtcpReceivers[trackID].Report(now)
c.nconn.SetWriteDeadline(time.Now().Add(c.c.WriteTimeout)) c.nconn.SetWriteDeadline(time.Now().Add(c.conf.WriteTimeout))
frame := base.InterleavedFrame{ frame := base.InterleavedFrame{
TrackID: trackID, TrackID: trackID,
StreamType: StreamTypeRtcp, StreamType: StreamTypeRtcp,

View File

@@ -31,7 +31,7 @@ type clientConnUDPListener struct {
} }
func newClientConnUDPListener(c *ClientConn, port int) (*clientConnUDPListener, error) { func newClientConnUDPListener(c *ClientConn, port int) (*clientConnUDPListener, error) {
pc, err := c.c.ListenPacket("udp", ":"+strconv.FormatInt(int64(port), 10)) pc, err := c.conf.ListenPacket("udp", ":"+strconv.FormatInt(int64(port), 10))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -44,7 +44,7 @@ func newClientConnUDPListener(c *ClientConn, port int) (*clientConnUDPListener,
return &clientConnUDPListener{ return &clientConnUDPListener{
c: c, c: c,
pc: pc, pc: pc,
udpFrameBuffer: multibuffer.New(c.c.ReadBufferCount, clientConnUDPReadBufferSize), udpFrameBuffer: multibuffer.New(c.conf.ReadBufferCount, clientConnUDPReadBufferSize),
}, nil }, nil
} }
@@ -92,7 +92,7 @@ func (l *clientConnUDPListener) run() {
} }
func (l *clientConnUDPListener) write(buf []byte) error { func (l *clientConnUDPListener) write(buf []byte) error {
l.pc.SetWriteDeadline(time.Now().Add(l.c.c.WriteTimeout)) l.pc.SetWriteDeadline(time.Now().Add(l.c.conf.WriteTimeout))
_, err := l.pc.WriteTo(buf, &net.UDPAddr{ _, err := l.pc.WriteTo(buf, &net.UDPAddr{
IP: l.remoteIP, IP: l.remoteIP,
Zone: l.remoteZone, Zone: l.remoteZone,

View File

@@ -13,7 +13,7 @@ type ServerHandler interface {
} }
type Server struct { type Server struct {
c ServerConf conf ServerConf
listener *net.TCPListener listener *net.TCPListener
} }
@@ -27,24 +27,24 @@ func (s *Server) Accept() (*ServerConn, error) {
return nil, err return nil, err
} }
if s.c.ReadTimeout == 0 { if s.conf.ReadTimeout == 0 {
s.c.ReadTimeout = 10 * time.Second s.conf.ReadTimeout = 10 * time.Second
} }
if s.c.WriteTimeout == 0 { if s.conf.WriteTimeout == 0 {
s.c.WriteTimeout = 10 * time.Second s.conf.WriteTimeout = 10 * time.Second
} }
if s.c.ReadBufferCount == 0 { if s.conf.ReadBufferCount == 0 {
s.c.ReadBufferCount = 1 s.conf.ReadBufferCount = 1
} }
sc := &ServerConn{ sc := &ServerConn{
c: s.c, conf: s.conf,
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(nconn, serverReadBufferSize), br: bufio.NewReaderSize(nconn, serverReadBufferSize),
bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), bw: bufio.NewWriterSize(nconn, serverWriteBufferSize),
request: &base.Request{}, request: &base.Request{},
frame: &base.InterleavedFrame{}, frame: &base.InterleavedFrame{},
tcpFrameBuffer: multibuffer.New(s.c.ReadBufferCount, clientTCPFrameReadBufferSize), tcpFrameBuffer: multibuffer.New(s.conf.ReadBufferCount, clientTCPFrameReadBufferSize),
} }
return sc, nil return sc, nil

View File

@@ -61,7 +61,7 @@ func (c ServerConf) Serve(address string, handler ServerHandler) (*Server, error
} }
s := &Server{ s := &Server{
c: c, conf: c,
listener: listener, listener: listener,
} }

View File

@@ -16,7 +16,7 @@ const (
// ServerConn is a server-side RTSP connection. // ServerConn is a server-side RTSP connection.
type ServerConn struct { type ServerConn struct {
c ServerConf conf ServerConf
nconn net.Conn nconn net.Conn
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
@@ -51,7 +51,7 @@ func (s *ServerConn) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) {
s.frame.Content = s.tcpFrameBuffer.Next() s.frame.Content = s.tcpFrameBuffer.Next()
if timeout { if timeout {
s.nconn.SetReadDeadline(time.Now().Add(s.c.ReadTimeout)) s.nconn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
} }
return base.ReadInterleavedFrameOrRequest(s.frame, s.request, s.br) return base.ReadInterleavedFrameOrRequest(s.frame, s.request, s.br)
@@ -59,7 +59,7 @@ func (s *ServerConn) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) {
// WriteResponse writes a Response. // WriteResponse writes a Response.
func (s *ServerConn) WriteResponse(res *base.Response) error { func (s *ServerConn) WriteResponse(res *base.Response) error {
s.nconn.SetWriteDeadline(time.Now().Add(s.c.WriteTimeout)) s.nconn.SetWriteDeadline(time.Now().Add(s.conf.WriteTimeout))
return res.Write(s.bw) return res.Write(s.bw)
} }
@@ -71,6 +71,6 @@ func (s *ServerConn) WriteFrameTCP(trackID int, streamType StreamType, content [
Content: content, Content: content,
} }
s.nconn.SetWriteDeadline(time.Now().Add(s.c.WriteTimeout)) s.nconn.SetWriteDeadline(time.Now().Add(s.conf.WriteTimeout))
return frame.Write(s.bw) return frame.Write(s.bw)
} }