server: test invalid paths

This commit is contained in:
aler9
2021-05-08 16:01:06 +02:00
parent 028ed2b973
commit 11a5fb68ad
4 changed files with 174 additions and 24 deletions

View File

@@ -77,12 +77,12 @@ func (e ErrServerWrongState) Error() string {
e.AllowedList, e.State)
}
// ErrServerNoPath is an error that can be returned by a server.
type ErrServerNoPath struct{}
// ErrServerInvalidPath is an error that can be returned by a server.
type ErrServerInvalidPath struct{}
// Error implements the error interface.
func (e ErrServerNoPath) Error() string {
return "RTSP path can't be retrieved"
func (e ErrServerInvalidPath) Error() string {
return "invalid path"
}
// ErrServerContentTypeMissing is an error that can be returned by a server.

View File

@@ -5,10 +5,12 @@ import (
"crypto/tls"
"fmt"
"net"
"strconv"
"sync"
"testing"
"time"
psdp "github.com/pion/sdp/v3"
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
@@ -927,6 +929,10 @@ func TestServerSessionClose(t *testing.T) {
}.Write(bconn.Writer)
require.NoError(t, err)
res, err := readResponse(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
<-sessionClosed
}
@@ -975,7 +981,151 @@ func TestServerSessionAutoClose(t *testing.T) {
}.Write(bconn.Writer)
require.NoError(t, err)
res, err := readResponse(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
conn.Close()
<-sessionClosed
}
func TestServerErrorInvalidPath(t *testing.T) {
for _, method := range []base.Method{
base.Describe,
base.Announce,
base.Play,
base.Record,
base.Pause,
//base.GetParameter,
//base.SetParameter,
} {
t.Run(string(method), func(t *testing.T) {
s := &Server{
Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) {
require.Equal(t, "invalid path", err.Error())
},
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
}
err := s.Start("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
sxID := ""
if method == base.Record {
track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err)
tracks := Tracks{track}
for i, t := range tracks {
t.Media.Attributes = append(t.Media.Attributes, psdp.Attribute{
Key: "control",
Value: "trackID=" + strconv.FormatInt(int64(i), 10),
})
}
err = base.Request{
Method: base.Announce,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: tracks.Write(),
}.Write(bconn.Writer)
require.NoError(t, err)
res, err := readResponse(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
sxID = res.Header["Session"][0]
}
if method == base.Play || method == base.Record || method == base.Pause {
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sxID},
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
if method == base.Play || method == base.Pause {
v := headers.TransportModePlay
return &v
}
v := headers.TransportModeRecord
return &v
}(),
InterleavedIDs: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
res, err := readResponse(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
sxID = res.Header["Session"][0]
}
if method == base.Pause {
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream/"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sxID},
},
}.Write(bconn.Writer)
require.NoError(t, err)
res, err := readResponse(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
err = base.Request{
Method: method,
URL: base.MustParseURL("rtsp://localhost:8554"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": base.HeaderValue{sxID},
},
}.Write(bconn.Writer)
require.NoError(t, err)
res, err := readResponse(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
})
}
}

View File

@@ -245,14 +245,14 @@ func (sc *ServerConn) run() {
}
}()
sc.nconn.Close()
<-readDone
if sc.tcpFrameEnabled {
sc.tcpFrameWriteBuffer.Close()
<-sc.tcpFrameBackgroundWriteDone
}
sc.nconn.Close()
<-readDone
for _, ss := range sc.sessions {
ss.connRemove <- sc
}
@@ -336,7 +336,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
@@ -430,7 +430,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
@@ -449,7 +449,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)

View File

@@ -26,7 +26,7 @@ func setupGetTrackIDPathQuery(url *base.URL,
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
return 0, "", "", liberrors.ErrServerNoPath{}
return 0, "", "", liberrors.ErrServerInvalidPath{}
}
if thMode == nil || *thMode == headers.TransportModePlay {
@@ -417,6 +417,15 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
@@ -443,15 +452,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerSDPNoTracksDefined{}
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
for _, track := range tracks {
trackURL, err := track.URL()
if err != nil {
@@ -681,7 +681,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
// path can end with a slash due to Content-Base, remove it
@@ -750,7 +750,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
// path can end with a slash due to Content-Base, remove it
@@ -818,7 +818,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
// path can end with a slash due to Content-Base, remove it
@@ -875,7 +875,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}, liberrors.ErrServerInvalidPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)