server: implement sessions

This commit is contained in:
aler9
2021-05-02 19:12:51 +02:00
committed by Alessandro Ros
parent 712432bcef
commit 259043685d
14 changed files with 1333 additions and 1015 deletions

View File

@@ -17,8 +17,8 @@ import (
type serverHandler struct { type serverHandler struct {
mutex sync.Mutex mutex sync.Mutex
publisher *gortsplib.ServerConn publisher *gortsplib.ServerSession
readers map[*gortsplib.ServerConn]struct{} readers map[*gortsplib.ServerSession]struct{}
sdp []byte sdp []byte
} }
@@ -30,15 +30,18 @@ func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
// called when a connection is closed. // called when a connection is closed.
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
log.Println("conn closed (%v)", err) log.Println("conn closed (%v)", err)
}
// called when a session is closed.
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession) {
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
if sc == sh.publisher { if ss == sh.publisher {
sh.publisher = nil sh.publisher = nil
sh.sdp = nil sh.sdp = nil
} else { } else {
delete(sh.readers, sc) delete(sh.readers, ss)
} }
} }
@@ -70,14 +73,11 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
} }
sh.publisher = ctx.Conn sh.publisher = ctx.Session
sh.sdp = ctx.Tracks.Write() sh.sdp = ctx.Tracks.Write()
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -85,9 +85,6 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) { func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -96,13 +93,10 @@ func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Re
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
sh.readers[ctx.Conn] = struct{}{} sh.readers[ctx.Session] = struct{}{}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -111,7 +105,7 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
if ctx.Conn != sh.publisher { if ctx.Session != sh.publisher {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
@@ -119,9 +113,6 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -131,7 +122,7 @@ func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) {
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route frames to readers // if we are the publisher, route frames to readers
if ctx.Conn == sh.publisher { if ctx.Session == sh.publisher {
for r := range sh.readers { for r := range sh.readers {
r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload)
} }
@@ -149,7 +140,7 @@ func main() {
// configure server // configure server
s := &gortsplib.Server{ s := &gortsplib.Server{
Handler: &serverHandler{}, Handler: &serverHandler{},
TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
} }

View File

@@ -16,8 +16,8 @@ import (
type serverHandler struct { type serverHandler struct {
mutex sync.Mutex mutex sync.Mutex
publisher *gortsplib.ServerConn publisher *gortsplib.ServerSession
readers map[*gortsplib.ServerConn]struct{} readers map[*gortsplib.ServerSession]struct{}
sdp []byte sdp []byte
} }
@@ -29,15 +29,18 @@ func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
// called when a connection is closed. // called when a connection is closed.
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
log.Println("conn closed (%v)", err) log.Println("conn closed (%v)", err)
}
// called when a session is closed.
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession) {
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
if sc == sh.publisher { if ss == sh.publisher {
sh.publisher = nil sh.publisher = nil
sh.sdp = nil sh.sdp = nil
} else { } else {
delete(sh.readers, sc) delete(sh.readers, ss)
} }
} }
@@ -69,14 +72,11 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
} }
sh.publisher = ctx.Conn sh.publisher = ctx.Session
sh.sdp = ctx.Tracks.Write() sh.sdp = ctx.Tracks.Write()
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -84,9 +84,6 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) { func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -95,13 +92,10 @@ func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Re
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
sh.readers[ctx.Conn] = struct{}{} sh.readers[ctx.Session] = struct{}{}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -110,7 +104,7 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
if ctx.Conn != sh.publisher { if ctx.Session != sh.publisher {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
@@ -118,9 +112,6 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil }, nil
} }
@@ -130,7 +121,7 @@ func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) {
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route frames to readers // if we are the publisher, route frames to readers
if ctx.Conn == sh.publisher { if ctx.Session == sh.publisher {
for r := range sh.readers { for r := range sh.readers {
r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload)
} }
@@ -140,7 +131,7 @@ func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) {
func main() { func main() {
// configure server // configure server
s := &gortsplib.Server{ s := &gortsplib.Server{
Handler: &serverHandler{}, Handler: &serverHandler{},
UDPRTPAddress: ":8000", UDPRTPAddress: ":8000",
UDPRTCPAddress: ":8001", UDPRTCPAddress: ":8001",
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
) )
// ErrClientWrongState is returned in case of a wrong client state. // ErrClientWrongState is an error that can be returned by a client.
type ErrClientWrongState struct { type ErrClientWrongState struct {
AllowedList []fmt.Stringer AllowedList []fmt.Stringer
State fmt.Stringer State fmt.Stringer
@@ -18,7 +18,7 @@ func (e ErrClientWrongState) Error() string {
e.AllowedList, e.State) e.AllowedList, e.State)
} }
// ErrClientSessionHeaderInvalid is returned in case of an invalid session header. // ErrClientSessionHeaderInvalid is an error that can be returned by a client.
type ErrClientSessionHeaderInvalid struct { type ErrClientSessionHeaderInvalid struct {
Err error Err error
} }
@@ -28,7 +28,7 @@ func (e ErrClientSessionHeaderInvalid) Error() string {
return fmt.Sprintf("invalid session header: %v", e.Err) return fmt.Sprintf("invalid session header: %v", e.Err)
} }
// ErrClientWrongStatusCode is returned in case of a wrong status code. // ErrClientWrongStatusCode is an error that can be returned by a client.
type ErrClientWrongStatusCode struct { type ErrClientWrongStatusCode struct {
Code base.StatusCode Code base.StatusCode
Message string Message string
@@ -39,7 +39,7 @@ func (e ErrClientWrongStatusCode) Error() string {
return fmt.Sprintf("wrong status code: %d (%s)", e.Code, e.Message) return fmt.Sprintf("wrong status code: %d (%s)", e.Code, e.Message)
} }
// ErrClientContentTypeMissing is returned in case the Content-Type header is missing. // ErrClientContentTypeMissing is an error that can be returned by a client.
type ErrClientContentTypeMissing struct{} type ErrClientContentTypeMissing struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -47,7 +47,7 @@ func (e ErrClientContentTypeMissing) Error() string {
return "Content-Type header is missing" return "Content-Type header is missing"
} }
// ErrClientContentTypeUnsupported is returned in case the Content-Type header is unsupported. // ErrClientContentTypeUnsupported is an error that can be returned by a client.
type ErrClientContentTypeUnsupported struct { type ErrClientContentTypeUnsupported struct {
CT base.HeaderValue CT base.HeaderValue
} }
@@ -57,7 +57,7 @@ func (e ErrClientContentTypeUnsupported) Error() string {
return fmt.Sprintf("unsupported Content-Type header '%v'", e.CT) return fmt.Sprintf("unsupported Content-Type header '%v'", e.CT)
} }
// ErrClientCannotReadPublishAtSameTime is returned when the client is trying to read and publish at the same time. // ErrClientCannotReadPublishAtSameTime is an error that can be returned by a client.
type ErrClientCannotReadPublishAtSameTime struct{} type ErrClientCannotReadPublishAtSameTime struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -65,7 +65,7 @@ func (e ErrClientCannotReadPublishAtSameTime) Error() string {
return "cannot read and publish at the same time" return "cannot read and publish at the same time"
} }
// ErrClientCannotSetupTracksDifferentURLs is returned when the client is trying to setup tracks with different base URLs. // ErrClientCannotSetupTracksDifferentURLs is an error that can be returned by a client.
type ErrClientCannotSetupTracksDifferentURLs struct{} type ErrClientCannotSetupTracksDifferentURLs struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -73,7 +73,7 @@ func (e ErrClientCannotSetupTracksDifferentURLs) Error() string {
return "cannot setup tracks with different base URLs" return "cannot setup tracks with different base URLs"
} }
// ErrClientUDPPortsZero is returned when one of the UDP ports is zero. // ErrClientUDPPortsZero is an error that can be returned by a client.
type ErrClientUDPPortsZero struct{} type ErrClientUDPPortsZero struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -81,7 +81,7 @@ func (e ErrClientUDPPortsZero) Error() string {
return "rtpPort and rtcpPort must be both zero or non-zero" return "rtpPort and rtcpPort must be both zero or non-zero"
} }
// ErrClientUDPPortsNotConsecutive is returned when the two UDP ports are not consecutive. // ErrClientUDPPortsNotConsecutive is an error that can be returned by a client.
type ErrClientUDPPortsNotConsecutive struct{} type ErrClientUDPPortsNotConsecutive struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -89,7 +89,7 @@ func (e ErrClientUDPPortsNotConsecutive) Error() string {
return "rtcpPort must be rtpPort + 1" return "rtcpPort must be rtpPort + 1"
} }
// ErrClientServerPortsZero is returned when one of the server ports is zero. // ErrClientServerPortsZero is an error that can be returned by a client.
type ErrClientServerPortsZero struct{} type ErrClientServerPortsZero struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -97,7 +97,7 @@ func (e ErrClientServerPortsZero) Error() string {
return "server ports must be both zero or both not zero" return "server ports must be both zero or both not zero"
} }
// ErrClientServerPortsNotProvided is returned in case the server ports have not been provided. // ErrClientServerPortsNotProvided is an error that can be returned by a client.
type ErrClientServerPortsNotProvided struct{} type ErrClientServerPortsNotProvided struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -105,7 +105,7 @@ func (e ErrClientServerPortsNotProvided) Error() string {
return "server ports have not been provided. Use AnyPortEnable to communicate with this server" return "server ports have not been provided. Use AnyPortEnable to communicate with this server"
} }
// ErrClientTransportHeaderInvalid is returned in case the transport header is invalid. // ErrClientTransportHeaderInvalid is an error that can be returned by a client.
type ErrClientTransportHeaderInvalid struct { type ErrClientTransportHeaderInvalid struct {
Err error Err error
} }
@@ -115,7 +115,7 @@ func (e ErrClientTransportHeaderInvalid) Error() string {
return fmt.Sprintf("invalid transport header: %v", e.Err) return fmt.Sprintf("invalid transport header: %v", e.Err)
} }
// ErrClientTransportHeaderNoInterleavedIDs is returned in case the transport header doesn't contain interleaved IDs. // ErrClientTransportHeaderNoInterleavedIDs is an error that can be returned by a client.
type ErrClientTransportHeaderNoInterleavedIDs struct{} type ErrClientTransportHeaderNoInterleavedIDs struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -123,7 +123,7 @@ func (e ErrClientTransportHeaderNoInterleavedIDs) Error() string {
return "transport header does not contain interleaved IDs" return "transport header does not contain interleaved IDs"
} }
// ErrClientTransportHeaderWrongInterleavedIDs is returned in case the transport header contains wrong interleaved IDs. // ErrClientTransportHeaderWrongInterleavedIDs is an error that can be returned by a client.
type ErrClientTransportHeaderWrongInterleavedIDs struct { type ErrClientTransportHeaderWrongInterleavedIDs struct {
Expected [2]int Expected [2]int
Value [2]int Value [2]int
@@ -134,7 +134,7 @@ func (e ErrClientTransportHeaderWrongInterleavedIDs) Error() string {
return fmt.Sprintf("wrong interleaved IDs, expected %v, got %v", e.Expected, e.Value) return fmt.Sprintf("wrong interleaved IDs, expected %v, got %v", e.Expected, e.Value)
} }
// ErrClientNoUDPPacketsRecently is returned when no UDP packets have been received recently. // ErrClientNoUDPPacketsRecently is an error that can be returned by a client.
type ErrClientNoUDPPacketsRecently struct{} type ErrClientNoUDPPacketsRecently struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -142,8 +142,7 @@ func (e ErrClientNoUDPPacketsRecently) Error() string {
return "no UDP packets received (maybe there's a firewall/NAT in between)" return "no UDP packets received (maybe there's a firewall/NAT in between)"
} }
// ErrClientUDPTimeout is returned when timeout has exceeded but UDP packets have been received previously // ErrClientUDPTimeout is an error that can be returned by a client.
// but now nothing is being received.
type ErrClientUDPTimeout struct{} type ErrClientUDPTimeout struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -151,7 +150,7 @@ func (e ErrClientUDPTimeout) Error() string {
return "UDP timeout" return "UDP timeout"
} }
// ErrClientTCPTimeout is returned when timeout has exceeded. // ErrClientTCPTimeout is an error that can be returned by a client.
type ErrClientTCPTimeout struct{} type ErrClientTCPTimeout struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -159,7 +158,7 @@ func (e ErrClientTCPTimeout) Error() string {
return "TCP timeout" return "TCP timeout"
} }
// ErrClientRTPInfoInvalid is returned in case of an invalid RTP-Info. // ErrClientRTPInfoInvalid is an error that can be returned by a client.
type ErrClientRTPInfoInvalid struct { type ErrClientRTPInfoInvalid struct {
Err error Err error
} }

