add ServerConn.Close(), ServerSession.Close()

This commit is contained in:
aler9
2021-05-07 11:42:01 +02:00
parent 1c9fe6c394
commit e52fda806d
9 changed files with 430 additions and 203 deletions

View File

@@ -74,7 +74,7 @@ type Client struct {
// callback called before every request. // callback called before every request.
OnRequest func(req *base.Request) OnRequest func(req *base.Request)
// callback called after very response. // callback called after every response.
OnResponse func(res *base.Response) OnResponse func(res *base.Response)
// function used to initialize the TCP client. // function used to initialize the TCP client.

View File

@@ -22,22 +22,22 @@ type serverHandler struct {
sdp []byte sdp []byte
} }
// called when a connection is opened. // called after a connection is opened.
func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
log.Printf("conn opened") log.Printf("conn opened")
} }
// called when a connection is closed. // called after a connection is closed.
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
log.Printf("conn closed (%v)", err) log.Printf("conn closed (%v)", err)
} }
// called when a session is opened. // called after a session is opened.
func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) {
log.Printf("session opened") log.Printf("session opened")
} }
// called when a session is closed. // called after a session is closed.
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) {
log.Printf("session closed") log.Printf("session closed")

View File

@@ -21,22 +21,22 @@ type serverHandler struct {
sdp []byte sdp []byte
} }
// called when a connection is opened. // called after a connection is opened.
func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
log.Printf("conn opened") log.Printf("conn opened")
} }
// called when a connection is closed. // called after a connection is closed.
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
log.Printf("conn closed (%v)", err) log.Printf("conn closed (%v)", err)
} }
// called when a session is opened. // called after a session is opened.
func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) {
log.Printf("session opened") log.Printf("session opened")
} }
// called when a session is closed. // called after a session is closed.
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) {
log.Printf("session closed") log.Printf("session closed")

View File

@@ -15,6 +15,14 @@ func (e ErrServerTerminated) Error() string {
return "terminated" return "terminated"
} }
// ErrServerSessionNotFound is an error that can be returned by a server.
type ErrServerSessionNotFound struct{}
// Error implements the error interface.
func (e ErrServerSessionNotFound) Error() string {
return "session not found"
}
// ErrServerSessionTimedOut is an error that can be returned by a server. // ErrServerSessionTimedOut is an error that can be returned by a server.
type ErrServerSessionTimedOut struct{} type ErrServerSessionTimedOut struct{}
@@ -48,11 +56,13 @@ func (e ErrServerCSeqMissing) Error() string {
} }
// ErrServerUnhandledRequest is an error that can be returned by a server. // ErrServerUnhandledRequest is an error that can be returned by a server.
type ErrServerUnhandledRequest struct{} type ErrServerUnhandledRequest struct {
Req *base.Request
}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerUnhandledRequest) Error() string { func (e ErrServerUnhandledRequest) Error() string {
return "unhandled request" return fmt.Sprintf("unhandled request (%v %v)", e.Req.Method, e.Req.URL)
} }
// ErrServerWrongState is an error that can be returned by a server. // ErrServerWrongState is an error that can be returned by a server.

View File

