server: add Session to ServerHandlerOnSetParameterCtx

This commit is contained in:
aler9
2022-08-07 13:43:59 +02:00
parent 702cac94a6
commit 4940e8faeb
5 changed files with 184 additions and 218 deletions

View File

@@ -128,7 +128,7 @@ func (e ErrServerTrackAlreadySetup) Error() string {
// ErrServerTransportHeaderInvalidMode is an error that can be returned by a server. // ErrServerTransportHeaderInvalidMode is an error that can be returned by a server.
type ErrServerTransportHeaderInvalidMode struct { type ErrServerTransportHeaderInvalidMode struct {
Mode *headers.TransportMode Mode headers.TransportMode
} }
// Error implements the error interface. // Error implements the error interface.

View File

@@ -557,17 +557,39 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) {
} }
func TestServerGetSetParameter(t *testing.T) { func TestServerGetSetParameter(t *testing.T) {
for _, ca := range []string{"inside session", "outside session"} {
t.Run(ca, func(t *testing.T) {
track := &TrackH264{
PayloadType: 96,
SPS: []byte{0x01, 0x02, 0x03, 0x04},
PPS: []byte{0x01, 0x02, 0x03, 0x04},
}
stream := NewServerStream(Tracks{track})
defer stream.Close()
var params []byte var params []byte
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetParameter: func(ctx *ServerHandlerOnSetParameterCtx) (*base.Response, error) { onSetParameter: func(ctx *ServerHandlerOnSetParameterCtx) (*base.Response, error) {
if ca == "inside session" {
require.NotNil(t, ctx.Session)
}
params = ctx.Request.Body params = ctx.Request.Body
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) { onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) {
if ca == "inside session" {
require.NotNil(t, ctx.Session)
}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Body: params, Body: params,
@@ -586,38 +608,69 @@ func TestServerGetSetParameter(t *testing.T) {
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var sx headers.Session
if ca == "inside session" {
res, err := writeReqReadRes(conn, br, base.Request{ res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Options, Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream"), URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"1"}, "CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIDs: &[2]int{0, 1},
}.Marshal(),
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
res, err = writeReqReadRes(conn, br, base.Request{ err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
}
headers := base.Header{
"CSeq": base.HeaderValue{"2"},
}
if ca == "inside session" {
headers["Session"] = base.HeaderValue{sx.Session}
}
res, err := writeReqReadRes(conn, br, base.Request{
Method: base.SetParameter, Method: base.SetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream"), URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: headers,
"CSeq": base.HeaderValue{"12"},
},
Body: []byte("param1: 123456\r\n"), Body: []byte("param1: 123456\r\n"),
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
headers = base.Header{
"CSeq": base.HeaderValue{"3"},
}
if ca == "inside session" {
headers["Session"] = base.HeaderValue{sx.Session}
}
res, err = writeReqReadRes(conn, br, base.Request{ res, err = writeReqReadRes(conn, br, base.Request{
Method: base.GetParameter, Method: base.GetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream"), URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: headers,
"CSeq": base.HeaderValue{"3"},
},
Body: []byte("param1\r\n"), Body: []byte("param1\r\n"),
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
require.Equal(t, []byte("param1: 123456\r\n"), res.Body) require.Equal(t, []byte("param1: 123456\r\n"), res.Body)
})
}
} }
func TestServerErrorInvalidSession(t *testing.T) { func TestServerErrorInvalidSession(t *testing.T) {
@@ -799,16 +852,8 @@ func TestServerSessionAutoClose(t *testing.T) {
} }
func TestServerErrorInvalidPath(t *testing.T) { func TestServerErrorInvalidPath(t *testing.T) {
for _, method := range []base.Method{ for _, ca := range []string{"inside session", "outside session"} {
base.Describe, t.Run(ca, func(t *testing.T) {
base.Announce,
base.Play,
base.Record,
base.Pause,
// base.GetParameter,
// base.SetParameter,
} {
t.Run(string(method), func(t *testing.T) {
connClosed := make(chan struct{}) connClosed := make(chan struct{})
track := &TrackH264{ track := &TrackH264{
@@ -826,21 +871,11 @@ func TestServerErrorInvalidPath(t *testing.T) {
require.EqualError(t, ctx.Error, "invalid path") require.EqualError(t, ctx.Error, "invalid path")
close(connClosed) close(connClosed)
}, },
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, stream, nil }, stream, nil
}, },
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
}, },
RTSPAddress: "localhost:8554", RTSPAddress: "localhost:8554",
} }
@@ -854,38 +889,12 @@ func TestServerErrorInvalidPath(t *testing.T) {
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
sxID := "" if ca == "inside session" {
if method == base.Record {
track := &TrackH264{
PayloadType: 96,
SPS: []byte{0x01, 0x02, 0x03, 0x04},
PPS: []byte{0x01, 0x02, 0x03, 0x04},
}
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: tracks.Marshal(false),
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
if method == base.Play || method == base.Record || method == base.Pause {
res, err := writeReqReadRes(conn, br, base.Request{ res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Setup, Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"1"},
"Session": base.HeaderValue{sxID},
"Transport": headers.Transport{ "Transport": headers.Transport{
Protocol: headers.TransportProtocolTCP, Protocol: headers.TransportProtocolTCP,
Delivery: func() *headers.TransportDelivery { Delivery: func() *headers.TransportDelivery {
@@ -893,12 +902,8 @@ func TestServerErrorInvalidPath(t *testing.T) {
return &v return &v
}(), }(),
Mode: func() *headers.TransportMode { Mode: func() *headers.TransportMode {
if method == base.Play || method == base.Pause {
v := headers.TransportModePlay v := headers.TransportModePlay
return &v return &v
}
v := headers.TransportModeRecord
return &v
}(), }(),
InterleavedIDs: &[2]int{0, 1}, InterleavedIDs: &[2]int{0, 1},
}.Marshal(), }.Marshal(),
@@ -911,32 +916,27 @@ func TestServerErrorInvalidPath(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"]) err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err) require.NoError(t, err)
sxID = sx.Session res, err = writeReqReadRes(conn, br, base.Request{
} Method: base.SetParameter,
if method == base.Pause {
res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream/"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sxID},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
res, err := writeReqReadRes(conn, br, base.Request{
Method: method,
URL: mustParseURL("rtsp://localhost:8554"), URL: mustParseURL("rtsp://localhost:8554"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sxID}, "Session": base.HeaderValue{sx.Session},
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode) require.Equal(t, base.StatusBadRequest, res.StatusCode)
} else {
res, err := writeReqReadRes(conn, br, base.Request{
Method: base.SetParameter,
URL: mustParseURL("rtsp://localhost:8554"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
<-connClosed <-connClosed
}) })

View File

@@ -350,13 +350,26 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
sxID := getSessionID(req.Header) sxID := getSessionID(req.Header)
var path string
var query string
switch req.Method {
case base.Describe, base.GetParameter, base.SetParameter:
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
path, query = url.PathSplitQuery(pathAndQuery)
}
switch req.Method { switch req.Method {
case base.Options: case base.Options:
if sxID != "" { if sxID != "" {
return sc.handleRequestInSession(sxID, req, false) return sc.handleRequestInSession(sxID, req, false)
} }
// handle request here
var methods []string var methods []string
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
methods = append(methods, string(base.Describe)) methods = append(methods, string(base.Describe))
@@ -391,15 +404,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Describe: case base.Describe:
if h, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { if h, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
path, query := url.PathSplitQuery(pathAndQuery)
res, stream, err := h.OnDescribe(&ServerHandlerOnDescribeCtx{ res, stream, err := h.OnDescribe(&ServerHandlerOnDescribeCtx{
Conn: sc, Conn: sc,
Request: req, Request: req,
@@ -476,17 +480,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
return sc.handleRequestInSession(sxID, req, false) return sc.handleRequestInSession(sxID, req, false)
} }
// handle request here
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
path, query := url.PathSplitQuery(pathAndQuery)
return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{ return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{
Conn: sc, Conn: sc,
Request: req, Request: req,
@@ -496,16 +490,11 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
} }
case base.SetParameter: case base.SetParameter:
if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok { if sxID != "" {
pathAndQuery, ok := req.URL.RTSPPathAndQuery() return sc.handleRequestInSession(sxID, req, false)
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
} }
path, query := url.PathSplitQuery(pathAndQuery) if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
return h.OnSetParameter(&ServerHandlerOnSetParameterCtx{ return h.OnSetParameter(&ServerHandlerOnSetParameterCtx{
Conn: sc, Conn: sc,
Request: req, Request: req,

View File

@@ -173,6 +173,7 @@ type ServerHandlerOnGetParameter interface {
// ServerHandlerOnSetParameterCtx is the context of a SET_PARAMETER request. // ServerHandlerOnSetParameterCtx is the context of a SET_PARAMETER request.
type ServerHandlerOnSetParameterCtx struct { type ServerHandlerOnSetParameterCtx struct {
Session *ServerSession
Conn *ServerConn Conn *ServerConn
Request *base.Request Request *base.Request
Path string Path string

View File

@@ -437,6 +437,25 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerSessionLinkedToOtherConn{} }, liberrors.ErrServerSessionLinkedToOtherConn{}
} }
var path string
var query string
switch req.Method {
case base.Announce, base.Play, base.Record, base.Pause, base.GetParameter, base.SetParameter:
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
if req.Method != base.Announce {
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
}
path, query = url.PathSplitQuery(pathAndQuery)
}
switch req.Method { switch req.Method {
case base.Options: case base.Options:
var methods []string var methods []string
@@ -481,15 +500,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, err }, err
} }
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
path, query := url.PathSplitQuery(pathAndQuery)
ct, ok := req.Header["Content-Type"] ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 { if !ok || len(ct) != 1 {
return &base.Response{ return &base.Response{
@@ -652,7 +662,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay { if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode}
} }
default: // record default: // record
@@ -665,7 +675,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord { if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode}
} }
} }
@@ -802,18 +812,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, err }, err
} }
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := url.PathSplitQuery(pathAndQuery)
if ss.State() == ServerSessionStatePrePlay && if ss.State() == ServerSessionStatePrePlay &&
path != *ss.setuppedPath { path != *ss.setuppedPath {
return &base.Response{ return &base.Response{
@@ -934,18 +932,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerNotAllAnnouncedTracksSetup{} }, liberrors.ErrServerNotAllAnnouncedTracksSetup{}
} }
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := url.PathSplitQuery(pathAndQuery)
if path != *ss.setuppedPath { if path != *ss.setuppedPath {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -1033,18 +1019,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, err }, err
} }
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := url.PathSplitQuery(pathAndQuery)
res, err := ss.s.Handler.(ServerHandlerOnPause).OnPause(&ServerHandlerOnPauseCtx{ res, err := ss.s.Handler.(ServerHandlerOnPause).OnPause(&ServerHandlerOnPauseCtx{
Session: ss, Session: ss,
Conn: sc, Conn: sc,
@@ -1127,15 +1101,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
case base.GetParameter: case base.GetParameter:
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
path, query := url.PathSplitQuery(pathAndQuery)
return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{ return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{
Session: ss, Session: ss,
Conn: sc, Conn: sc,
@@ -1154,6 +1119,17 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, },
Body: []byte{}, Body: []byte{},
}, nil }, nil
case base.SetParameter:
if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
return h.OnSetParameter(&ServerHandlerOnSetParameterCtx{
Session: ss,
Conn: sc,
Request: req,
Path: path,
Query: query,
})
}
} }
return &base.Response{ return &base.Response{