mirror of
https://github.com/aler9/gortsplib
synced 2025-10-04 06:46:42 +08:00
add ServerConn.Close(), ServerSession.Close()
This commit is contained in:
@@ -74,7 +74,7 @@ type Client struct {
|
||||
// callback called before every request.
|
||||
OnRequest func(req *base.Request)
|
||||
|
||||
// callback called after very response.
|
||||
// callback called after every response.
|
||||
OnResponse func(res *base.Response)
|
||||
|
||||
// function used to initialize the TCP client.
|
||||
|
@@ -22,22 +22,22 @@ type serverHandler struct {
|
||||
sdp []byte
|
||||
}
|
||||
|
||||
// called when a connection is opened.
|
||||
// called after a connection is opened.
|
||||
func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
|
||||
log.Printf("conn opened")
|
||||
}
|
||||
|
||||
// called when a connection is closed.
|
||||
// called after a connection is closed.
|
||||
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
|
||||
log.Printf("conn closed (%v)", err)
|
||||
}
|
||||
|
||||
// called when a session is opened.
|
||||
// called after a session is opened.
|
||||
func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) {
|
||||
log.Printf("session opened")
|
||||
}
|
||||
|
||||
// called when a session is closed.
|
||||
// called after a session is closed.
|
||||
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) {
|
||||
log.Printf("session closed")
|
||||
|
||||
|
@@ -21,22 +21,22 @@ type serverHandler struct {
|
||||
sdp []byte
|
||||
}
|
||||
|
||||
// called when a connection is opened.
|
||||
// called after a connection is opened.
|
||||
func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
|
||||
log.Printf("conn opened")
|
||||
}
|
||||
|
||||
// called when a connection is closed.
|
||||
// called after a connection is closed.
|
||||
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
|
||||
log.Printf("conn closed (%v)", err)
|
||||
}
|
||||
|
||||
// called when a session is opened.
|
||||
// called after a session is opened.
|
||||
func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) {
|
||||
log.Printf("session opened")
|
||||
}
|
||||
|
||||
// called when a session is closed.
|
||||
// called after a session is closed.
|
||||
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) {
|
||||
log.Printf("session closed")
|
||||
|
||||
|
@@ -15,6 +15,14 @@ func (e ErrServerTerminated) Error() string {
|
||||
return "terminated"
|
||||
}
|
||||
|
||||
// ErrServerSessionNotFound is an error that can be returned by a server.
|
||||
type ErrServerSessionNotFound struct{}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ErrServerSessionNotFound) Error() string {
|
||||
return "session not found"
|
||||
}
|
||||
|
||||
// ErrServerSessionTimedOut is an error that can be returned by a server.
|
||||
type ErrServerSessionTimedOut struct{}
|
||||
|
||||
@@ -48,11 +56,13 @@ func (e ErrServerCSeqMissing) Error() string {
|
||||
}
|
||||
|
||||
// ErrServerUnhandledRequest is an error that can be returned by a server.
|
||||
type ErrServerUnhandledRequest struct{}
|
||||
type ErrServerUnhandledRequest struct {
|
||||
Req *base.Request
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ErrServerUnhandledRequest) Error() string {
|
||||
return "unhandled request"
|
||||
return fmt.Sprintf("unhandled request (%v %v)", e.Req.Method, e.Req.URL)
|
||||
}
|
||||
|
||||
// ErrServerWrongState is an error that can be returned by a server.
|
||||
|
50
server.go
50
server.go
@@ -9,6 +9,9 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aler9/gortsplib/pkg/base"
|
||||
"github.com/aler9/gortsplib/pkg/liberrors"
|
||||
)
|
||||
|
||||
func extractPort(address string) (int, error) {
|
||||
@@ -41,10 +44,18 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
type sessionGetReq struct {
|
||||
type sessionReqRes struct {
|
||||
res *base.Response
|
||||
err error
|
||||
ss *ServerSession
|
||||
}
|
||||
|
||||
type sessionReq struct {
|
||||
sc *ServerConn
|
||||
req *base.Request
|
||||
id string
|
||||
create bool
|
||||
res chan *ServerSession
|
||||
res chan sessionReqRes
|
||||
}
|
||||
|
||||
// Server is a RTSP server.
|
||||
@@ -100,7 +111,7 @@ type Server struct {
|
||||
|
||||
// in
|
||||
connClose chan *ServerConn
|
||||
sessionGet chan sessionGetReq
|
||||
sessionReq chan sessionReq
|
||||
sessionClose chan *ServerSession
|
||||
terminate chan struct{}
|
||||
|
||||
@@ -194,7 +205,7 @@ func (s *Server) run() {
|
||||
s.sessions = make(map[string]*ServerSession)
|
||||
s.conns = make(map[*ServerConn]struct{})
|
||||
s.connClose = make(chan *ServerConn)
|
||||
s.sessionGet = make(chan sessionGetReq)
|
||||
s.sessionReq = make(chan sessionReq)
|
||||
s.sessionClose = make(chan *ServerSession)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -233,25 +244,35 @@ outer:
|
||||
}
|
||||
s.doConnClose(sc)
|
||||
|
||||
case req := <-s.sessionGet:
|
||||
case req := <-s.sessionReq:
|
||||
if ss, ok := s.sessions[req.id]; ok {
|
||||
req.res <- ss
|
||||
ss.request <- req
|
||||
|
||||
} else {
|
||||
if !req.create {
|
||||
req.res <- nil
|
||||
req.res <- sessionReqRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerSessionNotFound{},
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
id, err := newSessionID(s.sessions)
|
||||
if err != nil {
|
||||
req.res <- nil
|
||||
req.res <- sessionReqRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: fmt.Errorf("internal error"),
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ss := newServerSession(s, id, &wg)
|
||||
s.sessions[id] = ss
|
||||
req.res <- ss
|
||||
ss.request <- req
|
||||
}
|
||||
|
||||
case ss := <-s.sessionClose:
|
||||
@@ -284,11 +305,16 @@ outer:
|
||||
return
|
||||
}
|
||||
|
||||
case req, ok := <-s.sessionGet:
|
||||
case req, ok := <-s.sessionReq:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
req.res <- nil
|
||||
req.res <- sessionReqRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerTerminated{},
|
||||
}
|
||||
|
||||
case _, ok := <-s.sessionClose:
|
||||
if !ok {
|
||||
@@ -321,7 +347,7 @@ outer:
|
||||
close(acceptErr)
|
||||
close(connNew)
|
||||
close(s.connClose)
|
||||
close(s.sessionGet)
|
||||
close(s.sessionReq)
|
||||
close(s.sessionClose)
|
||||
close(s.done)
|
||||
}
|
||||
|
@@ -310,6 +310,11 @@ func TestServerRead(t *testing.T) {
|
||||
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload)
|
||||
close(framesReceived)
|
||||
},
|
||||
onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -453,11 +458,43 @@ func TestServerRead(t *testing.T) {
|
||||
|
||||
<-framesReceived
|
||||
|
||||
if proto == "udp" {
|
||||
// ping with OPTIONS
|
||||
err = base.Request{
|
||||
Method: base.Options,
|
||||
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
|
||||
Header: base.Header{
|
||||
"CSeq": base.HeaderValue{"4"},
|
||||
"Session": res.Header["Session"],
|
||||
},
|
||||
}.Write(bconn.Writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = res.Read(bconn.Reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, base.StatusOK, res.StatusCode)
|
||||
|
||||
// ping with GET_PARAMETER
|
||||
err = base.Request{
|
||||
Method: base.GetParameter,
|
||||
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
|
||||
Header: base.Header{
|
||||
"CSeq": base.HeaderValue{"5"},
|
||||
"Session": res.Header["Session"],
|
||||
},
|
||||
}.Write(bconn.Writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = res.Read(bconn.Reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, base.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
err = base.Request{
|
||||
Method: base.Teardown,
|
||||
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
|
||||
Header: base.Header{
|
||||
"CSeq": base.HeaderValue{"3"},
|
||||
"CSeq": base.HeaderValue{"6"},
|
||||
"Session": res.Header["Session"],
|
||||
},
|
||||
}.Write(bconn.Writer)
|
||||
|
@@ -427,6 +427,31 @@ func TestServerErrorWrongUDPPorts(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerConnClose(t *testing.T) {
|
||||
connClosed := make(chan struct{})
|
||||
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onConnOpen: func(sc *ServerConn) {
|
||||
sc.Close()
|
||||
},
|
||||
onConnClose: func(sc *ServerConn, err error) {
|
||||
close(connClosed)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.Start("127.0.0.1:8554")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
conn, err := net.Dial("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
<-connClosed
|
||||
}
|
||||
|
||||
func TestServerCSeq(t *testing.T) {
|
||||
s := &Server{}
|
||||
err := s.Start("127.0.0.1:8554")
|
||||
@@ -493,7 +518,7 @@ func TestServerErrorCSeqMissing(t *testing.T) {
|
||||
func TestServerErrorInvalidMethod(t *testing.T) {
|
||||
h := &testServerHandler{
|
||||
onConnClose: func(sc *ServerConn, err error) {
|
||||
require.Equal(t, "unhandled request", err.Error())
|
||||
require.Equal(t, "unhandled request (INVALID rtsp://localhost:8554/)", err.Error())
|
||||
},
|
||||
}
|
||||
|
||||
@@ -846,3 +871,55 @@ func TestServerErrorInvalidSession(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerSessionClose(t *testing.T) {
|
||||
sessionClosed := make(chan struct{})
|
||||
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onSessionOpen: func(ss *ServerSession) {
|
||||
ss.Close()
|
||||
},
|
||||
onSessionClose: func(ss *ServerSession, err error) {
|
||||
close(sessionClosed)
|
||||
},
|
||||
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.Start("127.0.0.1:8554")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
conn, err := net.Dial("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
|
||||
err = base.Request{
|
||||
Method: base.Setup,
|
||||
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
|
||||
Header: base.Header{
|
||||
"CSeq": base.HeaderValue{"1"},
|
||||
"Transport": headers.Transport{
|
||||
Protocol: StreamProtocolTCP,
|
||||
Delivery: func() *base.StreamDelivery {
|
||||
v := base.StreamDeliveryUnicast
|
||||
return &v
|
||||
}(),
|
||||
Mode: func() *headers.TransportMode {
|
||||
v := headers.TransportModePlay
|
||||
return &v
|
||||
}(),
|
||||
InterleavedIDs: &[2]int{0, 1},
|
||||
}.Write(),
|
||||
},
|
||||
}.Write(bconn.Writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
<-sessionClosed
|
||||
}
|
||||
|
199
serverconn.go
199
serverconn.go
@@ -3,7 +3,6 @@ package gortsplib
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -55,7 +54,8 @@ type ServerConn struct {
|
||||
tcpFrameBackgroundWriteDone chan struct{}
|
||||
|
||||
// in
|
||||
terminate chan struct{}
|
||||
innerTerminate chan struct{}
|
||||
terminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerConn(
|
||||
@@ -64,10 +64,11 @@ func newServerConn(
|
||||
nconn net.Conn) *ServerConn {
|
||||
|
||||
sc := &ServerConn{
|
||||
s: s,
|
||||
wg: wg,
|
||||
nconn: nconn,
|
||||
terminate: make(chan struct{}),
|
||||
s: s,
|
||||
wg: wg,
|
||||
nconn: nconn,
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
terminate: make(chan struct{}),
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -76,6 +77,15 @@ func newServerConn(
|
||||
return sc
|
||||
}
|
||||
|
||||
// Close closes the ServerConn.
|
||||
func (sc *ServerConn) Close() error {
|
||||
select {
|
||||
case sc.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NetConn returns the underlying net.Conn.
|
||||
func (sc *ServerConn) NetConn() net.Conn {
|
||||
return sc.nconn
|
||||
@@ -177,12 +187,26 @@ func (sc *ServerConn) run() {
|
||||
}
|
||||
|
||||
sc.nconn.Close()
|
||||
sc.s.connClose <- sc
|
||||
|
||||
sc.s.connClose <- sc
|
||||
<-sc.terminate
|
||||
|
||||
return err
|
||||
|
||||
case <-sc.innerTerminate:
|
||||
sc.nconn.Close()
|
||||
<-readDone
|
||||
|
||||
if sc.tcpFrameEnabled {
|
||||
sc.tcpFrameWriteBuffer.Close()
|
||||
<-sc.tcpFrameBackgroundWriteDone
|
||||
}
|
||||
|
||||
sc.s.connClose <- sc
|
||||
<-sc.terminate
|
||||
|
||||
return liberrors.ErrServerTerminated{}
|
||||
|
||||
case <-sc.terminate:
|
||||
sc.nconn.Close()
|
||||
<-readDone
|
||||
@@ -226,6 +250,21 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
switch req.Method {
|
||||
case base.Options:
|
||||
// handle request in session
|
||||
if sxID != "" {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
// handle request here
|
||||
var methods []string
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
|
||||
methods = append(methods, string(base.Describe))
|
||||
@@ -291,58 +330,46 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Announce:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
if ss == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("terminated")
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: true,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
case base.Setup:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
if ss == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("terminated")
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: true,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
case base.Play:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
if ss == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerInvalidSession{}
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
res := <-cres
|
||||
|
||||
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok {
|
||||
sc.tcpFrameLinkedSession = ss
|
||||
sc.tcpFrameLinkedSession = res.ss
|
||||
sc.tcpFrameIsRecording = false
|
||||
sc.tcpFrameSetEnabled = true
|
||||
return res.res, nil
|
||||
@@ -353,22 +380,18 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Record:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
if ss == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerInvalidSession{}
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
res := <-cres
|
||||
|
||||
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok {
|
||||
sc.tcpFrameLinkedSession = ss
|
||||
sc.tcpFrameLinkedSession = res.ss
|
||||
sc.tcpFrameIsRecording = true
|
||||
sc.tcpFrameSetEnabled = true
|
||||
return res.res, nil
|
||||
@@ -379,19 +402,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Pause:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
if ss == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerInvalidSession{}
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
res := <-cres
|
||||
|
||||
if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok {
|
||||
sc.tcpFrameSetEnabled = false
|
||||
@@ -402,31 +421,29 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}
|
||||
|
||||
case base.Teardown:
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
if ss == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerInvalidSession{}
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
|
||||
case base.GetParameter:
|
||||
sres := make(chan *ServerSession)
|
||||
sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
|
||||
ss := <-sres
|
||||
|
||||
// send request to session
|
||||
if ss != nil {
|
||||
rres := make(chan requestRes)
|
||||
ss.request <- requestReq{sc: sc, req: req, res: rres}
|
||||
res := <-rres
|
||||
// handle request in session
|
||||
if sxID != "" {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
}
|
||||
|
||||
@@ -471,7 +488,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerUnhandledRequest{}
|
||||
}, liberrors.ErrServerUnhandledRequest{Req: req}
|
||||
}
|
||||
|
||||
func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
||||
|
234
serversession.go
234
serversession.go
@@ -112,17 +112,6 @@ type ServerSessionAnnouncedTrack struct {
|
||||
rtcpReceiver *rtcpreceiver.RTCPReceiver
|
||||
}
|
||||
|
||||
type requestRes struct {
|
||||
res *base.Response
|
||||
err error
|
||||
}
|
||||
|
||||
type requestReq struct {
|
||||
sc *ServerConn
|
||||
req *base.Request
|
||||
res chan requestRes
|
||||
}
|
||||
|
||||
// ServerSession is a server-side RTSP session.
|
||||
type ServerSession struct {
|
||||
s *Server
|
||||
@@ -142,8 +131,9 @@ type ServerSession struct {
|
||||
udpLastFrameTime *int64 // publish, udp
|
||||
|
||||
// in
|
||||
request chan requestReq
|
||||
terminate chan struct{}
|
||||
request chan sessionReq
|
||||
innerTerminate chan struct{}
|
||||
terminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
@@ -152,7 +142,8 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
id: id,
|
||||
wg: wg,
|
||||
lastRequestTime: time.Now(),
|
||||
request: make(chan requestReq),
|
||||
request: make(chan sessionReq),
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
terminate: make(chan struct{}),
|
||||
}
|
||||
|
||||
@@ -162,6 +153,15 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
return ss
|
||||
}
|
||||
|
||||
// Close closes the ServerSession.
|
||||
func (ss *ServerSession) Close() error {
|
||||
select {
|
||||
case ss.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// State returns the state of the session.
|
||||
func (ss *ServerSession) State() ServerSessionState {
|
||||
return ss.state
|
||||
@@ -203,78 +203,106 @@ func (ss *ServerSession) run() {
|
||||
h.OnSessionOpen(ss)
|
||||
}
|
||||
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
|
||||
err := func() error {
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- requestRes{res, nil}
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- requestRes{res, err}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.terminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
readDone := make(chan error)
|
||||
go func() {
|
||||
for req := range ss.request {
|
||||
req.res <- requestRes{nil, fmt.Errorf("terminated")}
|
||||
}
|
||||
readDone <- func() error {
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- sessionReqRes{res: res, err: nil}
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- sessionReqRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-readDone:
|
||||
go func() {
|
||||
for req := range ss.request {
|
||||
req.res <- sessionReqRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerTerminated{},
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.terminate
|
||||
|
||||
case <-ss.terminate:
|
||||
select {
|
||||
case ss.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-readDone
|
||||
|
||||
err = liberrors.ErrServerTerminated{}
|
||||
}
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
@@ -292,9 +320,6 @@ func (ss *ServerSession) run() {
|
||||
ss.s.connClose <- ss.linkedConn
|
||||
}
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.terminate
|
||||
|
||||
close(ss.request)
|
||||
|
||||
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
|
||||
@@ -310,6 +335,39 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case base.Options:
|
||||
var methods []string
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
|
||||
methods = append(methods, string(base.Describe))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
|
||||
methods = append(methods, string(base.Announce))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
|
||||
methods = append(methods, string(base.Setup))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
||||
methods = append(methods, string(base.Play))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
||||
methods = append(methods, string(base.Record))
|
||||
}
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
||||
methods = append(methods, string(base.Pause))
|
||||
}
|
||||
methods = append(methods, string(base.GetParameter))
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
|
||||
methods = append(methods, string(base.SetParameter))
|
||||
}
|
||||
methods = append(methods, string(base.Teardown))
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Public": base.HeaderValue{strings.Join(methods, ", ")},
|
||||
},
|
||||
}, nil
|
||||
|
||||
case base.Announce:
|
||||
err := ss.checkState(map[ServerSessionState]struct{}{
|
||||
ServerSessionStateInitial: {},
|
||||
@@ -808,7 +866,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unimplemented")
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerUnhandledRequest{Req: req}
|
||||
}
|
||||
|
||||
// WriteFrame writes a frame.
|
||||
|
Reference in New Issue
Block a user