Set cluster node ips on session block list

This commit is contained in:
Ingo Oppermann
2022-08-15 13:52:38 +03:00
parent 2b05d5fb31
commit fe2d4e247e
7 changed files with 190 additions and 62 deletions

View File

@@ -5,7 +5,6 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"io/ioutil"
golog "log" golog "log"
gonet "net" gonet "net"
gohttp "net/http" gohttp "net/http"
@@ -72,6 +71,7 @@ type api struct {
prom prometheus.Metrics prom prometheus.Metrics
service service.Service service service.Service
sessions session.Registry sessions session.Registry
sessionsLimiter net.IPLimiter
cache cache.Cacher cache cache.Cacher
mainserver *gohttp.Server mainserver *gohttp.Server
sidecarserver *gohttp.Server sidecarserver *gohttp.Server
@@ -122,7 +122,7 @@ func New(configpath string, logwriter io.Writer) (API, error) {
a.log.writer = logwriter a.log.writer = logwriter
if a.log.writer == nil { if a.log.writer == nil {
a.log.writer = ioutil.Discard a.log.writer = io.Discard
} }
a.errorChan = make(chan error, 1) a.errorChan = make(chan error, 1)
@@ -301,6 +301,8 @@ func (a *api) start() error {
return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err) return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err)
} }
a.sessionsLimiter = iplimiter
config := session.CollectorConfig{ config := session.CollectorConfig{
MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024, MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024,
MaxSessions: cfg.Sessions.MaxSessions, MaxSessions: cfg.Sessions.MaxSessions,
@@ -498,6 +500,7 @@ func (a *api) start() error {
a.restream = restream a.restream = restream
if cluster, err := cluster.New(cluster.ClusterConfig{ if cluster, err := cluster.New(cluster.ClusterConfig{
IPLimiter: a.sessionsLimiter,
Logger: a.log.logger.core.WithComponent("Cluster"), Logger: a.log.logger.core.WithComponent("Cluster"),
}); err != nil { }); err != nil {
return fmt.Errorf("unable to create cluster: %w", err) return fmt.Errorf("unable to create cluster: %w", err)
@@ -778,7 +781,7 @@ func (a *api) start() error {
logcontext = "HTTPS" logcontext = "HTTPS"
} }
var iplimiter net.IPLimiter var iplimiter net.IPLimitValidator
if cfg.TLS.Enable { if cfg.TLS.Enable {
limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow) limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow)

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/log"
"github.com/datarhei/core/v16/net"
) )
type ClusterReader interface { type ClusterReader interface {
@@ -33,6 +34,7 @@ type Cluster interface {
} }
type ClusterConfig struct { type ClusterConfig struct {
IPLimiter net.IPLimiter
Logger log.Logger Logger log.Logger
} }
@@ -42,6 +44,8 @@ type cluster struct {
idupdate map[string]time.Time idupdate map[string]time.Time
fileid map[string]string fileid map[string]string
limiter net.IPLimiter
updates chan NodeState updates chan NodeState
lock sync.RWMutex lock sync.RWMutex
@@ -57,10 +61,15 @@ func New(config ClusterConfig) (Cluster, error) {
idfiles: map[string][]string{}, idfiles: map[string][]string{},
idupdate: map[string]time.Time{}, idupdate: map[string]time.Time{},
fileid: map[string]string{}, fileid: map[string]string{},
limiter: config.IPLimiter,
updates: make(chan NodeState, 64), updates: make(chan NodeState, 64),
logger: config.Logger, logger: config.Logger,
} }
if c.limiter == nil {
c.limiter = net.NewNullIPLimiter()
}
if c.logger == nil { if c.logger == nil {
c.logger = log.New("") c.logger = log.New("")
} }
@@ -134,9 +143,15 @@ func (c *cluster) AddNode(address, username, password string) (string, error) {
defer c.lock.Unlock() defer c.lock.Unlock()
if _, ok := c.nodes[id]; ok { if _, ok := c.nodes[id]; ok {
node.stop()
return id, nil return id, nil
} }
ips := node.IPs()
for _, ip := range ips {
c.limiter.AddBlock(ip)
}
c.nodes[id] = node c.nodes[id] = node
c.logger.Info().WithFields(log.Fields{ c.logger.Info().WithFields(log.Fields{
@@ -160,6 +175,12 @@ func (c *cluster) RemoveNode(id string) error {
delete(c.nodes, id) delete(c.nodes, id)
ips := node.IPs()
for _, ip := range ips {
c.limiter.RemoveBlock(ip)
}
c.logger.Info().WithFields(log.Fields{ c.logger.Info().WithFields(log.Fields{
"id": id, "id": id,
}).Log("Removed node") }).Log("Removed node")

View File

@@ -17,6 +17,7 @@ import (
type NodeReader interface { type NodeReader interface {
Address() string Address() string
IPs() []string
State() NodeState State() NodeState
} }
@@ -40,6 +41,7 @@ const (
type node struct { type node struct {
address string address string
ips []string
state nodeState state nodeState
username string username string
password string password string
@@ -75,8 +77,14 @@ func newNode(address, username, password string, updates chan<- NodeState) (*nod
return nil, fmt.Errorf("invalid address: %w", err) return nil, fmt.Errorf("invalid address: %w", err)
} }
addrs, err := net.LookupHost(host)
if err != nil {
return nil, fmt.Errorf("lookup failed: %w", err)
}
n := &node{ n := &node{
address: address, address: address,
ips: addrs,
username: username, username: username,
password: password, password: password,
state: stateDisconnected, state: stateDisconnected,
@@ -172,6 +180,10 @@ func (n *node) Address() string {
return n.address return n.address
} }
func (n *node) IPs() []string {
return n.ips
}
func (n *node) ID() string { func (n *node) ID() string {
return n.peer.ID() return n.peer.ID()
} }
@@ -239,8 +251,6 @@ func (n *node) files() {
n.fileList[nfiles] = "srt:" + file.Name n.fileList[nfiles] = "srt:" + file.Name
nfiles++ nfiles++
} }
return
} }
func (n *node) GetURL(path string) (string, error) { func (n *node) GetURL(path string) (string, error) {

View File

@@ -12,7 +12,7 @@ import (
type Config struct { type Config struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper middleware.Skipper Skipper middleware.Skipper
Limiter net.IPLimiter Limiter net.IPLimitValidator
} }
var DefaultConfig = Config{ var DefaultConfig = Config{

View File

@@ -83,7 +83,7 @@ type Config struct {
MimeTypesFile string MimeTypesFile string
DiskFS fs.Filesystem DiskFS fs.Filesystem
MemFS MemFSConfig MemFS MemFSConfig
IPLimiter net.IPLimiter IPLimiter net.IPLimitValidator
Profiling bool Profiling bool
Cors CorsConfig Cors CorsConfig
RTMP rtmp.Server RTMP rtmp.Server

View File

@@ -3,71 +3,169 @@ package net
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"strings" "strings"
"sync"
) )
// The IPLimiter interface allows to check whether a certain IP // The IPLimitValidator interface allows to check whether a certain IP is allowed.
// is allowed. type IPLimitValidator interface {
type IPLimiter interface {
// Tests whether the IP is allowed in respect to the underlying implementation // Tests whether the IP is allowed in respect to the underlying implementation
IsAllowed(ip string) bool IsAllowed(ip string) bool
} }
type IPLimiter interface {
// AddAllow adds a CIDR block to the allow list. If only an IP is provided
// a CIDR will be generated.
AddAllow(cidr string) error
// RemoveAllow removes a CIDR block from the allow list. If only an IP is provided
// a CIDR will be generated.
RemoveAllow(cidr string) error
// AddBlock adds a CIDR block to the block list. If only an IP is provided
// a CIDR will be generated.
AddBlock(cidr string) error
// RemoveBlock removes a CIDR block from the block list. If only an IP is provided
// a CIDR will be generated.
RemoveBlock(cidr string) error
IPLimitValidator
}
// IPLimit implements the IPLimiter interface by having an allow and block list // IPLimit implements the IPLimiter interface by having an allow and block list
// of CIDR ranges. // of CIDR ranges.
type iplimit struct { type iplimit struct {
// Array of allowed IP ranges // allowList is an array of allowed IP ranges
allowlist []*net.IPNet allowlist map[string]*net.IPNet
// Array of blocked IP ranges // blocklist is an array of blocked IP ranges
blocklist []*net.IPNet blocklist map[string]*net.IPNet
// lock is synchronizing the acces to the allow and block lists
lock sync.RWMutex
} }
// NewIPLimiter creates a new IPLimiter with the given IP ranges for the // NewIPLimiter creates a new IPLimiter with the given IP ranges for the
// allowed and blocked IPs. Empty strings are ignored. Returns an error // allowed and blocked IPs. Empty strings are ignored. Returns an error
// if an invalid IP range has been found. // if an invalid IP range has been found.
func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) { func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) {
ipl := &iplimit{} ipl := &iplimit{
allowlist: make(map[string]*net.IPNet),
blocklist: make(map[string]*net.IPNet),
}
for _, ipblock := range blocklist { for _, ipblock := range blocklist {
ipblock = strings.TrimSpace(ipblock) err := ipl.AddBlock(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil { if err != nil {
return nil, fmt.Errorf("the IP block %s in the block list is invalid", ipblock) return nil, fmt.Errorf("block list: %w", err)
} }
ipl.blocklist = append(ipl.blocklist, cidr)
} }
for _, ipblock := range allowlist { for _, ipblock := range allowlist {
ipblock = strings.TrimSpace(ipblock) err := ipl.AddAllow(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil { if err != nil {
return nil, fmt.Errorf("the IP block %s in the allow list is invalid", ipblock) return nil, fmt.Errorf("allow list: %w", err)
} }
ipl.allowlist = append(ipl.allowlist, cidr)
} }
return ipl, nil return ipl, nil
} }
// IsAllowed checks whether the provided IP is allowed according func (ipl *iplimit) validate(ipblock string) (*net.IPNet, error) {
// to the IP ranges in the allow- and blocklists. ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
return nil, fmt.Errorf("invalid IP block")
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil {
addr, err := netip.ParseAddr(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
if addr.Is4() {
ipblock = addr.String() + "/32"
} else {
ipblock = addr.String() + "/128"
}
_, cidr, err = net.ParseCIDR(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
}
return cidr, nil
}
func (ipl *iplimit) AddAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.allowlist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.allowlist, cidr.String())
return nil
}
func (ipl *iplimit) AddBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.blocklist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.blocklist, cidr.String())
return nil
}
func (ipl *iplimit) IsAllowed(ip string) bool { func (ipl *iplimit) IsAllowed(ip string) bool {
parsedIP := net.ParseIP(ip) parsedIP := net.ParseIP(ip)
if parsedIP == nil { if parsedIP == nil {
return false return false
} }
ipl.lock.RLock()
defer ipl.lock.RUnlock()
for _, r := range ipl.blocklist { for _, r := range ipl.blocklist {
if r.Contains(parsedIP) { if r.Contains(parsedIP) {
return false return false
@@ -87,12 +185,8 @@ func (ipl *iplimit) IsAllowed(ip string) bool {
return false return false
} }
type nulliplimiter struct{}
func NewNullIPLimiter() IPLimiter { func NewNullIPLimiter() IPLimiter {
return &nulliplimiter{} ipl, _ := NewIPLimiter(nil, nil)
}
func (ipl *nulliplimiter) IsAllowed(ip string) bool { return ipl
return true
} }

View File

@@ -188,7 +188,7 @@ type CollectorConfig struct {
// Limiter is an IPLimiter. It is used to query whether a session for an IP // Limiter is an IPLimiter. It is used to query whether a session for an IP
// should be created. // should be created.
Limiter net.IPLimiter Limiter net.IPLimitValidator
// InactiveTimeout is the duration of how long a not yet activated session is kept. // InactiveTimeout is the duration of how long a not yet activated session is kept.
// A session gets activated with the first ingress or egress bytes. // A session gets activated with the first ingress or egress bytes.
@@ -253,7 +253,7 @@ type collector struct {
inactiveTimeout time.Duration inactiveTimeout time.Duration
sessionTimeout time.Duration sessionTimeout time.Duration
limiter net.IPLimiter limiter net.IPLimitValidator
companions []Collector companions []Collector