Set cluster node ips on session block list

This commit is contained in:
Ingo Oppermann
2022-08-15 13:52:38 +03:00
parent 2b05d5fb31
commit fe2d4e247e
7 changed files with 190 additions and 62 deletions

View File

@@ -5,7 +5,6 @@ import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
golog "log"
gonet "net"
gohttp "net/http"
@@ -62,23 +61,24 @@ type API interface {
}
type api struct {
restream restream.Restreamer
ffmpeg ffmpeg.FFmpeg
diskfs fs.Filesystem
memfs fs.Filesystem
rtmpserver rtmp.Server
srtserver srt.Server
metrics monitor.HistoryMonitor
prom prometheus.Metrics
service service.Service
sessions session.Registry
cache cache.Cacher
mainserver *gohttp.Server
sidecarserver *gohttp.Server
httpjwt jwt.JWT
update update.Checker
replacer replace.Replacer
cluster cluster.Cluster
restream restream.Restreamer
ffmpeg ffmpeg.FFmpeg
diskfs fs.Filesystem
memfs fs.Filesystem
rtmpserver rtmp.Server
srtserver srt.Server
metrics monitor.HistoryMonitor
prom prometheus.Metrics
service service.Service
sessions session.Registry
sessionsLimiter net.IPLimiter
cache cache.Cacher
mainserver *gohttp.Server
sidecarserver *gohttp.Server
httpjwt jwt.JWT
update update.Checker
replacer replace.Replacer
cluster cluster.Cluster
errorChan chan error
@@ -122,7 +122,7 @@ func New(configpath string, logwriter io.Writer) (API, error) {
a.log.writer = logwriter
if a.log.writer == nil {
a.log.writer = ioutil.Discard
a.log.writer = io.Discard
}
a.errorChan = make(chan error, 1)
@@ -301,6 +301,8 @@ func (a *api) start() error {
return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err)
}
a.sessionsLimiter = iplimiter
config := session.CollectorConfig{
MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024,
MaxSessions: cfg.Sessions.MaxSessions,
@@ -498,7 +500,8 @@ func (a *api) start() error {
a.restream = restream
if cluster, err := cluster.New(cluster.ClusterConfig{
Logger: a.log.logger.core.WithComponent("Cluster"),
IPLimiter: a.sessionsLimiter,
Logger: a.log.logger.core.WithComponent("Cluster"),
}); err != nil {
return fmt.Errorf("unable to create cluster: %w", err)
} else {
@@ -778,7 +781,7 @@ func (a *api) start() error {
logcontext = "HTTPS"
}
var iplimiter net.IPLimiter
var iplimiter net.IPLimitValidator
if cfg.TLS.Enable {
limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow)

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/datarhei/core/v16/log"
"github.com/datarhei/core/v16/net"
)
type ClusterReader interface {
@@ -33,7 +34,8 @@ type Cluster interface {
}
type ClusterConfig struct {
Logger log.Logger
IPLimiter net.IPLimiter
Logger log.Logger
}
type cluster struct {
@@ -42,6 +44,8 @@ type cluster struct {
idupdate map[string]time.Time
fileid map[string]string
limiter net.IPLimiter
updates chan NodeState
lock sync.RWMutex
@@ -57,10 +61,15 @@ func New(config ClusterConfig) (Cluster, error) {
idfiles: map[string][]string{},
idupdate: map[string]time.Time{},
fileid: map[string]string{},
limiter: config.IPLimiter,
updates: make(chan NodeState, 64),
logger: config.Logger,
}
if c.limiter == nil {
c.limiter = net.NewNullIPLimiter()
}
if c.logger == nil {
c.logger = log.New("")
}
@@ -134,9 +143,15 @@ func (c *cluster) AddNode(address, username, password string) (string, error) {
defer c.lock.Unlock()
if _, ok := c.nodes[id]; ok {
node.stop()
return id, nil
}
ips := node.IPs()
for _, ip := range ips {
c.limiter.AddBlock(ip)
}
c.nodes[id] = node
c.logger.Info().WithFields(log.Fields{
@@ -160,6 +175,12 @@ func (c *cluster) RemoveNode(id string) error {
delete(c.nodes, id)
ips := node.IPs()
for _, ip := range ips {
c.limiter.RemoveBlock(ip)
}
c.logger.Info().WithFields(log.Fields{
"id": id,
}).Log("Removed node")

View File

@@ -17,6 +17,7 @@ import (
type NodeReader interface {
Address() string
IPs() []string
State() NodeState
}
@@ -40,6 +41,7 @@ const (
type node struct {
address string
ips []string
state nodeState
username string
password string
@@ -75,8 +77,14 @@ func newNode(address, username, password string, updates chan<- NodeState) (*nod
return nil, fmt.Errorf("invalid address: %w", err)
}
addrs, err := net.LookupHost(host)
if err != nil {
return nil, fmt.Errorf("lookup failed: %w", err)
}
n := &node{
address: address,
ips: addrs,
username: username,
password: password,
state: stateDisconnected,
@@ -172,6 +180,10 @@ func (n *node) Address() string {
return n.address
}
func (n *node) IPs() []string {
return n.ips
}
func (n *node) ID() string {
return n.peer.ID()
}
@@ -239,8 +251,6 @@ func (n *node) files() {
n.fileList[nfiles] = "srt:" + file.Name
nfiles++
}
return
}
func (n *node) GetURL(path string) (string, error) {

View File

@@ -12,7 +12,7 @@ import (
type Config struct {
// Skipper defines a function to skip middleware.
Skipper middleware.Skipper
Limiter net.IPLimiter
Limiter net.IPLimitValidator
}
var DefaultConfig = Config{

View File

@@ -83,7 +83,7 @@ type Config struct {
MimeTypesFile string
DiskFS fs.Filesystem
MemFS MemFSConfig
IPLimiter net.IPLimiter
IPLimiter net.IPLimitValidator
Profiling bool
Cors CorsConfig
RTMP rtmp.Server

View File

@@ -3,71 +3,169 @@ package net
import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
)
// The IPLimiter interface allows to check whether a certain IP
// is allowed.
type IPLimiter interface {
// The IPLimitValidator interface allows to check whether a certain IP is allowed.
type IPLimitValidator interface {
// Tests whether the IP is allowed in respect to the underlying implementation
IsAllowed(ip string) bool
}
type IPLimiter interface {
// AddAllow adds a CIDR block to the allow list. If only an IP is provided
// a CIDR will be generated.
AddAllow(cidr string) error
// RemoveAllow removes a CIDR block from the allow list. If only an IP is provided
// a CIDR will be generated.
RemoveAllow(cidr string) error
// AddBlock adds a CIDR block to the block list. If only an IP is provided
// a CIDR will be generated.
AddBlock(cidr string) error
// RemoveBlock removes a CIDR block from the block list. If only an IP is provided
// a CIDR will be generated.
RemoveBlock(cidr string) error
IPLimitValidator
}
// IPLimit implements the IPLimiter interface by having an allow and block list
// of CIDR ranges.
type iplimit struct {
// Array of allowed IP ranges
allowlist []*net.IPNet
// allowList is an array of allowed IP ranges
allowlist map[string]*net.IPNet
// Array of blocked IP ranges
blocklist []*net.IPNet
// blocklist is an array of blocked IP ranges
blocklist map[string]*net.IPNet
// lock is synchronizing the acces to the allow and block lists
lock sync.RWMutex
}
// NewIPLimiter creates a new IPLimiter with the given IP ranges for the
// allowed and blocked IPs. Empty strings are ignored. Returns an error
// if an invalid IP range has been found.
func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) {
ipl := &iplimit{}
ipl := &iplimit{
allowlist: make(map[string]*net.IPNet),
blocklist: make(map[string]*net.IPNet),
}
for _, ipblock := range blocklist {
ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
err := ipl.AddBlock(ipblock)
if err != nil {
return nil, fmt.Errorf("the IP block %s in the block list is invalid", ipblock)
return nil, fmt.Errorf("block list: %w", err)
}
ipl.blocklist = append(ipl.blocklist, cidr)
}
for _, ipblock := range allowlist {
ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
err := ipl.AddAllow(ipblock)
if err != nil {
return nil, fmt.Errorf("the IP block %s in the allow list is invalid", ipblock)
return nil, fmt.Errorf("allow list: %w", err)
}
ipl.allowlist = append(ipl.allowlist, cidr)
}
return ipl, nil
}
// IsAllowed checks whether the provided IP is allowed according
// to the IP ranges in the allow- and blocklists.
func (ipl *iplimit) validate(ipblock string) (*net.IPNet, error) {
ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
return nil, fmt.Errorf("invalid IP block")
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil {
addr, err := netip.ParseAddr(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
if addr.Is4() {
ipblock = addr.String() + "/32"
} else {
ipblock = addr.String() + "/128"
}
_, cidr, err = net.ParseCIDR(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
}
return cidr, nil
}
func (ipl *iplimit) AddAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.allowlist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.allowlist, cidr.String())
return nil
}
func (ipl *iplimit) AddBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.blocklist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.blocklist, cidr.String())
return nil
}
func (ipl *iplimit) IsAllowed(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
ipl.lock.RLock()
defer ipl.lock.RUnlock()
for _, r := range ipl.blocklist {
if r.Contains(parsedIP) {
return false
@@ -87,12 +185,8 @@ func (ipl *iplimit) IsAllowed(ip string) bool {
return false
}
type nulliplimiter struct{}
func NewNullIPLimiter() IPLimiter {
return &nulliplimiter{}
}
ipl, _ := NewIPLimiter(nil, nil)
func (ipl *nulliplimiter) IsAllowed(ip string) bool {
return true
return ipl
}

View File

@@ -188,7 +188,7 @@ type CollectorConfig struct {
// Limiter is an IPLimiter. It is used to query whether a session for an IP
// should be created.
Limiter net.IPLimiter
Limiter net.IPLimitValidator
// InactiveTimeout is the duration of how long a not yet activated session is kept.
// A session gets activated with the first ingress or egress bytes.
@@ -253,7 +253,7 @@ type collector struct {
inactiveTimeout time.Duration
sessionTimeout time.Duration
limiter net.IPLimiter
limiter net.IPLimitValidator
companions []Collector