View File

@@ -7,15 +7,23 @@ import (
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
) )
// ErrServerTeardown is returned in case of a teardown request. // ErrServerTCPFramesEnable is an error that can be returned by a server.
type ErrServerTeardown struct{} type ErrServerTCPFramesEnable struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerTeardown) Error() string { func (e ErrServerTCPFramesEnable) Error() string {
return "teardown" return ""
} }
// ErrServerCSeqMissing is returned in case the CSeq is missing. // ErrServerTCPFramesDisable is an error that can be returned by a server.
type ErrServerTCPFramesDisable struct{}
// Error implements the error interface.
func (e ErrServerTCPFramesDisable) Error() string {
return ""
}
// ErrServerCSeqMissing is an error that can be returned by a server.
type ErrServerCSeqMissing struct{} type ErrServerCSeqMissing struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -23,7 +31,7 @@ func (e ErrServerCSeqMissing) Error() string {
return "CSeq is missing" return "CSeq is missing"
} }
// ErrServerWrongState is returned in case of a wrong client state. // ErrServerWrongState is an error that can be returned by a server.
type ErrServerWrongState struct { type ErrServerWrongState struct {
AllowedList []fmt.Stringer AllowedList []fmt.Stringer
State fmt.Stringer State fmt.Stringer
@@ -35,7 +43,7 @@ func (e ErrServerWrongState) Error() string {
e.AllowedList, e.State) e.AllowedList, e.State)
} }
// ErrServerNoPath is returned in case the path can't be retrieved. // ErrServerNoPath is an error that can be returned by a server.
type ErrServerNoPath struct{} type ErrServerNoPath struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -43,7 +51,7 @@ func (e ErrServerNoPath) Error() string {
return "RTSP path can't be retrieved" return "RTSP path can't be retrieved"
} }
// ErrServerContentTypeMissing is returned in case the Content-Type header is missing. // ErrServerContentTypeMissing is an error that can be returned by a server.
type ErrServerContentTypeMissing struct{} type ErrServerContentTypeMissing struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -51,7 +59,7 @@ func (e ErrServerContentTypeMissing) Error() string {
return "Content-Type header is missing" return "Content-Type header is missing"
} }
// ErrServerContentTypeUnsupported is returned in case the Content-Type header is unsupported. // ErrServerContentTypeUnsupported is an error that can be returned by a server.
type ErrServerContentTypeUnsupported struct { type ErrServerContentTypeUnsupported struct {
CT base.HeaderValue CT base.HeaderValue
} }
@@ -61,7 +69,7 @@ func (e ErrServerContentTypeUnsupported) Error() string {
return fmt.Sprintf("unsupported Content-Type header '%v'", e.CT) return fmt.Sprintf("unsupported Content-Type header '%v'", e.CT)
} }
// ErrServerSDPInvalid is returned in case the SDP is invalid. // ErrServerSDPInvalid is an error that can be returned by a server.
type ErrServerSDPInvalid struct { type ErrServerSDPInvalid struct {
Err error Err error
} }
@@ -71,7 +79,7 @@ func (e ErrServerSDPInvalid) Error() string {
return fmt.Sprintf("invalid SDP: %v", e.Err) return fmt.Sprintf("invalid SDP: %v", e.Err)
} }
// ErrServerSDPNoTracksDefined is returned in case the SDP has no tracks defined. // ErrServerSDPNoTracksDefined is an error that can be returned by a server.
type ErrServerSDPNoTracksDefined struct{} type ErrServerSDPNoTracksDefined struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -79,7 +87,7 @@ func (e ErrServerSDPNoTracksDefined) Error() string {
return "no tracks defined in the SDP" return "no tracks defined in the SDP"
} }
// ErrServerTransportHeaderInvalid is returned in case the transport header is invalid. // ErrServerTransportHeaderInvalid is an error that can be returned by a server.
type ErrServerTransportHeaderInvalid struct { type ErrServerTransportHeaderInvalid struct {
Err error Err error
} }
@@ -89,7 +97,7 @@ func (e ErrServerTransportHeaderInvalid) Error() string {
return fmt.Sprintf("invalid transport header: %v", e.Err) return fmt.Sprintf("invalid transport header: %v", e.Err)
} }
// ErrServerTrackAlreadySetup is returned in case a track has already been setup. // ErrServerTrackAlreadySetup is an error that can be returned by a server.
type ErrServerTrackAlreadySetup struct { type ErrServerTrackAlreadySetup struct {
TrackID int TrackID int
} }
@@ -99,7 +107,7 @@ func (e ErrServerTrackAlreadySetup) Error() string {
return fmt.Sprintf("track %d has already been setup", e.TrackID) return fmt.Sprintf("track %d has already been setup", e.TrackID)
} }
// ErrServerTransportHeaderWrongMode is returned in case the transport header contains a wrong mode. // ErrServerTransportHeaderWrongMode is an error that can be returned by a server.
type ErrServerTransportHeaderWrongMode struct { type ErrServerTransportHeaderWrongMode struct {
Mode *headers.TransportMode Mode *headers.TransportMode
} }
@@ -109,7 +117,7 @@ func (e ErrServerTransportHeaderWrongMode) Error() string {
return fmt.Sprintf("transport header contains a wrong mode (%v)", e.Mode) return fmt.Sprintf("transport header contains a wrong mode (%v)", e.Mode)
} }
// ErrServerTransportHeaderNoClientPorts is returned in case the transport header doesn't contain client ports. // ErrServerTransportHeaderNoClientPorts is an error that can be returned by a server.
type ErrServerTransportHeaderNoClientPorts struct{} type ErrServerTransportHeaderNoClientPorts struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -117,7 +125,7 @@ func (e ErrServerTransportHeaderNoClientPorts) Error() string {
return "transport header does not contain client ports" return "transport header does not contain client ports"
} }
// ErrServerTransportHeaderNoInterleavedIDs is returned in case the transport header doesn't contain interleaved IDs. // ErrServerTransportHeaderNoInterleavedIDs is an error that can be returned by a server.
type ErrServerTransportHeaderNoInterleavedIDs struct{} type ErrServerTransportHeaderNoInterleavedIDs struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -125,7 +133,7 @@ func (e ErrServerTransportHeaderNoInterleavedIDs) Error() string {
return "transport header does not contain interleaved IDs" return "transport header does not contain interleaved IDs"
} }
// ErrServerTransportHeaderWrongInterleavedIDs is returned in case the transport header contains wrong interleaved IDs. // ErrServerTransportHeaderWrongInterleavedIDs is an error that can be returned by a server.
type ErrServerTransportHeaderWrongInterleavedIDs struct { type ErrServerTransportHeaderWrongInterleavedIDs struct {
Expected [2]int Expected [2]int
Value [2]int Value [2]int
@@ -136,7 +144,7 @@ func (e ErrServerTransportHeaderWrongInterleavedIDs) Error() string {
return fmt.Sprintf("wrong interleaved IDs, expected %v, got %v", e.Expected, e.Value) return fmt.Sprintf("wrong interleaved IDs, expected %v, got %v", e.Expected, e.Value)
} }
// ErrServerTracksDifferentProtocols is returned in case the client is trying to setup tracks with different protocols. // ErrServerTracksDifferentProtocols is an error that can be returned by a server.
type ErrServerTracksDifferentProtocols struct{} type ErrServerTracksDifferentProtocols struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -144,7 +152,7 @@ func (e ErrServerTracksDifferentProtocols) Error() string {
return "can't setup tracks with different protocols" return "can't setup tracks with different protocols"
} }
// ErrServerNoTracksSetup is returned in case no tracks have been setup. // ErrServerNoTracksSetup is an error that can be returned by a server.
type ErrServerNoTracksSetup struct{} type ErrServerNoTracksSetup struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -152,7 +160,7 @@ func (e ErrServerNoTracksSetup) Error() string {
return "no tracks have been setup" return "no tracks have been setup"
} }
// ErrServerNotAllAnnouncedTracksSetup is returned in case not all announced tracks have been setup. // ErrServerNotAllAnnouncedTracksSetup is an error that can be returned by a server.
type ErrServerNotAllAnnouncedTracksSetup struct{} type ErrServerNotAllAnnouncedTracksSetup struct{}
// Error implements the error interface. // Error implements the error interface.
@@ -160,10 +168,18 @@ func (e ErrServerNotAllAnnouncedTracksSetup) Error() string {
return "not all announced tracks have been setup" return "not all announced tracks have been setup"
} }
// ErrServerNoUDPPacketsRecently is returned when no UDP packets have been received recently. // ErrServerNoUDPPacketsRecently is an error that can be returned by a server.
type ErrServerNoUDPPacketsRecently struct{} type ErrServerNoUDPPacketsRecently struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerNoUDPPacketsRecently) Error() string { func (e ErrServerNoUDPPacketsRecently) Error() string {
return "no UDP packets received (maybe there's a firewall/NAT in between)" return "no UDP packets received (maybe there's a firewall/NAT in between)"
} }
// ErrServerLinkedToOtherSession is an error that can be returned by a server.
type ErrServerLinkedToOtherSession struct{}
// Error implements the error interface.
func (e ErrServerLinkedToOtherSession) Error() string {
return "connection is linked to another session"
}

