server: do not allow a client to control a session created with a different IP

This commit is contained in:
aler9
2021-09-23 19:52:57 +02:00
parent 0454e5407f
commit 239b71d975
9 changed files with 42 additions and 41 deletions

View File

@@ -21,7 +21,7 @@ const (
func randUint32() uint32 { func randUint32() uint32 {
var b [4]byte var b [4]byte
rand.Read(b[:]) rand.Read(b[:])
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
} }
func randIntn(n int) int { func randIntn(n int) int {

View File

@@ -258,12 +258,12 @@ func (e ErrServerPathHasChanged) Error() string {
return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur) return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur)
} }
// ErrServerCannotSetupFromDifferentIPs is an error that can be returned by a server. // ErrServerCannotUseSessionCreatedByOtherIP is an error that can be returned by a server.
type ErrServerCannotSetupFromDifferentIPs struct{} type ErrServerCannotUseSessionCreatedByOtherIP struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerCannotSetupFromDifferentIPs) Error() string { func (e ErrServerCannotUseSessionCreatedByOtherIP) Error() string {
return "cannot setup tracks from different IPs" return "cannot use a session created with a different IP"
} }
// ErrServerUDPPortsAlreadyInUse is an error that can be returned by a server. // ErrServerUDPPortsAlreadyInUse is an error that can be returned by a server.

View File

