From 239b71d9754c552e0e77450cf5b72f7d0d3499f2 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 23 Sep 2021 19:52:57 +0200 Subject: [PATCH] server: do not allow a client to control a session created with a different IP --- clientconnudpl.go | 2 +- pkg/liberrors/server.go | 8 +++---- pkg/rtcpreceiver/rtcpreceiver.go | 2 +- pkg/rtpaac/encoder.go | 2 +- pkg/rtph264/encoder.go | 2 +- server.go | 13 ++++++++++- serverconn.go | 6 +++-- serversession.go | 40 ++++++++++++-------------------- serverstream.go | 8 +++---- 9 files changed, 42 insertions(+), 41 deletions(-) diff --git a/clientconnudpl.go b/clientconnudpl.go index 4c3210ce..30c85e45 100644 --- a/clientconnudpl.go +++ b/clientconnudpl.go @@ -21,7 +21,7 @@ const ( func randUint32() uint32 { var b [4]byte rand.Read(b[:]) - return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) } func randIntn(n int) int { diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 3775a058..5caef787 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -258,12 +258,12 @@ func (e ErrServerPathHasChanged) Error() string { return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur) } -// ErrServerCannotSetupFromDifferentIPs is an error that can be returned by a server. -type ErrServerCannotSetupFromDifferentIPs struct{} +// ErrServerCannotUseSessionCreatedByOtherIP is an error that can be returned by a server. +type ErrServerCannotUseSessionCreatedByOtherIP struct{} // Error implements the error interface. -func (e ErrServerCannotSetupFromDifferentIPs) Error() string { - return "cannot setup tracks from different IPs" +func (e ErrServerCannotUseSessionCreatedByOtherIP) Error() string { + return "cannot use a session created with a different IP" } // ErrServerUDPPortsAlreadyInUse is an error that can be returned by a server. diff --git a/pkg/rtcpreceiver/rtcpreceiver.go b/pkg/rtcpreceiver/rtcpreceiver.go index 2f5af518..1650fffa 100644 --- a/pkg/rtcpreceiver/rtcpreceiver.go +++ b/pkg/rtcpreceiver/rtcpreceiver.go @@ -14,7 +14,7 @@ import ( func randUint32() uint32 { var b [4]byte rand.Read(b[:]) - return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) } // RTCPReceiver is a utility to generate RTCP receiver reports. diff --git a/pkg/rtpaac/encoder.go b/pkg/rtpaac/encoder.go index 6cd5d1f2..fc79ca9b 100644 --- a/pkg/rtpaac/encoder.go +++ b/pkg/rtpaac/encoder.go @@ -16,7 +16,7 @@ const ( func randUint32() uint32 { var b [4]byte rand.Read(b[:]) - return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) } // Encoder is a RTP/AAC encoder. diff --git a/pkg/rtph264/encoder.go b/pkg/rtph264/encoder.go index 47119b7a..67e1f18f 100644 --- a/pkg/rtph264/encoder.go +++ b/pkg/rtph264/encoder.go @@ -17,7 +17,7 @@ const ( func randUint32() uint32 { var b [4]byte rand.Read(b[:]) - return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3]) + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) } // Encoder is a RTP/H264 encoder. diff --git a/server.go b/server.go index 53f34c38..73bd9dac 100644 --- a/server.go +++ b/server.go @@ -32,7 +32,7 @@ func extractPort(address string) (int, error) { func newSessionSecretID(sessions map[string]*ServerSession) (string, error) { for { b := make([]byte, 4) - _, err := rand.Read(b[:]) + _, err := rand.Read(b) if err != nil { return "", err } @@ -363,6 +363,17 @@ outer: case req := <-s.sessionRequest: if ss, ok := s.sessions[req.id]; ok { + if !req.sc.ip().Equal(ss.ip()) || + req.sc.zone() != ss.zone() { + req.res <- sessionRequestRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{}, + } + continue + } + ss.request <- req } else { if !req.create { diff --git a/serverconn.go b/serverconn.go index 76642658..c837ff8d 100644 --- a/serverconn.go +++ b/serverconn.go @@ -47,6 +47,7 @@ type ServerConn struct { ctx context.Context ctxCancel func() + remoteAddr *net.TCPAddr // to improve speed br *bufio.Reader bw *bufio.Writer sessions map[string]*ServerSession @@ -76,6 +77,7 @@ func newServerConn( nconn: nconn, ctx: ctx, ctxCancel: ctxCancel, + remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), sessionRemove: make(chan *ServerSession), done: make(chan struct{}), } @@ -98,11 +100,11 @@ func (sc *ServerConn) NetConn() net.Conn { } func (sc *ServerConn) ip() net.IP { - return sc.nconn.RemoteAddr().(*net.TCPAddr).IP + return sc.remoteAddr.IP } func (sc *ServerConn) zone() string { - return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone + return sc.remoteAddr.Zone } func (sc *ServerConn) run() { diff --git a/serversession.go b/serversession.go index 39805726..5c9075c5 100644 --- a/serversession.go +++ b/serversession.go @@ -137,8 +137,6 @@ type ServerSession struct { setuppedQuery *string lastRequestTime time.Time tcpConn *ServerConn // tcp - udpIP net.IP // udp - udpZone string // udp announcedTracks []ServerSessionAnnouncedTrack // publish udpLastFrameTime *int64 // publish, udp @@ -203,6 +201,14 @@ func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack { return ss.announcedTracks } +func (ss *ServerSession) ip() net.IP { + return ss.author.ip() +} + +func (ss *ServerSession) zone() string { + return ss.author.zone() +} + func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error { if _, ok := allowed[ss.state]; ok { return nil @@ -225,7 +231,6 @@ func (ss *ServerSession) run() { Session: ss, Conn: ss.author, }) - ss.author = nil } err := func() error { @@ -587,13 +592,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base StatusCode: base.StatusBadRequest, }, liberrors.ErrServerTransportHeaderNoClientPorts{} } - - if ss.setuppedTracks != nil && - (!ss.udpIP.Equal(sc.ip()) || ss.udpZone != sc.zone()) { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerCannotSetupFromDifferentIPs{} - } } else if ss.s.MulticastIPRange == "" { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, @@ -649,8 +647,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base err := stream.readerAdd(ss, inTH.Protocol, delivery, - sc.ip(), - sc.zone(), inTH.ClientPorts, ) if err != nil { @@ -685,9 +681,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base switch { case delivery == base.StreamDeliveryMulticast: - ss.udpIP = sc.ip() - ss.udpZone = sc.zone() - th.Protocol = base.StreamProtocolUDP de := base.StreamDeliveryMulticast th.Delivery = &de @@ -701,9 +694,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } case inTH.Protocol == base.StreamProtocolUDP: - ss.udpIP = sc.ip() - ss.udpZone = sc.zone() - sst.udpRTPPort = inTH.ClientPorts[0] sst.udpRTCPPort = inTH.ClientPorts[1] @@ -847,7 +837,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if *ss.setuppedDelivery == base.StreamDeliveryUnicast { for trackID, track := range ss.setuppedTracks { // readers can send RTCP frames - sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) + sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false) // open the firewall by sending packets to the counterpart ss.WriteFrame(trackID, StreamTypeRTCP, @@ -916,8 +906,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if *ss.setuppedProtocol == base.StreamProtocolUDP { for trackID, track := range ss.setuppedTracks { - ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true) - ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true) + ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true) + ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true) // open the firewall by sending packets to the counterpart ss.WriteFrame(trackID, StreamTypeRTP, @@ -1050,14 +1040,14 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload if streamType == StreamTypeRTP { ss.s.udpRTPListener.write(payload, &net.UDPAddr{ - IP: ss.udpIP, - Zone: ss.udpZone, + IP: ss.ip(), + Zone: ss.zone(), Port: track.udpRTPPort, }) } else { ss.s.udpRTCPListener.write(payload, &net.UDPAddr{ - IP: ss.udpIP, - Zone: ss.udpZone, + IP: ss.ip(), + Zone: ss.zone(), Port: track.udpRTCPPort, }) } diff --git a/serverstream.go b/serverstream.go index 33654af4..b5a5f699 100644 --- a/serverstream.go +++ b/serverstream.go @@ -116,8 +116,6 @@ func (st *ServerStream) readerAdd( ss *ServerSession, protocol base.StreamProtocol, delivery base.StreamDelivery, - ip net.IP, - zone string, clientPorts *[2]int, ) error { st.mutex.Lock() @@ -137,8 +135,8 @@ func (st *ServerStream) readerAdd( for r := range st.readersUnicast { if *r.setuppedProtocol == base.StreamProtocolUDP && *r.setuppedDelivery == base.StreamDeliveryUnicast && - r.udpIP.Equal(ip) && - r.udpZone == zone { + r.ip().Equal(ss.ip()) && + r.zone() == ss.zone() { for _, rt := range r.setuppedTracks { if rt.udpRTPPort == clientPorts[0] { return liberrors.ErrServerUDPPortsAlreadyInUse{Port: rt.udpRTPPort} @@ -203,7 +201,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { } else { for trackID := range ss.setuppedTracks { st.multicastListeners[trackID].rtcpListener.addClient( - ss.udpIP, st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false) + ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false) } } }