mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 07:06:58 +08:00
server: do not allow a client to control a session created with a different IP
This commit is contained in:
@@ -21,7 +21,7 @@ const (
|
||||
func randUint32() uint32 {
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
|
||||
func randIntn(n int) int {
|
||||
|
@@ -258,12 +258,12 @@ func (e ErrServerPathHasChanged) Error() string {
|
||||
return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur)
|
||||
}
|
||||
|
||||
// ErrServerCannotSetupFromDifferentIPs is an error that can be returned by a server.
|
||||
type ErrServerCannotSetupFromDifferentIPs struct{}
|
||||
// ErrServerCannotUseSessionCreatedByOtherIP is an error that can be returned by a server.
|
||||
type ErrServerCannotUseSessionCreatedByOtherIP struct{}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ErrServerCannotSetupFromDifferentIPs) Error() string {
|
||||
return "cannot setup tracks from different IPs"
|
||||
func (e ErrServerCannotUseSessionCreatedByOtherIP) Error() string {
|
||||
return "cannot use a session created with a different IP"
|
||||
}
|
||||
|
||||
// ErrServerUDPPortsAlreadyInUse is an error that can be returned by a server.
|
||||
|
@@ -14,7 +14,7 @@ import (
|
||||
func randUint32() uint32 {
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
|
||||
// RTCPReceiver is a utility to generate RTCP receiver reports.
|
||||
|
@@ -16,7 +16,7 @@ const (
|
||||
func randUint32() uint32 {
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
|
||||
// Encoder is a RTP/AAC encoder.
|
||||
|
@@ -17,7 +17,7 @@ const (
|
||||
func randUint32() uint32 {
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
|
||||
// Encoder is a RTP/H264 encoder.
|
||||
|
13
server.go
13
server.go
@@ -32,7 +32,7 @@ func extractPort(address string) (int, error) {
|
||||
func newSessionSecretID(sessions map[string]*ServerSession) (string, error) {
|
||||
for {
|
||||
b := make([]byte, 4)
|
||||
_, err := rand.Read(b[:])
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -363,6 +363,17 @@ outer:
|
||||
|
||||
case req := <-s.sessionRequest:
|
||||
if ss, ok := s.sessions[req.id]; ok {
|
||||
if !req.sc.ip().Equal(ss.ip()) ||
|
||||
req.sc.zone() != ss.zone() {
|
||||
req.res <- sessionRequestRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{},
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ss.request <- req
|
||||
} else {
|
||||
if !req.create {
|
||||
|
@@ -47,6 +47,7 @@ type ServerConn struct {
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel func()
|
||||
remoteAddr *net.TCPAddr // to improve speed
|
||||
br *bufio.Reader
|
||||
bw *bufio.Writer
|
||||
sessions map[string]*ServerSession
|
||||
@@ -76,6 +77,7 @@ func newServerConn(
|
||||
nconn: nconn,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
|
||||
sessionRemove: make(chan *ServerSession),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
@@ -98,11 +100,11 @@ func (sc *ServerConn) NetConn() net.Conn {
|
||||
}
|
||||
|
||||
func (sc *ServerConn) ip() net.IP {
|
||||
return sc.nconn.RemoteAddr().(*net.TCPAddr).IP
|
||||
return sc.remoteAddr.IP
|
||||
}
|
||||
|
||||
func (sc *ServerConn) zone() string {
|
||||
return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone
|
||||
return sc.remoteAddr.Zone
|
||||
}
|
||||
|
||||
func (sc *ServerConn) run() {
|
||||
|
@@ -137,8 +137,6 @@ type ServerSession struct {
|
||||
setuppedQuery *string
|
||||
lastRequestTime time.Time
|
||||
tcpConn *ServerConn // tcp
|
||||
udpIP net.IP // udp
|
||||
udpZone string // udp
|
||||
announcedTracks []ServerSessionAnnouncedTrack // publish
|
||||
udpLastFrameTime *int64 // publish, udp
|
||||
|
||||
@@ -203,6 +201,14 @@ func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack {
|
||||
return ss.announcedTracks
|
||||
}
|
||||
|
||||
func (ss *ServerSession) ip() net.IP {
|
||||
return ss.author.ip()
|
||||
}
|
||||
|
||||
func (ss *ServerSession) zone() string {
|
||||
return ss.author.zone()
|
||||
}
|
||||
|
||||
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
|
||||
if _, ok := allowed[ss.state]; ok {
|
||||
return nil
|
||||
@@ -225,7 +231,6 @@ func (ss *ServerSession) run() {
|
||||
Session: ss,
|
||||
Conn: ss.author,
|
||||
})
|
||||
ss.author = nil
|
||||
}
|
||||
|
||||
err := func() error {
|
||||
@@ -587,13 +592,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
|
||||
}
|
||||
|
||||
if ss.setuppedTracks != nil &&
|
||||
(!ss.udpIP.Equal(sc.ip()) || ss.udpZone != sc.zone()) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerCannotSetupFromDifferentIPs{}
|
||||
}
|
||||
} else if ss.s.MulticastIPRange == "" {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusUnsupportedTransport,
|
||||
@@ -649,8 +647,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
err := stream.readerAdd(ss,
|
||||
inTH.Protocol,
|
||||
delivery,
|
||||
sc.ip(),
|
||||
sc.zone(),
|
||||
inTH.ClientPorts,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -685,9 +681,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
switch {
|
||||
case delivery == base.StreamDeliveryMulticast:
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
|
||||
th.Protocol = base.StreamProtocolUDP
|
||||
de := base.StreamDeliveryMulticast
|
||||
th.Delivery = &de
|
||||
@@ -701,9 +694,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
}
|
||||
|
||||
case inTH.Protocol == base.StreamProtocolUDP:
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
|
||||
sst.udpRTPPort = inTH.ClientPorts[0]
|
||||
sst.udpRTCPPort = inTH.ClientPorts[1]
|
||||
|
||||
@@ -847,7 +837,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
|
||||
for trackID, track := range ss.setuppedTracks {
|
||||
// readers can send RTCP frames
|
||||
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
|
||||
sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false)
|
||||
|
||||
// open the firewall by sending packets to the counterpart
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP,
|
||||
@@ -916,8 +906,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
if *ss.setuppedProtocol == base.StreamProtocolUDP {
|
||||
for trackID, track := range ss.setuppedTracks {
|
||||
ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true)
|
||||
ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true)
|
||||
ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true)
|
||||
ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true)
|
||||
|
||||
// open the firewall by sending packets to the counterpart
|
||||
ss.WriteFrame(trackID, StreamTypeRTP,
|
||||
@@ -1050,14 +1040,14 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
|
||||
|
||||
if streamType == StreamTypeRTP {
|
||||
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
|
||||
IP: ss.udpIP,
|
||||
Zone: ss.udpZone,
|
||||
IP: ss.ip(),
|
||||
Zone: ss.zone(),
|
||||
Port: track.udpRTPPort,
|
||||
})
|
||||
} else {
|
||||
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
|
||||
IP: ss.udpIP,
|
||||
Zone: ss.udpZone,
|
||||
IP: ss.ip(),
|
||||
Zone: ss.zone(),
|
||||
Port: track.udpRTCPPort,
|
||||
})
|
||||
}
|
||||
|
@@ -116,8 +116,6 @@ func (st *ServerStream) readerAdd(
|
||||
ss *ServerSession,
|
||||
protocol base.StreamProtocol,
|
||||
delivery base.StreamDelivery,
|
||||
ip net.IP,
|
||||
zone string,
|
||||
clientPorts *[2]int,
|
||||
) error {
|
||||
st.mutex.Lock()
|
||||
@@ -137,8 +135,8 @@ func (st *ServerStream) readerAdd(
|
||||
for r := range st.readersUnicast {
|
||||
if *r.setuppedProtocol == base.StreamProtocolUDP &&
|
||||
*r.setuppedDelivery == base.StreamDeliveryUnicast &&
|
||||
r.udpIP.Equal(ip) &&
|
||||
r.udpZone == zone {
|
||||
r.ip().Equal(ss.ip()) &&
|
||||
r.zone() == ss.zone() {
|
||||
for _, rt := range r.setuppedTracks {
|
||||
if rt.udpRTPPort == clientPorts[0] {
|
||||
return liberrors.ErrServerUDPPortsAlreadyInUse{Port: rt.udpRTPPort}
|
||||
@@ -203,7 +201,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
|
||||
} else {
|
||||
for trackID := range ss.setuppedTracks {
|
||||
st.multicastListeners[trackID].rtcpListener.addClient(
|
||||
ss.udpIP, st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
|
||||
ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user