mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 07:06:58 +08:00
server: do not allow a client to control a session created with a different IP
This commit is contained in:
@@ -21,7 +21,7 @@ const (
|
|||||||
func randUint32() uint32 {
|
func randUint32() uint32 {
|
||||||
var b [4]byte
|
var b [4]byte
|
||||||
rand.Read(b[:])
|
rand.Read(b[:])
|
||||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
}
|
}
|
||||||
|
|
||||||
func randIntn(n int) int {
|
func randIntn(n int) int {
|
||||||
|
@@ -258,12 +258,12 @@ func (e ErrServerPathHasChanged) Error() string {
|
|||||||
return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur)
|
return fmt.Sprintf("path has changed, was '%s', now is '%s'", e.Prev, e.Cur)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrServerCannotSetupFromDifferentIPs is an error that can be returned by a server.
|
// ErrServerCannotUseSessionCreatedByOtherIP is an error that can be returned by a server.
|
||||||
type ErrServerCannotSetupFromDifferentIPs struct{}
|
type ErrServerCannotUseSessionCreatedByOtherIP struct{}
|
||||||
|
|
||||||
// Error implements the error interface.
|
// Error implements the error interface.
|
||||||
func (e ErrServerCannotSetupFromDifferentIPs) Error() string {
|
func (e ErrServerCannotUseSessionCreatedByOtherIP) Error() string {
|
||||||
return "cannot setup tracks from different IPs"
|
return "cannot use a session created with a different IP"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrServerUDPPortsAlreadyInUse is an error that can be returned by a server.
|
// ErrServerUDPPortsAlreadyInUse is an error that can be returned by a server.
|
||||||
|
@@ -14,7 +14,7 @@ import (
|
|||||||
func randUint32() uint32 {
|
func randUint32() uint32 {
|
||||||
var b [4]byte
|
var b [4]byte
|
||||||
rand.Read(b[:])
|
rand.Read(b[:])
|
||||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
}
|
}
|
||||||
|
|
||||||
// RTCPReceiver is a utility to generate RTCP receiver reports.
|
// RTCPReceiver is a utility to generate RTCP receiver reports.
|
||||||
|
@@ -16,7 +16,7 @@ const (
|
|||||||
func randUint32() uint32 {
|
func randUint32() uint32 {
|
||||||
var b [4]byte
|
var b [4]byte
|
||||||
rand.Read(b[:])
|
rand.Read(b[:])
|
||||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encoder is a RTP/AAC encoder.
|
// Encoder is a RTP/AAC encoder.
|
||||||
|
@@ -17,7 +17,7 @@ const (
|
|||||||
func randUint32() uint32 {
|
func randUint32() uint32 {
|
||||||
var b [4]byte
|
var b [4]byte
|
||||||
rand.Read(b[:])
|
rand.Read(b[:])
|
||||||
return uint32(b[0]<<24) | uint32(b[1]<<16) | uint32(b[2]<<8) | uint32(b[3])
|
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encoder is a RTP/H264 encoder.
|
// Encoder is a RTP/H264 encoder.
|
||||||
|
13
server.go
13
server.go
@@ -32,7 +32,7 @@ func extractPort(address string) (int, error) {
|
|||||||
func newSessionSecretID(sessions map[string]*ServerSession) (string, error) {
|
func newSessionSecretID(sessions map[string]*ServerSession) (string, error) {
|
||||||
for {
|
for {
|
||||||
b := make([]byte, 4)
|
b := make([]byte, 4)
|
||||||
_, err := rand.Read(b[:])
|
_, err := rand.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -363,6 +363,17 @@ outer:
|
|||||||
|
|
||||||
case req := <-s.sessionRequest:
|
case req := <-s.sessionRequest:
|
||||||
if ss, ok := s.sessions[req.id]; ok {
|
if ss, ok := s.sessions[req.id]; ok {
|
||||||
|
if !req.sc.ip().Equal(ss.ip()) ||
|
||||||
|
req.sc.zone() != ss.zone() {
|
||||||
|
req.res <- sessionRequestRes{
|
||||||
|
res: &base.Response{
|
||||||
|
StatusCode: base.StatusBadRequest,
|
||||||
|
},
|
||||||
|
err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{},
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
ss.request <- req
|
ss.request <- req
|
||||||
} else {
|
} else {
|
||||||
if !req.create {
|
if !req.create {
|
||||||
|
@@ -47,6 +47,7 @@ type ServerConn struct {
|
|||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel func()
|
ctxCancel func()
|
||||||
|
remoteAddr *net.TCPAddr // to improve speed
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
bw *bufio.Writer
|
bw *bufio.Writer
|
||||||
sessions map[string]*ServerSession
|
sessions map[string]*ServerSession
|
||||||
@@ -76,6 +77,7 @@ func newServerConn(
|
|||||||
nconn: nconn,
|
nconn: nconn,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: ctxCancel,
|
ctxCancel: ctxCancel,
|
||||||
|
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
|
||||||
sessionRemove: make(chan *ServerSession),
|
sessionRemove: make(chan *ServerSession),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
@@ -98,11 +100,11 @@ func (sc *ServerConn) NetConn() net.Conn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sc *ServerConn) ip() net.IP {
|
func (sc *ServerConn) ip() net.IP {
|
||||||
return sc.nconn.RemoteAddr().(*net.TCPAddr).IP
|
return sc.remoteAddr.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *ServerConn) zone() string {
|
func (sc *ServerConn) zone() string {
|
||||||
return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone
|
return sc.remoteAddr.Zone
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *ServerConn) run() {
|
func (sc *ServerConn) run() {
|
||||||
|
@@ -137,8 +137,6 @@ type ServerSession struct {
|
|||||||
setuppedQuery *string
|
setuppedQuery *string
|
||||||
lastRequestTime time.Time
|
lastRequestTime time.Time
|
||||||
tcpConn *ServerConn // tcp
|
tcpConn *ServerConn // tcp
|
||||||
udpIP net.IP // udp
|
|
||||||
udpZone string // udp
|
|
||||||
announcedTracks []ServerSessionAnnouncedTrack // publish
|
announcedTracks []ServerSessionAnnouncedTrack // publish
|
||||||
udpLastFrameTime *int64 // publish, udp
|
udpLastFrameTime *int64 // publish, udp
|
||||||
|
|
||||||
@@ -203,6 +201,14 @@ func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack {
|
|||||||
return ss.announcedTracks
|
return ss.announcedTracks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ss *ServerSession) ip() net.IP {
|
||||||
|
return ss.author.ip()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *ServerSession) zone() string {
|
||||||
|
return ss.author.zone()
|
||||||
|
}
|
||||||
|
|
||||||
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
|
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
|
||||||
if _, ok := allowed[ss.state]; ok {
|
if _, ok := allowed[ss.state]; ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -225,7 +231,6 @@ func (ss *ServerSession) run() {
|
|||||||
Session: ss,
|
Session: ss,
|
||||||
Conn: ss.author,
|
Conn: ss.author,
|
||||||
})
|
})
|
||||||
ss.author = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := func() error {
|
err := func() error {
|
||||||
@@ -587,13 +592,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
|||||||
StatusCode: base.StatusBadRequest,
|
StatusCode: base.StatusBadRequest,
|
||||||
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
|
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ss.setuppedTracks != nil &&
|
|
||||||
(!ss.udpIP.Equal(sc.ip()) || ss.udpZone != sc.zone()) {
|
|
||||||
return &base.Response{
|
|
||||||
StatusCode: base.StatusBadRequest,
|
|
||||||
}, liberrors.ErrServerCannotSetupFromDifferentIPs{}
|
|
||||||
}
|
|
||||||
} else if ss.s.MulticastIPRange == "" {
|
} else if ss.s.MulticastIPRange == "" {
|
||||||
return &base.Response{
|
return &base.Response{
|
||||||
StatusCode: base.StatusUnsupportedTransport,
|
StatusCode: base.StatusUnsupportedTransport,
|
||||||
@@ -649,8 +647,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
|||||||
err := stream.readerAdd(ss,
|
err := stream.readerAdd(ss,
|
||||||
inTH.Protocol,
|
inTH.Protocol,
|
||||||
delivery,
|
delivery,
|
||||||
sc.ip(),
|
|
||||||
sc.zone(),
|
|
||||||
inTH.ClientPorts,
|
inTH.ClientPorts,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -685,9 +681,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case delivery == base.StreamDeliveryMulticast:
|
case delivery == base.StreamDeliveryMulticast:
|
||||||
ss.udpIP = sc.ip()
|
|
||||||
ss.udpZone = sc.zone()
|
|
||||||
|
|
||||||
th.Protocol = base.StreamProtocolUDP
|
th.Protocol = base.StreamProtocolUDP
|
||||||
de := base.StreamDeliveryMulticast
|
de := base.StreamDeliveryMulticast
|
||||||
th.Delivery = &de
|
th.Delivery = &de
|
||||||
@@ -701,9 +694,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
|||||||
}
|
}
|
||||||
|
|
||||||
case inTH.Protocol == base.StreamProtocolUDP:
|
case inTH.Protocol == base.StreamProtocolUDP:
|
||||||
ss.udpIP = sc.ip()
|
|
||||||
ss.udpZone = sc.zone()
|
|
||||||
|
|
||||||
sst.udpRTPPort = inTH.ClientPorts[0]
|
sst.udpRTPPort = inTH.ClientPorts[0]
|
||||||
sst.udpRTCPPort = inTH.ClientPorts[1]
|
sst.udpRTCPPort = inTH.ClientPorts[1]
|
||||||
|
|
||||||
@@ -847,7 +837,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
|||||||
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
|
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
|
||||||
for trackID, track := range ss.setuppedTracks {
|
for trackID, track := range ss.setuppedTracks {
|
||||||
// readers can send RTCP frames
|
// readers can send RTCP frames
|
||||||
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
|
sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false)
|
||||||
|
|
||||||
// open the firewall by sending packets to the counterpart
|
// open the firewall by sending packets to the counterpart
|
||||||
ss.WriteFrame(trackID, StreamTypeRTCP,
|
ss.WriteFrame(trackID, StreamTypeRTCP,
|
||||||
@@ -916,8 +906,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
|||||||
|
|
||||||
if *ss.setuppedProtocol == base.StreamProtocolUDP {
|
if *ss.setuppedProtocol == base.StreamProtocolUDP {
|
||||||
for trackID, track := range ss.setuppedTracks {
|
for trackID, track := range ss.setuppedTracks {
|
||||||
ss.s.udpRTPListener.addClient(ss.udpIP, track.udpRTPPort, ss, trackID, true)
|
ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true)
|
||||||
ss.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, true)
|
ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true)
|
||||||
|
|
||||||
// open the firewall by sending packets to the counterpart
|
// open the firewall by sending packets to the counterpart
|
||||||
ss.WriteFrame(trackID, StreamTypeRTP,
|
ss.WriteFrame(trackID, StreamTypeRTP,
|
||||||
@@ -1050,14 +1040,14 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
|
|||||||
|
|
||||||
if streamType == StreamTypeRTP {
|
if streamType == StreamTypeRTP {
|
||||||
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
|
ss.s.udpRTPListener.write(payload, &net.UDPAddr{
|
||||||
IP: ss.udpIP,
|
IP: ss.ip(),
|
||||||
Zone: ss.udpZone,
|
Zone: ss.zone(),
|
||||||
Port: track.udpRTPPort,
|
Port: track.udpRTPPort,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
|
ss.s.udpRTCPListener.write(payload, &net.UDPAddr{
|
||||||
IP: ss.udpIP,
|
IP: ss.ip(),
|
||||||
Zone: ss.udpZone,
|
Zone: ss.zone(),
|
||||||
Port: track.udpRTCPPort,
|
Port: track.udpRTCPPort,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@@ -116,8 +116,6 @@ func (st *ServerStream) readerAdd(
|
|||||||
ss *ServerSession,
|
ss *ServerSession,
|
||||||
protocol base.StreamProtocol,
|
protocol base.StreamProtocol,
|
||||||
delivery base.StreamDelivery,
|
delivery base.StreamDelivery,
|
||||||
ip net.IP,
|
|
||||||
zone string,
|
|
||||||
clientPorts *[2]int,
|
clientPorts *[2]int,
|
||||||
) error {
|
) error {
|
||||||
st.mutex.Lock()
|
st.mutex.Lock()
|
||||||
@@ -137,8 +135,8 @@ func (st *ServerStream) readerAdd(
|
|||||||
for r := range st.readersUnicast {
|
for r := range st.readersUnicast {
|
||||||
if *r.setuppedProtocol == base.StreamProtocolUDP &&
|
if *r.setuppedProtocol == base.StreamProtocolUDP &&
|
||||||
*r.setuppedDelivery == base.StreamDeliveryUnicast &&
|
*r.setuppedDelivery == base.StreamDeliveryUnicast &&
|
||||||
r.udpIP.Equal(ip) &&
|
r.ip().Equal(ss.ip()) &&
|
||||||
r.udpZone == zone {
|
r.zone() == ss.zone() {
|
||||||
for _, rt := range r.setuppedTracks {
|
for _, rt := range r.setuppedTracks {
|
||||||
if rt.udpRTPPort == clientPorts[0] {
|
if rt.udpRTPPort == clientPorts[0] {
|
||||||
return liberrors.ErrServerUDPPortsAlreadyInUse{Port: rt.udpRTPPort}
|
return liberrors.ErrServerUDPPortsAlreadyInUse{Port: rt.udpRTPPort}
|
||||||
@@ -203,7 +201,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
|
|||||||
} else {
|
} else {
|
||||||
for trackID := range ss.setuppedTracks {
|
for trackID := range ss.setuppedTracks {
|
||||||
st.multicastListeners[trackID].rtcpListener.addClient(
|
st.multicastListeners[trackID].rtcpListener.addClient(
|
||||||
ss.udpIP, st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
|
ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user