server: shut down session after a TEARDOWN request

This commit is contained in:
aler9
2021-05-02 21:10:30 +02:00
committed by Alessandro Ros
parent 259043685d
commit ab7ede2c00
8 changed files with 158 additions and 62 deletions

View File

@@ -12,7 +12,7 @@ type ErrServerTCPFramesEnable struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerTCPFramesEnable) Error() string { func (e ErrServerTCPFramesEnable) Error() string {
return "" return "tcp frame enable"
} }
// ErrServerTCPFramesDisable is an error that can be returned by a server. // ErrServerTCPFramesDisable is an error that can be returned by a server.
@@ -20,7 +20,7 @@ type ErrServerTCPFramesDisable struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerTCPFramesDisable) Error() string { func (e ErrServerTCPFramesDisable) Error() string {
return "" return "tcp frame disable"
} }
// ErrServerCSeqMissing is an error that can be returned by a server. // ErrServerCSeqMissing is an error that can be returned by a server.
@@ -183,3 +183,19 @@ type ErrServerLinkedToOtherSession struct{}
func (e ErrServerLinkedToOtherSession) Error() string { func (e ErrServerLinkedToOtherSession) Error() string {
return "connection is linked to another session" return "connection is linked to another session"
} }
// ErrServerTeardown is an error that can be returned by a server.
type ErrServerTeardown struct{}
// Error implements the error interface.
func (e ErrServerTeardown) Error() string {
return "teardown"
}
// ErrServerSessionLinkedToOtherConn is an error that can be returned by a server.
type ErrServerSessionLinkedToOtherConn struct{}
// Error implements the error interface.
func (e ErrServerSessionLinkedToOtherConn) Error() string {
return "session is linked to another connection"
}

View File

@@ -42,8 +42,9 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) {
} }
type sessionGetReq struct { type sessionGetReq struct {
id string id string
res chan *ServerSession create bool
res chan *ServerSession
} }
// Server is a RTSP server. // Server is a RTSP server.
@@ -233,6 +234,11 @@ outer:
req.res <- ss req.res <- ss
} else { } else {
if !req.create {
req.res <- nil
continue
}
id, err := newSessionID(s.sessions) id, err := newSessionID(s.sessions)
if err != nil { if err != nil {
req.res <- nil req.res <- nil

View File

@@ -482,10 +482,26 @@ func TestServerPublish(t *testing.T) {
"tcp", "tcp",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(proto, func(t *testing.T) {
connOpened := make(chan struct{})
connClosed := make(chan struct{})
sessionOpened := make(chan struct{})
sessionClosed := make(chan struct{})
rtpReceived := uint64(0) rtpReceived := uint64(0)
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnOpen: func(sc *ServerConn) {
close(connOpened)
},
onConnClose: func(sc *ServerConn, err error) {
close(connClosed)
},
onSessionOpen: func(ss *ServerSession) {
close(sessionOpened)
},
onSessionClose: func(ss *ServerSession) {
close(sessionClosed)
},
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -531,6 +547,8 @@ func TestServerPublish(t *testing.T) {
defer conn.Close() defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
<-connOpened
track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err) require.NoError(t, err)
@@ -558,6 +576,8 @@ func TestServerPublish(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
<-sessionOpened
inTH := &headers.Transport{ inTH := &headers.Transport{
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
@@ -674,6 +694,25 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, StreamTypeRTCP, f.StreamType) require.Equal(t, StreamTypeRTCP, f.StreamType)
require.Equal(t, []byte{0x09, 0x0A, 0x0B, 0x0C}, f.Payload) require.Equal(t, []byte{0x09, 0x0A, 0x0B, 0x0C}, f.Payload)
} }
err = base.Request{
Method: base.Teardown,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
<-sessionClosed
conn.Close()
<-connClosed
}) })
} }
} }

View File