@@ -9,6 +9,9 @@ import (
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/liberrors"
) )
func extractPort(address string) (int, error) { func extractPort(address string) (int, error) {
@@ -41,10 +44,18 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) {
} }
} }
type sessionGetReq struct { type sessionReqRes struct {
res *base.Response
err error
ss *ServerSession
}
type sessionReq struct {
sc *ServerConn
req *base.Request
id string id string
create bool create bool
res chan *ServerSession res chan sessionReqRes
} }
// Server is a RTSP server. // Server is a RTSP server.
@@ -100,7 +111,7 @@ type Server struct {
// in // in
connClose chan *ServerConn connClose chan *ServerConn
sessionGet chan sessionGetReq sessionReq chan sessionReq
sessionClose chan *ServerSession sessionClose chan *ServerSession
terminate chan struct{} terminate chan struct{}
@@ -194,7 +205,7 @@ func (s *Server) run() {
s.sessions = make(map[string]*ServerSession) s.sessions = make(map[string]*ServerSession)
s.conns = make(map[*ServerConn]struct{}) s.conns = make(map[*ServerConn]struct{})
s.connClose = make(chan *ServerConn) s.connClose = make(chan *ServerConn)
s.sessionGet = make(chan sessionGetReq) s.sessionReq = make(chan sessionReq)
s.sessionClose = make(chan *ServerSession) s.sessionClose = make(chan *ServerSession)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -233,25 +244,35 @@ outer:
} }
s.doConnClose(sc) s.doConnClose(sc)
case req := <-s.sessionGet: case req := <-s.sessionReq:
if ss, ok := s.sessions[req.id]; ok { if ss, ok := s.sessions[req.id]; ok {
req.res <- ss ss.request <- req
} else { } else {
if !req.create { if !req.create {
req.res <- nil req.res <- sessionReqRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerSessionNotFound{},
}
continue continue
} }
id, err := newSessionID(s.sessions) id, err := newSessionID(s.sessions)
if err != nil { if err != nil {
req.res <- nil req.res <- sessionReqRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: fmt.Errorf("internal error"),
}
continue continue
} }
ss := newServerSession(s, id, &wg) ss := newServerSession(s, id, &wg)
s.sessions[id] = ss s.sessions[id] = ss
req.res <- ss ss.request <- req
} }
case ss := <-s.sessionClose: case ss := <-s.sessionClose:
@@ -284,11 +305,16 @@ outer:
return return
} }
case req, ok := <-s.sessionGet: case req, ok := <-s.sessionReq:
if !ok { if !ok {
return return
} }
req.res <- nil req.res <- sessionReqRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
case _, ok := <-s.sessionClose: case _, ok := <-s.sessionClose:
if !ok { if !ok {
@@ -321,7 +347,7 @@ outer:
close(acceptErr) close(acceptErr)
close(connNew) close(connNew)
close(s.connClose) close(s.connClose)
close(s.sessionGet) close(s.sessionReq)
close(s.sessionClose) close(s.sessionClose)
close(s.done) close(s.done)
} }

View File

