mirror of
https://github.com/aler9/gortsplib
synced 2025-10-29 01:33:00 +08:00
server: fix panic when recording with wrong transport header (https://github.com/bluenviron/mediamtx/issues/3677) (#604)
This commit is contained in:
@@ -47,6 +47,14 @@ const (
|
|||||||
TransportProtocolTCP
|
TransportProtocolTCP
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// String implements fmt.Stringer.
|
||||||
|
func (p TransportProtocol) String() string {
|
||||||
|
if p == TransportProtocolUDP {
|
||||||
|
return "RTP/AVP"
|
||||||
|
}
|
||||||
|
return "RTP/AVP/TCP"
|
||||||
|
}
|
||||||
|
|
||||||
// TransportDelivery is a delivery method.
|
// TransportDelivery is a delivery method.
|
||||||
type TransportDelivery int
|
type TransportDelivery int
|
||||||
|
|
||||||
@@ -56,6 +64,14 @@ const (
|
|||||||
TransportDeliveryMulticast
|
TransportDeliveryMulticast
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// String implements fmt.Stringer.
|
||||||
|
func (d TransportDelivery) String() string {
|
||||||
|
if d == TransportDeliveryUnicast {
|
||||||
|
return "unicast"
|
||||||
|
}
|
||||||
|
return "multicast"
|
||||||
|
}
|
||||||
|
|
||||||
// TransportMode is a transport mode.
|
// TransportMode is a transport mode.
|
||||||
type TransportMode int
|
type TransportMode int
|
||||||
|
|
||||||
@@ -67,6 +83,33 @@ const (
|
|||||||
TransportModeRecord
|
TransportModeRecord
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (m *TransportMode) unmarshal(v string) error {
|
||||||
|
str := strings.ToLower(v)
|
||||||
|
|
||||||
|
switch str {
|
||||||
|
case "play":
|
||||||
|
*m = TransportModePlay
|
||||||
|
return nil
|
||||||
|
|
||||||
|
// receive is an old alias for record, used by ffmpeg with the
|
||||||
|
// -listen flag, and by Darwin Streaming Server
|
||||||
|
case "record", "receive":
|
||||||
|
*m = TransportModeRecord
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid transport mode: '%s'", str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements fmt.Stringer.
|
||||||
|
func (m TransportMode) String() string {
|
||||||
|
if m == TransportModePlay {
|
||||||
|
return "play"
|
||||||
|
}
|
||||||
|
return "record"
|
||||||
|
}
|
||||||
|
|
||||||
// Transport is a Transport header.
|
// Transport is a Transport header.
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
// protocol of the stream
|
// protocol of the stream
|
||||||
@@ -218,24 +261,12 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "mode":
|
case "mode":
|
||||||
str := strings.ToLower(v)
|
var m TransportMode
|
||||||
str = strings.TrimPrefix(str, "\"")
|
err = m.unmarshal(v)
|
||||||
str = strings.TrimSuffix(str, "\"")
|
if err != nil {
|
||||||
|
return err
|
||||||
switch str {
|
|
||||||
case "play":
|
|
||||||
v := TransportModePlay
|
|
||||||
h.Mode = &v
|
|
||||||
|
|
||||||
// receive is an old alias for record, used by ffmpeg with the
|
|
||||||
// -listen flag, and by Darwin Streaming Server
|
|
||||||
case "record", "receive":
|
|
||||||
v := TransportModeRecord
|
|
||||||
h.Mode = &v
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("invalid transport mode: '%s'", str)
|
|
||||||
}
|
}
|
||||||
|
h.Mode = &m
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// ignore non-standard keys
|
// ignore non-standard keys
|
||||||
@@ -253,18 +284,10 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
|
|||||||
func (h Transport) Marshal() base.HeaderValue {
|
func (h Transport) Marshal() base.HeaderValue {
|
||||||
var rets []string
|
var rets []string
|
||||||
|
|
||||||
if h.Protocol == TransportProtocolUDP {
|
rets = append(rets, h.Protocol.String())
|
||||||
rets = append(rets, "RTP/AVP")
|
|
||||||
} else {
|
|
||||||
rets = append(rets, "RTP/AVP/TCP")
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.Delivery != nil {
|
if h.Delivery != nil {
|
||||||
if *h.Delivery == TransportDeliveryUnicast {
|
rets = append(rets, h.Delivery.String())
|
||||||
rets = append(rets, "unicast")
|
|
||||||
} else {
|
|
||||||
rets = append(rets, "multicast")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.Source != nil {
|
if h.Source != nil {
|
||||||
@@ -309,11 +332,7 @@ func (h Transport) Marshal() base.HeaderValue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if h.Mode != nil {
|
if h.Mode != nil {
|
||||||
if *h.Mode == TransportModePlay {
|
rets = append(rets, "mode="+h.Mode.String())
|
||||||
rets = append(rets, "mode=play")
|
|
||||||
} else {
|
|
||||||
rets = append(rets, "mode=record")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return base.HeaderValue{strings.Join(rets, ";")}
|
return base.HeaderValue{strings.Join(rets, ";")}
|
||||||
|
|||||||
@@ -84,12 +84,16 @@ func (e ErrServerMediaNotFound) Error() string {
|
|||||||
|
|
||||||
// ErrServerTransportHeaderInvalidMode is an error that can be returned by a server.
|
// ErrServerTransportHeaderInvalidMode is an error that can be returned by a server.
|
||||||
type ErrServerTransportHeaderInvalidMode struct {
|
type ErrServerTransportHeaderInvalidMode struct {
|
||||||
Mode headers.TransportMode
|
Mode *headers.TransportMode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error implements the error interface.
|
// Error implements the error interface.
|
||||||
func (e ErrServerTransportHeaderInvalidMode) Error() string {
|
func (e ErrServerTransportHeaderInvalidMode) Error() string {
|
||||||
return fmt.Sprintf("transport header contains a invalid mode (%v)", e.Mode)
|
m := "null"
|
||||||
|
if e.Mode != nil {
|
||||||
|
m = e.Mode.String()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("transport header contains a invalid mode (%v)", m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrServerTransportHeaderNoClientPorts is an error that can be returned by a server.
|
// ErrServerTransportHeaderNoClientPorts is an error that can be returned by a server.
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func TestServerRecordErrorAnnounce(t *testing.T) {
|
|||||||
"unsupported Content-Type header '[aa]'",
|
"unsupported Content-Type header '[aa]'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"invalid medias",
|
"invalid sdp",
|
||||||
base.Request{
|
base.Request{
|
||||||
Method: base.Announce,
|
Method: base.Announce,
|
||||||
URL: mustParseURL("rtsp://localhost:8554/teststream"),
|
URL: mustParseURL("rtsp://localhost:8554/teststream"),
|
||||||
@@ -122,6 +122,29 @@ func TestServerRecordErrorAnnounce(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"invalid SDP: invalid line: (\x01\x02\x03\x04)",
|
"invalid SDP: invalid line: (\x01\x02\x03\x04)",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"invalid session",
|
||||||
|
base.Request{
|
||||||
|
Method: base.Announce,
|
||||||
|
URL: mustParseURL("rtsp://localhost:8554/teststream"),
|
||||||
|
Header: base.Header{
|
||||||
|
"CSeq": base.HeaderValue{"1"},
|
||||||
|
"Content-Type": base.HeaderValue{"application/sdp"},
|
||||||
|
},
|
||||||
|
Body: []byte("v=0\r\n" +
|
||||||
|
"o=- 0 0 IN IP4 127.0.0.1\r\n" +
|
||||||
|
"s=-\r\n" +
|
||||||
|
"c=IN IP4 0.0.0.0\r\n" +
|
||||||
|
"t=0 0\r\n" +
|
||||||
|
"m=video 0 RTP/AVP 96\r\n" +
|
||||||
|
"a=control\r\n" +
|
||||||
|
"a=rtpmap:97 H264/90000\r\n" +
|
||||||
|
"a=fmtp:aa packetization-mode=1; profile-level-id=4D002A; " +
|
||||||
|
"sprop-parameter-sets=Z00AKp2oHgCJ+WbgICAgQA==,aO48gA==\r\n",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"invalid SDP: media 1 is invalid: clock rate not found",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"invalid URL 1",
|
"invalid URL 1",
|
||||||
invalidURLAnnounceReq(t, "rtsp:// aaaaa"),
|
invalidURLAnnounceReq(t, "rtsp:// aaaaa"),
|
||||||
@@ -168,6 +191,87 @@ func TestServerRecordErrorAnnounce(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerRecordErrorSetup(t *testing.T) {
|
||||||
|
for _, ca := range []struct {
|
||||||
|
name string
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"invalid transport",
|
||||||
|
"transport header contains a invalid mode (null)",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(ca.name, func(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
Handler: &testServerHandler{
|
||||||
|
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
|
||||||
|
require.EqualError(t, ctx.Error, ca.err)
|
||||||
|
},
|
||||||
|
onAnnounce: func(_ *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
|
||||||
|
return &base.Response{
|
||||||
|
StatusCode: base.StatusOK,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
|
||||||
|
return &base.Response{
|
||||||
|
StatusCode: base.StatusOK,
|
||||||
|
}, nil, nil
|
||||||
|
},
|
||||||
|
onRecord: func(_ *ServerHandlerOnRecordCtx) (*base.Response, error) {
|
||||||
|
return &base.Response{
|
||||||
|
StatusCode: base.StatusOK,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
|
||||||
|
return &base.Response{
|
||||||
|
StatusCode: base.StatusOK,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
RTSPAddress: "localhost:8554",
|
||||||
|
UDPRTPAddress: "127.0.0.1:8000",
|
||||||
|
UDPRTCPAddress: "127.0.0.1:8001",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
nconn, err := net.Dial("tcp", "localhost:8554")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer nconn.Close()
|
||||||
|
conn := conn.NewConn(nconn)
|
||||||
|
|
||||||
|
medias := []*description.Media{testH264Media}
|
||||||
|
|
||||||
|
doAnnounce(t, conn, "rtsp://localhost:8554/teststream", medias)
|
||||||
|
|
||||||
|
var inTH *headers.Transport
|
||||||
|
|
||||||
|
switch ca.name {
|
||||||
|
case "invalid transport":
|
||||||
|
inTH = &headers.Transport{
|
||||||
|
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||||
|
Mode: nil,
|
||||||
|
Protocol: headers.TransportProtocolUDP,
|
||||||
|
ClientPorts: &[2]int{35466, 35467},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := writeReqReadRes(conn, base.Request{
|
||||||
|
Method: base.Setup,
|
||||||
|
URL: mustParseURL("rtsp://localhost:8554/teststream/" + medias[0].Control),
|
||||||
|
Header: base.Header{
|
||||||
|
"CSeq": base.HeaderValue{"1"},
|
||||||
|
"Transport": inTH.Marshal(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqual(t, base.StatusOK, res.StatusCode)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerRecordPath(t *testing.T) {
|
func TestServerRecordPath(t *testing.T) {
|
||||||
for _, ca := range []struct {
|
for _, ca := range []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -728,7 +728,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
|||||||
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
|
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
|
||||||
return &base.Response{
|
return &base.Response{
|
||||||
StatusCode: base.StatusBadRequest,
|
StatusCode: base.StatusBadRequest,
|
||||||
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode}
|
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
|
||||||
}
|
}
|
||||||
|
|
||||||
default: // record
|
default: // record
|
||||||
@@ -741,7 +741,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
|||||||
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
|
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
|
||||||
return &base.Response{
|
return &base.Response{
|
||||||
StatusCode: base.StatusBadRequest,
|
StatusCode: base.StatusBadRequest,
|
||||||
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode}
|
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user