server: implement sessions

This commit is contained in:
aler9
2021-05-02 19:12:51 +02:00
committed by Alessandro Ros
parent 712432bcef
commit 259043685d
14 changed files with 1333 additions and 1015 deletions

View File

@@ -17,8 +17,8 @@ import (
type serverHandler struct {
mutex sync.Mutex
publisher *gortsplib.ServerConn
readers map[*gortsplib.ServerConn]struct{}
publisher *gortsplib.ServerSession
readers map[*gortsplib.ServerSession]struct{}
sdp []byte
}
@@ -30,15 +30,18 @@ func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
// called when a connection is closed.
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
log.Println("conn closed (%v)", err)
}
// called when a session is closed.
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession) {
sh.mutex.Lock()
defer sh.mutex.Unlock()
if sc == sh.publisher {
if ss == sh.publisher {
sh.publisher = nil
sh.sdp = nil
} else {
delete(sh.readers, sc)
delete(sh.readers, ss)
}
}
@@ -70,14 +73,11 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
}, fmt.Errorf("someone is already publishing")
}
sh.publisher = ctx.Conn
sh.publisher = ctx.Session
sh.sdp = ctx.Tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -85,9 +85,6 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -96,13 +93,10 @@ func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Re
sh.mutex.Lock()
defer sh.mutex.Unlock()
sh.readers[ctx.Conn] = struct{}{}
sh.readers[ctx.Session] = struct{}{}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -111,7 +105,7 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
sh.mutex.Lock()
defer sh.mutex.Unlock()
if ctx.Conn != sh.publisher {
if ctx.Session != sh.publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
@@ -119,9 +113,6 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -131,7 +122,7 @@ func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) {
defer sh.mutex.Unlock()
// if we are the publisher, route frames to readers
if ctx.Conn == sh.publisher {
if ctx.Session == sh.publisher {
for r := range sh.readers {
r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload)
}

View File

@@ -16,8 +16,8 @@ import (
type serverHandler struct {
mutex sync.Mutex
publisher *gortsplib.ServerConn
readers map[*gortsplib.ServerConn]struct{}
publisher *gortsplib.ServerSession
readers map[*gortsplib.ServerSession]struct{}
sdp []byte
}
@@ -29,15 +29,18 @@ func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) {
// called when a connection is closed.
func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) {
log.Println("conn closed (%v)", err)
}
// called when a session is closed.
func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession) {
sh.mutex.Lock()
defer sh.mutex.Unlock()
if sc == sh.publisher {
if ss == sh.publisher {
sh.publisher = nil
sh.sdp = nil
} else {
delete(sh.readers, sc)
delete(sh.readers, ss)
}
}
@@ -69,14 +72,11 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
}, fmt.Errorf("someone is already publishing")
}
sh.publisher = ctx.Conn
sh.publisher = ctx.Session
sh.sdp = ctx.Tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -84,9 +84,6 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -95,13 +92,10 @@ func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Re
sh.mutex.Lock()
defer sh.mutex.Unlock()
sh.readers[ctx.Conn] = struct{}{}
sh.readers[ctx.Session] = struct{}{}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -110,7 +104,7 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
sh.mutex.Lock()
defer sh.mutex.Unlock()
if ctx.Conn != sh.publisher {
if ctx.Session != sh.publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
@@ -118,9 +112,6 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
@@ -130,7 +121,7 @@ func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) {
defer sh.mutex.Unlock()
// if we are the publisher, route frames to readers
if ctx.Conn == sh.publisher {
if ctx.Session == sh.publisher {
for r := range sh.readers {
r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload)
}

View File

@@ -6,7 +6,7 @@ import (
"github.com/aler9/gortsplib/pkg/base"
)
// ErrClientWrongState is returned in case of a wrong client state.
// ErrClientWrongState is an error that can be returned by a client.
type ErrClientWrongState struct {
AllowedList []fmt.Stringer
State fmt.Stringer
@@ -18,7 +18,7 @@ func (e ErrClientWrongState) Error() string {
e.AllowedList, e.State)
}
// ErrClientSessionHeaderInvalid is returned in case of an invalid session header.
// ErrClientSessionHeaderInvalid is an error that can be returned by a client.
type ErrClientSessionHeaderInvalid struct {
Err error
}
@@ -28,7 +28,7 @@ func (e ErrClientSessionHeaderInvalid) Error() string {
return fmt.Sprintf("invalid session header: %v", e.Err)
}
// ErrClientWrongStatusCode is returned in case of a wrong status code.
// ErrClientWrongStatusCode is an error that can be returned by a client.
type ErrClientWrongStatusCode struct {
Code base.StatusCode
Message string
@@ -39,7 +39,7 @@ func (e ErrClientWrongStatusCode) Error() string {
return fmt.Sprintf("wrong status code: %d (%s)", e.Code, e.Message)
}
// ErrClientContentTypeMissing is returned in case the Content-Type header is missing.
// ErrClientContentTypeMissing is an error that can be returned by a client.
type ErrClientContentTypeMissing struct{}
// Error implements the error interface.
@@ -47,7 +47,7 @@ func (e ErrClientContentTypeMissing) Error() string {
return "Content-Type header is missing"
}
// ErrClientContentTypeUnsupported is returned in case the Content-Type header is unsupported.
// ErrClientContentTypeUnsupported is an error that can be returned by a client.
type ErrClientContentTypeUnsupported struct {
CT base.HeaderValue
}
@@ -57,7 +57,7 @@ func (e ErrClientContentTypeUnsupported) Error() string {
return fmt.Sprintf("unsupported Content-Type header '%v'", e.CT)
}
// ErrClientCannotReadPublishAtSameTime is returned when the client is trying to read and publish at the same time.
// ErrClientCannotReadPublishAtSameTime is an error that can be returned by a client.
type ErrClientCannotReadPublishAtSameTime struct{}
// Error implements the error interface.
@@ -65,7 +65,7 @@ func (e ErrClientCannotReadPublishAtSameTime) Error() string {
return "cannot read and publish at the same time"
}
// ErrClientCannotSetupTracksDifferentURLs is returned when the client is trying to setup tracks with different base URLs.
// ErrClientCannotSetupTracksDifferentURLs is an error that can be returned by a client.
type ErrClientCannotSetupTracksDifferentURLs struct{}
// Error implements the error interface.
@@ -73,7 +73,7 @@ func (e ErrClientCannotSetupTracksDifferentURLs) Error() string {
return "cannot setup tracks with different base URLs"
}
// ErrClientUDPPortsZero is returned when one of the UDP ports is zero.
// ErrClientUDPPortsZero is an error that can be returned by a client.
type ErrClientUDPPortsZero struct{}
// Error implements the error interface.
@@ -81,7 +81,7 @@ func (e ErrClientUDPPortsZero) Error() string {
return "rtpPort and rtcpPort must be both zero or non-zero"
}
// ErrClientUDPPortsNotConsecutive is returned when the two UDP ports are not consecutive.
// ErrClientUDPPortsNotConsecutive is an error that can be returned by a client.
type ErrClientUDPPortsNotConsecutive struct{}
// Error implements the error interface.
@@ -89,7 +89,7 @@ func (e ErrClientUDPPortsNotConsecutive) Error() string {
return "rtcpPort must be rtpPort + 1"
}
// ErrClientServerPortsZero is returned when one of the server ports is zero.
// ErrClientServerPortsZero is an error that can be returned by a client.
type ErrClientServerPortsZero struct{}
// Error implements the error interface.
@@ -97,7 +97,7 @@ func (e ErrClientServerPortsZero) Error() string {
return "server ports must be both zero or both not zero"
}
// ErrClientServerPortsNotProvided is returned in case the server ports have not been provided.
// ErrClientServerPortsNotProvided is an error that can be returned by a client.
type ErrClientServerPortsNotProvided struct{}
// Error implements the error interface.
@@ -105,7 +105,7 @@ func (e ErrClientServerPortsNotProvided) Error() string {
return "server ports have not been provided. Use AnyPortEnable to communicate with this server"
}
// ErrClientTransportHeaderInvalid is returned in case the transport header is invalid.
// ErrClientTransportHeaderInvalid is an error that can be returned by a client.
type ErrClientTransportHeaderInvalid struct {
Err error
}
@@ -115,7 +115,7 @@ func (e ErrClientTransportHeaderInvalid) Error() string {
return fmt.Sprintf("invalid transport header: %v", e.Err)
}
// ErrClientTransportHeaderNoInterleavedIDs is returned in case the transport header doesn't contain interleaved IDs.
// ErrClientTransportHeaderNoInterleavedIDs is an error that can be returned by a client.
type ErrClientTransportHeaderNoInterleavedIDs struct{}
// Error implements the error interface.
@@ -123,7 +123,7 @@ func (e ErrClientTransportHeaderNoInterleavedIDs) Error() string {
return "transport header does not contain interleaved IDs"
}
// ErrClientTransportHeaderWrongInterleavedIDs is returned in case the transport header contains wrong interleaved IDs.
// ErrClientTransportHeaderWrongInterleavedIDs is an error that can be returned by a client.
type ErrClientTransportHeaderWrongInterleavedIDs struct {
Expected [2]int
Value [2]int
@@ -134,7 +134,7 @@ func (e ErrClientTransportHeaderWrongInterleavedIDs) Error() string {
return fmt.Sprintf("wrong interleaved IDs, expected %v, got %v", e.Expected, e.Value)
}
// ErrClientNoUDPPacketsRecently is returned when no UDP packets have been received recently.
// ErrClientNoUDPPacketsRecently is an error that can be returned by a client.
type ErrClientNoUDPPacketsRecently struct{}
// Error implements the error interface.
@@ -142,8 +142,7 @@ func (e ErrClientNoUDPPacketsRecently) Error() string {
return "no UDP packets received (maybe there's a firewall/NAT in between)"
}
// ErrClientUDPTimeout is returned when timeout has exceeded but UDP packets have been received previously
// but now nothing is being received.
// ErrClientUDPTimeout is an error that can be returned by a client.
type ErrClientUDPTimeout struct{}
// Error implements the error interface.
@@ -151,7 +150,7 @@ func (e ErrClientUDPTimeout) Error() string {
return "UDP timeout"
}
// ErrClientTCPTimeout is returned when timeout has exceeded.
// ErrClientTCPTimeout is an error that can be returned by a client.
type ErrClientTCPTimeout struct{}
// Error implements the error interface.
@@ -159,7 +158,7 @@ func (e ErrClientTCPTimeout) Error() string {
return "TCP timeout"
}
// ErrClientRTPInfoInvalid is returned in case of an invalid RTP-Info.
// ErrClientRTPInfoInvalid is an error that can be returned by a client.
type ErrClientRTPInfoInvalid struct {
Err error
}

View File

@@ -7,15 +7,23 @@ import (
"github.com/aler9/gortsplib/pkg/headers"
)
// ErrServerTeardown is returned in case of a teardown request.
type ErrServerTeardown struct{}
// ErrServerTCPFramesEnable is an error that can be returned by a server.
type ErrServerTCPFramesEnable struct{}
// Error implements the error interface.
func (e ErrServerTeardown) Error() string {
return "teardown"
func (e ErrServerTCPFramesEnable) Error() string {
return ""
}
// ErrServerCSeqMissing is returned in case the CSeq is missing.
// ErrServerTCPFramesDisable is an error that can be returned by a server.
type ErrServerTCPFramesDisable struct{}
// Error implements the error interface.
func (e ErrServerTCPFramesDisable) Error() string {
return ""
}
// ErrServerCSeqMissing is an error that can be returned by a server.
type ErrServerCSeqMissing struct{}
// Error implements the error interface.
@@ -23,7 +31,7 @@ func (e ErrServerCSeqMissing) Error() string {
return "CSeq is missing"
}
// ErrServerWrongState is returned in case of a wrong client state.
// ErrServerWrongState is an error that can be returned by a server.
type ErrServerWrongState struct {
AllowedList []fmt.Stringer
State fmt.Stringer
@@ -35,7 +43,7 @@ func (e ErrServerWrongState) Error() string {
e.AllowedList, e.State)
}
// ErrServerNoPath is returned in case the path can't be retrieved.
// ErrServerNoPath is an error that can be returned by a server.
type ErrServerNoPath struct{}
// Error implements the error interface.
@@ -43,7 +51,7 @@ func (e ErrServerNoPath) Error() string {
return "RTSP path can't be retrieved"
}
// ErrServerContentTypeMissing is returned in case the Content-Type header is missing.
// ErrServerContentTypeMissing is an error that can be returned by a server.
type ErrServerContentTypeMissing struct{}
// Error implements the error interface.
@@ -51,7 +59,7 @@ func (e ErrServerContentTypeMissing) Error() string {
return "Content-Type header is missing"
}
// ErrServerContentTypeUnsupported is returned in case the Content-Type header is unsupported.
// ErrServerContentTypeUnsupported is an error that can be returned by a server.
type ErrServerContentTypeUnsupported struct {
CT base.HeaderValue
}
@@ -61,7 +69,7 @@ func (e ErrServerContentTypeUnsupported) Error() string {
return fmt.Sprintf("unsupported Content-Type header '%v'", e.CT)
}
// ErrServerSDPInvalid is returned in case the SDP is invalid.
// ErrServerSDPInvalid is an error that can be returned by a server.
type ErrServerSDPInvalid struct {
Err error
}
@@ -71,7 +79,7 @@ func (e ErrServerSDPInvalid) Error() string {
return fmt.Sprintf("invalid SDP: %v", e.Err)
}
// ErrServerSDPNoTracksDefined is returned in case the SDP has no tracks defined.
// ErrServerSDPNoTracksDefined is an error that can be returned by a server.
type ErrServerSDPNoTracksDefined struct{}
// Error implements the error interface.
@@ -79,7 +87,7 @@ func (e ErrServerSDPNoTracksDefined) Error() string {
return "no tracks defined in the SDP"
}
// ErrServerTransportHeaderInvalid is returned in case the transport header is invalid.
// ErrServerTransportHeaderInvalid is an error that can be returned by a server.
type ErrServerTransportHeaderInvalid struct {
Err error
}
@@ -89,7 +97,7 @@ func (e ErrServerTransportHeaderInvalid) Error() string {
return fmt.Sprintf("invalid transport header: %v", e.Err)
}
// ErrServerTrackAlreadySetup is returned in case a track has already been setup.
// ErrServerTrackAlreadySetup is an error that can be returned by a server.
type ErrServerTrackAlreadySetup struct {
TrackID int
}
@@ -99,7 +107,7 @@ func (e ErrServerTrackAlreadySetup) Error() string {
return fmt.Sprintf("track %d has already been setup", e.TrackID)
}
// ErrServerTransportHeaderWrongMode is returned in case the transport header contains a wrong mode.
// ErrServerTransportHeaderWrongMode is an error that can be returned by a server.
type ErrServerTransportHeaderWrongMode struct {
Mode *headers.TransportMode
}
@@ -109,7 +117,7 @@ func (e ErrServerTransportHeaderWrongMode) Error() string {
return fmt.Sprintf("transport header contains a wrong mode (%v)", e.Mode)
}
// ErrServerTransportHeaderNoClientPorts is returned in case the transport header doesn't contain client ports.
// ErrServerTransportHeaderNoClientPorts is an error that can be returned by a server.
type ErrServerTransportHeaderNoClientPorts struct{}
// Error implements the error interface.
@@ -117,7 +125,7 @@ func (e ErrServerTransportHeaderNoClientPorts) Error() string {
return "transport header does not contain client ports"
}
// ErrServerTransportHeaderNoInterleavedIDs is returned in case the transport header doesn't contain interleaved IDs.
// ErrServerTransportHeaderNoInterleavedIDs is an error that can be returned by a server.
type ErrServerTransportHeaderNoInterleavedIDs struct{}
// Error implements the error interface.
@@ -125,7 +133,7 @@ func (e ErrServerTransportHeaderNoInterleavedIDs) Error() string {
return "transport header does not contain interleaved IDs"
}
// ErrServerTransportHeaderWrongInterleavedIDs is returned in case the transport header contains wrong interleaved IDs.
// ErrServerTransportHeaderWrongInterleavedIDs is an error that can be returned by a server.
type ErrServerTransportHeaderWrongInterleavedIDs struct {
Expected [2]int
Value [2]int
@@ -136,7 +144,7 @@ func (e ErrServerTransportHeaderWrongInterleavedIDs) Error() string {
return fmt.Sprintf("wrong interleaved IDs, expected %v, got %v", e.Expected, e.Value)
}
// ErrServerTracksDifferentProtocols is returned in case the client is trying to setup tracks with different protocols.
// ErrServerTracksDifferentProtocols is an error that can be returned by a server.
type ErrServerTracksDifferentProtocols struct{}
// Error implements the error interface.
@@ -144,7 +152,7 @@ func (e ErrServerTracksDifferentProtocols) Error() string {
return "can't setup tracks with different protocols"
}
// ErrServerNoTracksSetup is returned in case no tracks have been setup.
// ErrServerNoTracksSetup is an error that can be returned by a server.
type ErrServerNoTracksSetup struct{}
// Error implements the error interface.
@@ -152,7 +160,7 @@ func (e ErrServerNoTracksSetup) Error() string {
return "no tracks have been setup"
}
// ErrServerNotAllAnnouncedTracksSetup is returned in case not all announced tracks have been setup.
// ErrServerNotAllAnnouncedTracksSetup is an error that can be returned by a server.
type ErrServerNotAllAnnouncedTracksSetup struct{}
// Error implements the error interface.
@@ -160,10 +168,18 @@ func (e ErrServerNotAllAnnouncedTracksSetup) Error() string {
return "not all announced tracks have been setup"
}
// ErrServerNoUDPPacketsRecently is returned when no UDP packets have been received recently.
// ErrServerNoUDPPacketsRecently is an error that can be returned by a server.
type ErrServerNoUDPPacketsRecently struct{}
// Error implements the error interface.
func (e ErrServerNoUDPPacketsRecently) Error() string {
return "no UDP packets received (maybe there's a firewall/NAT in between)"
}
// ErrServerLinkedToOtherSession is an error that can be returned by a server.
type ErrServerLinkedToOtherSession struct{}
// Error implements the error interface.
func (e ErrServerLinkedToOtherSession) Error() string {
return "connection is linked to another session"
}

View File

@@ -1,7 +1,9 @@
package gortsplib
import (
"crypto/rand"
"crypto/tls"
"encoding/binary"
"fmt"
"net"
"strconv"
@@ -23,6 +25,27 @@ func extractPort(address string) (int, error) {
return int(tmp2), nil
}
func newSessionID(sessions map[string]*ServerSession) (string, error) {
for {
b := make([]byte, 4)
_, err := rand.Read(b)
if err != nil {
return "", err
}
id := strconv.FormatUint(uint64(binary.LittleEndian.Uint32(b)), 10)
if _, ok := sessions[id]; !ok {
return id, nil
}
}
}
type sessionGetReq struct {
id string
res chan *ServerSession
}
// Server is a RTSP server.
type Server struct {
// an handler to handle requests.
@@ -69,11 +92,14 @@ type Server struct {
tcpListener net.Listener
udpRTPListener *serverUDPListener
udpRTCPListener *serverUDPListener
sessions map[string]*ServerSession
conns map[*ServerConn]struct{}
exitError error
// in
connClose chan *ServerConn
sessionGet chan sessionGetReq
sessionClose chan *ServerSession
terminate chan struct{}
// out
@@ -160,8 +186,11 @@ func (s *Server) Start(address string) error {
}
func (s *Server) run() {
s.sessions = make(map[string]*ServerSession)
s.conns = make(map[*ServerConn]struct{})
s.connClose = make(chan *ServerConn)
s.sessionGet = make(chan sessionGetReq)
s.sessionClose = make(chan *ServerSession)
var wg sync.WaitGroup
@@ -199,6 +228,28 @@ outer:
}
s.doConnClose(sc)
case req := <-s.sessionGet:
if ss, ok := s.sessions[req.id]; ok {
req.res <- ss
} else {
id, err := newSessionID(s.sessions)
if err != nil {
req.res <- nil
continue
}
ss := newServerSession(s, id, &wg)
s.sessions[id] = ss
req.res <- ss
}
case ss := <-s.sessionClose:
if _, ok := s.sessions[ss.id]; !ok {
continue
}
s.doSessionClose(ss)
case <-s.terminate:
break outer
}
@@ -222,6 +273,17 @@ outer:
if !ok {
return
}
case req, ok := <-s.sessionGet:
if !ok {
return
}
req.res <- nil
case _, ok := <-s.sessionClose:
if !ok {
return
}
}
}
}()
@@ -240,11 +302,17 @@ outer:
s.doConnClose(sc)
}
for _, ss := range s.sessions {
s.doSessionClose(ss)
}
wg.Wait()
close(acceptErr)
close(connNew)
close(s.connClose)
close(s.sessionGet)
close(s.sessionClose)
close(s.done)
}
@@ -275,3 +343,8 @@ func (s *Server) doConnClose(sc *ServerConn) {
delete(s.conns, sc)
close(sc.terminate)
}
func (s *Server) doSessionClose(ss *ServerSession) {
delete(s.sessions, ss.id)
close(ss.terminate)
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"net"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@@ -159,6 +158,7 @@ func TestServerPublishSetupPath(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -248,6 +248,7 @@ func TestServerPublishSetupErrorDifferentPaths(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -336,6 +337,7 @@ func TestServerPublishSetupErrorTrackTwice(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -350,6 +352,7 @@ func TestServerPublishSetupErrorTrackTwice(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -446,6 +449,7 @@ func TestServerPublishRecordErrorPartialTracks(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -459,6 +463,7 @@ func TestServerPublishRecordErrorPartialTracks(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -506,7 +511,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, StreamTypeRTCP, ctx.StreamType)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload)
ctx.Conn.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C})
ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C})
}
},
},
@@ -578,6 +583,7 @@ func TestServerPublish(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -607,6 +613,7 @@ func TestServerPublish(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -752,6 +759,7 @@ func TestServerPublishErrorWrongProtocol(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -769,6 +777,7 @@ func TestServerPublishErrorWrongProtocol(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -862,6 +871,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -879,6 +889,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -945,12 +956,12 @@ func TestServerPublishErrorTimeout(t *testing.T) {
s := &Server{
Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) {
if proto == "udp" {
onSessionClose: func(ss *ServerSession) {
/*if proto == "udp" {
require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error())
} else {
require.True(t, strings.HasSuffix(err.Error(), "i/o timeout"))
}
}*/
close(errDone)
},
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
@@ -1038,6 +1049,7 @@ func TestServerPublishErrorTimeout(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -1055,6 +1067,7 @@ func TestServerPublishErrorTimeout(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)

View File

@@ -176,6 +176,7 @@ func TestServerReadSetupErrorDifferentPaths(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -249,6 +250,7 @@ func TestServerReadSetupErrorTrackTwice(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -277,8 +279,8 @@ func TestServerRead(t *testing.T) {
}, nil
},
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04})
ctx.Conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08})
ctx.Session.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04})
ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08})
return &base.Response{
StatusCode: base.StatusOK,
@@ -359,6 +361,7 @@ func TestServerRead(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -435,7 +438,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
go func() {
defer close(writerDone)
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
@@ -443,7 +446,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
for {
select {
case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate:
return
}
@@ -499,6 +502,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -514,44 +518,21 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
}
func TestServerReadPlayPlay(t *testing.T) {
writerTerminate := make(chan struct{})
writerDone := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) {
close(writerTerminate)
<-writerDone
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
if ctx.Conn.State() != ServerConnStatePlay {
go func() {
defer close(writerDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate:
return
}
}
}()
}
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001",
}
err := s.Start("127.0.0.1:8554")
@@ -569,7 +550,7 @@ func TestServerReadPlayPlay(t *testing.T) {
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
@@ -578,7 +559,7 @@ func TestServerReadPlayPlay(t *testing.T) {
v := headers.TransportModePlay
return &v
}(),
InterleavedIDs: &[2]int{0, 1},
ClientPorts: &[2]int{30450, 30451},
}.Write(),
},
}.Write(bconn.Writer)
@@ -594,6 +575,7 @@ func TestServerReadPlayPlay(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -607,12 +589,12 @@ func TestServerReadPlayPlay(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
buf := make([]byte, 2048)
err = res.ReadIgnoreFrames(bconn.Reader, buf)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
@@ -645,7 +627,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
for {
select {
case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate:
return
}
@@ -705,6 +687,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -718,20 +701,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
buf := make([]byte, 2048)
err = res.ReadIgnoreFrames(bconn.Reader, buf)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -740,10 +710,19 @@ func TestServerReadPlayPausePlay(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame
fr.Payload = make([]byte, 2048)
err = fr.Read(bconn.Reader)
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
func TestServerReadPlayPausePause(t *testing.T) {
@@ -771,7 +750,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
for {
select {
case <-t.C:
ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
ctx.Session.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate:
return
}
@@ -830,6 +809,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -843,6 +823,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)
@@ -857,6 +838,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
},
}.Write(bconn.Writer)
require.NoError(t, err)

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"testing"
@@ -17,6 +16,7 @@ import (
type testServerHandler struct {
onConnClose func(*ServerConn, error)
onSessionClose func(*ServerSession)
onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error)
onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error)
onSetup func(*ServerHandlerOnSetupCtx) (*base.Response, error)
@@ -32,6 +32,12 @@ func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) {
}
}
func (sh *testServerHandler) OnSessionClose(ss *ServerSession) {
if sh.onSessionClose != nil {
sh.onSessionClose(ss)
}
}
func (sh *testServerHandler) OnDescribe(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) {
if sh.onDescribe != nil {
return sh.onDescribe(ctx)
@@ -167,23 +173,22 @@ func TestServerHighLevelPublishRead(t *testing.T) {
t.Run(encryptedStr+"_"+ca.publisherSoft+"_"+ca.publisherProto+"_"+
ca.readerSoft+"_"+ca.readerProto, func(t *testing.T) {
var mutex sync.Mutex
var publisher *ServerConn
var publisher *ServerSession
var sdp []byte
readers := make(map[*ServerConn]struct{})
readers := make(map[*ServerSession]struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(sc *ServerConn, err error) {
onSessionClose: func(ss *ServerSession) {
mutex.Lock()
defer mutex.Unlock()
if sc == publisher {
if ss == publisher {
publisher = nil
sdp = nil
} else {
delete(readers, sc)
delete(readers, ss)
}
},
onDescribe: func(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) {
@@ -222,7 +227,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
}, fmt.Errorf("someone is already publishing")
}
publisher = ctx.Conn
publisher = ctx.Session
sdp = ctx.Tracks.Write()
return &base.Response{
@@ -256,7 +261,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
readers[ctx.Conn] = struct{}{}
readers[ctx.Session] = struct{}{}
return &base.Response{
StatusCode: base.StatusOK,
@@ -275,7 +280,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
if ctx.Conn != publisher {
if ctx.Session != publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
@@ -292,7 +297,7 @@ func TestServerHighLevelPublishRead(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
if ctx.Conn == publisher {
if ctx.Session == publisher {
for r := range readers {
r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload)
}
@@ -448,33 +453,3 @@ func TestServerErrorCSeqMissing(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
func TestServerTeardownResponse(t *testing.T) {
s := &Server{}
err := s.Start("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
err = base.Request{
Method: base.Teardown,
URL: base.MustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
var res base.Response
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
buf := make([]byte, 2048)
_, err = bconn.Read(buf)
require.Equal(t, io.EOF, err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -19,6 +19,16 @@ type ServerHandlerOnConnClose interface {
OnConnClose(*ServerConn, error)
}
// ServerHandlerOnSessionOpen can be implemented by a ServerHandler.
type ServerHandlerOnSessionOpen interface {
OnSessionOpen(*ServerSession)
}
// ServerHandlerOnSessionClose can be implemented by a ServerHandler.
type ServerHandlerOnSessionClose interface {
OnSessionClose(*ServerSession)
}
// ServerHandlerOnRequest can be implemented by a ServerHandler.
type ServerHandlerOnRequest interface {
OnRequest(*base.Request)
@@ -57,8 +67,8 @@ type ServerHandlerOnDescribe interface {
// ServerHandlerOnAnnounceCtx is the context of an ANNOUNCE request.
type ServerHandlerOnAnnounceCtx struct {
Session *ServerSession
Conn *ServerConn
// Session *ServerSession
Req *base.Request
Path string
Query string
@@ -72,8 +82,8 @@ type ServerHandlerOnAnnounce interface {
// ServerHandlerOnSetupCtx is the context of a OPTIONS request.
type ServerHandlerOnSetupCtx struct {
Conn *ServerConn
Session *ServerSession
Conn *ServerConn
Req *base.Request
Path string
Query string
@@ -88,8 +98,8 @@ type ServerHandlerOnSetup interface {
// ServerHandlerOnPlayCtx is the context of a PLAY request.
type ServerHandlerOnPlayCtx struct {
Session *ServerSession
Conn *ServerConn
// Session *ServerSession
Req *base.Request
Path string
Query string
@@ -102,8 +112,8 @@ type ServerHandlerOnPlay interface {
// ServerHandlerOnRecordCtx is the context of a RECORD request.
type ServerHandlerOnRecordCtx struct {
Session *ServerSession
Conn *ServerConn
// Session *ServerSession
Req *base.Request
Path string
Query string
@@ -116,8 +126,8 @@ type ServerHandlerOnRecord interface {
// ServerHandlerOnPauseCtx is the context of a PAUSE request.
type ServerHandlerOnPauseCtx struct {
Session *ServerSession
Conn *ServerConn
// Session *ServerSession
Req *base.Request
Path string
Query string
@@ -156,8 +166,8 @@ type ServerHandlerOnSetParameter interface {
// ServerHandlerOnTeardownCtx is the context of a TEARDOWN request.
type ServerHandlerOnTeardownCtx struct {
Session *ServerSession
Conn *ServerConn
// Session *ServerSession
Req *base.Request
Path string
Query string
@@ -170,8 +180,7 @@ type ServerHandlerOnTeardown interface {
// ServerHandlerOnFrameCtx is the context of a frame request.
type ServerHandlerOnFrameCtx struct {
Conn *ServerConn
// Session *ServerSession
Session *ServerSession
TrackID int
StreamType StreamType
Payload []byte

View File

@@ -1,5 +1,822 @@
package gortsplib
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/rtcpreceiver"
)
const (
serverSessionCheckStreamPeriod = 1 * time.Second
)
func setupGetTrackIDPathQuery(url *base.URL,
thMode *headers.TransportMode,
announcedTracks []ServerSessionAnnouncedTrack,
setupPath *string, setupQuery *string) (int, string, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
return 0, "", "", liberrors.ErrServerNoPath{}
}
if thMode == nil || *thMode == headers.TransportModePlay {
i := stringsReverseIndex(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - it's track zero
if i < 0 {
if !strings.HasSuffix(pathAndQuery, "/") {
return 0, "", "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
}
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
path, query := base.PathSplitQuery(pathAndQuery)
// we assume it's track 0
return 0, path, query, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, "", "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
}
trackID := int(tmp)
pathAndQuery = pathAndQuery[:i]
path, query := base.PathSplitQuery(pathAndQuery)
if setupPath != nil && (path != *setupPath || query != *setupQuery) {
return 0, "", "", fmt.Errorf("can't setup tracks with different paths")
}
return trackID, path, query, nil
}
for trackID, track := range announcedTracks {
u, _ := track.track.URL()
if u.String() == url.String() {
return trackID, *setupPath, *setupQuery, nil
}
}
return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
// ServerSessionState is a state of a ServerSession.
type ServerSessionState int
// standard states.
const (
ServerSessionStateInitial ServerSessionState = iota
ServerSessionStatePrePlay
ServerSessionStatePlay
ServerSessionStatePreRecord
ServerSessionStateRecord
)
// String implements fmt.Stringer.
func (s ServerSessionState) String() string {
switch s {
case ServerSessionStateInitial:
return "initial"
case ServerSessionStatePrePlay:
return "prePlay"
case ServerSessionStatePlay:
return "play"
case ServerSessionStatePreRecord:
return "preRecord"
case ServerSessionStateRecord:
return "record"
}
return "unknown"
}
// ServerSessionSetuppedTrack is a setupped track of a ServerSession.
type ServerSessionSetuppedTrack struct {
udpRTPPort int
udpRTCPPort int
}
// ServerSessionAnnouncedTrack is an announced track of a ServerSession.
type ServerSessionAnnouncedTrack struct {
track *Track
rtcpReceiver *rtcpreceiver.RTCPReceiver
udpLastFrameTime *int64
}
type requestRes struct {
res *base.Response
err error
}
type requestReq struct {
sc *ServerConn
req *base.Request
res chan requestRes
}
// ServerSession is a server-side RTSP session.
type ServerSession struct {
s *Server
id string
wg *sync.WaitGroup
state ServerSessionState
setuppedTracks map[int]ServerSessionSetuppedTrack
setupProtocol *StreamProtocol
setupPath *string
setupQuery *string
// TCP stream protocol
linkedConn *ServerConn
// UDP stream protocol
udpIP net.IP
udpZone string
// publish
announcedTracks []ServerSessionAnnouncedTrack
// in
request chan requestReq
terminate chan struct{}
}
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
ss := &ServerSession{
s: s,
id: id,
wg: wg,
request: make(chan requestReq),
terminate: make(chan struct{}),
}
wg.Add(1)
go ss.run()
return ss
}
// State returns the state of the session.
func (ss *ServerSession) State() ServerSessionState {
return ss.state
}
// StreamProtocol returns the stream protocol of the setupped tracks.
func (ss *ServerSession) StreamProtocol() *StreamProtocol {
return ss.setupProtocol
}
// SetuppedTracks returns the setupped tracks.
func (ss *ServerSession) SetuppedTracks() map[int]ServerSessionSetuppedTrack {
return ss.setuppedTracks
}
// AnnouncedTracks returns the announced tracks.
func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack {
return ss.announcedTracks
}
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
if _, ok := allowed[ss.state]; ok {
return nil
}
allowedList := make([]fmt.Stringer, len(allowed))
i := 0
for a := range allowed {
allowedList[i] = a
i++
}
return liberrors.ErrServerWrongState{AllowedList: allowedList, State: ss.state}
}
func (ss *ServerSession) run() {
defer ss.wg.Done()
if h, ok := ss.s.Handler.(ServerHandlerOnSessionOpen); ok {
h.OnSessionOpen(ss)
}
checkStreamTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkStreamTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop()
outer:
for {
select {
case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req)
req.res <- requestRes{res, err}
case <-checkStreamTicker.C:
if ss.state != ServerSessionStateRecord || *ss.setupProtocol != StreamProtocolUDP {
continue
}
inTimeout := func() bool {
now := time.Now()
for _, track := range ss.announcedTracks {
lft := atomic.LoadInt64(track.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) < ss.s.ReadTimeout {
return false
}
}
return true
}()
if inTimeout {
break outer
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.terminate:
break outer
}
}
go func() {
for req := range ss.request {
req.res <- requestRes{nil, fmt.Errorf("terminated")}
}
}()
switch ss.state {
case ServerSessionStatePlay:
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTCPListener.removeClient(ss)
}
case ServerSessionStateRecord:
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss)
}
}
if ss.linkedConn != nil {
ss.s.connClose <- ss.linkedConn
}
ss.s.sessionClose <- ss
<-ss.terminate
close(ss.request)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
h.OnSessionClose(ss)
}
}
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
switch req.Method {
case base.Announce:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStateInitial: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerContentTypeMissing{}
}
if ct[0] != "application/sdp" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerContentTypeUnsupported{CT: ct}
}
tracks, err := ReadTracks(req.Body, req.URL)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerSDPInvalid{Err: err}
}
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerSDPNoTracksDefined{}
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
for _, track := range tracks {
trackURL, err := track.URL()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unable to generate track URL")
}
trackPath, ok := trackURL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL (%v)", trackURL)
}
if !strings.HasPrefix(trackPath, path) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'",
path, trackPath)
}
}
res, err := ss.s.Handler.(ServerHandlerOnAnnounce).OnAnnounce(&ServerHandlerOnAnnounceCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
Tracks: tracks,
})
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStatePreRecord
ss.setupPath = &path
ss.setupQuery = &query
ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
clockRate, _ := track.ClockRate()
v := time.Now().Unix()
ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{
track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
}
}
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
}
return res, err
case base.Setup:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStateInitial: {},
ServerSessionStatePrePlay: {},
ServerSessionStatePreRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
var th headers.Transport
err = th.Read(req.Header["Transport"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalid{Err: err}
}
if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, th.Mode,
ss.announcedTracks, ss.setupPath, ss.setupQuery)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if _, ok := ss.setuppedTracks[trackID]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID}
}
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePrePlay: // play
if th.Mode != nil && *th.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderWrongMode{Mode: th.Mode}
}
default: // record
if th.Mode == nil || *th.Mode != headers.TransportModeRecord {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderWrongMode{Mode: th.Mode}
}
}
if th.Protocol == StreamProtocolUDP {
if ss.s.udpRTPListener == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if th.ClientPorts == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
}
} else {
if th.InterleavedIDs == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoInterleavedIDs{}
}
if th.InterleavedIDs[0] != (trackID*2) ||
th.InterleavedIDs[1] != (1+trackID*2) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderWrongInterleavedIDs{
Expected: [2]int{(trackID * 2), (1 + trackID*2)}, Value: *th.InterleavedIDs}
}
}
if ss.setupProtocol != nil && *ss.setupProtocol != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTracksDifferentProtocols{}
}
res, err := ss.s.Handler.(ServerHandlerOnSetup).OnSetup(&ServerHandlerOnSetupCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
TrackID: trackID,
Transport: &th,
})
if res.StatusCode == base.StatusOK {
ss.setupProtocol = &th.Protocol
if ss.setuppedTracks == nil {
ss.setuppedTracks = make(map[int]ServerSessionSetuppedTrack)
}
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if th.Protocol == StreamProtocolUDP {
ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{
udpRTPPort: th.ClientPorts[0],
udpRTCPPort: th.ClientPorts[1],
}
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
ClientPorts: th.ClientPorts,
ServerPorts: &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()},
}.Write()
} else {
ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{}
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolTCP,
InterleavedIDs: th.InterleavedIDs,
}.Write()
}
}
if ss.state == ServerSessionStateInitial {
ss.state = ServerSessionStatePrePlay
ss.setupPath = &path
ss.setupQuery = &query
}
// workaround to prevent a bug in rtspclientsink
// that makes impossible for the client to receive the response
// and send frames.
// this was causing problems during unit tests.
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
strings.HasPrefix(ua[0], "GStreamer") {
select {
case <-time.After(1 * time.Second):
case <-sc.terminate:
}
}
return res, err
case base.Play:
// play can be sent twice, allow calling it even if we're already playing
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStatePrePlay: {},
ServerSessionStatePlay: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if len(ss.setuppedTracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoTracksSetup{}
}
// with TCP, PLAY can't be called twice
// with UDP, it can
if ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolTCP {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
if ss.state != ServerSessionStatePlay {
ss.linkedConn = sc
}
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
if ss.state != ServerSessionStatePlay {
ss.state = ServerSessionStatePlay
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if *ss.setupProtocol == StreamProtocolUDP {
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
// readers can send RTCP frames, they cannot sent RTP frames
for trackID, track := range ss.setuppedTracks {
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
}
return res, err
}
return res, liberrors.ErrServerTCPFramesEnable{}
}
} else {
ss.linkedConn = nil
}
return res, err
case base.Record:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStatePreRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if len(ss.setuppedTracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoTracksSetup{}
}
if len(ss.setuppedTracks) != len(ss.announcedTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNotAllAnnouncedTracksSetup{}
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStateRecord
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
if *ss.setupProtocol == StreamProtocolUDP {
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
for trackID, track := range ss.setuppedTracks {
ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true)
ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true)
// open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTP,
[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
}
return res, err
}
ss.linkedConn = sc
return res, liberrors.ErrServerTCPFramesEnable{}
}
return res, err
case base.Pause:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStatePrePlay: {},
ServerSessionStatePlay: {},
ServerSessionStatePreRecord: {},
ServerSessionStateRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := ss.s.Handler.(ServerHandlerOnPause).OnPause(&ServerHandlerOnPauseCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
ss.linkedConn = nil
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTCPListener.removeClient(ss)
} else {
return res, liberrors.ErrServerTCPFramesDisable{}
}
case ServerSessionStateRecord:
ss.state = ServerSessionStatePreRecord
ss.linkedConn = nil
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTPListener.removeClient(ss)
} else {
return res, liberrors.ErrServerTCPFramesDisable{}
}
}
}
return res, err
case base.Teardown:
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerNoPath{}
}
path, query := base.PathSplitQuery(pathAndQuery)
return ss.s.Handler.(ServerHandlerOnTeardown).OnTeardown(&ServerHandlerOnTeardownCtx{
Session: ss,
Conn: sc,
Req: req,
Path: path,
Query: query,
})
}
return nil, fmt.Errorf("unimplemented")
}
// WriteFrame writes a frame.
func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *ss.setupProtocol == StreamProtocolUDP {
track := ss.setuppedTracks[trackID]
if streamType == StreamTypeRTP {
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
IP: ss.udpIP,
Zone: ss.udpZone,
Port: track.udpRTPPort,
})
} else {
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
IP: ss.udpIP,
Zone: ss.udpZone,
Port: track.udpRTCPPort,
})
}
} else {
ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Payload: payload,
})
}
}

