diff --git a/clientconn.go b/clientconn.go index 914e695c..04a5dd06 100644 --- a/clientconn.go +++ b/clientconn.go @@ -194,7 +194,7 @@ func (c *ClientConn) checkState(allowed map[clientConnState]struct{}) error { for a := range allowed { allowedList = append(allowedList, a) } - return fmt.Errorf("client must be in state %v, while is in state %v", + return fmt.Errorf("must be in state %v, while is in state %v", allowedList, c.state) } diff --git a/server.go b/server.go index cff87338..43ff0726 100644 --- a/server.go +++ b/server.go @@ -22,5 +22,5 @@ func (s *Server) Accept() (*ServerConn, error) { return nil, err } - return newServerConn(s, nconn), nil + return newServerConn(s.conf, nconn), nil } diff --git a/serverconn.go b/serverconn.go index b969c388..c6f191db 100644 --- a/serverconn.go +++ b/serverconn.go @@ -26,14 +26,35 @@ var ( ErrServerTeardown = errors.New("teardown") ) -type serverConnState int +// ServerConnState is the state of the connection. +type ServerConnState int +// standard states. const ( - serverConnStateInitial serverConnState = iota - serverConnStatePlay - serverConnStateRecord + ServerConnStateInitial ServerConnState = iota + ServerConnStatePrePlay + ServerConnStatePlay + ServerConnStatePreRecord + ServerConnStateRecord ) +// String implements fmt.Stringer. +func (s ServerConnState) String() string { + switch s { + case ServerConnStateInitial: + return "initial" + case ServerConnStatePrePlay: + return "prePlay" + case ServerConnStatePlay: + return "play" + case ServerConnStatePreRecord: + return "preRecord" + case ServerConnStateRecord: + return "record" + } + return "uknown" +} + type serverConnTrack struct { proto StreamProtocol rtpPort int @@ -106,13 +127,13 @@ type ServerConnReadHandlers struct { // ServerConn is a server-side RTSP connection. type ServerConn struct { - s *Server + conf ServerConf nconn net.Conn br *bufio.Reader bw *bufio.Writer - state serverConnState + state ServerConnState tracks map[int]serverConnTrack - tracksProto *StreamProtocol + tracksProtocol *StreamProtocol writeMutex sync.Mutex readHandlers ServerConnReadHandlers nextFramesEnabled bool @@ -123,16 +144,16 @@ type ServerConn struct { terminate chan struct{} } -func newServerConn(s *Server, nconn net.Conn) *ServerConn { +func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { conn := func() net.Conn { - if s.conf.TLSConfig != nil { - return tls.Server(nconn, s.conf.TLSConfig) + if conf.TLSConfig != nil { + return tls.Server(nconn, conf.TLSConfig) } return nconn }() return &ServerConn{ - s: s, + conf: conf, nconn: nconn, br: bufio.NewReaderSize(conn, serverReadBufferSize), bw: bufio.NewWriterSize(conn, serverWriteBufferSize), @@ -148,6 +169,29 @@ func (sc *ServerConn) Close() error { return err } +// State returns the state. +func (sc *ServerConn) State() ServerConnState { + return sc.state +} + +// TracksProtocol returns the tracks protocol. +func (sc *ServerConn) TracksProtocol() *StreamProtocol { + return sc.tracksProtocol +} + +func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error { + if _, ok := allowed[sc.state]; ok { + return nil + } + + var allowedList []ServerConnState + for a := range allowed { + allowedList = append(allowedList, a) + } + return fmt.Errorf("must be in state %v, while is in state %v", + allowedList, sc.state) +} + // NetConn returns the underlying net.Conn. func (sc *ServerConn) NetConn() net.Conn { return sc.nconn @@ -163,20 +207,20 @@ func (sc *ServerConn) zone() string { func (sc *ServerConn) frameModeEnable() { switch sc.state { - case serverConnStatePlay: - if *sc.tracksProto == StreamProtocolTCP { + case ServerConnStatePlay: + if *sc.tracksProtocol == StreamProtocolTCP { sc.nextFramesEnabled = true } - case serverConnStateRecord: - if *sc.tracksProto == StreamProtocolTCP { + case ServerConnStateRecord: + if *sc.tracksProtocol == StreamProtocolTCP { sc.nextFramesEnabled = true sc.readTimeoutEnabled = true } else { for trackID, track := range sc.tracks { - sc.s.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) - sc.s.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) + sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) + sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) } } } @@ -184,17 +228,17 @@ func (sc *ServerConn) frameModeEnable() { func (sc *ServerConn) frameModeDisable() { switch sc.state { - case serverConnStatePlay: + case ServerConnStatePlay: sc.nextFramesEnabled = false - case serverConnStateRecord: + case ServerConnStateRecord: sc.nextFramesEnabled = false sc.readTimeoutEnabled = false for _, track := range sc.tracks { if track.proto == StreamProtocolUDP { - sc.s.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) - sc.s.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) + sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) + sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) } } } @@ -245,11 +289,29 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Describe: if sc.readHandlers.OnDescribe != nil { + err := sc.checkState(map[ServerConnState]struct{}{ + ServerConnStateInitial: {}, + }) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + return sc.readHandlers.OnDescribe(req) } case base.Announce: if sc.readHandlers.OnAnnounce != nil { + err := sc.checkState(map[ServerConnState]struct{}{ + ServerConnStateInitial: {}, + }) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + ct, ok := req.Header["Content-Type"] if !ok || len(ct) != 1 { return &base.Response{ @@ -277,11 +339,27 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } res, err := sc.readHandlers.OnAnnounce(req, tracks) + + if res.StatusCode == 200 { + sc.state = ServerConnStatePreRecord + } + return res, err } case base.Setup: if sc.readHandlers.OnSetup != nil { + err := sc.checkState(map[ServerConnState]struct{}{ + ServerConnStateInitial: {}, + ServerConnStatePrePlay: {}, + ServerConnStatePreRecord: {}, + }) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + _, controlPath, ok := req.URL.BasePathControlAttr() if !ok { return &base.Response{ @@ -309,14 +387,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("track %d has already been setup", trackID) } - if sc.tracksProto != nil && *sc.tracksProto != th.Protocol { + if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("can't receive tracks with different protocols") + }, fmt.Errorf("can't setup tracks with different protocols") } if th.Protocol == StreamProtocolUDP { - if sc.s.conf.UDPRTPListener == nil { + if sc.conf.UDPRTPListener == nil { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, }, nil @@ -347,9 +425,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { res, err := sc.readHandlers.OnSetup(req, th) if res.StatusCode == 200 { - sc.tracksProto = &th.Protocol + sc.tracksProtocol = &th.Protocol if th.Protocol == StreamProtocolUDP { + sc.tracks[trackID] = serverConnTrack{ + proto: StreamProtocolUDP, + rtpPort: th.ClientPorts[0], + rtcpPort: th.ClientPorts[1], + } + res.Header["Transport"] = headers.Transport{ Protocol: StreamProtocolUDP, Delivery: func() *base.StreamDelivery { @@ -357,27 +441,26 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { return &v }(), ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{sc.s.conf.UDPRTPListener.port(), sc.s.conf.UDPRTCPListener.port()}, + ServerPorts: &[2]int{sc.conf.UDPRTPListener.port(), sc.conf.UDPRTCPListener.port()}, }.Write() + } else { sc.tracks[trackID] = serverConnTrack{ - proto: StreamProtocolUDP, - rtpPort: th.ClientPorts[0], - rtcpPort: th.ClientPorts[1], + proto: StreamProtocolTCP, } - } else { res.Header["Transport"] = headers.Transport{ Protocol: StreamProtocolTCP, InterleavedIds: th.InterleavedIds, }.Write() - - sc.tracks[trackID] = serverConnTrack{ - proto: StreamProtocolTCP, - } } } + switch sc.state { + case ServerConnStateInitial: + sc.state = ServerConnStatePrePlay + } + // workaround to prevent a bug in rtspclientsink // that makes impossible for the client to receive the response // and send frames. @@ -397,10 +480,21 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Play: if sc.readHandlers.OnPlay != nil { + // play can be sent twice, allow calling it even if we're already playing + err := sc.checkState(map[ServerConnState]struct{}{ + ServerConnStatePrePlay: {}, + ServerConnStatePlay: {}, + }) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + res, err := sc.readHandlers.OnPlay(req) - if res.StatusCode == 200 { - sc.state = serverConnStatePlay + if res.StatusCode == 200 && sc.state != ServerConnStatePlay { + sc.state = ServerConnStatePlay sc.frameModeEnable() } @@ -409,10 +503,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Record: if sc.readHandlers.OnRecord != nil { + err := sc.checkState(map[ServerConnState]struct{}{ + ServerConnStatePreRecord: {}, + }) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + res, err := sc.readHandlers.OnRecord(req) if res.StatusCode == 200 { - sc.state = serverConnStateRecord + sc.state = ServerConnStateRecord sc.frameModeEnable() } @@ -421,11 +524,30 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Pause: if sc.readHandlers.OnPause != nil { + err := sc.checkState(map[ServerConnState]struct{}{ + ServerConnStatePrePlay: {}, + ServerConnStatePlay: {}, + ServerConnStatePreRecord: {}, + ServerConnStateRecord: {}, + }) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + res, err := sc.readHandlers.OnPause(req) if res.StatusCode == 200 { - sc.frameModeDisable() - sc.state = serverConnStateInitial + switch sc.state { + case ServerConnStatePlay: + sc.frameModeDisable() + sc.state = ServerConnStatePrePlay + + case ServerConnStateRecord: + sc.frameModeDisable() + sc.state = ServerConnStatePreRecord + } } return res, err @@ -471,7 +593,7 @@ func (sc *ServerConn) backgroundRead() error { cseq, ok := req.Header["CSeq"] if !ok || len(cseq) != 1 { sc.writeMutex.Lock() - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) base.Response{ StatusCode: base.StatusBadRequest, Header: base.Header{}, @@ -498,7 +620,7 @@ func (sc *ServerConn) backgroundRead() error { sc.writeMutex.Lock() - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) res.Write(sc.bw) // set framesEnabled after sending the response @@ -514,13 +636,13 @@ func (sc *ServerConn) backgroundRead() error { var req base.Request var frame base.InterleavedFrame - tcpFrameBuffer := multibuffer.New(sc.s.conf.ReadBufferCount, clientTCPFrameReadBufferSize) + tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, clientTCPFrameReadBufferSize) var errRet error outer: for { if sc.readTimeoutEnabled { - sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout)) + sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) } else { sc.nconn.SetReadDeadline(time.Time{}) } @@ -589,14 +711,14 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b if track.proto == StreamProtocolUDP { if streamType == StreamTypeRTP { - return sc.s.conf.UDPRTPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{ + return sc.conf.UDPRTPListener.write(sc.conf.WriteTimeout, payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.rtpPort, }) } - return sc.s.conf.UDPRTCPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{ + return sc.conf.UDPRTCPListener.write(sc.conf.WriteTimeout, payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.rtcpPort, @@ -609,7 +731,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b return errors.New("frames are disabled") } - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) frame := base.InterleavedFrame{ TrackID: trackID, StreamType: streamType,