diff --git a/serversession.go b/serversession.go index 59850279..37a81453 100644 --- a/serversession.go +++ b/serversession.go @@ -595,8 +595,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base path, query := base.PathSplitQuery(pathAndQuery) - if ss.state != ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolTCP { - ss.linkedConn = sc + // allow to use WriteFrame() before response + if ss.state != ServerSessionStatePlay { + if *ss.setupProtocol == StreamProtocolUDP { + ss.udpIP = sc.ip() + ss.udpZone = sc.zone() + } else { + ss.linkedConn = sc + } } res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ @@ -607,14 +613,11 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base Query: query, }) - if res.StatusCode == base.StatusOK { - if ss.state != ServerSessionStatePlay { + if ss.state != ServerSessionStatePlay { + if res.StatusCode == base.StatusOK { ss.state = ServerSessionStatePlay if *ss.setupProtocol == StreamProtocolUDP { - ss.udpIP = sc.ip() - ss.udpZone = sc.zone() - // readers can send RTCP frames, they cannot sent RTP frames for trackID, track := range ss.setuppedTracks { sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) @@ -625,7 +628,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base return res, liberrors.ErrServerTCPFramesEnable{} } - } else { + + ss.udpIP = nil + ss.udpZone = "" ss.linkedConn = nil } @@ -659,6 +664,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base path, query := base.PathSplitQuery(pathAndQuery) + // allow to use WriteFrame() before response + if *ss.setupProtocol == StreamProtocolUDP { + ss.udpIP = sc.ip() + ss.udpZone = sc.zone() + } else { + ss.linkedConn = sc + } + res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ Session: ss, Conn: sc, @@ -671,9 +684,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStateRecord if *ss.setupProtocol == StreamProtocolUDP { - ss.udpIP = sc.ip() - ss.udpZone = sc.zone() - for trackID, track := range ss.setuppedTracks { ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true) ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true) @@ -688,10 +698,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base return res, err } - ss.linkedConn = sc return res, liberrors.ErrServerTCPFramesEnable{} } + ss.udpIP = nil + ss.udpZone = "" + ss.linkedConn = nil + return res, err case base.Pause: @@ -731,6 +744,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base switch ss.state { case ServerSessionStatePlay: ss.state = ServerSessionStatePrePlay + ss.udpIP = nil + ss.udpZone = "" ss.linkedConn = nil if *ss.setupProtocol == StreamProtocolUDP { @@ -741,6 +756,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base case ServerSessionStateRecord: ss.state = ServerSessionStatePreRecord + ss.udpIP = nil + ss.udpZone = "" ss.linkedConn = nil if *ss.setupProtocol == StreamProtocolUDP {