mirror of
https://github.com/aler9/gortsplib
synced 2025-10-06 07:37:07 +08:00
server: rewrite timeout system
This commit is contained in:
@@ -31,12 +31,12 @@ func (e ErrServerCSeqMissing) Error() string {
|
||||
return "CSeq is missing"
|
||||
}
|
||||
|
||||
// ErrServerInvalidMethod is an error that can be returned by a server.
|
||||
type ErrServerInvalidMethod struct{}
|
||||
// ErrServerUnhandledRequest is an error that can be returned by a server.
|
||||
type ErrServerUnhandledRequest struct{}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ErrServerInvalidMethod) Error() string {
|
||||
return "invalid method"
|
||||
func (e ErrServerUnhandledRequest) Error() string {
|
||||
return "unhandled request"
|
||||
}
|
||||
|
||||
// ErrServerWrongState is an error that can be returned by a server.
|
||||
|
@@ -993,17 +993,12 @@ func TestServerPublishErrorTimeout(t *testing.T) {
|
||||
"tls",
|
||||
} {
|
||||
t.Run(proto, func(t *testing.T) {
|
||||
errDone := make(chan struct{})
|
||||
sessionClosed := make(chan struct{})
|
||||
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onSessionClose: func(ss *ServerSession) {
|
||||
/*if proto == "udp" {
|
||||
require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error())
|
||||
} else {
|
||||
require.True(t, strings.HasSuffix(err.Error(), "i/o timeout"))
|
||||
}*/
|
||||
close(errDone)
|
||||
close(sessionClosed)
|
||||
},
|
||||
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
@@ -1130,7 +1125,7 @@ func TestServerPublishErrorTimeout(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, base.StatusOK, res.StatusCode)
|
||||
|
||||
<-errDone
|
||||
<-sessionClosed
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -488,7 +488,7 @@ func TestServerErrorCSeqMissing(t *testing.T) {
|
||||
func TestServerErrorInvalidMethod(t *testing.T) {
|
||||
h := &testServerHandler{
|
||||
onConnClose: func(sc *ServerConn, err error) {
|
||||
require.Equal(t, "invalid method", err.Error())
|
||||
require.Equal(t, "unhandled request", err.Error())
|
||||
},
|
||||
}
|
||||
|
||||
|
@@ -296,7 +296,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
@@ -315,7 +314,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
@@ -409,10 +407,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
|
||||
return res.res, res.err
|
||||
|
||||
case base.GetParameter:
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
// send request to session
|
||||
if ss != nil {
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
// handle request here
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
|
||||
pathAndQuery, ok := req.URL.RTSPPath()
|
||||
if !ok {
|
||||
@@ -431,15 +441,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// GET_PARAMETER is used like a ping
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Content-Type": base.HeaderValue{"text/parameters"},
|
||||
},
|
||||
Body: []byte("\n"),
|
||||
}, nil
|
||||
|
||||
case base.SetParameter:
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
|
||||
pathAndQuery, ok := req.URL.RTSPPath()
|
||||
@@ -462,7 +463,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerInvalidMethod{}
|
||||
}, liberrors.ErrServerUnhandledRequest{}
|
||||
}
|
||||
|
||||
func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
||||
@@ -473,7 +474,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
||||
res, err := sc.handleRequest(req)
|
||||
|
||||
if res.Header == nil {
|
||||
res.Header = base.Header{}
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
// add cseq
|
||||
|
@@ -127,6 +127,7 @@ type ServerHandlerOnPause interface {
|
||||
|
||||
// ServerHandlerOnGetParameterCtx is the context of a GET_PARAMETER request.
|
||||
type ServerHandlerOnGetParameterCtx struct {
|
||||
Session *ServerSession
|
||||
Conn *ServerConn
|
||||
Req *base.Request
|
||||
Path string
|
||||
|
124
serversession.go
124
serversession.go
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
const (
|
||||
serverSessionCheckStreamPeriod = 1 * time.Second
|
||||
serverSessionCloseAfterNoRequestsFor = 1 * 60 * time.Second
|
||||
)
|
||||
|
||||
func setupGetTrackIDPathQuery(url *base.URL,
|
||||
@@ -110,7 +111,6 @@ type ServerSessionSetuppedTrack struct {
|
||||
type ServerSessionAnnouncedTrack struct {
|
||||
track *Track
|
||||
rtcpReceiver *rtcpreceiver.RTCPReceiver
|
||||
udpLastFrameTime *int64
|
||||
}
|
||||
|
||||
type requestRes struct {
|
||||
@@ -135,16 +135,12 @@ type ServerSession struct {
|
||||
setupProtocol *StreamProtocol
|
||||
setupPath *string
|
||||
setupQuery *string
|
||||
|
||||
// TCP stream protocol
|
||||
linkedConn *ServerConn
|
||||
|
||||
// UDP stream protocol
|
||||
udpIP net.IP
|
||||
udpZone string
|
||||
|
||||
// publish
|
||||
announcedTracks []ServerSessionAnnouncedTrack
|
||||
lastRequestTime time.Time
|
||||
linkedConn *ServerConn // tcp
|
||||
udpIP net.IP // udp
|
||||
udpZone string // udp
|
||||
announcedTracks []ServerSessionAnnouncedTrack // publish
|
||||
udpLastFrameTime *int64 // publish, udp
|
||||
|
||||
// in
|
||||
request chan requestReq
|
||||
@@ -156,6 +152,7 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
s: s,
|
||||
id: id,
|
||||
wg: wg,
|
||||
lastRequestTime: time.Now(),
|
||||
request: make(chan requestReq),
|
||||
terminate: make(chan struct{}),
|
||||
}
|
||||
@@ -207,8 +204,8 @@ func (ss *ServerSession) run() {
|
||||
h.OnSessionOpen(ss)
|
||||
}
|
||||
|
||||
checkStreamTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkStreamTicker.Stop()
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
@@ -219,6 +216,15 @@ outer:
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerTeardown); ok {
|
||||
req.res <- requestRes{res, nil}
|
||||
break outer
|
||||
@@ -226,24 +232,26 @@ outer:
|
||||
|
||||
req.res <- requestRes{res, err}
|
||||
|
||||
case <-checkStreamTicker.C:
|
||||
if ss.state != ServerSessionStateRecord || *ss.setupProtocol != StreamProtocolUDP {
|
||||
continue
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
break outer
|
||||
}
|
||||
|
||||
inTimeout := func() bool {
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
for _, track := range ss.announcedTracks {
|
||||
lft := atomic.LoadInt64(track.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) < ss.s.ReadTimeout {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}()
|
||||
if inTimeout {
|
||||
if now.Sub(ss.lastRequestTime) >= serverSessionCloseAfterNoRequestsFor {
|
||||
break outer
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
@@ -387,20 +395,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks))
|
||||
for trackID, track := range tracks {
|
||||
clockRate, _ := track.ClockRate()
|
||||
v := time.Now().Unix()
|
||||
|
||||
ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{
|
||||
track: track,
|
||||
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
|
||||
udpLastFrameTime: &v,
|
||||
}
|
||||
}
|
||||
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
v := time.Now().Unix()
|
||||
ss.udpLastFrameTime = &v
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -517,8 +519,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
|
||||
if th.Protocol == StreamProtocolUDP {
|
||||
ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{
|
||||
udpRTPPort: th.ClientPorts[0],
|
||||
@@ -595,7 +595,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
path, query := base.PathSplitQuery(pathAndQuery)
|
||||
|
||||
if ss.state != ServerSessionStatePlay {
|
||||
if ss.state != ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolTCP {
|
||||
ss.linkedConn = sc
|
||||
}
|
||||
|
||||
@@ -611,12 +611,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
if ss.state != ServerSessionStatePlay {
|
||||
ss.state = ServerSessionStatePlay
|
||||
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
@@ -625,6 +619,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
for trackID, track := range ss.setuppedTracks {
|
||||
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -675,12 +670,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
if res.StatusCode == base.StatusOK {
|
||||
ss.state = ServerSessionStateRecord
|
||||
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
@@ -695,6 +684,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP,
|
||||
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -738,12 +728,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
})
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
@@ -775,6 +759,36 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, liberrors.ErrServerTeardown{}
|
||||
|
||||
case base.GetParameter:
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
|
||||
pathAndQuery, ok := req.URL.RTSPPath()
|
||||
if !ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerNoPath{}
|
||||
}
|
||||
|
||||
path, query := base.PathSplitQuery(pathAndQuery)
|
||||
|
||||
return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{
|
||||
Session: ss,
|
||||
Conn: sc,
|
||||
Req: req,
|
||||
Path: path,
|
||||
Query: query,
|
||||
})
|
||||
}
|
||||
|
||||
// GET_PARAMETER is used like a ping when reading, and sometimes
|
||||
// also when publishing; reply with 200
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Content-Type": base.HeaderValue{"text/parameters"},
|
||||
},
|
||||
Body: []byte("\n"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unimplemented")
|
||||
|
@@ -123,7 +123,7 @@ func (u *serverUDPListener) run() {
|
||||
|
||||
if clientData.isPublishing {
|
||||
now := time.Now()
|
||||
atomic.StoreInt64(clientData.ss.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix())
|
||||
atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix())
|
||||
clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n])
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user