View File

@@ -20,7 +20,7 @@ type bufAddrPair struct {
}
type clientData struct {
sc *ServerConn
ss *ServerSession
trackID int
isPublishing bool
}
@@ -123,13 +123,13 @@ func (u *serverUDPListener) run() {
if clientData.isPublishing {
now := time.Now()
atomic.StoreInt64(clientData.sc.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix())
clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n])
atomic.StoreInt64(clientData.ss.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix())
clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n])
}
if h, ok := u.s.Handler.(ServerHandlerOnFrame); ok {
h.OnFrame(&ServerHandlerOnFrameCtx{
Conn: clientData.sc,
Session: clientData.ss,
TrackID: clientData.trackID,
StreamType: u.streamType,
Payload: buf[:n],
@@ -166,7 +166,7 @@ func (u *serverUDPListener) write(buf []byte, addr *net.UDPAddr) {
u.ringBuffer.Push(bufAddrPair{buf, addr})
}
func (u *serverUDPListener) addClient(ip net.IP, port int, sc *ServerConn, trackID int, isPublishing bool) {
func (u *serverUDPListener) addClient(ip net.IP, port int, ss *ServerSession, trackID int, isPublishing bool) {
u.clientsMutex.Lock()
defer u.clientsMutex.Unlock()
@@ -174,18 +174,19 @@ func (u *serverUDPListener) addClient(ip net.IP, port int, sc *ServerConn, track
addr.fill(ip, port)
u.clients[addr] = &clientData{
sc: sc,
ss: ss,
trackID: trackID,
isPublishing: isPublishing,
}
}
func (u *serverUDPListener) removeClient(ip net.IP, port int) {
func (u *serverUDPListener) removeClient(ss *ServerSession) {
u.clientsMutex.Lock()
defer u.clientsMutex.Unlock()
var addr clientAddr
addr.fill(ip, port)
for addr, data := range u.clients {
if data.ss == ss {
delete(u.clients, addr)
}
}
}