add ServerConn states

This commit is contained in:
aler9
2021-01-06 15:34:54 +01:00
parent 8e70ac4d59
commit 8cd0b902ed
3 changed files with 170 additions and 48 deletions

View File

@@ -194,7 +194,7 @@ func (c *ClientConn) checkState(allowed map[clientConnState]struct{}) error {
for a := range allowed {
allowedList = append(allowedList, a)
}
return fmt.Errorf("client must be in state %v, while is in state %v",
return fmt.Errorf("must be in state %v, while is in state %v",
allowedList, c.state)
}

View File

@@ -22,5 +22,5 @@ func (s *Server) Accept() (*ServerConn, error) {
return nil, err
}
return newServerConn(s, nconn), nil
return newServerConn(s.conf, nconn), nil
}

View File

@@ -26,14 +26,35 @@ var (
ErrServerTeardown = errors.New("teardown")
)
type serverConnState int
// ServerConnState is the state of the connection.
type ServerConnState int
// standard states.
const (
serverConnStateInitial serverConnState = iota
serverConnStatePlay
serverConnStateRecord
ServerConnStateInitial ServerConnState = iota
ServerConnStatePrePlay
ServerConnStatePlay
ServerConnStatePreRecord
ServerConnStateRecord
)
// String implements fmt.Stringer.
func (s ServerConnState) String() string {
switch s {
case ServerConnStateInitial:
return "initial"
case ServerConnStatePrePlay:
return "prePlay"
case ServerConnStatePlay:
return "play"
case ServerConnStatePreRecord:
return "preRecord"
case ServerConnStateRecord:
return "record"
}
return "uknown"
}
type serverConnTrack struct {
proto StreamProtocol
rtpPort int
@@ -106,13 +127,13 @@ type ServerConnReadHandlers struct {
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
s *Server
conf ServerConf
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
state serverConnState
state ServerConnState
tracks map[int]serverConnTrack
tracksProto *StreamProtocol
tracksProtocol *StreamProtocol
writeMutex sync.Mutex
readHandlers ServerConnReadHandlers
nextFramesEnabled bool
@@ -123,16 +144,16 @@ type ServerConn struct {
terminate chan struct{}
}
func newServerConn(s *Server, nconn net.Conn) *ServerConn {
func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
conn := func() net.Conn {
if s.conf.TLSConfig != nil {
return tls.Server(nconn, s.conf.TLSConfig)
if conf.TLSConfig != nil {
return tls.Server(nconn, conf.TLSConfig)
}
return nconn
}()
return &ServerConn{
s: s,
conf: conf,
nconn: nconn,
br: bufio.NewReaderSize(conn, serverReadBufferSize),
bw: bufio.NewWriterSize(conn, serverWriteBufferSize),
@@ -148,6 +169,29 @@ func (sc *ServerConn) Close() error {
return err
}
// State returns the state.
func (sc *ServerConn) State() ServerConnState {
return sc.state
}
// TracksProtocol returns the tracks protocol.
func (sc *ServerConn) TracksProtocol() *StreamProtocol {
return sc.tracksProtocol
}
func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error {
if _, ok := allowed[sc.state]; ok {
return nil
}
var allowedList []ServerConnState
for a := range allowed {
allowedList = append(allowedList, a)
}
return fmt.Errorf("must be in state %v, while is in state %v",
allowedList, sc.state)
}
// NetConn returns the underlying net.Conn.
func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn
@@ -163,20 +207,20 @@ func (sc *ServerConn) zone() string {
func (sc *ServerConn) frameModeEnable() {
switch sc.state {
case serverConnStatePlay:
if *sc.tracksProto == StreamProtocolTCP {
case ServerConnStatePlay:
if *sc.tracksProtocol == StreamProtocolTCP {
sc.nextFramesEnabled = true
}
case serverConnStateRecord:
if *sc.tracksProto == StreamProtocolTCP {
case ServerConnStateRecord:
if *sc.tracksProtocol == StreamProtocolTCP {
sc.nextFramesEnabled = true
sc.readTimeoutEnabled = true
} else {
for trackID, track := range sc.tracks {
sc.s.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
sc.s.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
}
}
}
@@ -184,17 +228,17 @@ func (sc *ServerConn) frameModeEnable() {
func (sc *ServerConn) frameModeDisable() {
switch sc.state {
case serverConnStatePlay:
case ServerConnStatePlay:
sc.nextFramesEnabled = false
case serverConnStateRecord:
case ServerConnStateRecord:
sc.nextFramesEnabled = false
sc.readTimeoutEnabled = false
for _, track := range sc.tracks {
if track.proto == StreamProtocolUDP {
sc.s.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.s.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
}
}
}
@@ -245,11 +289,29 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Describe:
if sc.readHandlers.OnDescribe != nil {
err := sc.checkState(map[ServerConnState]struct{}{
ServerConnStateInitial: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
return sc.readHandlers.OnDescribe(req)
}
case base.Announce:
if sc.readHandlers.OnAnnounce != nil {
err := sc.checkState(map[ServerConnState]struct{}{
ServerConnStateInitial: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
@@ -277,11 +339,27 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}
res, err := sc.readHandlers.OnAnnounce(req, tracks)
if res.StatusCode == 200 {
sc.state = ServerConnStatePreRecord
}
return res, err
}
case base.Setup:
if sc.readHandlers.OnSetup != nil {
err := sc.checkState(map[ServerConnState]struct{}{
ServerConnStateInitial: {},
ServerConnStatePrePlay: {},
ServerConnStatePreRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
_, controlPath, ok := req.URL.BasePathControlAttr()
if !ok {
return &base.Response{
@@ -309,14 +387,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("track %d has already been setup", trackID)
}
if sc.tracksProto != nil && *sc.tracksProto != th.Protocol {
if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("can't receive tracks with different protocols")
}, fmt.Errorf("can't setup tracks with different protocols")
}
if th.Protocol == StreamProtocolUDP {
if sc.s.conf.UDPRTPListener == nil {
if sc.conf.UDPRTPListener == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
@@ -347,9 +425,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
res, err := sc.readHandlers.OnSetup(req, th)
if res.StatusCode == 200 {
sc.tracksProto = &th.Protocol
sc.tracksProtocol = &th.Protocol
if th.Protocol == StreamProtocolUDP {
sc.tracks[trackID] = serverConnTrack{
proto: StreamProtocolUDP,
rtpPort: th.ClientPorts[0],
rtcpPort: th.ClientPorts[1],
}
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery {
@@ -357,27 +441,26 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
return &v
}(),
ClientPorts: th.ClientPorts,
ServerPorts: &[2]int{sc.s.conf.UDPRTPListener.port(), sc.s.conf.UDPRTCPListener.port()},
ServerPorts: &[2]int{sc.conf.UDPRTPListener.port(), sc.conf.UDPRTCPListener.port()},
}.Write()
} else {
sc.tracks[trackID] = serverConnTrack{
proto: StreamProtocolUDP,
rtpPort: th.ClientPorts[0],
rtcpPort: th.ClientPorts[1],
proto: StreamProtocolTCP,
}
} else {
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolTCP,
InterleavedIds: th.InterleavedIds,
}.Write()
sc.tracks[trackID] = serverConnTrack{
proto: StreamProtocolTCP,
}
}
}
switch sc.state {
case ServerConnStateInitial:
sc.state = ServerConnStatePrePlay
}
// workaround to prevent a bug in rtspclientsink
// that makes impossible for the client to receive the response
// and send frames.
@@ -397,10 +480,21 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Play:
if sc.readHandlers.OnPlay != nil {
// play can be sent twice, allow calling it even if we're already playing
err := sc.checkState(map[ServerConnState]struct{}{
ServerConnStatePrePlay: {},
ServerConnStatePlay: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
res, err := sc.readHandlers.OnPlay(req)
if res.StatusCode == 200 {
sc.state = serverConnStatePlay
if res.StatusCode == 200 && sc.state != ServerConnStatePlay {
sc.state = ServerConnStatePlay
sc.frameModeEnable()
}
@@ -409,10 +503,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Record:
if sc.readHandlers.OnRecord != nil {
err := sc.checkState(map[ServerConnState]struct{}{
ServerConnStatePreRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
res, err := sc.readHandlers.OnRecord(req)
if res.StatusCode == 200 {
sc.state = serverConnStateRecord
sc.state = ServerConnStateRecord
sc.frameModeEnable()
}
@@ -421,11 +524,30 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Pause:
if sc.readHandlers.OnPause != nil {
err := sc.checkState(map[ServerConnState]struct{}{
ServerConnStatePrePlay: {},
ServerConnStatePlay: {},
ServerConnStatePreRecord: {},
ServerConnStateRecord: {},
})
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
res, err := sc.readHandlers.OnPause(req)
if res.StatusCode == 200 {
sc.frameModeDisable()
sc.state = serverConnStateInitial
switch sc.state {
case ServerConnStatePlay:
sc.frameModeDisable()
sc.state = ServerConnStatePrePlay
case ServerConnStateRecord:
sc.frameModeDisable()
sc.state = ServerConnStatePreRecord
}
}
return res, err
@@ -471,7 +593,7 @@ func (sc *ServerConn) backgroundRead() error {
cseq, ok := req.Header["CSeq"]
if !ok || len(cseq) != 1 {
sc.writeMutex.Lock()
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
base.Response{
StatusCode: base.StatusBadRequest,
Header: base.Header{},
@@ -498,7 +620,7 @@ func (sc *ServerConn) backgroundRead() error {
sc.writeMutex.Lock()
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
res.Write(sc.bw)
// set framesEnabled after sending the response
@@ -514,13 +636,13 @@ func (sc *ServerConn) backgroundRead() error {
var req base.Request
var frame base.InterleavedFrame
tcpFrameBuffer := multibuffer.New(sc.s.conf.ReadBufferCount, clientTCPFrameReadBufferSize)
tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, clientTCPFrameReadBufferSize)
var errRet error
outer:
for {
if sc.readTimeoutEnabled {
sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout))
sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout))
} else {
sc.nconn.SetReadDeadline(time.Time{})
}
@@ -589,14 +711,14 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
if track.proto == StreamProtocolUDP {
if streamType == StreamTypeRTP {
return sc.s.conf.UDPRTPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
return sc.conf.UDPRTPListener.write(sc.conf.WriteTimeout, payload, &net.UDPAddr{
IP: sc.ip(),
Zone: sc.zone(),
Port: track.rtpPort,
})
}
return sc.s.conf.UDPRTCPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
return sc.conf.UDPRTCPListener.write(sc.conf.WriteTimeout, payload, &net.UDPAddr{
IP: sc.ip(),
Zone: sc.zone(),
Port: track.rtcpPort,
@@ -609,7 +731,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
return errors.New("frames are disabled")
}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
frame := base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,