diff --git a/go.mod b/go.mod index a5f93931..1a3c7291 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.15 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aler9/gortsplib v0.0.0-20210106112607-8e70ac4d59c4 + github.com/aler9/gortsplib v0.0.0-20210106201702-d17ef3fcc3ff github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 diff --git a/go.sum b/go.sum index a0344d79..4815f029 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aler9/gortsplib v0.0.0-20210106112607-8e70ac4d59c4 h1:gfqoSKl2KyzzZjHfs//ImOtX6u1vgbcFcp0bUCd6Q2Q= -github.com/aler9/gortsplib v0.0.0-20210106112607-8e70ac4d59c4/go.mod h1:8P09VjpiPJFyfkVosyF5/TY82jNwkMN165NS/7sc32I= +github.com/aler9/gortsplib v0.0.0-20210106201702-d17ef3fcc3ff h1:rXe4QSWV7QwDaOW1NCqEOa7T4p5N86Q13urvo82TuPg= +github.com/aler9/gortsplib v0.0.0-20210106201702-d17ef3fcc3ff/go.mod h1:8P09VjpiPJFyfkVosyF5/TY82jNwkMN165NS/7sc32I= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/internal/client/client.go b/internal/client/client.go index 49979661..680232d9 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -30,43 +30,12 @@ const ( pauseAfterAuthError = 2 * time.Second ) -type streamTrack struct { - rtpPort int - rtcpPort int -} - type describeData struct { sdp []byte redirect string err error } -type state int - -const ( - stateInitial state = iota - statePrePlay - statePlay - statePreRecord - stateRecord -) - -func (s state) String() string { - switch s { - case stateInitial: - return "initial" - case statePrePlay: - return "prePlay" - case statePlay: - return "play" - case statePreRecord: - return "preRecord" - case stateRecord: - return "record" - } - return "invalid" -} - // Path is implemented by path.Path. type Path interface { Name() string @@ -100,14 +69,11 @@ type Client struct { conn *gortsplib.ServerConn parent Parent - state state path Path authUser string authPass string authValidator *auth.Validator authFailures int - streamProtocol gortsplib.StreamProtocol - streamTracks map[int]*streamTrack rtcpReceivers map[int]*rtcpreceiver.RTCPReceiver udpLastFrameTimes []*int64 onReadCmd *externalcmd.Cmd @@ -144,8 +110,6 @@ func New( stats: stats, conn: conn, parent: parent, - state: stateInitial, - streamTracks: make(map[int]*streamTrack), rtcpReceivers: make(map[int]*rtcpreceiver.RTCPReceiver), terminate: make(chan struct{}), } @@ -203,15 +167,6 @@ func (c *Client) run() { } onDescribe := func(req *base.Request) (*base.Response, error) { - err := c.checkState(map[state]struct{}{ - stateInitial: {}, - }) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - basePath, ok := req.URL.BasePath() if !ok { return &base.Response{ @@ -274,7 +229,7 @@ func (c *Client) run() { "Content-Base": base.HeaderValue{req.URL.String() + "/"}, "Content-Type": base.HeaderValue{"application/sdp"}, }, - Content: res.sdp, + Body: res.sdp, }, nil case <-c.terminate: @@ -296,15 +251,6 @@ func (c *Client) run() { } onAnnounce := func(req *base.Request, tracks gortsplib.Tracks) (*base.Response, error) { - err := c.checkState(map[state]struct{}{ - stateInitial: {}, - }) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - basePath, ok := req.URL.BasePath() if !ok { return &base.Response{ @@ -342,62 +288,28 @@ func (c *Client) run() { } c.path = path - c.state = statePreRecord return &base.Response{ StatusCode: base.StatusOK, }, nil } - onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) { - if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("multicast is not supported") - } - - basePath, controlPath, ok := req.URL.BasePathControlAttr() + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + basePath, _, ok := req.URL.BasePathControlAttr() if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("unable to find control attribute (%s)", req.URL) } - switch c.state { - // play - case stateInitial, statePrePlay: - if th.Mode != nil && *th.Mode != headers.TransportModePlay { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header must contain mode=play or not contain a mode") - } - + switch c.conn.State() { + case gortsplib.ServerConnStateInitial, gortsplib.ServerConnStatePrePlay: // play if c.path != nil && basePath != c.path.Name() { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } - if !strings.HasPrefix(controlPath, "trackID=") { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid control attribute (%s)", controlPath) - } - - tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64) - if err != nil || tmp < 0 { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid track id (%s)", controlPath) - } - trackID := int(tmp) - - if _, ok := c.streamTracks[trackID]; ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("track %d has already been setup", trackID) - } - // play with UDP if th.Protocol == gortsplib.StreamProtocolUDP { if _, ok := c.protocols[gortsplib.StreamProtocolUDP]; !ok { @@ -406,18 +318,6 @@ func (c *Client) run() { }, nil } - if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolUDP { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("can't receive tracks with different protocols") - } - - if th.ClientPorts == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"]) - } - path, err := c.parent.OnClientSetupPlay(c, basePath, trackID, req) if err != nil { switch terr := err.(type) { @@ -443,13 +343,6 @@ func (c *Client) run() { } c.path = path - c.state = statePrePlay - - c.streamProtocol = gortsplib.StreamProtocolUDP - c.streamTracks[trackID] = &streamTrack{ - rtpPort: (*th.ClientPorts)[0], - rtcpPort: (*th.ClientPorts)[1], - } return &base.Response{ StatusCode: base.StatusOK, @@ -467,12 +360,6 @@ func (c *Client) run() { }, nil } - if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolTCP { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("can't receive tracks with different protocols") - } - path, err := c.parent.OnClientSetupPlay(c, basePath, trackID, req) if err != nil { switch terr := err.(type) { @@ -498,13 +385,6 @@ func (c *Client) run() { } c.path = path - c.state = statePrePlay - - c.streamProtocol = gortsplib.StreamProtocolTCP - c.streamTracks[trackID] = &streamTrack{ - rtpPort: 0, - rtcpPort: 0, - } return &base.Response{ StatusCode: base.StatusOK, @@ -513,14 +393,7 @@ func (c *Client) run() { }, }, nil - // record - case statePreRecord: - if th.Mode == nil || *th.Mode != headers.TransportModeRecord { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header does not contain mode=record") - } - + default: // record // after ANNOUNCE, c.path is already set if basePath != c.path.Name() { return &base.Response{ @@ -536,31 +409,18 @@ func (c *Client) run() { }, nil } - if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolUDP { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("can't publish tracks with different protocols") - } - if th.ClientPorts == nil { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("transport header does not have valid client ports (%s)", req.Header["Transport"]) } - if len(c.streamTracks) >= c.path.SourceTrackCount() { + if c.conn.TracksLen() >= c.path.SourceTrackCount() { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("all the tracks have already been setup") } - c.streamProtocol = gortsplib.StreamProtocolUDP - trackID := len(c.streamTracks) - c.streamTracks[trackID] = &streamTrack{ - rtpPort: (*th.ClientPorts)[0], - rtcpPort: (*th.ClientPorts)[1], - } - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ @@ -570,72 +430,30 @@ func (c *Client) run() { } // record with TCP + if _, ok := c.protocols[gortsplib.StreamProtocolTCP]; !ok { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, }, nil } - if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolTCP { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("can't publish tracks with different protocols") - } - - interleavedIds := [2]int{len(c.streamTracks) * 2, 1 + len(c.streamTracks)*2} - - if th.InterleavedIds == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header does not contain the interleaved field") - } - - if (*th.InterleavedIds)[0] != interleavedIds[0] || (*th.InterleavedIds)[1] != interleavedIds[1] { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("wrong interleaved ids, expected %v, got %v", interleavedIds, *th.InterleavedIds) - } - - if len(c.streamTracks) >= c.path.SourceTrackCount() { + if c.conn.TracksLen() >= c.path.SourceTrackCount() { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("all the tracks have already been setup") } - c.streamProtocol = gortsplib.StreamProtocolTCP - trackID := len(c.streamTracks) - c.streamTracks[trackID] = &streamTrack{ - rtpPort: 0, - rtcpPort: 0, - } - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Session": base.HeaderValue{sessionID}, }, }, nil - - default: - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("client is in state '%s'", c.state) } } onPlay := func(req *base.Request) (*base.Response, error) { - // play can be sent twice, allow calling it even if we're already playing - err := c.checkState(map[state]struct{}{ - statePrePlay: {}, - statePlay: {}, - }) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - - if c.state == statePrePlay { + if c.conn.State() == gortsplib.ServerConnStatePrePlay { basePath, ok := req.URL.BasePath() if !ok { return &base.Response{ @@ -652,15 +470,9 @@ func (c *Client) run() { }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } - if len(c.streamTracks) == 0 { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("no tracks have been setup") - } + c.startPlay() } - c.startPlay() - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ @@ -670,15 +482,6 @@ func (c *Client) run() { } onRecord := func(req *base.Request) (*base.Response, error) { - err := c.checkState(map[state]struct{}{ - statePreRecord: {}, - }) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - basePath, ok := req.URL.BasePath() if !ok { return &base.Response{ @@ -695,7 +498,7 @@ func (c *Client) run() { }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } - if len(c.streamTracks) != c.path.SourceTrackCount() { + if c.conn.TracksLen() != c.path.SourceTrackCount() { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("not all tracks have been setup") @@ -712,27 +515,13 @@ func (c *Client) run() { } onPause := func(req *base.Request) (*base.Response, error) { - err := c.checkState(map[state]struct{}{ - statePrePlay: {}, - statePlay: {}, - statePreRecord: {}, - stateRecord: {}, - }) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - - switch c.state { - case statePlay: + switch c.conn.State() { + case gortsplib.ServerConnStatePlay: c.stopPlay() - c.state = statePrePlay c.path.OnClientPause(c) - case stateRecord: + case gortsplib.ServerConnStateRecord: c.stopRecord() - c.state = statePreRecord c.path.OnClientPause(c) } @@ -745,21 +534,17 @@ func (c *Client) run() { } onFrame := func(trackID int, streamType gortsplib.StreamType, payload []byte) { - if c.state != stateRecord { + if c.conn.State() != gortsplib.ServerConnStateRecord { return } - if c.streamProtocol == gortsplib.StreamProtocolUDP { + if *c.conn.TracksProtocol() == gortsplib.StreamProtocolUDP { now := time.Now() atomic.StoreInt64(c.udpLastFrameTimes[trackID], now.Unix()) c.rtcpReceivers[trackID].ProcessFrame(now, streamType, payload) c.path.OnFrame(trackID, streamType, payload) } else { - if trackID >= len(c.streamTracks) { - return - } - c.rtcpReceivers[trackID].ProcessFrame(time.Now(), streamType, payload) c.path.OnFrame(trackID, streamType, payload) } @@ -784,11 +569,11 @@ func (c *Client) run() { c.log(logger.Info, "ERR: %s", err) } - switch c.state { - case statePlay: + switch c.conn.State() { + case gortsplib.ServerConnStatePlay: c.stopPlay() - case stateRecord: + case gortsplib.ServerConnStateRecord: c.stopRecord() } @@ -804,11 +589,11 @@ func (c *Client) run() { c.conn.Close() <-readDone - switch c.state { - case statePlay: + switch c.conn.State() { + case gortsplib.ServerConnStatePlay: c.stopPlay() - case stateRecord: + case gortsplib.ServerConnStateRecord: c.stopRecord() } @@ -899,30 +684,15 @@ func (c *Client) Authenticate(authMethods []headers.AuthMethod, ips []interface{ return nil } -func (c *Client) checkState(allowed map[state]struct{}) error { - if _, ok := allowed[c.state]; ok { - return nil - } - - var allowedList []state - for s := range allowed { - allowedList = append(allowedList, s) - } - - return fmt.Errorf("client must be in state %v, while is in state %v", - allowedList, c.state) -} - func (c *Client) startPlay() { - c.state = statePlay c.path.OnClientPlay(c) - c.log(logger.Info, "is reading from path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { - if len(c.streamTracks) == 1 { + c.log(logger.Info, "is reading from path '%s', %d %s with %s", c.path.Name(), c.conn.TracksLen(), func() string { + if c.conn.TracksLen() == 1 { return "track" } return "tracks" - }(), c.streamProtocol) + }(), *c.conn.TracksProtocol()) if c.path.Conf().RunOnRead != "" { c.onReadCmd = externalcmd.New(c.path.Conf().RunOnRead, c.path.Conf().RunOnReadRestart, externalcmd.Environment{ @@ -939,31 +709,21 @@ func (c *Client) stopPlay() { } func (c *Client) startRecord() { - c.state = stateRecord c.path.OnClientRecord(c) - c.log(logger.Info, "is publishing to path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { - if len(c.streamTracks) == 1 { + c.log(logger.Info, "is publishing to path '%s', %d %s with %s", c.path.Name(), c.conn.TracksLen(), func() string { + if c.conn.TracksLen() == 1 { return "track" } return "tracks" - }(), c.streamProtocol) + }(), *c.conn.TracksProtocol()) - if c.streamProtocol == gortsplib.StreamProtocolUDP { - c.udpLastFrameTimes = make([]*int64, len(c.streamTracks)) - for trackID := range c.streamTracks { + if *c.conn.TracksProtocol() == gortsplib.StreamProtocolUDP { + c.udpLastFrameTimes = make([]*int64, c.conn.TracksLen()) + for trackID := range c.conn.Tracks() { v := time.Now().Unix() c.udpLastFrameTimes[trackID] = &v } - - // open the firewall by sending packets to the counterpart - for trackID := range c.streamTracks { - c.conn.WriteFrame(trackID, gortsplib.StreamTypeRTP, - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - - c.conn.WriteFrame(trackID, gortsplib.StreamTypeRTCP, - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) - } } if c.path.Conf().RunOnPublish != "" { @@ -976,7 +736,7 @@ func (c *Client) startRecord() { c.backgroundRecordTerminate = make(chan struct{}) c.backgroundRecordDone = make(chan struct{}) - if c.streamProtocol == gortsplib.StreamProtocolUDP { + if *c.conn.TracksProtocol() == gortsplib.StreamProtocolUDP { go c.backgroundRecordUDP() } else { go c.backgroundRecordTCP() @@ -1005,7 +765,6 @@ func (c *Client) backgroundRecordUDP() { select { case <-checkStreamTicker.C: now := time.Now() - for _, lastUnix := range c.udpLastFrameTimes { last := time.Unix(atomic.LoadInt64(lastUnix), 0) @@ -1018,7 +777,7 @@ func (c *Client) backgroundRecordUDP() { case <-receiverReportTicker.C: now := time.Now() - for trackID := range c.streamTracks { + for trackID := range c.conn.Tracks() { r := c.rtcpReceivers[trackID].Report(now) c.conn.WriteFrame(trackID, gortsplib.StreamTypeRTP, r) } @@ -1039,7 +798,7 @@ func (c *Client) backgroundRecordTCP() { select { case <-receiverReportTicker.C: now := time.Now() - for trackID := range c.streamTracks { + for trackID := range c.conn.Tracks() { r := c.rtcpReceivers[trackID].Report(now) c.conn.WriteFrame(trackID, gortsplib.StreamTypeRTCP, r) } @@ -1052,8 +811,7 @@ func (c *Client) backgroundRecordTCP() { // OnReaderFrame implements path.Reader. func (c *Client) OnReaderFrame(trackID int, streamType base.StreamType, buf []byte) { - _, ok := c.streamTracks[trackID] - if !ok { + if !c.conn.HasTrack(trackID) { return } diff --git a/internal/serverplain/server.go b/internal/serverplain/server.go index 39227d55..986a1e4a 100644 --- a/internal/serverplain/server.go +++ b/internal/serverplain/server.go @@ -53,7 +53,7 @@ func New(port int, done: make(chan struct{}), } - parent.Log(logger.Info, "[TCP/RTSP server] opened on :%d", port) + parent.Log(logger.Info, "[TCP/RTSP listener] opened on :%d", port) go s.run() return s, nil diff --git a/internal/servertls/server.go b/internal/servertls/server.go index 2a688101..65fa5d25 100644 --- a/internal/servertls/server.go +++ b/internal/servertls/server.go @@ -58,7 +58,7 @@ func New(port int, done: make(chan struct{}), } - parent.Log(logger.Info, "[TCP/TLS/RTSPS server] opened on :%d", port) + parent.Log(logger.Info, "[TCP/TLS/RTSPS listener] opened on :%d", port) go s.run() return s, nil diff --git a/main.go b/main.go index 0424f538..87562b38 100644 --- a/main.go +++ b/main.go @@ -296,7 +296,7 @@ func (p *program) closeResources(newConf *conf.Conf) { closeServerUDPRTCP = true } - closeServerTCP := false + closeServerPlain := false if newConf == nil || newConf.EncryptionParsed != p.conf.EncryptionParsed || newConf.RtspPort != p.conf.RtspPort || @@ -304,7 +304,7 @@ func (p *program) closeResources(newConf *conf.Conf) { newConf.WriteTimeout != p.conf.WriteTimeout || closeServerUDPRTP || closeServerUDPRTCP { - closeServerTCP = true + closeServerPlain = true } closeServerTLS := false @@ -329,7 +329,7 @@ func (p *program) closeResources(newConf *conf.Conf) { closeClientMan := false if newConf == nil || - closeServerTCP || + closeServerPlain || closeServerTLS || closePathMan || newConf.RtspPort != p.conf.RtspPort || @@ -360,7 +360,7 @@ func (p *program) closeResources(newConf *conf.Conf) { p.serverTLS = nil } - if closeServerTCP && p.serverTCP != nil { + if closeServerPlain && p.serverTCP != nil { p.serverTCP.Close() p.serverTCP = nil }