View File

@@ -1,7 +1,9 @@
package gortsplib package gortsplib
import ( import (
"crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@@ -23,6 +25,27 @@ func extractPort(address string) (int, error) {
return int(tmp2), nil return int(tmp2), nil
} }
func newSessionID(sessions map[string]*ServerSession) (string, error) {
for {
b := make([]byte, 4)
_, err := rand.Read(b)
if err != nil {
return "", err
}
id := strconv.FormatUint(uint64(binary.LittleEndian.Uint32(b)), 10)
if _, ok := sessions[id]; !ok {
return id, nil
}
}
}
type sessionGetReq struct {
id string
res chan *ServerSession
}
// Server is a RTSP server. // Server is a RTSP server.
type Server struct { type Server struct {
// an handler to handle requests. // an handler to handle requests.
@@ -69,12 +92,15 @@ type Server struct {
tcpListener net.Listener tcpListener net.Listener
udpRTPListener *serverUDPListener udpRTPListener *serverUDPListener
udpRTCPListener *serverUDPListener udpRTCPListener *serverUDPListener
sessions map[string]*ServerSession
conns map[*ServerConn]struct{} conns map[*ServerConn]struct{}
exitError error exitError error
// in // in
connClose chan *ServerConn connClose chan *ServerConn
terminate chan struct{} sessionGet chan sessionGetReq
sessionClose chan *ServerSession
terminate chan struct{}
// out // out
done chan struct{} done chan struct{}
@@ -160,8 +186,11 @@ func (s *Server) Start(address string) error {
} }
func (s *Server) run() { func (s *Server) run() {
s.sessions = make(map[string]*ServerSession)
s.conns = make(map[*ServerConn]struct{}) s.conns = make(map[*ServerConn]struct{})
s.connClose = make(chan *ServerConn) s.connClose = make(chan *ServerConn)
s.sessionGet = make(chan sessionGetReq)
s.sessionClose = make(chan *ServerSession)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -199,6 +228,28 @@ outer:
} }
s.doConnClose(sc) s.doConnClose(sc)
case req := <-s.sessionGet:
if ss, ok := s.sessions[req.id]; ok {
req.res <- ss
} else {
id, err := newSessionID(s.sessions)
if err != nil {
req.res <- nil
continue
}
ss := newServerSession(s, id, &wg)
s.sessions[id] = ss
req.res <- ss
}
case ss := <-s.sessionClose:
if _, ok := s.sessions[ss.id]; !ok {
continue
}
s.doSessionClose(ss)
case <-s.terminate: case <-s.terminate:
break outer break outer
} }
@@ -222,6 +273,17 @@ outer:
if !ok { if !ok {
return return
} }
case req, ok := <-s.sessionGet:
if !ok {
return
}
req.res <- nil
case _, ok := <-s.sessionClose:
if !ok {
return
}
} }
} }
}() }()
@@ -240,11 +302,17 @@ outer:
s.doConnClose(sc) s.doConnClose(sc)
} }
for _, ss := range s.sessions {
s.doSessionClose(ss)
}
wg.Wait() wg.Wait()
close(acceptErr) close(acceptErr)
close(connNew) close(connNew)
close(s.connClose) close(s.connClose)
close(s.sessionGet)
close(s.sessionClose)
close(s.done) close(s.done)
} }
@@ -275,3 +343,8 @@ func (s *Server) doConnClose(sc *ServerConn) {
delete(s.conns, sc) delete(s.conns, sc)
close(sc.terminate) close(sc.terminate)
} }
func (s *Server) doSessionClose(ss *ServerSession) {
delete(s.sessions, ss.id)
close(ss.terminate)
}

