Set cluster node ips on session block list

This commit is contained in:
Ingo Oppermann
2022-08-15 13:52:38 +03:00
parent 2b05d5fb31
commit fe2d4e247e
7 changed files with 190 additions and 62 deletions

View File

@@ -5,7 +5,6 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"io/ioutil"
golog "log" golog "log"
gonet "net" gonet "net"
gohttp "net/http" gohttp "net/http"
@@ -62,23 +61,24 @@ type API interface {
} }
type api struct { type api struct {
restream restream.Restreamer restream restream.Restreamer
ffmpeg ffmpeg.FFmpeg ffmpeg ffmpeg.FFmpeg
diskfs fs.Filesystem diskfs fs.Filesystem
memfs fs.Filesystem memfs fs.Filesystem
rtmpserver rtmp.Server rtmpserver rtmp.Server
srtserver srt.Server srtserver srt.Server
metrics monitor.HistoryMonitor metrics monitor.HistoryMonitor
prom prometheus.Metrics prom prometheus.Metrics
service service.Service service service.Service
sessions session.Registry sessions session.Registry
cache cache.Cacher sessionsLimiter net.IPLimiter
mainserver *gohttp.Server cache cache.Cacher
sidecarserver *gohttp.Server mainserver *gohttp.Server
httpjwt jwt.JWT sidecarserver *gohttp.Server
update update.Checker httpjwt jwt.JWT
replacer replace.Replacer update update.Checker
cluster cluster.Cluster replacer replace.Replacer
cluster cluster.Cluster
errorChan chan error errorChan chan error
@@ -122,7 +122,7 @@ func New(configpath string, logwriter io.Writer) (API, error) {
a.log.writer = logwriter a.log.writer = logwriter
if a.log.writer == nil { if a.log.writer == nil {
a.log.writer = ioutil.Discard a.log.writer = io.Discard
} }
a.errorChan = make(chan error, 1) a.errorChan = make(chan error, 1)
@@ -301,6 +301,8 @@ func (a *api) start() error {
return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err) return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err)
} }
a.sessionsLimiter = iplimiter
config := session.CollectorConfig{ config := session.CollectorConfig{
MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024, MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024,
MaxSessions: cfg.Sessions.MaxSessions, MaxSessions: cfg.Sessions.MaxSessions,
@@ -498,7 +500,8 @@ func (a *api) start() error {
a.restream = restream a.restream = restream
if cluster, err := cluster.New(cluster.ClusterConfig{ if cluster, err := cluster.New(cluster.ClusterConfig{
Logger: a.log.logger.core.WithComponent("Cluster"), IPLimiter: a.sessionsLimiter,
Logger: a.log.logger.core.WithComponent("Cluster"),
}); err != nil { }); err != nil {
return fmt.Errorf("unable to create cluster: %w", err) return fmt.Errorf("unable to create cluster: %w", err)
} else { } else {
@@ -778,7 +781,7 @@ func (a *api) start() error {
logcontext = "HTTPS" logcontext = "HTTPS"
} }
var iplimiter net.IPLimiter var iplimiter net.IPLimitValidator
if cfg.TLS.Enable { if cfg.TLS.Enable {
limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow) limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow)

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/log"
"github.com/datarhei/core/v16/net"
) )
type ClusterReader interface { type ClusterReader interface {
@@ -33,7 +34,8 @@ type Cluster interface {
} }
type ClusterConfig struct { type ClusterConfig struct {
Logger log.Logger IPLimiter net.IPLimiter
Logger log.Logger
} }
type cluster struct { type cluster struct {
@@ -42,6 +44,8 @@ type cluster struct {
idupdate map[string]time.Time idupdate map[string]time.Time
fileid map[string]string fileid map[string]string
limiter net.IPLimiter
updates chan NodeState updates chan NodeState
lock sync.RWMutex lock sync.RWMutex
@@ -57,10 +61,15 @@ func New(config ClusterConfig) (Cluster, error) {
idfiles: map[string][]string{}, idfiles: map[string][]string{},
idupdate: map[string]time.Time{}, idupdate: map[string]time.Time{},
fileid: map[string]string{}, fileid: map[string]string{},
limiter: config.IPLimiter,
updates: make(chan NodeState, 64), updates: make(chan NodeState, 64),
logger: config.Logger, logger: config.Logger,
} }
if c.limiter == nil {
c.limiter = net.NewNullIPLimiter()
}
if c.logger == nil { if c.logger == nil {
c.logger = log.New("") c.logger = log.New("")
} }
@@ -134,9 +143,15 @@ func (c *cluster) AddNode(address, username, password string) (string, error) {
defer c.lock.Unlock() defer c.lock.Unlock()
if _, ok := c.nodes[id]; ok { if _, ok := c.nodes[id]; ok {
node.stop()
return id, nil return id, nil
} }
ips := node.IPs()
for _, ip := range ips {
c.limiter.AddBlock(ip)
}
c.nodes[id] = node c.nodes[id] = node
c.logger.Info().WithFields(log.Fields{ c.logger.Info().WithFields(log.Fields{
@@ -160,6 +175,12 @@ func (c *cluster) RemoveNode(id string) error {
delete(c.nodes, id) delete(c.nodes, id)
ips := node.IPs()
for _, ip := range ips {
c.limiter.RemoveBlock(ip)
}
c.logger.Info().WithFields(log.Fields{ c.logger.Info().WithFields(log.Fields{
"id": id, "id": id,
}).Log("Removed node") }).Log("Removed node")

View File

@@ -17,6 +17,7 @@ import (
type NodeReader interface { type NodeReader interface {
Address() string Address() string
IPs() []string
State() NodeState State() NodeState
} }
@@ -40,6 +41,7 @@ const (
type node struct { type node struct {
address string address string
ips []string
state nodeState state nodeState
username string username string
password string password string
@@ -75,8 +77,14 @@ func newNode(address, username, password string, updates chan<- NodeState) (*nod
return nil, fmt.Errorf("invalid address: %w", err) return nil, fmt.Errorf("invalid address: %w", err)
} }
addrs, err := net.LookupHost(host)
if err != nil {
return nil, fmt.Errorf("lookup failed: %w", err)
}
n := &node{ n := &node{
address: address, address: address,
ips: addrs,
username: username, username: username,
password: password, password: password,
state: stateDisconnected, state: stateDisconnected,
@@ -172,6 +180,10 @@ func (n *node) Address() string {
return n.address return n.address
} }
func (n *node) IPs() []string {
return n.ips
}
func (n *node) ID() string { func (n *node) ID() string {
return n.peer.ID() return n.peer.ID()
} }
@@ -239,8 +251,6 @@ func (n *node) files() {
n.fileList[nfiles] = "srt:" + file.Name n.fileList[nfiles] = "srt:" + file.Name
nfiles++ nfiles++
} }
return
} }
func (n *node) GetURL(path string) (string, error) { func (n *node) GetURL(path string) (string, error) {

View File

@@ -12,7 +12,7 @@ import (
type Config struct { type Config struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper middleware.Skipper Skipper middleware.Skipper
Limiter net.IPLimiter Limiter net.IPLimitValidator
} }
var DefaultConfig = Config{ var DefaultConfig = Config{

View File

@@ -83,7 +83,7 @@ type Config struct {
MimeTypesFile string MimeTypesFile string
DiskFS fs.Filesystem DiskFS fs.Filesystem
MemFS MemFSConfig MemFS MemFSConfig
IPLimiter net.IPLimiter IPLimiter net.IPLimitValidator
Profiling bool Profiling bool
Cors CorsConfig Cors CorsConfig
RTMP rtmp.Server RTMP rtmp.Server

View File

@@ -3,71 +3,169 @@ package net
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"strings" "strings"
"sync"
) )
// The IPLimiter interface allows to check whether a certain IP // The IPLimitValidator interface allows to check whether a certain IP is allowed.
// is allowed. type IPLimitValidator interface {
type IPLimiter interface {
// Tests whether the IP is allowed in respect to the underlying implementation // Tests whether the IP is allowed in respect to the underlying implementation
IsAllowed(ip string) bool IsAllowed(ip string) bool
} }
type IPLimiter interface {
// AddAllow adds a CIDR block to the allow list. If only an IP is provided
// a CIDR will be generated.
AddAllow(cidr string) error
// RemoveAllow removes a CIDR block from the allow list. If only an IP is provided
// a CIDR will be generated.
RemoveAllow(cidr string) error
// AddBlock adds a CIDR block to the block list. If only an IP is provided
// a CIDR will be generated.
AddBlock(cidr string) error
// RemoveBlock removes a CIDR block from the block list. If only an IP is provided
// a CIDR will be generated.
RemoveBlock(cidr string) error
IPLimitValidator
}
// IPLimit implements the IPLimiter interface by having an allow and block list // IPLimit implements the IPLimiter interface by having an allow and block list
// of CIDR ranges. // of CIDR ranges.
type iplimit struct { type iplimit struct {
// Array of allowed IP ranges // allowList is an array of allowed IP ranges
allowlist []*net.IPNet allowlist map[string]*net.IPNet
// Array of blocked IP ranges // blocklist is an array of blocked IP ranges
blocklist []*net.IPNet blocklist map[string]*net.IPNet
// lock is synchronizing the acces to the allow and block lists
lock sync.RWMutex
} }
// NewIPLimiter creates a new IPLimiter with the given IP ranges for the // NewIPLimiter creates a new IPLimiter with the given IP ranges for the
// allowed and blocked IPs. Empty strings are ignored. Returns an error // allowed and blocked IPs. Empty strings are ignored. Returns an error
// if an invalid IP range has been found. // if an invalid IP range has been found.
func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) { func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) {
ipl := &iplimit{} ipl := &iplimit{
allowlist: make(map[string]*net.IPNet),
blocklist: make(map[string]*net.IPNet),
}
for _, ipblock := range blocklist { for _, ipblock := range blocklist {
ipblock = strings.TrimSpace(ipblock) err := ipl.AddBlock(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil { if err != nil {
return nil, fmt.Errorf("the IP block %s in the block list is invalid", ipblock) return nil, fmt.Errorf("block list: %w", err)
} }
ipl.blocklist = append(ipl.blocklist, cidr)
} }
for _, ipblock := range allowlist { for _, ipblock := range allowlist {
ipblock = strings.TrimSpace(ipblock) err := ipl.AddAllow(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil { if err != nil {
return nil, fmt.Errorf("the IP block %s in the allow list is invalid", ipblock) return nil, fmt.Errorf("allow list: %w", err)
} }
ipl.allowlist = append(ipl.allowlist, cidr)
} }
return ipl, nil return ipl, nil
} }
// IsAllowed checks whether the provided IP is allowed according func (ipl *iplimit) validate(ipblock string) (*net.IPNet, error) {
// to the IP ranges in the allow- and blocklists. ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
return nil, fmt.Errorf("invalid IP block")
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil {
addr, err := netip.ParseAddr(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
if addr.Is4() {
ipblock = addr.String() + "/32"
} else {
ipblock = addr.String() + "/128"
}
_, cidr, err = net.ParseCIDR(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
}
return cidr, nil
}
func (ipl *iplimit) AddAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.allowlist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.allowlist, cidr.String())
return nil
}
func (ipl *iplimit) AddBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.blocklist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.blocklist, cidr.String())
return nil
}
func (ipl *iplimit) IsAllowed(ip string) bool { func (ipl *iplimit) IsAllowed(ip string) bool {
parsedIP := net.ParseIP(ip) parsedIP := net.ParseIP(ip)
if parsedIP == nil { if parsedIP == nil {
return false return false
} }
ipl.lock.RLock()
defer ipl.lock.RUnlock()
for _, r := range ipl.blocklist { for _, r := range ipl.blocklist {
if r.Contains(parsedIP) { if r.Contains(parsedIP) {
return false return false
@@ -87,12 +185,8 @@ func (ipl *iplimit) IsAllowed(ip string) bool {
return false return false
} }
type nulliplimiter struct{}
func NewNullIPLimiter() IPLimiter { func NewNullIPLimiter() IPLimiter {
return &nulliplimiter{} ipl, _ := NewIPLimiter(nil, nil)
}
func (ipl *nulliplimiter) IsAllowed(ip string) bool { return ipl
return true
} }

View File

@@ -188,7 +188,7 @@ type CollectorConfig struct {
// Limiter is an IPLimiter. It is used to query whether a session for an IP // Limiter is an IPLimiter. It is used to query whether a session for an IP
// should be created. // should be created.
Limiter net.IPLimiter Limiter net.IPLimitValidator
// InactiveTimeout is the duration of how long a not yet activated session is kept. // InactiveTimeout is the duration of how long a not yet activated session is kept.
// A session gets activated with the first ingress or egress bytes. // A session gets activated with the first ingress or egress bytes.
@@ -253,7 +253,7 @@ type collector struct {
inactiveTimeout time.Duration inactiveTimeout time.Duration
sessionTimeout time.Duration sessionTimeout time.Duration
limiter net.IPLimiter limiter net.IPLimitValidator
companions []Collector companions []Collector