@@ -310,6 +310,11 @@ func TestServerRead(t *testing.T) {
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload)
close(framesReceived) close(framesReceived)
}, },
onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
}, },
} }
@@ -453,11 +458,43 @@ func TestServerRead(t *testing.T) {
<-framesReceived <-framesReceived
if proto == "udp" {
// ping with OPTIONS
err = base.Request{
Method: base.Options,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"4"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
// ping with GET_PARAMETER
err = base.Request{
Method: base.GetParameter,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"5"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
err = base.Request{ err = base.Request{
Method: base.Teardown, Method: base.Teardown,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"6"},
"Session": res.Header["Session"], "Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)

View File

@@ -427,6 +427,31 @@ func TestServerErrorWrongUDPPorts(t *testing.T) {
}) })
} }
func TestServerConnClose(t *testing.T) {
connClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnOpen: func(sc *ServerConn) {
sc.Close()
},
onConnClose: func(sc *ServerConn, err error) {
close(connClosed)
},
},
}
err := s.Start("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
<-connClosed
}
func TestServerCSeq(t *testing.T) { func TestServerCSeq(t *testing.T) {
s := &Server{} s := &Server{}
err := s.Start("127.0.0.1:8554") err := s.Start("127.0.0.1:8554")
@@ -493,7 +518,7 @@ func TestServerErrorCSeqMissing(t *testing.T) {
func TestServerErrorInvalidMethod(t *testing.T) { func TestServerErrorInvalidMethod(t *testing.T) {
h := &testServerHandler{ h := &testServerHandler{
onConnClose: func(sc *ServerConn, err error) { onConnClose: func(sc *ServerConn, err error) {
require.Equal(t, "unhandled request", err.Error()) require.Equal(t, "unhandled request (INVALID rtsp://localhost:8554/)", err.Error())
}, },
} }
@@ -846,3 +871,55 @@ func TestServerErrorInvalidSession(t *testing.T) {
}) })
} }
} }
func TestServerSessionClose(t *testing.T) {
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onSessionOpen: func(ss *ServerSession) {
ss.Close()
},
onSessionClose: func(ss *ServerSession, err error) {
close(sessionClosed)
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
}
err := s.Start("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIDs: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
<-sessionClosed
}

View File

@@ -3,7 +3,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
@@ -55,6 +54,7 @@ type ServerConn struct {
tcpFrameBackgroundWriteDone chan struct{} tcpFrameBackgroundWriteDone chan struct{}
// in // in
innerTerminate chan struct{}
terminate chan struct{} terminate chan struct{}
} }
@@ -67,6 +67,7 @@ func newServerConn(
s: s, s: s,
wg: wg, wg: wg,
nconn: nconn, nconn: nconn,
innerTerminate: make(chan struct{}, 1),
terminate: make(chan struct{}), terminate: make(chan struct{}),
} }
@@ -76,6 +77,15 @@ func newServerConn(
return sc return sc
} }
// Close closes the ServerConn.
func (sc *ServerConn) Close() error {
select {
case sc.innerTerminate <- struct{}{}:
default:
}
return nil
}
// NetConn returns the underlying net.Conn. // NetConn returns the underlying net.Conn.
func (sc *ServerConn) NetConn() net.Conn { func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn return sc.nconn
@@ -177,12 +187,26 @@ func (sc *ServerConn) run() {
} }
sc.nconn.Close() sc.nconn.Close()
sc.s.connClose <- sc
sc.s.connClose <- sc
<-sc.terminate <-sc.terminate
return err return err
case <-sc.innerTerminate:
sc.nconn.Close()
<-readDone
if sc.tcpFrameEnabled {
sc.tcpFrameWriteBuffer.Close()
<-sc.tcpFrameBackgroundWriteDone
}
sc.s.connClose <- sc
<-sc.terminate
return liberrors.ErrServerTerminated{}
case <-sc.terminate: case <-sc.terminate:
sc.nconn.Close() sc.nconn.Close()
<-readDone <-readDone
@@ -226,6 +250,21 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
switch req.Method { switch req.Method {
case base.Options: case base.Options:
// handle request in session
if sxID != "" {
cres := make(chan sessionReqRes)
sc.s.sessionReq <- sessionReq{
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
return res.res, res.err
}
// handle request here
var methods []string var methods []string
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
methods = append(methods, string(base.Describe)) methods = append(methods, string(base.Describe))
@@ -291,58 +330,46 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Announce: case base.Announce:
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
sres := make(chan *ServerSession) cres := make(chan sessionReqRes)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres} sc.s.sessionReq <- sessionReq{
ss := <-sres sc: sc,
req: req,
if ss == nil { id: sxID,
return &base.Response{ create: true,
StatusCode: base.StatusBadRequest, res: cres,
}, fmt.Errorf("terminated")
} }
res := <-cres
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
return res.res, res.err return res.res, res.err
} }
case base.Setup: case base.Setup:
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
sres := make(chan *ServerSession) cres := make(chan sessionReqRes)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres} sc.s.sessionReq <- sessionReq{
ss := <-sres sc: sc,
req: req,
if ss == nil { id: sxID,
return &base.Response{ create: true,
StatusCode: base.StatusBadRequest, res: cres,
}, fmt.Errorf("terminated")
} }
res := <-cres
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
return res.res, res.err return res.res, res.err
} }
case base.Play: case base.Play:
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
sres := make(chan *ServerSession) cres := make(chan sessionReqRes)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} sc.s.sessionReq <- sessionReq{
ss := <-sres sc: sc,
req: req,
if ss == nil { id: sxID,
return &base.Response{ create: false,
StatusCode: base.StatusBadRequest, res: cres,
}, liberrors.ErrServerInvalidSession{}
} }
res := <-cres
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok {
sc.tcpFrameLinkedSession = ss sc.tcpFrameLinkedSession = res.ss
sc.tcpFrameIsRecording = false sc.tcpFrameIsRecording = false
sc.tcpFrameSetEnabled = true sc.tcpFrameSetEnabled = true
return res.res, nil return res.res, nil
@@ -353,22 +380,18 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Record: case base.Record:
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
sres := make(chan *ServerSession) cres := make(chan sessionReqRes)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} sc.s.sessionReq <- sessionReq{
ss := <-sres sc: sc,
req: req,
if ss == nil { id: sxID,
return &base.Response{ create: false,
StatusCode: base.StatusBadRequest, res: cres,
}, liberrors.ErrServerInvalidSession{}
} }
res := <-cres
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok {
sc.tcpFrameLinkedSession = ss sc.tcpFrameLinkedSession = res.ss
sc.tcpFrameIsRecording = true sc.tcpFrameIsRecording = true
sc.tcpFrameSetEnabled = true sc.tcpFrameSetEnabled = true
return res.res, nil return res.res, nil
@@ -379,19 +402,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Pause: case base.Pause:
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
sres := make(chan *ServerSession) cres := make(chan sessionReqRes)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} sc.s.sessionReq <- sessionReq{
ss := <-sres sc: sc,
req: req,
if ss == nil { id: sxID,
return &base.Response{ create: false,
StatusCode: base.StatusBadRequest, res: cres,
}, liberrors.ErrServerInvalidSession{}
} }
res := <-cres
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok { if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok {
sc.tcpFrameSetEnabled = false sc.tcpFrameSetEnabled = false
@@ -402,31 +421,29 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
} }
case base.Teardown: case base.Teardown:
sres := make(chan *ServerSession) cres := make(chan sessionReqRes)
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} sc.s.sessionReq <- sessionReq{
ss := <-sres sc: sc,
req: req,
if ss == nil { id: sxID,
return &base.Response{ create: false,
StatusCode: base.StatusBadRequest, res: cres,
}, liberrors.ErrServerInvalidSession{}
} }
res := <-cres
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
return res.res, res.err return res.res, res.err
case base.GetParameter: case base.GetParameter:
sres := make(chan *ServerSession) // handle request in session
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} if sxID != "" {
ss := <-sres cres := make(chan sessionReqRes)
sc.s.sessionReq <- sessionReq{
// send request to session sc: sc,
if ss != nil { req: req,
rres := make(chan requestRes) id: sxID,
ss.request <- requestReq{sc: sc, req: req, res: rres} create: false,
res := <-rres res: cres,
}
res := <-cres
return res.res, res.err return res.res, res.err
} }
@@ -471,7 +488,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerUnhandledRequest{} }, liberrors.ErrServerUnhandledRequest{Req: req}
} }
func (sc *ServerConn) handleRequestOuter(req *base.Request) error { func (sc *ServerConn) handleRequestOuter(req *base.Request) error {

View File

@@ -112,17 +112,6 @@ type ServerSessionAnnouncedTrack struct {
rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpReceiver *rtcpreceiver.RTCPReceiver
} }
type requestRes struct {
res *base.Response
err error
}
type requestReq struct {
sc *ServerConn
req *base.Request
res chan requestRes
}
// ServerSession is a server-side RTSP session. // ServerSession is a server-side RTSP session.
type ServerSession struct { type ServerSession struct {
s *Server s *Server
@@ -142,7 +131,8 @@ type ServerSession struct {
udpLastFrameTime *int64 // publish, udp udpLastFrameTime *int64 // publish, udp
// in // in
request chan requestReq request chan sessionReq
innerTerminate chan struct{}
terminate chan struct{} terminate chan struct{}
} }
@@ -152,7 +142,8 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
id: id, id: id,
wg: wg, wg: wg,
lastRequestTime: time.Now(), lastRequestTime: time.Now(),
request: make(chan requestReq), request: make(chan sessionReq),
innerTerminate: make(chan struct{}, 1),
terminate: make(chan struct{}), terminate: make(chan struct{}),
} }
@@ -162,6 +153,15 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
return ss return ss
} }
// Close closes the ServerSession.
func (ss *ServerSession) Close() error {
select {
case ss.innerTerminate <- struct{}{}:
default:
}
return nil
}
// State returns the state of the session. // State returns the state of the session.
func (ss *ServerSession) State() ServerSessionState { func (ss *ServerSession) State() ServerSessionState {
return ss.state return ss.state
@@ -203,13 +203,15 @@ func (ss *ServerSession) run() {
h.OnSessionOpen(ss) h.OnSessionOpen(ss)
} }
readDone := make(chan error)
go func() {
readDone <- func() error {
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod) checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkTimeoutTicker.Stop() defer checkTimeoutTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop() defer receiverReportTicker.Stop()
err := func() error {
for { for {
select { select {
case req := <-ss.request: case req := <-ss.request:
@@ -225,11 +227,15 @@ func (ss *ServerSession) run() {
} }
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
req.res <- requestRes{res, nil} req.res <- sessionReqRes{res: res, err: nil}
return liberrors.ErrServerSessionTeardown{} return liberrors.ErrServerSessionTeardown{}
} }
req.res <- requestRes{res, err} req.res <- sessionReqRes{
res: res,
err: err,
ss: ss,
}
case <-checkTimeoutTicker.C: case <-checkTimeoutTicker.C:
switch { switch {
@@ -263,18 +269,40 @@ func (ss *ServerSession) run() {
ss.WriteFrame(trackID, StreamTypeRTCP, r) ss.WriteFrame(trackID, StreamTypeRTCP, r)
} }
case <-ss.terminate: case <-ss.innerTerminate:
return liberrors.ErrServerTerminated{} return liberrors.ErrServerTerminated{}
} }
} }
}() }()
}()
var err error
select {
case err = <-readDone:
go func() { go func() {
for req := range ss.request { for req := range ss.request {
req.res <- requestRes{nil, fmt.Errorf("terminated")} req.res <- sessionReqRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
} }
}() }()
ss.s.sessionClose <- ss
<-ss.terminate
case <-ss.terminate:
select {
case ss.innerTerminate <- struct{}{}:
default:
}
<-readDone
err = liberrors.ErrServerTerminated{}
}
switch ss.state { switch ss.state {
case ServerSessionStatePlay: case ServerSessionStatePlay:
if *ss.setupProtocol == StreamProtocolUDP { if *ss.setupProtocol == StreamProtocolUDP {
@@ -292,9 +320,6 @@ func (ss *ServerSession) run() {
ss.s.connClose <- ss.linkedConn ss.s.connClose <- ss.linkedConn
} }
ss.s.sessionClose <- ss
<-ss.terminate
close(ss.request) close(ss.request)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
@@ -310,6 +335,39 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
} }
switch req.Method { switch req.Method {
case base.Options:
var methods []string
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
methods = append(methods, string(base.Describe))
}
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
methods = append(methods, string(base.Announce))
}
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
methods = append(methods, string(base.Setup))
}
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
methods = append(methods, string(base.Play))
}
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
methods = append(methods, string(base.Record))
}
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
methods = append(methods, string(base.Pause))
}
methods = append(methods, string(base.GetParameter))
if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
methods = append(methods, string(base.SetParameter))
}
methods = append(methods, string(base.Teardown))
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join(methods, ", ")},
},
}, nil
case base.Announce: case base.Announce:
err := ss.checkState(map[ServerSessionState]struct{}{ err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStateInitial: {}, ServerSessionStateInitial: {},
@@ -808,7 +866,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, nil }, nil
} }
return nil, fmt.Errorf("unimplemented") return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerUnhandledRequest{Req: req}
} }
// WriteFrame writes a frame. // WriteFrame writes a frame.