server: replace SetuppedProtocol() with SetuppedTransport()

This commit is contained in:
aler9
2021-10-22 17:40:18 +02:00
parent 7a000bed0e
commit e7ab15750c
6 changed files with 233 additions and 194 deletions

View File

@@ -75,6 +75,29 @@ func setupGetTrackIDPathQuery(
return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
func setupGetTransport(th headers.Transport) (ClientTransport, bool) {
delivery := func() base.StreamDelivery {
if th.Delivery != nil {
return *th.Delivery
}
return base.StreamDeliveryUnicast
}()
switch th.Protocol {
case base.StreamProtocolUDP:
if delivery == base.StreamDeliveryUnicast {
return ClientTransportUDP, true
}
return ClientTransportUDPMulticast, true
default: // TCP
if delivery != base.StreamDeliveryUnicast {
return 0, false
}
return ClientTransportTCP, true
}
}
// ServerSessionState is a state of a ServerSession.
type ServerSessionState int
@@ -129,8 +152,7 @@ type ServerSession struct {
state ServerSessionState
setuppedTracks map[int]ServerSessionSetuppedTrack
setuppedTracksByChannel map[int]int // tcp
setuppedProtocol *base.StreamProtocol
setuppedDelivery *base.StreamDelivery
setuppedTransport *ClientTransport
setuppedBaseURL *base.URL // publish
setuppedStream *ServerStream // read
setuppedPath *string
@@ -186,14 +208,9 @@ func (ss *ServerSession) SetuppedTracks() map[int]ServerSessionSetuppedTrack {
return ss.setuppedTracks
}
// SetuppedProtocol returns the stream protocol of the setupped tracks.
func (ss *ServerSession) SetuppedProtocol() *base.StreamProtocol {
return ss.setuppedProtocol
}
// SetuppedDelivery returns the delivery method of the setupped tracks.
func (ss *ServerSession) SetuppedDelivery() *base.StreamDelivery {
return ss.setuppedDelivery
// SetuppedTransport returns the transport of the setupped tracks.
func (ss *ServerSession) SetuppedTransport() *ClientTransport {
return ss.setuppedTransport
}
// AnnouncedTracks returns the announced tracks.
@@ -279,10 +296,10 @@ func (ss *ServerSession) run() {
}
}
// if session is not in state RECORD or PLAY, or protocol is TCP
// if session is not in state RECORD or PLAY, or transport is TCP
if (ss.state != ServerSessionStatePublish &&
ss.state != ServerSessionStateRead) ||
*ss.setuppedProtocol == base.StreamProtocolTCP {
*ss.setuppedTransport == ClientTransportTCP {
// close if there are no active connections
if len(ss.conns) == 0 {
@@ -293,7 +310,8 @@ func (ss *ServerSession) run() {
case <-checkTimeoutTicker.C:
switch {
// in case of RECORD and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStatePublish && *ss.setuppedProtocol == base.StreamProtocolUDP:
case ss.state == ServerSessionStatePublish && (*ss.setuppedTransport == ClientTransportUDP ||
*ss.setuppedTransport == ClientTransportUDPMulticast):
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
@@ -301,7 +319,8 @@ func (ss *ServerSession) run() {
}
// in case of PLAY and UDP, timeout happens when no request arrives
case ss.state == ServerSessionStateRead && *ss.setuppedProtocol == base.StreamProtocolUDP:
case ss.state == ServerSessionStateRead && (*ss.setuppedTransport == ClientTransportUDP ||
*ss.setuppedTransport == ClientTransportUDPMulticast):
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
@@ -333,13 +352,12 @@ func (ss *ServerSession) run() {
case ServerSessionStateRead:
ss.setuppedStream.readerSetInactive(ss)
if *ss.setuppedProtocol == base.StreamProtocolUDP &&
*ss.setuppedDelivery == base.StreamDeliveryUnicast {
if *ss.setuppedTransport == ClientTransportUDP {
ss.s.udpRTCPListener.removeClient(ss)
}
case ServerSessionStatePublish:
if *ss.setuppedProtocol == base.StreamProtocolUDP {
if *ss.setuppedTransport == ClientTransportUDP {
ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss)
}
@@ -550,60 +568,35 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID}
}
delivery := func() base.StreamDelivery {
if inTH.Delivery != nil {
return *inTH.Delivery
}
return base.StreamDeliveryUnicast
}()
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePreRead: // play
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
default: // record
if delivery == base.StreamDeliveryMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
transport, ok := setupGetTransport(inTH)
if !ok {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if inTH.Protocol == base.StreamProtocolUDP {
if delivery == base.StreamDeliveryUnicast {
if ss.s.udpRTPListener == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if inTH.ClientPorts == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
}
} else if ss.s.MulticastIPRange == "" {
switch transport {
case ClientTransportUDP:
if inTH.ClientPorts == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
}
} else {
if delivery == base.StreamDeliveryMulticast {
if ss.s.udpRTPListener == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
case ClientTransportUDPMulticast:
if ss.s.MulticastIPRange == "" {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
default: // TCP
if inTH.InterleavedIDs == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
@@ -624,13 +617,34 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}
}
if ss.setuppedProtocol != nil &&
(*ss.setuppedProtocol != inTH.Protocol || *ss.setuppedDelivery != delivery) {
if ss.setuppedTransport != nil && *ss.setuppedTransport != transport {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTracksDifferentProtocols{}
}
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePreRead: // play
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
default: // record
if transport == ClientTransportUDPMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
}
res, stream, err := ss.s.Handler.(ServerHandlerOnSetup).OnSetup(&ServerHandlerOnSetupCtx{
Server: ss.s,
Session: ss,
@@ -639,14 +653,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
Path: path,
Query: query,
TrackID: trackID,
Transport: &inTH,
Transport: transport,
})
if res.StatusCode == base.StatusOK {
if ss.state == ServerSessionStateInitial {
err := stream.readerAdd(ss,
inTH.Protocol,
delivery,
transport,
inTH.ClientPorts,
)
if err != nil {
@@ -670,8 +683,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}
}
ss.setuppedProtocol = &inTH.Protocol
ss.setuppedDelivery = &delivery
ss.setuppedTransport = &transport
if res.Header == nil {
res.Header = make(base.Header)
@@ -679,8 +691,18 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
sst := ServerSessionSetuppedTrack{}
switch {
case delivery == base.StreamDeliveryMulticast:
switch transport {
case ClientTransportUDP:
sst.udpRTPPort = inTH.ClientPorts[0]
sst.udpRTCPPort = inTH.ClientPorts[1]
th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
case ClientTransportUDPMulticast:
th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryMulticast
th.Delivery = &de
@@ -693,16 +715,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
stream.multicastListeners[trackID].rtcpListener.port(),
}
case inTH.Protocol == base.StreamProtocolUDP:
sst.udpRTPPort = inTH.ClientPorts[0]
sst.udpRTCPPort = inTH.ClientPorts[1]
th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
default: // TCP
sst.tcpChannel = inTH.InterleavedIDs[0]
@@ -790,7 +802,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStateRead
if *ss.setuppedProtocol == base.StreamProtocolTCP {
if *ss.setuppedTransport == ClientTransportTCP {
ss.tcpConn = sc
}
@@ -833,22 +845,26 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.setuppedStream.readerSetActive(ss)
if *ss.setuppedProtocol == base.StreamProtocolUDP {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
for trackID, track := range ss.setuppedTracks {
// readers can send RTCP packets
sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false)
switch *ss.setuppedTransport {
case ClientTransportUDP:
for trackID, track := range ss.setuppedTracks {
// readers can send RTCP packets
sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false)
// open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
}
// open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
}
return res, err
case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesEnable{}
}
return res, liberrors.ErrServerTCPFramesEnable{}
return res, err
}
}
@@ -883,7 +899,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
path, query := base.PathSplitQuery(pathAndQuery)
// allow to use WriteFrame() before response
if *ss.setuppedProtocol == base.StreamProtocolTCP {
if *ss.setuppedTransport == ClientTransportTCP {
ss.tcpConn = sc
}
@@ -904,7 +920,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStatePublish
if *ss.setuppedProtocol == base.StreamProtocolUDP {
switch *ss.setuppedTransport {
case ClientTransportUDP:
for trackID, track := range ss.setuppedTracks {
ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true)
ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true)
@@ -916,10 +933,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
}
return res, err
case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesEnable{}
}
return res, liberrors.ErrServerTCPFramesEnable{}
return res, err
}
ss.tcpConn = nil
@@ -967,23 +987,29 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStatePreRead
ss.tcpConn = nil
if *ss.setuppedProtocol == base.StreamProtocolUDP {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
ss.s.udpRTCPListener.removeClient(ss)
}
} else {
return res, liberrors.ErrServerTCPFramesDisable{}
switch *ss.setuppedTransport {
case ClientTransportUDP:
ss.s.udpRTCPListener.removeClient(ss)
case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesDisable{}
}
case ServerSessionStatePublish:
ss.state = ServerSessionStatePrePublish
ss.tcpConn = nil
if *ss.setuppedProtocol == base.StreamProtocolUDP {
switch *ss.setuppedTransport {
case ClientTransportUDP:
ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss)
} else {
return res, liberrors.ErrServerTCPFramesDisable{}
case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesDisable{}
}
}
}
@@ -1037,25 +1063,25 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
return
}
if *ss.setuppedProtocol == base.StreamProtocolUDP {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
track := ss.setuppedTracks[trackID]
switch *ss.setuppedTransport {
case ClientTransportUDP:
track := ss.setuppedTracks[trackID]
if streamType == StreamTypeRTP {
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
IP: ss.ip(),
Zone: ss.zone(),
Port: track.udpRTPPort,
})
} else {
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
IP: ss.ip(),
Zone: ss.zone(),
Port: track.udpRTCPPort,
})
}
if streamType == StreamTypeRTP {
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
IP: ss.ip(),
Zone: ss.zone(),
Port: track.udpRTPPort,
})
} else {
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
IP: ss.ip(),
Zone: ss.zone(),
Port: track.udpRTCPPort,
})
}
} else {
case ClientTransportTCP:
channel := ss.setuppedTracks[trackID].tcpChannel
if streamType == base.StreamTypeRTCP {
channel++