diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 24e0481e..4041d37f 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -257,3 +257,11 @@ type ErrServerPathHasChanged struct { func (e ErrServerPathHasChanged) Error() string { return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur) } + +// ErrServerCannotSetupFromDifferentIPs is an error that can be returned by a server. +type ErrServerCannotSetupFromDifferentIPs struct{} + +// Error implements the error interface. +func (e ErrServerCannotSetupFromDifferentIPs) Error() string { + return "cannot setup tracks from different IPs" +} diff --git a/serversession.go b/serversession.go index 21240e1c..0acafc51 100644 --- a/serversession.go +++ b/serversession.go @@ -587,6 +587,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base StatusCode: base.StatusBadRequest, }, liberrors.ErrServerTransportHeaderNoClientPorts{} } + + if ss.setuppedTracks != nil && + (!ss.udpIP.Equal(sc.ip()) || ss.udpZone != sc.zone()) { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerCannotSetupFromDifferentIPs{} + } + } else if ss.s.MulticastIPRange == "" { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, @@ -672,6 +680,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base switch { case delivery == base.StreamDeliveryMulticast: + ss.udpIP = sc.ip() + ss.udpZone = sc.zone() + th.Protocol = base.StreamProtocolUDP de := base.StreamDeliveryMulticast th.Delivery = &de @@ -685,6 +696,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } case inTH.Protocol == base.StreamProtocolUDP: + ss.udpIP = sc.ip() + ss.udpZone = sc.zone() + sst.udpRTPPort = inTH.ClientPorts[0] sst.udpRTCPPort = inTH.ClientPorts[1] @@ -781,10 +795,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if res.StatusCode == base.StatusOK { ss.state = ServerSessionStateRead - if *ss.setuppedProtocol == base.StreamProtocolUDP { - ss.udpIP = sc.ip() - ss.udpZone = sc.zone() - } else { + if *ss.setuppedProtocol == base.StreamProtocolTCP { ss.tcpConn = sc } @@ -877,10 +888,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base path, query := base.PathSplitQuery(pathAndQuery) // allow to use WriteFrame() before response - if *ss.setuppedProtocol == base.StreamProtocolUDP { - ss.udpIP = sc.ip() - ss.udpZone = sc.zone() - } else { + if *ss.setuppedProtocol == base.StreamProtocolTCP { ss.tcpConn = sc } @@ -919,8 +927,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base return res, liberrors.ErrServerTCPFramesEnable{} } - ss.udpIP = nil - ss.udpZone = "" ss.tcpConn = nil return res, err @@ -964,8 +970,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.setuppedStream.readerSetInactive(ss) ss.state = ServerSessionStatePreRead - ss.udpIP = nil - ss.udpZone = "" ss.tcpConn = nil if *ss.setuppedProtocol == base.StreamProtocolUDP { @@ -976,8 +980,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base case ServerSessionStatePublish: ss.state = ServerSessionStatePrePublish - ss.udpIP = nil - ss.udpZone = "" ss.tcpConn = nil if *ss.setuppedProtocol == base.StreamProtocolUDP {