server: rewrite timeout system

This commit is contained in:
aler9
2021-05-04 15:38:46 +02:00
committed by Alessandro Ros
parent 5527f4d1f7
commit 6f749e6ba8
7 changed files with 115 additions and 104 deletions

View File

@@ -31,12 +31,12 @@ func (e ErrServerCSeqMissing) Error() string {
return "CSeq is missing" return "CSeq is missing"
} }
// ErrServerInvalidMethod is an error that can be returned by a server. // ErrServerUnhandledRequest is an error that can be returned by a server.
type ErrServerInvalidMethod struct{} type ErrServerUnhandledRequest struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerInvalidMethod) Error() string { func (e ErrServerUnhandledRequest) Error() string {
return "invalid method" return "unhandled request"
} }
// ErrServerWrongState is an error that can be returned by a server. // ErrServerWrongState is an error that can be returned by a server.

View File

@@ -993,17 +993,12 @@ func TestServerPublishErrorTimeout(t *testing.T) {
"tls", "tls",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(proto, func(t *testing.T) {
errDone := make(chan struct{}) sessionClosed := make(chan struct{})
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onSessionClose: func(ss *ServerSession) { onSessionClose: func(ss *ServerSession) {
/*if proto == "udp" { close(sessionClosed)
require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error())
} else {
require.True(t, strings.HasSuffix(err.Error(), "i/o timeout"))
}*/
close(errDone)
}, },
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
@@ -1130,7 +1125,7 @@ func TestServerPublishErrorTimeout(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
<-errDone <-sessionClosed
}) })
} }
} }

View File

