mirror of
https://github.com/aler9/gortsplib
synced 2025-10-04 23:02:45 +08:00
add ServerConn states
This commit is contained in:
@@ -194,7 +194,7 @@ func (c *ClientConn) checkState(allowed map[clientConnState]struct{}) error {
|
||||
for a := range allowed {
|
||||
allowedList = append(allowedList, a)
|
||||
}
|
||||
return fmt.Errorf("client must be in state %v, while is in state %v",
|
||||
return fmt.Errorf("must be in state %v, while is in state %v",
|
||||
allowedList, c.state)
|
||||
}
|
||||
|
||||
|
@@ -22,5 +22,5 @@ func (s *Server) Accept() (*ServerConn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newServerConn(s, nconn), nil
|
||||
return newServerConn(s.conf, nconn), nil
|
||||
}
|
||||
|
212
serverconn.go
212
serverconn.go
@@ -26,14 +26,35 @@ var (
|
||||
ErrServerTeardown = errors.New("teardown")
|
||||
)
|
||||
|
||||
type serverConnState int
|
||||
// ServerConnState is the state of the connection.
|
||||
type ServerConnState int
|
||||
|
||||
// standard states.
|
||||
const (
|
||||
serverConnStateInitial serverConnState = iota
|
||||
serverConnStatePlay
|
||||
serverConnStateRecord
|
||||
ServerConnStateInitial ServerConnState = iota
|
||||
ServerConnStatePrePlay
|
||||
ServerConnStatePlay
|
||||
ServerConnStatePreRecord
|
||||
ServerConnStateRecord
|
||||
)
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (s ServerConnState) String() string {
|
||||
switch s {
|
||||
case ServerConnStateInitial:
|
||||
return "initial"
|
||||
case ServerConnStatePrePlay:
|
||||
return "prePlay"
|
||||
case ServerConnStatePlay:
|
||||
return "play"
|
||||
case ServerConnStatePreRecord:
|
||||
return "preRecord"
|
||||
case ServerConnStateRecord:
|
||||
return "record"
|
||||
}
|
||||
return "uknown"
|
||||
}
|
||||
|
||||
type serverConnTrack struct {
|
||||
proto StreamProtocol
|
||||
rtpPort int
|
||||
@@ -106,13 +127,13 @@ type ServerConnReadHandlers struct {
|
||||
|
||||
// ServerConn is a server-side RTSP connection.
|
||||
type ServerConn struct {
|
||||
s *Server
|
||||
conf ServerConf
|
||||
nconn net.Conn
|
||||
br *bufio.Reader
|
||||
bw *bufio.Writer
|
||||
state serverConnState
|
||||
state ServerConnState
|
||||
tracks map[int]serverConnTrack
|
||||
tracksProto *StreamProtocol
|
||||
tracksProtocol *StreamProtocol
|
||||
writeMutex sync.Mutex
|
||||
readHandlers ServerConnReadHandlers
|
||||
nextFramesEnabled bool
|
||||
@@ -123,16 +144,16 @@ type ServerConn struct {
|
||||
terminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerConn(s *Server, nconn net.Conn) *ServerConn {
|
||||
func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
|
||||
conn := func() net.Conn {
|
||||
if s.conf.TLSConfig != nil {
|
||||
return tls.Server(nconn, s.conf.TLSConfig)
|
||||
if conf.TLSConfig != nil {
|
||||
return tls.Server(nconn, conf.TLSConfig)
|
||||
}
|
||||
return nconn
|
||||
}()
|
||||
|
||||
return &ServerConn{
|
||||
s: s,
|
||||
conf: conf,
|
||||
nconn: nconn,
|
||||
br: bufio.NewReaderSize(conn, serverReadBufferSize),
|
||||
bw: bufio.NewWriterSize(conn, serverWriteBufferSize),
|
||||
@@ -148,6 +169,29 @@ func (sc *ServerConn) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// State returns the state.
|
||||
func (sc *ServerConn) State() ServerConnState {
|
||||
return sc.state
|
||||
}
|
||||
|
||||
// TracksProtocol returns the tracks protocol.
|
||||
func (sc *ServerConn) TracksProtocol() *StreamProtocol {
|
||||
return sc.tracksProtocol
|
||||
}
|
||||
|
||||
func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error {
|
||||
if _, ok := allowed[sc.state]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allowedList []ServerConnState
|
||||
for a := range allowed {
|
||||
allowedList = append(allowedList, a)
|
||||
}
|
||||
return fmt.Errorf("must be in state %v, while is in state %v",
|
||||
allowedList, sc.state)
|
||||
}
|
||||
|
||||
// NetConn returns the underlying net.Conn.
|
||||
func (sc *ServerConn) NetConn() net.Conn {
|
||||
return sc.nconn
|
||||
@@ -163,20 +207,20 @@ func (sc *ServerConn) zone() string {
|
||||
|
||||
func (sc *ServerConn) frameModeEnable() {
|
||||
switch sc.state {
|
||||
case serverConnStatePlay:
|
||||
if *sc.tracksProto == StreamProtocolTCP {
|
||||
case ServerConnStatePlay:
|
||||
if *sc.tracksProtocol == StreamProtocolTCP {
|
||||
sc.nextFramesEnabled = true
|
||||
}
|
||||
|
||||
case serverConnStateRecord:
|
||||
if *sc.tracksProto == StreamProtocolTCP {
|
||||
case ServerConnStateRecord:
|
||||
if *sc.tracksProtocol == StreamProtocolTCP {
|
||||
sc.nextFramesEnabled = true
|
||||
sc.readTimeoutEnabled = true
|
||||
|
||||
} else {
|
||||
for trackID, track := range sc.tracks {
|
||||
sc.s.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
|
||||
sc.s.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
|
||||
sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
|
||||
sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,17 +228,17 @@ func (sc *ServerConn) frameModeEnable() {
|
||||
|
||||
func (sc *ServerConn) frameModeDisable() {
|
||||
switch sc.state {
|
||||
case serverConnStatePlay:
|
||||
case ServerConnStatePlay:
|
||||
sc.nextFramesEnabled = false
|
||||
|
||||
case serverConnStateRecord:
|
||||
case ServerConnStateRecord:
|
||||
sc.nextFramesEnabled = false
|
||||
sc.readTimeoutEnabled = false
|
||||
|
||||
for _, track := range sc.tracks {
|
||||
if track.proto == StreamProtocolUDP {
|
||||
sc.s.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
|
||||
sc.s.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
|
||||
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
|
||||
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,11 +289,29 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Describe:
|
||||
if sc.readHandlers.OnDescribe != nil {
|
||||
err := sc.checkState(map[ServerConnState]struct{}{
|
||||
ServerConnStateInitial: {},
|
||||
})
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
return sc.readHandlers.OnDescribe(req)
|
||||
}
|
||||
|
||||
case base.Announce:
|
||||
if sc.readHandlers.OnAnnounce != nil {
|
||||
err := sc.checkState(map[ServerConnState]struct{}{
|
||||
ServerConnStateInitial: {},
|
||||
})
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
ct, ok := req.Header["Content-Type"]
|
||||
if !ok || len(ct) != 1 {
|
||||
return &base.Response{
|
||||
@@ -277,11 +339,27 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnAnnounce(req, tracks)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.state = ServerConnStatePreRecord
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
case base.Setup:
|
||||
if sc.readHandlers.OnSetup != nil {
|
||||
err := sc.checkState(map[ServerConnState]struct{}{
|
||||
ServerConnStateInitial: {},
|
||||
ServerConnStatePrePlay: {},
|
||||
ServerConnStatePreRecord: {},
|
||||
})
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
_, controlPath, ok := req.URL.BasePathControlAttr()
|
||||
if !ok {
|
||||
return &base.Response{
|
||||
@@ -309,14 +387,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, fmt.Errorf("track %d has already been setup", trackID)
|
||||
}
|
||||
|
||||
if sc.tracksProto != nil && *sc.tracksProto != th.Protocol {
|
||||
if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("can't receive tracks with different protocols")
|
||||
}, fmt.Errorf("can't setup tracks with different protocols")
|
||||
}
|
||||
|
||||
if th.Protocol == StreamProtocolUDP {
|
||||
if sc.s.conf.UDPRTPListener == nil {
|
||||
if sc.conf.UDPRTPListener == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusUnsupportedTransport,
|
||||
}, nil
|
||||
@@ -347,9 +425,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
res, err := sc.readHandlers.OnSetup(req, th)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.tracksProto = &th.Protocol
|
||||
sc.tracksProtocol = &th.Protocol
|
||||
|
||||
if th.Protocol == StreamProtocolUDP {
|
||||
sc.tracks[trackID] = serverConnTrack{
|
||||
proto: StreamProtocolUDP,
|
||||
rtpPort: th.ClientPorts[0],
|
||||
rtcpPort: th.ClientPorts[1],
|
||||
}
|
||||
|
||||
res.Header["Transport"] = headers.Transport{
|
||||
Protocol: StreamProtocolUDP,
|
||||
Delivery: func() *base.StreamDelivery {
|
||||
@@ -357,25 +441,24 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
return &v
|
||||
}(),
|
||||
ClientPorts: th.ClientPorts,
|
||||
ServerPorts: &[2]int{sc.s.conf.UDPRTPListener.port(), sc.s.conf.UDPRTCPListener.port()},
|
||||
ServerPorts: &[2]int{sc.conf.UDPRTPListener.port(), sc.conf.UDPRTCPListener.port()},
|
||||
}.Write()
|
||||
|
||||
} else {
|
||||
sc.tracks[trackID] = serverConnTrack{
|
||||
proto: StreamProtocolUDP,
|
||||
rtpPort: th.ClientPorts[0],
|
||||
rtcpPort: th.ClientPorts[1],
|
||||
proto: StreamProtocolTCP,
|
||||
}
|
||||
|
||||
} else {
|
||||
res.Header["Transport"] = headers.Transport{
|
||||
Protocol: StreamProtocolTCP,
|
||||
InterleavedIds: th.InterleavedIds,
|
||||
}.Write()
|
||||
}
|
||||
}
|
||||
|
||||
sc.tracks[trackID] = serverConnTrack{
|
||||
proto: StreamProtocolTCP,
|
||||
}
|
||||
}
|
||||
switch sc.state {
|
||||
case ServerConnStateInitial:
|
||||
sc.state = ServerConnStatePrePlay
|
||||
}
|
||||
|
||||
// workaround to prevent a bug in rtspclientsink
|
||||
@@ -397,10 +480,21 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Play:
|
||||
if sc.readHandlers.OnPlay != nil {
|
||||
// play can be sent twice, allow calling it even if we're already playing
|
||||
err := sc.checkState(map[ServerConnState]struct{}{
|
||||
ServerConnStatePrePlay: {},
|
||||
ServerConnStatePlay: {},
|
||||
})
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnPlay(req)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.state = serverConnStatePlay
|
||||
if res.StatusCode == 200 && sc.state != ServerConnStatePlay {
|
||||
sc.state = ServerConnStatePlay
|
||||
sc.frameModeEnable()
|
||||
}
|
||||
|
||||
@@ -409,10 +503,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Record:
|
||||
if sc.readHandlers.OnRecord != nil {
|
||||
err := sc.checkState(map[ServerConnState]struct{}{
|
||||
ServerConnStatePreRecord: {},
|
||||
})
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnRecord(req)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.state = serverConnStateRecord
|
||||
sc.state = ServerConnStateRecord
|
||||
sc.frameModeEnable()
|
||||
}
|
||||
|
||||
@@ -421,11 +524,30 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Pause:
|
||||
if sc.readHandlers.OnPause != nil {
|
||||
err := sc.checkState(map[ServerConnState]struct{}{
|
||||
ServerConnStatePrePlay: {},
|
||||
ServerConnStatePlay: {},
|
||||
ServerConnStatePreRecord: {},
|
||||
ServerConnStateRecord: {},
|
||||
})
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnPause(req)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
switch sc.state {
|
||||
case ServerConnStatePlay:
|
||||
sc.frameModeDisable()
|
||||
sc.state = serverConnStateInitial
|
||||
sc.state = ServerConnStatePrePlay
|
||||
|
||||
case ServerConnStateRecord:
|
||||
sc.frameModeDisable()
|
||||
sc.state = ServerConnStatePreRecord
|
||||
}
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -471,7 +593,7 @@ func (sc *ServerConn) backgroundRead() error {
|
||||
cseq, ok := req.Header["CSeq"]
|
||||
if !ok || len(cseq) != 1 {
|
||||
sc.writeMutex.Lock()
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
|
||||
base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
Header: base.Header{},
|
||||
@@ -498,7 +620,7 @@ func (sc *ServerConn) backgroundRead() error {
|
||||
|
||||
sc.writeMutex.Lock()
|
||||
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
|
||||
res.Write(sc.bw)
|
||||
|
||||
// set framesEnabled after sending the response
|
||||
@@ -514,13 +636,13 @@ func (sc *ServerConn) backgroundRead() error {
|
||||
|
||||
var req base.Request
|
||||
var frame base.InterleavedFrame
|
||||
tcpFrameBuffer := multibuffer.New(sc.s.conf.ReadBufferCount, clientTCPFrameReadBufferSize)
|
||||
tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, clientTCPFrameReadBufferSize)
|
||||
var errRet error
|
||||
|
||||
outer:
|
||||
for {
|
||||
if sc.readTimeoutEnabled {
|
||||
sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout))
|
||||
sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout))
|
||||
} else {
|
||||
sc.nconn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
@@ -589,14 +711,14 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
|
||||
|
||||
if track.proto == StreamProtocolUDP {
|
||||
if streamType == StreamTypeRTP {
|
||||
return sc.s.conf.UDPRTPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
|
||||
return sc.conf.UDPRTPListener.write(sc.conf.WriteTimeout, payload, &net.UDPAddr{
|
||||
IP: sc.ip(),
|
||||
Zone: sc.zone(),
|
||||
Port: track.rtpPort,
|
||||
})
|
||||
}
|
||||
|
||||
return sc.s.conf.UDPRTCPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
|
||||
return sc.conf.UDPRTCPListener.write(sc.conf.WriteTimeout, payload, &net.UDPAddr{
|
||||
IP: sc.ip(),
|
||||
Zone: sc.zone(),
|
||||
Port: track.rtcpPort,
|
||||
@@ -609,7 +731,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
|
||||
return errors.New("frames are disabled")
|
||||
}
|
||||
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
|
||||
frame := base.InterleavedFrame{
|
||||
TrackID: trackID,
|
||||
StreamType: streamType,
|
||||
|
Reference in New Issue
Block a user