@@ -116,12 +116,13 @@ func TestServerReadSetupPath(t *testing.T) {
} }
func TestServerReadSetupErrorDifferentPaths(t *testing.T) { func TestServerReadSetupErrorDifferentPaths(t *testing.T) {
serverErr := make(chan error) connClosed := make(chan struct{})
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) { onConnClose: func(sc *ServerConn, err error) {
serverErr <- err require.Equal(t, "can't setup tracks with different paths", err.Error())
close(connClosed)
}, },
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
@@ -185,17 +186,17 @@ func TestServerReadSetupErrorDifferentPaths(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode) require.Equal(t, base.StatusBadRequest, res.StatusCode)
err = <-serverErr <-connClosed
require.Equal(t, "can't setup tracks with different paths", err.Error())
} }
func TestServerReadSetupErrorTrackTwice(t *testing.T) { func TestServerReadSetupErrorTrackTwice(t *testing.T) {
serverErr := make(chan error) connClosed := make(chan struct{})
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) { onConnClose: func(sc *ServerConn, err error) {
serverErr <- err require.Equal(t, "track 0 has already been setup", err.Error())
close(connClosed)
}, },
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
@@ -259,8 +260,7 @@ func TestServerReadSetupErrorTrackTwice(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode) require.Equal(t, base.StatusBadRequest, res.StatusCode)
err = <-serverErr <-connClosed
require.Equal(t, "track 0 has already been setup", err.Error())
} }
func TestServerRead(t *testing.T) { func TestServerRead(t *testing.T) {
@@ -269,10 +269,26 @@ func TestServerRead(t *testing.T) {
"tcp", "tcp",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(proto, func(t *testing.T) {
connOpened := make(chan struct{})
connClosed := make(chan struct{})
sessionOpened := make(chan struct{})
sessionClosed := make(chan struct{})
framesReceived := make(chan struct{}) framesReceived := make(chan struct{})
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnOpen: func(sc *ServerConn) {
close(connOpened)
},
onConnClose: func(sc *ServerConn, err error) {
close(connClosed)
},
onSessionOpen: func(ss *ServerSession) {
close(sessionOpened)
},
onSessionClose: func(ss *ServerSession) {
close(sessionClosed)
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -303,9 +319,10 @@ func TestServerRead(t *testing.T) {
conn, err := net.Dial("tcp", "localhost:8554") conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err) require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
<-connOpened
inTH := &headers.Transport{ inTH := &headers.Transport{
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
@@ -344,6 +361,8 @@ func TestServerRead(t *testing.T) {
err = th.Read(res.Header["Transport"]) err = th.Read(res.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
<-sessionOpened
var l1 net.PacketConn var l1 net.PacketConn
var l2 net.PacketConn var l2 net.PacketConn
if proto == "udp" { if proto == "udp" {
@@ -415,6 +434,25 @@ func TestServerRead(t *testing.T) {
} }
<-framesReceived <-framesReceived
err = base.Request{
Method: base.Teardown,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
<-sessionClosed
conn.Close()
<-connClosed
}) })
} }
} }

View File

@@ -15,7 +15,9 @@ import (
) )
type testServerHandler struct { type testServerHandler struct {
onConnOpen func(*ServerConn)
onConnClose func(*ServerConn, error) onConnClose func(*ServerConn, error)
onSessionOpen func(*ServerSession)
onSessionClose func(*ServerSession) onSessionClose func(*ServerSession)
onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error)
onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error) onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error)
@@ -26,12 +28,24 @@ type testServerHandler struct {
onFrame func(*ServerHandlerOnFrameCtx) onFrame func(*ServerHandlerOnFrameCtx)
} }
func (sh *testServerHandler) OnConnOpen(sc *ServerConn) {
if sh.onConnOpen != nil {
sh.onConnOpen(sc)
}
}
func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) { func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) {
if sh.onConnClose != nil { if sh.onConnClose != nil {
sh.onConnClose(sc, err) sh.onConnClose(sc, err)
} }
} }
func (sh *testServerHandler) OnSessionOpen(ss *ServerSession) {
if sh.onSessionOpen != nil {
sh.onSessionOpen(ss)
}
}
func (sh *testServerHandler) OnSessionClose(ss *ServerSession) { func (sh *testServerHandler) OnSessionClose(ss *ServerSession) {
if sh.onSessionClose != nil { if sh.onSessionClose != nil {
sh.onSessionClose(ss) sh.onSessionClose(ss)

View File

@@ -307,7 +307,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Announce: case base.Announce:
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
sres := make(chan *ServerSession) sres := make(chan *ServerSession)
sc.s.sessionGet <- sessionGetReq{id: sxID, res: sres} sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres}
ss := <-sres ss := <-sres
if ss == nil { if ss == nil {
@@ -326,7 +326,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Setup: case base.Setup:
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
sres := make(chan *ServerSession) sres := make(chan *ServerSession)
sc.s.sessionGet <- sessionGetReq{id: sxID, res: sres} sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres}
ss := <-sres ss := <-sres
if ss == nil { if ss == nil {
@@ -345,7 +345,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Play: case base.Play:
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
sres := make(chan *ServerSession) sres := make(chan *ServerSession)
sc.s.sessionGet <- sessionGetReq{id: sxID, res: sres} sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
ss := <-sres ss := <-sres
if ss == nil { if ss == nil {
@@ -371,7 +371,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Record: case base.Record:
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
sres := make(chan *ServerSession) sres := make(chan *ServerSession)
sc.s.sessionGet <- sessionGetReq{id: sxID, res: sres} sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
ss := <-sres ss := <-sres
if ss == nil { if ss == nil {
@@ -397,7 +397,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Pause: case base.Pause:
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
sres := make(chan *ServerSession) sres := make(chan *ServerSession)
sc.s.sessionGet <- sessionGetReq{id: sxID, res: sres} sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
ss := <-sres ss := <-sres
if ss == nil { if ss == nil {
@@ -419,24 +419,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
} }
case base.Teardown: case base.Teardown:
if _, ok := sc.s.Handler.(ServerHandlerOnTeardown); ok { sres := make(chan *ServerSession)
sres := make(chan *ServerSession) sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres}
sc.s.sessionGet <- sessionGetReq{id: sxID, res: sres} ss := <-sres
ss := <-sres
if ss == nil { if ss == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("terminated") }, fmt.Errorf("terminated")
}
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
return res.res, res.err
} }
rres := make(chan requestRes)
ss.request <- requestReq{sc: sc, req: req, res: rres}
res := <-rres
return res.res, res.err
case base.GetParameter: case base.GetParameter:
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
pathAndQuery, ok := req.URL.RTSPPath() pathAndQuery, ok := req.URL.RTSPPath()

View File

@@ -164,20 +164,6 @@ type ServerHandlerOnSetParameter interface {
OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error) OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error)
} }
// ServerHandlerOnTeardownCtx is the context of a TEARDOWN request.
type ServerHandlerOnTeardownCtx struct {
Session *ServerSession
Conn *ServerConn
Req *base.Request
Path string
Query string
}
// ServerHandlerOnTeardown can be implemented by a ServerHandler.
type ServerHandlerOnTeardown interface {
OnTeardown(*ServerHandlerOnTeardownCtx) (*base.Response, error)
}
// ServerHandlerOnFrameCtx is the context of a frame request. // ServerHandlerOnFrameCtx is the context of a frame request.
type ServerHandlerOnFrameCtx struct { type ServerHandlerOnFrameCtx struct {
Session *ServerSession Session *ServerSession

View File

@@ -218,6 +218,12 @@ outer:
select { select {
case req := <-ss.request: case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req) res, err := ss.handleRequest(req.sc, req.req)
if _, ok := err.(liberrors.ErrServerTeardown); ok {
req.res <- requestRes{res, nil}
break outer
}
req.res <- requestRes{res, err} req.res <- requestRes{res, err}
case <-checkStreamTicker.C: case <-checkStreamTicker.C:
@@ -289,6 +295,10 @@ outer:
} }
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) { func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
if ss.linkedConn != nil && sc != ss.linkedConn {
return nil, liberrors.ErrServerSessionLinkedToOtherConn{}
}
switch req.Method { switch req.Method {
case base.Announce: case base.Announce:
err := ss.checkState(map[ServerSessionState]struct{}{ err := ss.checkState(map[ServerSessionState]struct{}{
@@ -772,22 +782,11 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
return res, err return res, err
case base.Teardown: case base.Teardown:
pathAndQuery, ok := req.URL.RTSPPath() ss.linkedConn = nil
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery) return &base.Response{
StatusCode: base.StatusOK,
return ss.s.Handler.(ServerHandlerOnTeardown).OnTeardown(&ServerHandlerOnTeardownCtx{ }, liberrors.ErrServerTeardown{}
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
} }
return nil, fmt.Errorf("unimplemented") return nil, fmt.Errorf("unimplemented")