View File

@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"net" "net"
"strconv" "strconv"
"strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@@ -159,6 +158,7 @@ func TestServerPublishSetupPath(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -248,6 +248,7 @@ func TestServerPublishSetupErrorDifferentPaths(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -336,6 +337,7 @@ func TestServerPublishSetupErrorTrackTwice(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -350,6 +352,7 @@ func TestServerPublishSetupErrorTrackTwice(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -446,6 +449,7 @@ func TestServerPublishRecordErrorPartialTracks(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -458,7 +462,8 @@ func TestServerPublishRecordErrorPartialTracks(t *testing.T) {
Method: base.Record, Method: base.Record,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -506,7 +511,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, StreamTypeRTCP, ctx.StreamType) require.Equal(t, StreamTypeRTCP, ctx.StreamType)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload)
ctx.Conn.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C})
} }
}, },
}, },
@@ -578,6 +583,7 @@ func TestServerPublish(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(), "Transport": inTH.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -606,7 +612,8 @@ func TestServerPublish(t *testing.T) {
Method: base.Record, Method: base.Record,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -752,6 +759,7 @@ func TestServerPublishErrorWrongProtocol(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(), "Transport": inTH.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -768,7 +776,8 @@ func TestServerPublishErrorWrongProtocol(t *testing.T) {
Method: base.Record, Method: base.Record,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -862,6 +871,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(), "Transport": inTH.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -878,7 +888,8 @@ func TestServerPublishRTCPReport(t *testing.T) {
Method: base.Record, Method: base.Record,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -945,12 +956,12 @@ func TestServerPublishErrorTimeout(t *testing.T) {
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) { onSessionClose: func(ss *ServerSession) {
if proto == "udp" { /*if proto == "udp" {
require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error()) require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error())
} else { } else {
require.True(t, strings.HasSuffix(err.Error(), "i/o timeout")) require.True(t, strings.HasSuffix(err.Error(), "i/o timeout"))
} }*/
close(errDone) close(errDone)
}, },
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
@@ -1038,6 +1049,7 @@ func TestServerPublishErrorTimeout(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(), "Transport": inTH.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -1054,7 +1066,8 @@ func TestServerPublishErrorTimeout(t *testing.T) {
Method: base.Record, Method: base.Record,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -176,6 +176,7 @@ func TestServerReadSetupErrorDifferentPaths(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -249,6 +250,7 @@ func TestServerReadSetupErrorTrackTwice(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": th.Write(),
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -277,8 +279,8 @@ func TestServerRead(t *testing.T) {
}, nil }, nil
}, },
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) ctx.Session.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04})
ctx.Conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08})
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -358,7 +360,8 @@ func TestServerRead(t *testing.T) {
Method: base.Play, Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -435,7 +438,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
go func() { go func() {
defer close(writerDone) defer close(writerDone)
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
t := time.NewTicker(50 * time.Millisecond) t := time.NewTicker(50 * time.Millisecond)
defer t.Stop() defer t.Stop()
@@ -443,7 +446,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
for { for {
select { select {
case <-t.C: case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate: case <-writerTerminate:
return return
} }
@@ -498,7 +501,8 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
Method: base.Play, Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -514,44 +518,21 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
} }
func TestServerReadPlayPlay(t *testing.T) { func TestServerReadPlayPlay(t *testing.T) {
writerTerminate := make(chan struct{})
writerDone := make(chan struct{})
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) {
close(writerTerminate)
<-writerDone
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
if ctx.Conn.State() != ServerConnStatePlay {
go func() {
defer close(writerDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate:
return
}
}
}()
}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
}, },
UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001",
} }
err := s.Start("127.0.0.1:8554") err := s.Start("127.0.0.1:8554")
@@ -569,7 +550,7 @@ func TestServerReadPlayPlay(t *testing.T) {
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"1"}, "CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{ "Transport": headers.Transport{
Protocol: StreamProtocolTCP, Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
return &v return &v
@@ -578,7 +559,7 @@ func TestServerReadPlayPlay(t *testing.T) {
v := headers.TransportModePlay v := headers.TransportModePlay
return &v return &v
}(), }(),
InterleavedIDs: &[2]int{0, 1}, ClientPorts: &[2]int{30450, 30451},
}.Write(), }.Write(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
@@ -593,7 +574,8 @@ func TestServerReadPlayPlay(t *testing.T) {
Method: base.Play, Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -606,13 +588,13 @@ func TestServerReadPlayPlay(t *testing.T) {
Method: base.Play, Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"3"}, "CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
buf := make([]byte, 2048) err = res.Read(bconn.Reader)
err = res.ReadIgnoreFrames(bconn.Reader, buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
} }
@@ -645,7 +627,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
for { for {
select { select {
case <-t.C: case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate: case <-writerTerminate:
return return
} }
@@ -704,7 +686,8 @@ func TestServerReadPlayPausePlay(t *testing.T) {
Method: base.Play, Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -717,21 +700,8 @@ func TestServerReadPlayPausePlay(t *testing.T) {
Method: base.Pause, Method: base.Pause,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
}, "Session": res.Header["Session"],
}.Write(bconn.Writer)
require.NoError(t, err)
buf := make([]byte, 2048)
err = res.ReadIgnoreFrames(bconn.Reader, buf)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -740,10 +710,19 @@ func TestServerReadPlayPausePlay(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame err = base.Request{
fr.Payload = make([]byte, 2048) Method: base.Play,
err = fr.Read(bconn.Reader) URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
} }
func TestServerReadPlayPausePause(t *testing.T) { func TestServerReadPlayPausePause(t *testing.T) {
@@ -771,7 +750,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
for { for {
select { select {
case <-t.C: case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate: case <-writerTerminate:
return return
} }
@@ -829,7 +808,8 @@ func TestServerReadPlayPausePause(t *testing.T) {
Method: base.Play, Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -842,7 +822,8 @@ func TestServerReadPlayPausePause(t *testing.T) {
Method: base.Pause, Method: base.Pause,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -856,7 +837,8 @@ func TestServerReadPlayPausePause(t *testing.T) {
Method: base.Pause, Method: base.Pause,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"), URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"testing" "testing"
@@ -16,14 +15,15 @@ import (
) )
type testServerHandler struct { type testServerHandler struct {
onConnClose func(*ServerConn, error) onConnClose func(*ServerConn, error)
onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) onSessionClose func(*ServerSession)
onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error) onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error)
onSetup func(*ServerHandlerOnSetupCtx) (*base.Response, error) onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error)
onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) onSetup func(*ServerHandlerOnSetupCtx) (*base.Response, error)
onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error)
onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error)
onFrame func(*ServerHandlerOnFrameCtx) onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error)
onFrame func(*ServerHandlerOnFrameCtx)
} }
func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) { func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) {
@@ -32,6 +32,12 @@ func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) {
} }
} }
func (sh *testServerHandler) OnSessionClose(ss *ServerSession) {
if sh.onSessionClose != nil {
sh.onSessionClose(ss)
}
}
func (sh *testServerHandler) OnDescribe(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { func (sh *testServerHandler) OnDescribe(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) {
if sh.onDescribe != nil { if sh.onDescribe != nil {
return sh.onDescribe(ctx) return sh.onDescribe(ctx)
@@ -167,23 +173,22 @@ func TestServerHighLevelPublishRead(t *testing.T) {
t.Run(encryptedStr+"_"+ca.publisherSoft+"_"+ca.publisherProto+"_"+ t.Run(encryptedStr+"_"+ca.publisherSoft+"_"+ca.publisherProto+"_"+
ca.readerSoft+"_"+ca.readerProto, func(t *testing.T) { ca.readerSoft+"_"+ca.readerProto, func(t *testing.T) {
var mutex sync.Mutex var mutex sync.Mutex
var publisher *ServerConn var publisher *ServerSession
var sdp []byte var sdp []byte
readers := make(map[*ServerConn]struct{}) readers := make(map[*ServerSession]struct{})
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) { onSessionClose: func(ss *ServerSession) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
if sc == publisher { if ss == publisher {
publisher = nil publisher = nil
sdp = nil sdp = nil
} else { } else {
delete(readers, sc) delete(readers, ss)
} }
}, },
onDescribe: func(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { onDescribe: func(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) {
@@ -222,7 +227,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
} }
publisher = ctx.Conn publisher = ctx.Session
sdp = ctx.Tracks.Write() sdp = ctx.Tracks.Write()
return &base.Response{ return &base.Response{
@@ -256,7 +261,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
readers[ctx.Conn] = struct{}{} readers[ctx.Session] = struct{}{}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -275,7 +280,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
if ctx.Conn != publisher { if ctx.Session != publisher {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
@@ -292,7 +297,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
if ctx.Conn == publisher { if ctx.Session == publisher {
for r := range readers { for r := range readers {
r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload)
} }
@@ -448,33 +453,3 @@ func TestServerErrorCSeqMissing(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode) require.Equal(t, base.StatusBadRequest, res.StatusCode)
} }
func TestServerTeardownResponse(t *testing.T) {
s := &Server{}
err := s.Start("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
err = base.Request{
Method: base.Teardown,
URL: base.MustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
var res base.Response
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
buf := make([]byte, 2048)
_, err = bconn.Read(buf)
require.Equal(t, io.EOF, err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -19,6 +19,16 @@ type ServerHandlerOnConnClose interface {
OnConnClose(*ServerConn, error) OnConnClose(*ServerConn, error)
} }
// ServerHandlerOnSessionOpen can be implemented by a ServerHandler.
type ServerHandlerOnSessionOpen interface {
OnSessionOpen(*ServerSession)
}
// ServerHandlerOnSessionClose can be implemented by a ServerHandler.
type ServerHandlerOnSessionClose interface {
OnSessionClose(*ServerSession)
}
// ServerHandlerOnRequest can be implemented by a ServerHandler. // ServerHandlerOnRequest can be implemented by a ServerHandler.
type ServerHandlerOnRequest interface { type ServerHandlerOnRequest interface {
OnRequest(*base.Request) OnRequest(*base.Request)
@@ -57,12 +67,12 @@ type ServerHandlerOnDescribe interface {
// ServerHandlerOnAnnounceCtx is the context of an ANNOUNCE request. // ServerHandlerOnAnnounceCtx is the context of an ANNOUNCE request.
type ServerHandlerOnAnnounceCtx struct { type ServerHandlerOnAnnounceCtx struct {
Conn *ServerConn Session *ServerSession
// Session *ServerSession Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string
Query string Query string
Tracks Tracks Tracks Tracks
} }
// ServerHandlerOnAnnounce can be implemented by a ServerHandler. // ServerHandlerOnAnnounce can be implemented by a ServerHandler.
@@ -72,8 +82,8 @@ type ServerHandlerOnAnnounce interface {
// ServerHandlerOnSetupCtx is the context of a OPTIONS request. // ServerHandlerOnSetupCtx is the context of a OPTIONS request.
type ServerHandlerOnSetupCtx struct { type ServerHandlerOnSetupCtx struct {
Conn *ServerConn
Session *ServerSession Session *ServerSession
Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string
Query string Query string
@@ -88,11 +98,11 @@ type ServerHandlerOnSetup interface {
// ServerHandlerOnPlayCtx is the context of a PLAY request. // ServerHandlerOnPlayCtx is the context of a PLAY request.
type ServerHandlerOnPlayCtx struct { type ServerHandlerOnPlayCtx struct {
Conn *ServerConn Session *ServerSession
// Session *ServerSession Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string
Query string Query string
} }
// ServerHandlerOnPlay can be implemented by a ServerHandler. // ServerHandlerOnPlay can be implemented by a ServerHandler.
@@ -102,11 +112,11 @@ type ServerHandlerOnPlay interface {
// ServerHandlerOnRecordCtx is the context of a RECORD request. // ServerHandlerOnRecordCtx is the context of a RECORD request.
type ServerHandlerOnRecordCtx struct { type ServerHandlerOnRecordCtx struct {
Conn *ServerConn Session *ServerSession
// Session *ServerSession Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string
Query string Query string
} }
// ServerHandlerOnRecord can be implemented by a ServerHandler. // ServerHandlerOnRecord can be implemented by a ServerHandler.
@@ -116,11 +126,11 @@ type ServerHandlerOnRecord interface {
// ServerHandlerOnPauseCtx is the context of a PAUSE request. // ServerHandlerOnPauseCtx is the context of a PAUSE request.
type ServerHandlerOnPauseCtx struct { type ServerHandlerOnPauseCtx struct {
Conn *ServerConn Session *ServerSession
// Session *ServerSession Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string
Query string Query string
} }
// ServerHandlerOnPause can be implemented by a ServerHandler. // ServerHandlerOnPause can be implemented by a ServerHandler.
@@ -156,11 +166,11 @@ type ServerHandlerOnSetParameter interface {
// ServerHandlerOnTeardownCtx is the context of a TEARDOWN request. // ServerHandlerOnTeardownCtx is the context of a TEARDOWN request.
type ServerHandlerOnTeardownCtx struct { type ServerHandlerOnTeardownCtx struct {
Conn *ServerConn Session *ServerSession
// Session *ServerSession Conn *ServerConn
Req *base.Request Req *base.Request
Path string Path string
Query string Query string
} }
// ServerHandlerOnTeardown can be implemented by a ServerHandler. // ServerHandlerOnTeardown can be implemented by a ServerHandler.
@@ -170,8 +180,7 @@ type ServerHandlerOnTeardown interface {
// ServerHandlerOnFrameCtx is the context of a frame request. // ServerHandlerOnFrameCtx is the context of a frame request.
type ServerHandlerOnFrameCtx struct { type ServerHandlerOnFrameCtx struct {
Conn *ServerConn Session *ServerSession
// Session *ServerSession
TrackID int TrackID int
StreamType StreamType StreamType StreamType
Payload []byte Payload []byte

View File

@@ -1,5 +1,822 @@
package gortsplib package gortsplib
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/rtcpreceiver"
)
const (
serverSessionCheckStreamPeriod = 1 * time.Second
)
func setupGetTrackIDPathQuery(url *base.URL,
thMode *headers.TransportMode,
announcedTracks []ServerSessionAnnouncedTrack,
setupPath *string, setupQuery *string) (int, string, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
return 0, "", "", liberrors.ErrServerNoPath{}
}
if thMode == nil || *thMode == headers.TransportModePlay {
i := stringsReverseIndex(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - it's track zero
if i < 0 {
if !strings.HasSuffix(pathAndQuery, "/") {
return 0, "", "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
}
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
path, query := base.PathSplitQuery(pathAndQuery)
// we assume it's track 0
return 0, path, query, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, "", "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
}
trackID := int(tmp)
pathAndQuery = pathAndQuery[:i]
path, query := base.PathSplitQuery(pathAndQuery)
if setupPath != nil && (path != *setupPath || query != *setupQuery) {
return 0, "", "", fmt.Errorf("can't setup tracks with different paths")
}
return trackID, path, query, nil
}
for trackID, track := range announcedTracks {
u, _ := track.track.URL()
if u.String() == url.String() {
return trackID, *setupPath, *setupQuery, nil
}
}
return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
// ServerSessionState is a state of a ServerSession.
type ServerSessionState int
// standard states.
const (
ServerSessionStateInitial ServerSessionState = iota
ServerSessionStatePrePlay
ServerSessionStatePlay
ServerSessionStatePreRecord
ServerSessionStateRecord
)
// String implements fmt.Stringer.
func (s ServerSessionState) String() string {
switch s {
case ServerSessionStateInitial:
return "initial"
case ServerSessionStatePrePlay:
return "prePlay"
case ServerSessionStatePlay:
return "play"
case ServerSessionStatePreRecord:
return "preRecord"
case ServerSessionStateRecord:
return "record"
}
return "unknown"
}
// ServerSessionSetuppedTrack is a setupped track of a ServerSession.
type ServerSessionSetuppedTrack struct {
udpRTPPort int
udpRTCPPort int
}
// ServerSessionAnnouncedTrack is an announced track of a ServerSession.
type ServerSessionAnnouncedTrack struct {
track *Track
rtcpReceiver *rtcpreceiver.RTCPReceiver
udpLastFrameTime *int64
}
type requestRes struct {
res *base.Response
err error
}
type requestReq struct {
sc *ServerConn
req *base.Request
res chan requestRes
}
// ServerSession is a server-side RTSP session. // ServerSession is a server-side RTSP session.
type ServerSession struct { type ServerSession struct {
s *Server
id string
wg *sync.WaitGroup
state ServerSessionState
setuppedTracks map[int]ServerSessionSetuppedTrack
setupProtocol *StreamProtocol
setupPath *string
setupQuery *string
// TCP stream protocol
linkedConn *ServerConn
// UDP stream protocol
udpIP net.IP
udpZone string
// publish
announcedTracks []ServerSessionAnnouncedTrack
// in
request chan requestReq
terminate chan struct{}
}
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
ss := &ServerSession{
s: s,
id: id,
wg: wg,
request: make(chan requestReq),
terminate: make(chan struct{}),
}
wg.Add(1)
go ss.run()
return ss
}
// State returns the state of the session.
func (ss *ServerSession) State() ServerSessionState {
return ss.state
}
// StreamProtocol returns the stream protocol of the setupped tracks.
func (ss *ServerSession) StreamProtocol() *StreamProtocol {
return ss.setupProtocol
}
// SetuppedTracks returns the setupped tracks.
func (ss *ServerSession) SetuppedTracks() map[int]ServerSessionSetuppedTrack {
return ss.setuppedTracks
}
// AnnouncedTracks returns the announced tracks.
func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack {
return ss.announcedTracks
}
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
if _, ok := allowed[ss.state]; ok {
return nil
}
allowedList := make([]fmt.Stringer, len(allowed))
i := 0
for a := range allowed {
allowedList[i] = a
i++
}
return liberrors.ErrServerWrongState{AllowedList: allowedList, State: ss.state}
}
func (ss *ServerSession) run() {
defer ss.wg.Done()
if h, ok := ss.s.Handler.(ServerHandlerOnSessionOpen); ok {
h.OnSessionOpen(ss)
}
checkStreamTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkStreamTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop()
outer:
for {
select {
case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req)
req.res <- requestRes{res, err}
case <-checkStreamTicker.C:
if ss.state != ServerSessionStateRecord || *ss.setupProtocol != StreamProtocolUDP {
continue
}
inTimeout := func() bool {
now := time.Now()
for _, track := range ss.announcedTracks {
lft := atomic.LoadInt64(track.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) < ss.s.ReadTimeout {
return false
}
}
return true
}()
if inTimeout {
break outer
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.terminate:
break outer
}
}
go func() {
for req := range ss.request {
req.res <- requestRes{nil, fmt.Errorf("terminated")}
}
}()
switch ss.state {
case ServerSessionStatePlay:
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTCPListener.removeClient(ss)
}
case ServerSessionStateRecord:
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss)
}
}
if ss.linkedConn != nil {
ss.s.connClose <- ss.linkedConn
}
ss.s.sessionClose <- ss
<-ss.terminate
close(ss.request)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
h.OnSessionClose(ss)
}
}
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
switch req.Method {
case base.Announce:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStateInitial: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerContentTypeMissing{}
}
if ct[0] != "application/sdp" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerContentTypeUnsupported{CT: ct}
}
tracks, err := ReadTracks(req.Body, req.URL)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerSDPInvalid{Err: err}
}
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerSDPNoTracksDefined{}
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
for _, track := range tracks {
trackURL, err := track.URL()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unable to generate track URL")
}
trackPath, ok := trackURL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL (%v)", trackURL)
}
if !strings.HasPrefix(trackPath, path) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'",
path, trackPath)
}
}
res, err := ss.s.Handler.(ServerHandlerOnAnnounce).OnAnnounce(&ServerHandlerOnAnnounceCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
Tracks: tracks,
})
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStatePreRecord
ss.setupPath = &path
ss.setupQuery = &query
ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
clockRate, _ := track.ClockRate()
v := time.Now().Unix()
ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{
track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
}
}
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
}
return res, err
case base.Setup:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStateInitial: {},
ServerSessionStatePrePlay: {},
ServerSessionStatePreRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
var th headers.Transport
err = th.Read(req.Header["Transport"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalid{Err: err}
}
if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, th.Mode,
ss.announcedTracks, ss.setupPath, ss.setupQuery)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if _, ok := ss.setuppedTracks[trackID]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID}
}
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePrePlay: // play
if th.Mode != nil && *th.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderWrongMode{Mode: th.Mode}
}
default: // record
if th.Mode == nil || *th.Mode != headers.TransportModeRecord {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderWrongMode{Mode: th.Mode}
}
}
if th.Protocol == StreamProtocolUDP {
if ss.s.udpRTPListener == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if th.ClientPorts == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
}
} else {
if th.InterleavedIDs == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoInterleavedIDs{}
}
if th.InterleavedIDs[0] != (trackID*2) ||
th.InterleavedIDs[1] != (1+trackID*2) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderWrongInterleavedIDs{
Expected: [2]int{(trackID * 2), (1 + trackID*2)}, Value: *th.InterleavedIDs}
}
}
if ss.setupProtocol != nil && *ss.setupProtocol != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTracksDifferentProtocols{}
}
res, err := ss.s.Handler.(ServerHandlerOnSetup).OnSetup(&ServerHandlerOnSetupCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
TrackID: trackID,
Transport: &th,
})
if res.StatusCode == base.StatusOK {
ss.setupProtocol = &th.Protocol
if ss.setuppedTracks == nil {
ss.setuppedTracks = make(map[int]ServerSessionSetuppedTrack)
}
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if th.Protocol == StreamProtocolUDP {
ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{
udpRTPPort: th.ClientPorts[0],
udpRTCPPort: th.ClientPorts[1],
}
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
ClientPorts: th.ClientPorts,
ServerPorts: &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()},
}.Write()
} else {
ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{}
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolTCP,
InterleavedIDs: th.InterleavedIDs,
}.Write()
}
}
if ss.state == ServerSessionStateInitial {
ss.state = ServerSessionStatePrePlay
ss.setupPath = &path
ss.setupQuery = &query
}
// workaround to prevent a bug in rtspclientsink
// that makes impossible for the client to receive the response
// and send frames.
// this was causing problems during unit tests.
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
strings.HasPrefix(ua[0], "GStreamer") {
select {
case <-time.After(1 * time.Second):
case <-sc.terminate:
}
}
return res, err
case base.Play:
// play can be sent twice, allow calling it even if we're already playing
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStatePrePlay: {},
ServerSessionStatePlay: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if len(ss.setuppedTracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoTracksSetup{}
}
// with TCP, PLAY can't be called twice
// with UDP, it can
if ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolTCP {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
if ss.state != ServerSessionStatePlay {
ss.linkedConn = sc
}
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
if ss.state != ServerSessionStatePlay {
ss.state = ServerSessionStatePlay
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if *ss.setupProtocol == StreamProtocolUDP {
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
// readers can send RTCP frames, they cannot sent RTP frames
for trackID, track := range ss.setuppedTracks {
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
}
return res, err
}
return res, liberrors.ErrServerTCPFramesEnable{}
}
} else {
ss.linkedConn = nil
}
return res, err
case base.Record:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStatePreRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if len(ss.setuppedTracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoTracksSetup{}
}
if len(ss.setuppedTracks) != len(ss.announcedTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNotAllAnnouncedTracksSetup{}
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStateRecord
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if *ss.setupProtocol == StreamProtocolUDP {
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
for trackID, track := range ss.setuppedTracks {
ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true)
ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true)
// open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTP,
[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
}
return res, err
}
ss.linkedConn = sc
return res, liberrors.ErrServerTCPFramesEnable{}
}
return res, err
case base.Pause:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStatePrePlay: {},
ServerSessionStatePlay: {},
ServerSessionStatePreRecord: {},
ServerSessionStateRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := ss.s.Handler.(ServerHandlerOnPause).OnPause(&ServerHandlerOnPauseCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
ss.linkedConn = nil
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTCPListener.removeClient(ss)
} else {
return res, liberrors.ErrServerTCPFramesDisable{}
}
case ServerSessionStateRecord:
ss.state = ServerSessionStatePreRecord
ss.linkedConn = nil
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTPListener.removeClient(ss)
} else {
return res, liberrors.ErrServerTCPFramesDisable{}
}
}
}
return res, err
case base.Teardown:
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
return ss.s.Handler.(ServerHandlerOnTeardown).OnTeardown(&ServerHandlerOnTeardownCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
}
return nil, fmt.Errorf("unimplemented")
}
// WriteFrame writes a frame.
func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *ss.setupProtocol == StreamProtocolUDP {
track := ss.setuppedTracks[trackID]
if streamType == StreamTypeRTP {
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
IP: ss.udpIP,
Zone: ss.udpZone,
Port: track.udpRTPPort,
})
} else {
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
IP: ss.udpIP,
Zone: ss.udpZone,
Port: track.udpRTCPPort,
})
}
} else {
ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Payload: payload,
})
}
} }

