Set cluster node ips on session block list

This commit is contained in:
Ingo Oppermann
2022-08-15 13:52:38 +03:00
parent 2b05d5fb31
commit fe2d4e247e
7 changed files with 190 additions and 62 deletions

View File

@@ -5,7 +5,6 @@ import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
golog "log"
gonet "net"
gohttp "net/http"
@@ -72,6 +71,7 @@ type api struct {
prom prometheus.Metrics
service service.Service
sessions session.Registry
sessionsLimiter net.IPLimiter
cache cache.Cacher
mainserver *gohttp.Server
sidecarserver *gohttp.Server
@@ -122,7 +122,7 @@ func New(configpath string, logwriter io.Writer) (API, error) {
a.log.writer = logwriter
if a.log.writer == nil {
a.log.writer = ioutil.Discard
a.log.writer = io.Discard
}
a.errorChan = make(chan error, 1)
@@ -301,6 +301,8 @@ func (a *api) start() error {
return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err)
}
a.sessionsLimiter = iplimiter
config := session.CollectorConfig{
MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024,
MaxSessions: cfg.Sessions.MaxSessions,
@@ -498,6 +500,7 @@ func (a *api) start() error {
a.restream = restream
if cluster, err := cluster.New(cluster.ClusterConfig{
IPLimiter: a.sessionsLimiter,
Logger: a.log.logger.core.WithComponent("Cluster"),
}); err != nil {
return fmt.Errorf("unable to create cluster: %w", err)
@@ -778,7 +781,7 @@ func (a *api) start() error {
logcontext = "HTTPS"
}
var iplimiter net.IPLimiter
var iplimiter net.IPLimitValidator
if cfg.TLS.Enable {
limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow)

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/datarhei/core/v16/log"
"github.com/datarhei/core/v16/net"
)
type ClusterReader interface {
@@ -33,6 +34,7 @@ type Cluster interface {
}
type ClusterConfig struct {
IPLimiter net.IPLimiter
Logger log.Logger
}
@@ -42,6 +44,8 @@ type cluster struct {
idupdate map[string]time.Time
fileid map[string]string
limiter net.IPLimiter
updates chan NodeState
lock sync.RWMutex
@@ -57,10 +61,15 @@ func New(config ClusterConfig) (Cluster, error) {
idfiles: map[string][]string{},
idupdate: map[string]time.Time{},
fileid: map[string]string{},
limiter: config.IPLimiter,
updates: make(chan NodeState, 64),
logger: config.Logger,
}
if c.limiter == nil {
c.limiter = net.NewNullIPLimiter()
}
if c.logger == nil {
c.logger = log.New("")
}
@@ -134,9 +143,15 @@ func (c *cluster) AddNode(address, username, password string) (string, error) {
defer c.lock.Unlock()
if _, ok := c.nodes[id]; ok {
node.stop()
return id, nil
}
ips := node.IPs()
for _, ip := range ips {
c.limiter.AddBlock(ip)
}
c.nodes[id] = node
c.logger.Info().WithFields(log.Fields{
@@ -160,6 +175,12 @@ func (c *cluster) RemoveNode(id string) error {
delete(c.nodes, id)
ips := node.IPs()
for _, ip := range ips {
c.limiter.RemoveBlock(ip)
}
c.logger.Info().WithFields(log.Fields{
"id": id,
}).Log("Removed node")

View File

@@ -17,6 +17,7 @@ import (
type NodeReader interface {
Address() string
IPs() []string
State() NodeState
}
@@ -40,6 +41,7 @@ const (
type node struct {
address string
ips []string
state nodeState
username string
password string
@@ -75,8 +77,14 @@ func newNode(address, username, password string, updates chan<- NodeState) (*nod
return nil, fmt.Errorf("invalid address: %w", err)
}
addrs, err := net.LookupHost(host)
if err != nil {
return nil, fmt.Errorf("lookup failed: %w", err)
}
n := &node{
address: address,
ips: addrs,
username: username,
password: password,
state: stateDisconnected,
@@ -172,6 +180,10 @@ func (n *node) Address() string {
return n.address
}
func (n *node) IPs() []string {
return n.ips
}
func (n *node) ID() string {
return n.peer.ID()
}
@@ -239,8 +251,6 @@ func (n *node) files() {
n.fileList[nfiles] = "srt:" + file.Name
nfiles++
}
return
}
func (n *node) GetURL(path string) (string, error) {

View File

@@ -12,7 +12,7 @@ import (
type Config struct {
// Skipper defines a function to skip middleware.
Skipper middleware.Skipper
Limiter net.IPLimiter
Limiter net.IPLimitValidator
}
var DefaultConfig = Config{

View File

@@ -83,7 +83,7 @@ type Config struct {
MimeTypesFile string
DiskFS fs.Filesystem
MemFS MemFSConfig
IPLimiter net.IPLimiter
IPLimiter net.IPLimitValidator
Profiling bool
Cors CorsConfig
RTMP rtmp.Server

View File

@@ -3,71 +3,169 @@ package net
import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
)
// The IPLimiter interface allows to check whether a certain IP
// is allowed.
type IPLimiter interface {
// The IPLimitValidator interface allows to check whether a certain IP is allowed.
type IPLimitValidator interface {
// Tests whether the IP is allowed in respect to the underlying implementation
IsAllowed(ip string) bool
}
type IPLimiter interface {
// AddAllow adds a CIDR block to the allow list. If only an IP is provided
// a CIDR will be generated.
AddAllow(cidr string) error
// RemoveAllow removes a CIDR block from the allow list. If only an IP is provided
// a CIDR will be generated.
RemoveAllow(cidr string) error
// AddBlock adds a CIDR block to the block list. If only an IP is provided
// a CIDR will be generated.
AddBlock(cidr string) error
// RemoveBlock removes a CIDR block from the block list. If only an IP is provided
// a CIDR will be generated.
RemoveBlock(cidr string) error
IPLimitValidator
}
// IPLimit implements the IPLimiter interface by having an allow and block list
// of CIDR ranges.
type iplimit struct {
// Array of allowed IP ranges
allowlist []*net.IPNet
// allowList is an array of allowed IP ranges
allowlist map[string]*net.IPNet
// Array of blocked IP ranges
blocklist []*net.IPNet
// blocklist is an array of blocked IP ranges
blocklist map[string]*net.IPNet
// lock is synchronizing the acces to the allow and block lists
lock sync.RWMutex
}
// NewIPLimiter creates a new IPLimiter with the given IP ranges for the
// allowed and blocked IPs. Empty strings are ignored. Returns an error
// if an invalid IP range has been found.
func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) {
ipl := &iplimit{}
ipl := &iplimit{
allowlist: make(map[string]*net.IPNet),
blocklist: make(map[string]*net.IPNet),
}
for _, ipblock := range blocklist {
ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
err := ipl.AddBlock(ipblock)
if err != nil {
return nil, fmt.Errorf("the IP block %s in the block list is invalid", ipblock)
return nil, fmt.Errorf("block list: %w", err)
}
ipl.blocklist = append(ipl.blocklist, cidr)
}
for _, ipblock := range allowlist {
ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
continue
}
_, cidr, err := net.ParseCIDR(ipblock)
err := ipl.AddAllow(ipblock)
if err != nil {
return nil, fmt.Errorf("the IP block %s in the allow list is invalid", ipblock)
return nil, fmt.Errorf("allow list: %w", err)
}
ipl.allowlist = append(ipl.allowlist, cidr)
}
return ipl, nil
}
// IsAllowed checks whether the provided IP is allowed according
// to the IP ranges in the allow- and blocklists.
func (ipl *iplimit) validate(ipblock string) (*net.IPNet, error) {
ipblock = strings.TrimSpace(ipblock)
if len(ipblock) == 0 {
return nil, fmt.Errorf("invalid IP block")
}
_, cidr, err := net.ParseCIDR(ipblock)
if err != nil {
addr, err := netip.ParseAddr(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
if addr.Is4() {
ipblock = addr.String() + "/32"
} else {
ipblock = addr.String() + "/128"
}
_, cidr, err = net.ParseCIDR(ipblock)
if err != nil {
return nil, fmt.Errorf("invalid IP block: %w", err)
}
}
return cidr, nil
}
func (ipl *iplimit) AddAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.allowlist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveAllow(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.allowlist, cidr.String())
return nil
}
func (ipl *iplimit) AddBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
ipl.blocklist[cidr.String()] = cidr
return nil
}
func (ipl *iplimit) RemoveBlock(ipblock string) error {
cidr, err := ipl.validate(ipblock)
if err != nil {
return err
}
ipl.lock.Lock()
defer ipl.lock.Unlock()
delete(ipl.blocklist, cidr.String())
return nil
}
func (ipl *iplimit) IsAllowed(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
ipl.lock.RLock()
defer ipl.lock.RUnlock()
for _, r := range ipl.blocklist {
if r.Contains(parsedIP) {
return false
@@ -87,12 +185,8 @@ func (ipl *iplimit) IsAllowed(ip string) bool {
return false
}
type nulliplimiter struct{}
func NewNullIPLimiter() IPLimiter {
return &nulliplimiter{}
}
ipl, _ := NewIPLimiter(nil, nil)
func (ipl *nulliplimiter) IsAllowed(ip string) bool {
return true
return ipl
}

View File

@@ -188,7 +188,7 @@ type CollectorConfig struct {
// Limiter is an IPLimiter. It is used to query whether a session for an IP
// should be created.
Limiter net.IPLimiter
Limiter net.IPLimitValidator
// InactiveTimeout is the duration of how long a not yet activated session is kept.
// A session gets activated with the first ingress or egress bytes.
@@ -253,7 +253,7 @@ type collector struct {
inactiveTimeout time.Duration
sessionTimeout time.Duration
limiter net.IPLimiter
limiter net.IPLimitValidator
companions []Collector