mirror of
https://github.com/aler9/gortsplib
synced 2025-10-16 20:20:40 +08:00
add ServerConn.Close(), ServerSession.Close()
This commit is contained in:
234
serversession.go
234
serversession.go
@@ -112,17 +112,6 @@ type ServerSessionAnnouncedTrack struct {
|
||||
rtcpReceiver *rtcpreceiver.RTCPReceiver
|
||||
}
|
||||
|
||||
type requestRes struct {
|
||||
res *base.Response
|
||||
err error
|
||||
}
|
||||
|
||||
type requestReq struct {
|
||||
sc *ServerConn
|
||||
req *base.Request
|
||||
res chan requestRes
|
||||
}
|
||||
|
||||
// ServerSession is a server-side RTSP session.
|
||||
type ServerSession struct {
|
||||
s *Server
|
||||
@@ -142,8 +131,9 @@ type ServerSession struct {
|
||||
udpLastFrameTime *int64 // publish, udp
|
||||
|
||||
// in
|
||||
request chan requestReq
|
||||
terminate chan struct{}
|
||||
request chan sessionReq
|
||||
innerTerminate chan struct{}
|
||||
terminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
@@ -152,7 +142,8 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
id: id,
|
||||
wg: wg,
|
||||
lastRequestTime: time.Now(),
|
||||
request: make(chan requestReq),
|
||||
request: make(chan sessionReq),
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
terminate: make(chan struct{}),
|
||||
}
|
||||
|
||||
@@ -162,6 +153,15 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
return ss
|
||||
}
|
||||
|
||||
// Close closes the ServerSession.
|
||||
func (ss *ServerSession) Close() error {
|
||||
select {
|
||||
case ss.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// State returns the state of the session.
|
||||
func (ss *ServerSession) State() ServerSessionState {
|
||||
return ss.state
|
||||
@@ -203,78 +203,106 @@ func (ss *ServerSession) run() {
|
||||
h.OnSessionOpen(ss)
|
||||
}
|
||||
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
|
||||
err := func() error {
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- requestRes{res, nil}
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- requestRes{res, err}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.terminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
readDone := make(chan error)
|
||||
go func() {
|
||||
for req := range ss.request {
|
||||
req.res <- requestRes{nil, fmt.Errorf("terminated")}
|
||||
}
|
||||
readDone <- func() error {
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- sessionReqRes{res: res, err: nil}
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- sessionReqRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-readDone:
|
||||
go func() {
|
||||
for req := range ss.request {
|
||||
req.res <- sessionReqRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerTerminated{},
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.terminate
|
||||
|
||||
case <-ss.terminate:
|
||||
select {
|
||||
case ss.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-readDone
|
||||
|
||||
err = liberrors.ErrServerTerminated{}
|
||||
}
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
@@ -292,9 +320,6 @@ func (ss *ServerSession) run() {
|
||||
ss.s.connClose <- ss.linkedConn
|
||||
}
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.terminate
|
||||
|
||||
close(ss.request)
|
||||
|
||||
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
|
||||
@@ -310,6 +335,39 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case base.Options:
|
||||
var methods []string
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
|
||||
methods = append(methods, string(base.Describe))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
|
||||
methods = append(methods, string(base.Announce))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
|
||||
methods = append(methods, string(base.Setup))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
||||
methods = append(methods, string(base.Play))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
||||
methods = append(methods, string(base.Record))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
||||
methods = append(methods, string(base.Pause))
|
||||
}
|
||||
methods = append(methods, string(base.GetParameter))
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
|
||||
methods = append(methods, string(base.SetParameter))
|
||||
}
|
||||
methods = append(methods, string(base.Teardown))
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Public": base.HeaderValue{strings.Join(methods, ", ")},
|
||||
},
|
||||
}, nil
|
||||
|
||||
case base.Announce:
|
||||
err := ss.checkState(map[ServerSessionState]struct{}{
|
||||
ServerSessionStateInitial: {},
|
||||
@@ -808,7 +866,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unimplemented")
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerUnhandledRequest{Req: req}
|
||||
}
|
||||
|
||||
// WriteFrame writes a frame.
|
||||
|
Reference in New Issue
Block a user