View File

@@ -20,7 +20,7 @@ type bufAddrPair struct {
} }
type clientData struct { type clientData struct {
sc *ServerConn ss *ServerSession
trackID int trackID int
isPublishing bool isPublishing bool
} }
@@ -123,13 +123,13 @@ func (u *serverUDPListener) run() {
if clientData.isPublishing { if clientData.isPublishing {
now := time.Now() now := time.Now()
atomic.StoreInt64(clientData.sc.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix()) atomic.StoreInt64(clientData.ss.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix())
clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n])
} }
if h, ok := u.s.Handler.(ServerHandlerOnFrame); ok { if h, ok := u.s.Handler.(ServerHandlerOnFrame); ok {
h.OnFrame(&ServerHandlerOnFrameCtx{ h.OnFrame(&ServerHandlerOnFrameCtx{
Conn: clientData.sc, Session: clientData.ss,
TrackID: clientData.trackID, TrackID: clientData.trackID,
StreamType: u.streamType, StreamType: u.streamType,
Payload: buf[:n], Payload: buf[:n],
@@ -166,7 +166,7 @@ func (u *serverUDPListener) write(buf []byte, addr *net.UDPAddr) {
u.ringBuffer.Push(bufAddrPair{buf, addr}) u.ringBuffer.Push(bufAddrPair{buf, addr})
} }
func (u *serverUDPListener) addClient(ip net.IP, port int, sc *ServerConn, trackID int, isPublishing bool) { func (u *serverUDPListener) addClient(ip net.IP, port int, ss *ServerSession, trackID int, isPublishing bool) {
u.clientsMutex.Lock() u.clientsMutex.Lock()
defer u.clientsMutex.Unlock() defer u.clientsMutex.Unlock()
@@ -174,18 +174,19 @@ func (u *serverUDPListener) addClient(ip net.IP, port int, sc *ServerConn, track
addr.fill(ip, port) addr.fill(ip, port)
u.clients[addr] = &clientData{ u.clients[addr] = &clientData{
sc: sc, ss: ss,
trackID: trackID, trackID: trackID,
isPublishing: isPublishing, isPublishing: isPublishing,
} }
} }
func (u *serverUDPListener) removeClient(ip net.IP, port int) { func (u *serverUDPListener) removeClient(ss *ServerSession) {
u.clientsMutex.Lock() u.clientsMutex.Lock()
defer u.clientsMutex.Unlock() defer u.clientsMutex.Unlock()
var addr clientAddr for addr, data := range u.clients {
addr.fill(ip, port) if data.ss == ss {
delete(u.clients, addr)
delete(u.clients, addr) }
}
} }