@@ -14,7 +14,7 @@ import (
func randUint32() uint32 { func randUint32() uint32 {
var b [4]byte var b [4]byte
rand.Read(b[:]) rand.Read(b[:])
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
} }
// RTCPReceiver is a utility to generate RTCP receiver reports. // RTCPReceiver is a utility to generate RTCP receiver reports.

View File

@@ -16,7 +16,7 @@ const (
func randUint32() uint32 { func randUint32() uint32 {
var b [4]byte var b [4]byte
rand.Read(b[:]) rand.Read(b[:])
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
} }
// Encoder is a RTP/AAC encoder. // Encoder is a RTP/AAC encoder.

View File

@@ -17,7 +17,7 @@ const (
func randUint32() uint32 { func randUint32() uint32 {
var b [4]byte var b [4]byte
rand.Read(b[:]) rand.Read(b[:])
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
} }
// Encoder is a RTP/H264 encoder. // Encoder is a RTP/H264 encoder.

View File

@@ -32,7 +32,7 @@ func extractPort(address string) (int, error) {
func newSessionSecretID(sessions map[string]*ServerSession) (string, error) { func newSessionSecretID(sessions map[string]*ServerSession) (string, error) {
for { for {
b := make([]byte, 4) b := make([]byte, 4)
_, err := rand.Read(b[:]) _, err := rand.Read(b)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -363,6 +363,17 @@ outer:
case req := <-s.sessionRequest: case req := <-s.sessionRequest:
if ss, ok := s.sessions[req.id]; ok { if ss, ok := s.sessions[req.id]; ok {
if !req.sc.ip().Equal(ss.ip()) ||
req.sc.zone() != ss.zone() {
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{},
}
continue
}
ss.request <- req ss.request <- req
} else { } else {
if !req.create { if !req.create {

View File

@@ -47,6 +47,7 @@ type ServerConn struct {
ctx context.Context ctx context.Context
ctxCancel func() ctxCancel func()
remoteAddr *net.TCPAddr // to improve speed
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
sessions map[string]*ServerSession sessions map[string]*ServerSession
@@ -76,6 +77,7 @@ func newServerConn(
nconn: nconn, nconn: nconn,
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
sessionRemove: make(chan *ServerSession), sessionRemove: make(chan *ServerSession),
done: make(chan struct{}), done: make(chan struct{}),
} }
@@ -98,11 +100,11 @@ func (sc *ServerConn) NetConn() net.Conn {
} }
func (sc *ServerConn) ip() net.IP { func (sc *ServerConn) ip() net.IP {
return sc.nconn.RemoteAddr().(*net.TCPAddr).IP return sc.remoteAddr.IP
} }
func (sc *ServerConn) zone() string { func (sc *ServerConn) zone() string {
return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone return sc.remoteAddr.Zone
} }
func (sc *ServerConn) run() { func (sc *ServerConn) run() {

View File

@@ -137,8 +137,6 @@ type ServerSession struct {
setuppedQuery *string setuppedQuery *string
lastRequestTime time.Time lastRequestTime time.Time
tcpConn *ServerConn // tcp tcpConn *ServerConn // tcp
udpIP net.IP // udp
udpZone string // udp
announcedTracks []ServerSessionAnnouncedTrack // publish announcedTracks []ServerSessionAnnouncedTrack // publish
udpLastFrameTime *int64 // publish, udp udpLastFrameTime *int64 // publish, udp
@@ -203,6 +201,14 @@ func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack {
return ss.announcedTracks return ss.announcedTracks
} }
func (ss *ServerSession) ip() net.IP {
return ss.author.ip()
}
func (ss *ServerSession) zone() string {
return ss.author.zone()
}
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error { func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
if _, ok := allowed[ss.state]; ok { if _, ok := allowed[ss.state]; ok {
return nil return nil
@@ -225,7 +231,6 @@ func (ss *ServerSession) run() {
Session: ss, Session: ss,
Conn: ss.author, Conn: ss.author,
}) })
ss.author = nil
} }
err := func() error { err := func() error {
@@ -587,13 +592,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoClientPorts{} }, liberrors.ErrServerTransportHeaderNoClientPorts{}
} }
if ss.setuppedTracks != nil &&
(!ss.udpIP.Equal(sc.ip()) || ss.udpZone != sc.zone()) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerCannotSetupFromDifferentIPs{}
}
} else if ss.s.MulticastIPRange == "" { } else if ss.s.MulticastIPRange == "" {
return &base.Response{ return &base.Response{
StatusCode: base.StatusUnsupportedTransport, StatusCode: base.StatusUnsupportedTransport,
@@ -649,8 +647,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
err := stream.readerAdd(ss, err := stream.readerAdd(ss,
inTH.Protocol, inTH.Protocol,
delivery, delivery,
sc.ip(),
sc.zone(),
inTH.ClientPorts, inTH.ClientPorts,
) )
if err != nil { if err != nil {
@@ -685,9 +681,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
switch { switch {
case delivery == base.StreamDeliveryMulticast: case delivery == base.StreamDeliveryMulticast:
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
th.Protocol = base.StreamProtocolUDP th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryMulticast de := base.StreamDeliveryMulticast
th.Delivery = &de th.Delivery = &de
@@ -701,9 +694,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
} }
case inTH.Protocol == base.StreamProtocolUDP: case inTH.Protocol == base.StreamProtocolUDP:
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
sst.udpRTPPort = inTH.ClientPorts[0] sst.udpRTPPort = inTH.ClientPorts[0]
sst.udpRTCPPort = inTH.ClientPorts[1] sst.udpRTCPPort = inTH.ClientPorts[1]
@@ -847,7 +837,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if *ss.setuppedDelivery == base.StreamDeliveryUnicast { if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
// readers can send RTCP frames // readers can send RTCP frames
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false)
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTCP, ss.WriteFrame(trackID, StreamTypeRTCP,
@@ -916,8 +906,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if *ss.setuppedProtocol == base.StreamProtocolUDP { if *ss.setuppedProtocol == base.StreamProtocolUDP {
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true) ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true)
ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true) ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true)
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTP, ss.WriteFrame(trackID, StreamTypeRTP,
@@ -1050,14 +1040,14 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
if streamType == StreamTypeRTP { if streamType == StreamTypeRTP {
ss.s.udpRTPListener.write(payload, &net.UDPAddr{ ss.s.udpRTPListener.write(payload, &net.UDPAddr{
IP: ss.udpIP, IP: ss.ip(),
Zone: ss.udpZone, Zone: ss.zone(),
Port: track.udpRTPPort, Port: track.udpRTPPort,
}) })
} else { } else {
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{ ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
IP: ss.udpIP, IP: ss.ip(),
Zone: ss.udpZone, Zone: ss.zone(),
Port: track.udpRTCPPort, Port: track.udpRTCPPort,
}) })
} }

View File

@@ -116,8 +116,6 @@ func (st *ServerStream) readerAdd(
ss *ServerSession, ss *ServerSession,
protocol base.StreamProtocol, protocol base.StreamProtocol,
delivery base.StreamDelivery, delivery base.StreamDelivery,
ip net.IP,
zone string,
clientPorts *[2]int, clientPorts *[2]int,
) error { ) error {
st.mutex.Lock() st.mutex.Lock()
@@ -137,8 +135,8 @@ func (st *ServerStream) readerAdd(
for r := range st.readersUnicast { for r := range st.readersUnicast {
if *r.setuppedProtocol == base.StreamProtocolUDP && if *r.setuppedProtocol == base.StreamProtocolUDP &&
*r.setuppedDelivery == base.StreamDeliveryUnicast && *r.setuppedDelivery == base.StreamDeliveryUnicast &&
r.udpIP.Equal(ip) && r.ip().Equal(ss.ip()) &&
r.udpZone == zone { r.zone() == ss.zone() {
for _, rt := range r.setuppedTracks { for _, rt := range r.setuppedTracks {
if rt.udpRTPPort == clientPorts[0] { if rt.udpRTPPort == clientPorts[0] {
return liberrors.ErrServerUDPPortsAlreadyInUse{Port: rt.udpRTPPort} return liberrors.ErrServerUDPPortsAlreadyInUse{Port: rt.udpRTPPort}
@@ -203,7 +201,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
} else { } else {
for trackID := range ss.setuppedTracks { for trackID := range ss.setuppedTracks {
st.multicastListeners[trackID].rtcpListener.addClient( st.multicastListeners[trackID].rtcpListener.addClient(
ss.udpIP, st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false) ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
} }
} }
} }