@@ -488,7 +488,7 @@ func TestServerErrorCSeqMissing(t *testing.T) {
func TestServerErrorInvalidMethod(t *testing.T) { func TestServerErrorInvalidMethod(t *testing.T) {
h := &testServerHandler{ h := &testServerHandler{
onConnClose: func(sc *ServerConn, err error) { onConnClose: func(sc *ServerConn, err error) {
require.Equal(t, "invalid method", err.Error()) require.Equal(t, "unhandled request", err.Error())
}, },
} }

View File

@@ -296,7 +296,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
rres := make(chan requestRes) rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres} ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres res := <-rres
return res.res, res.err return res.res, res.err
} }
@@ -315,7 +314,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
rres := make(chan requestRes) rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres} ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres res := <-rres
return res.res, res.err return res.res, res.err
} }
@@ -409,10 +407,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
rres := make(chan requestRes) rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres} ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres res := <-rres
return res.res, res.err return res.res, res.err
case base.GetParameter: case base.GetParameter:
sres := make(chan *ServerSession)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
ss := <-sres
// send request to session
if ss != nil {
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
return res.res, res.err
}
// handle request here
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
pathAndQuery, ok := req.URL.RTSPPath() pathAndQuery, ok := req.URL.RTSPPath()
if !ok { if !ok {
@@ -431,15 +441,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}) })
} }
// GET_PARAMETER is used like a ping
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"text/parameters"},
},
Body: []byte("\n"),
}, nil
case base.SetParameter: case base.SetParameter:
if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok { if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
pathAndQuery, ok := req.URL.RTSPPath() pathAndQuery, ok := req.URL.RTSPPath()
@@ -462,7 +463,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidMethod{} }, liberrors.ErrServerUnhandledRequest{}
} }
func (sc *ServerConn) handleRequestOuter(req *base.Request) error { func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
@@ -473,7 +474,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
res, err := sc.handleRequest(req) res, err := sc.handleRequest(req)
if res.Header == nil { if res.Header == nil {
res.Header = base.Header{} res.Header = make(base.Header)
} }
// add cseq // add cseq

View File

@@ -127,6 +127,7 @@ type ServerHandlerOnPause interface {
// ServerHandlerOnGetParameterCtx is the context of a GET_PARAMETER request. // ServerHandlerOnGetParameterCtx is the context of a GET_PARAMETER request.
type ServerHandlerOnGetParameterCtx struct { type ServerHandlerOnGetParameterCtx struct {
Session *ServerSession
Conn *ServerConn Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string

View File

@@ -17,6 +17,7 @@ import (
const ( const (
serverSessionCheckStreamPeriod = 1 * time.Second serverSessionCheckStreamPeriod = 1 * time.Second
serverSessionCloseAfterNoRequestsFor = 1 * 60 * time.Second
) )
func setupGetTrackIDPathQuery(url *base.URL, func setupGetTrackIDPathQuery(url *base.URL,
@@ -110,7 +111,6 @@ type ServerSessionSetuppedTrack struct {
type ServerSessionAnnouncedTrack struct { type ServerSessionAnnouncedTrack struct {
track *Track track *Track
rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpReceiver *rtcpreceiver.RTCPReceiver
udpLastFrameTime *int64
} }
type requestRes struct { type requestRes struct {
@@ -135,16 +135,12 @@ type ServerSession struct {
setupProtocol *StreamProtocol setupProtocol *StreamProtocol
setupPath *string setupPath *string
setupQuery *string setupQuery *string
lastRequestTime time.Time
// TCP stream protocol linkedConn *ServerConn // tcp
linkedConn *ServerConn udpIP net.IP // udp
udpZone string // udp
// UDP stream protocol announcedTracks []ServerSessionAnnouncedTrack // publish
udpIP net.IP udpLastFrameTime *int64 // publish, udp
udpZone string
// publish
announcedTracks []ServerSessionAnnouncedTrack
// in // in
request chan requestReq request chan requestReq
@@ -156,6 +152,7 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
s: s, s: s,
id: id, id: id,
wg: wg, wg: wg,
lastRequestTime: time.Now(),
request: make(chan requestReq), request: make(chan requestReq),
terminate: make(chan struct{}), terminate: make(chan struct{}),
} }
@@ -207,8 +204,8 @@ func (ss *ServerSession) run() {
h.OnSessionOpen(ss) h.OnSessionOpen(ss)
} }
checkStreamTicker := time.NewTicker(serverSessionCheckStreamPeriod) checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkStreamTicker.Stop() defer checkTimeoutTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop() defer receiverReportTicker.Stop()
@@ -219,6 +216,15 @@ outer:
case req := <-ss.request: case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req) res, err := ss.handleRequest(req.sc, req.req)
ss.lastRequestTime = time.Now()
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
}
if _, ok := err.(liberrors.ErrServerTeardown); ok { if _, ok := err.(liberrors.ErrServerTeardown); ok {
req.res <- requestRes{res, nil} req.res <- requestRes{res, nil}
break outer break outer
@@ -226,24 +232,26 @@ outer:
req.res <- requestRes{res, err} req.res <- requestRes{res, err}
case <-checkStreamTicker.C: case <-checkTimeoutTicker.C:
if ss.state != ServerSessionStateRecord || *ss.setupProtocol != StreamProtocolUDP { switch {
continue // in case of record and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
break outer
} }
inTimeout := func() bool { // in case there's a linked TCP connection, timeout is handled in the connection
case ss.linkedConn != nil:
// otherwise, timeout happens when no requests arrives
default:
now := time.Now() now := time.Now()
for _, track := range ss.announcedTracks { if now.Sub(ss.lastRequestTime) >= serverSessionCloseAfterNoRequestsFor {
lft := atomic.LoadInt64(track.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) < ss.s.ReadTimeout {
return false
}
}
return true
}()
if inTimeout {
break outer break outer
} }
}
case <-receiverReportTicker.C: case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord { if ss.state != ServerSessionStateRecord {
@@ -387,20 +395,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks)) ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks))
for trackID, track := range tracks { for trackID, track := range tracks {
clockRate, _ := track.ClockRate() clockRate, _ := track.ClockRate()
v := time.Now().Unix()
ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{ ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{
track: track, track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate), rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
} }
} }
if res.Header == nil { v := time.Now().Unix()
res.Header = make(base.Header) ss.udpLastFrameTime = &v
}
res.Header["Session"] = base.HeaderValue{ss.id}
} }
return res, err return res, err
@@ -517,8 +519,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
res.Header = make(base.Header) res.Header = make(base.Header)
} }
res.Header["Session"] = base.HeaderValue{ss.id}
if th.Protocol == StreamProtocolUDP { if th.Protocol == StreamProtocolUDP {
ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{ ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{
udpRTPPort: th.ClientPorts[0], udpRTPPort: th.ClientPorts[0],
@@ -595,7 +595,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
path, query := base.PathSplitQuery(pathAndQuery) path, query := base.PathSplitQuery(pathAndQuery)
if ss.state != ServerSessionStatePlay { if ss.state != ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolTCP {
ss.linkedConn = sc ss.linkedConn = sc
} }
@@ -611,12 +611,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if ss.state != ServerSessionStatePlay { if ss.state != ServerSessionStatePlay {
ss.state = ServerSessionStatePlay ss.state = ServerSessionStatePlay
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if *ss.setupProtocol == StreamProtocolUDP { if *ss.setupProtocol == StreamProtocolUDP {
ss.udpIP = sc.ip() ss.udpIP = sc.ip()
ss.udpZone = sc.zone() ss.udpZone = sc.zone()
@@ -625,6 +619,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
} }
return res, err return res, err
} }
@@ -675,12 +670,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStateRecord ss.state = ServerSessionStateRecord
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if *ss.setupProtocol == StreamProtocolUDP { if *ss.setupProtocol == StreamProtocolUDP {
ss.udpIP = sc.ip() ss.udpIP = sc.ip()
ss.udpZone = sc.zone() ss.udpZone = sc.zone()
@@ -695,6 +684,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.WriteFrame(trackID, StreamTypeRTCP, ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
} }
return res, err return res, err
} }
@@ -738,12 +728,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
switch ss.state { switch ss.state {
case ServerSessionStatePlay: case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay ss.state = ServerSessionStatePrePlay
@@ -775,6 +759,36 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, liberrors.ErrServerTeardown{} }, liberrors.ErrServerTeardown{}
case base.GetParameter:
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
}
// GET_PARAMETER is used like a ping when reading, and sometimes
// also when publishing; reply with 200
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"text/parameters"},
},
Body: []byte("\n"),
}, nil
} }
return nil, fmt.Errorf("unimplemented") return nil, fmt.Errorf("unimplemented")

View File

@@ -123,7 +123,7 @@ func (u *serverUDPListener) run() {
if clientData.isPublishing { if clientData.isPublishing {
now := time.Now() now := time.Now()
atomic.StoreInt64(clientData.ss.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix()) atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix())
clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n])
} }