mirror of
https://github.com/datarhei/core.git
synced 2025-10-05 16:07:07 +08:00
Set cluster node ips on session block list
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
golog "log"
|
||||
gonet "net"
|
||||
gohttp "net/http"
|
||||
@@ -62,23 +61,24 @@ type API interface {
|
||||
}
|
||||
|
||||
type api struct {
|
||||
restream restream.Restreamer
|
||||
ffmpeg ffmpeg.FFmpeg
|
||||
diskfs fs.Filesystem
|
||||
memfs fs.Filesystem
|
||||
rtmpserver rtmp.Server
|
||||
srtserver srt.Server
|
||||
metrics monitor.HistoryMonitor
|
||||
prom prometheus.Metrics
|
||||
service service.Service
|
||||
sessions session.Registry
|
||||
cache cache.Cacher
|
||||
mainserver *gohttp.Server
|
||||
sidecarserver *gohttp.Server
|
||||
httpjwt jwt.JWT
|
||||
update update.Checker
|
||||
replacer replace.Replacer
|
||||
cluster cluster.Cluster
|
||||
restream restream.Restreamer
|
||||
ffmpeg ffmpeg.FFmpeg
|
||||
diskfs fs.Filesystem
|
||||
memfs fs.Filesystem
|
||||
rtmpserver rtmp.Server
|
||||
srtserver srt.Server
|
||||
metrics monitor.HistoryMonitor
|
||||
prom prometheus.Metrics
|
||||
service service.Service
|
||||
sessions session.Registry
|
||||
sessionsLimiter net.IPLimiter
|
||||
cache cache.Cacher
|
||||
mainserver *gohttp.Server
|
||||
sidecarserver *gohttp.Server
|
||||
httpjwt jwt.JWT
|
||||
update update.Checker
|
||||
replacer replace.Replacer
|
||||
cluster cluster.Cluster
|
||||
|
||||
errorChan chan error
|
||||
|
||||
@@ -122,7 +122,7 @@ func New(configpath string, logwriter io.Writer) (API, error) {
|
||||
a.log.writer = logwriter
|
||||
|
||||
if a.log.writer == nil {
|
||||
a.log.writer = ioutil.Discard
|
||||
a.log.writer = io.Discard
|
||||
}
|
||||
|
||||
a.errorChan = make(chan error, 1)
|
||||
@@ -301,6 +301,8 @@ func (a *api) start() error {
|
||||
return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err)
|
||||
}
|
||||
|
||||
a.sessionsLimiter = iplimiter
|
||||
|
||||
config := session.CollectorConfig{
|
||||
MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024,
|
||||
MaxSessions: cfg.Sessions.MaxSessions,
|
||||
@@ -498,7 +500,8 @@ func (a *api) start() error {
|
||||
a.restream = restream
|
||||
|
||||
if cluster, err := cluster.New(cluster.ClusterConfig{
|
||||
Logger: a.log.logger.core.WithComponent("Cluster"),
|
||||
IPLimiter: a.sessionsLimiter,
|
||||
Logger: a.log.logger.core.WithComponent("Cluster"),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("unable to create cluster: %w", err)
|
||||
} else {
|
||||
@@ -778,7 +781,7 @@ func (a *api) start() error {
|
||||
logcontext = "HTTPS"
|
||||
}
|
||||
|
||||
var iplimiter net.IPLimiter
|
||||
var iplimiter net.IPLimitValidator
|
||||
|
||||
if cfg.TLS.Enable {
|
||||
limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow)
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/datarhei/core/v16/log"
|
||||
"github.com/datarhei/core/v16/net"
|
||||
)
|
||||
|
||||
type ClusterReader interface {
|
||||
@@ -33,7 +34,8 @@ type Cluster interface {
|
||||
}
|
||||
|
||||
type ClusterConfig struct {
|
||||
Logger log.Logger
|
||||
IPLimiter net.IPLimiter
|
||||
Logger log.Logger
|
||||
}
|
||||
|
||||
type cluster struct {
|
||||
@@ -42,6 +44,8 @@ type cluster struct {
|
||||
idupdate map[string]time.Time
|
||||
fileid map[string]string
|
||||
|
||||
limiter net.IPLimiter
|
||||
|
||||
updates chan NodeState
|
||||
|
||||
lock sync.RWMutex
|
||||
@@ -57,10 +61,15 @@ func New(config ClusterConfig) (Cluster, error) {
|
||||
idfiles: map[string][]string{},
|
||||
idupdate: map[string]time.Time{},
|
||||
fileid: map[string]string{},
|
||||
limiter: config.IPLimiter,
|
||||
updates: make(chan NodeState, 64),
|
||||
logger: config.Logger,
|
||||
}
|
||||
|
||||
if c.limiter == nil {
|
||||
c.limiter = net.NewNullIPLimiter()
|
||||
}
|
||||
|
||||
if c.logger == nil {
|
||||
c.logger = log.New("")
|
||||
}
|
||||
@@ -134,9 +143,15 @@ func (c *cluster) AddNode(address, username, password string) (string, error) {
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if _, ok := c.nodes[id]; ok {
|
||||
node.stop()
|
||||
return id, nil
|
||||
}
|
||||
|
||||
ips := node.IPs()
|
||||
for _, ip := range ips {
|
||||
c.limiter.AddBlock(ip)
|
||||
}
|
||||
|
||||
c.nodes[id] = node
|
||||
|
||||
c.logger.Info().WithFields(log.Fields{
|
||||
@@ -160,6 +175,12 @@ func (c *cluster) RemoveNode(id string) error {
|
||||
|
||||
delete(c.nodes, id)
|
||||
|
||||
ips := node.IPs()
|
||||
|
||||
for _, ip := range ips {
|
||||
c.limiter.RemoveBlock(ip)
|
||||
}
|
||||
|
||||
c.logger.Info().WithFields(log.Fields{
|
||||
"id": id,
|
||||
}).Log("Removed node")
|
||||
|
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
type NodeReader interface {
|
||||
Address() string
|
||||
IPs() []string
|
||||
State() NodeState
|
||||
}
|
||||
|
||||
@@ -40,6 +41,7 @@ const (
|
||||
|
||||
type node struct {
|
||||
address string
|
||||
ips []string
|
||||
state nodeState
|
||||
username string
|
||||
password string
|
||||
@@ -75,8 +77,14 @@ func newNode(address, username, password string, updates chan<- NodeState) (*nod
|
||||
return nil, fmt.Errorf("invalid address: %w", err)
|
||||
}
|
||||
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup failed: %w", err)
|
||||
}
|
||||
|
||||
n := &node{
|
||||
address: address,
|
||||
ips: addrs,
|
||||
username: username,
|
||||
password: password,
|
||||
state: stateDisconnected,
|
||||
@@ -172,6 +180,10 @@ func (n *node) Address() string {
|
||||
return n.address
|
||||
}
|
||||
|
||||
func (n *node) IPs() []string {
|
||||
return n.ips
|
||||
}
|
||||
|
||||
func (n *node) ID() string {
|
||||
return n.peer.ID()
|
||||
}
|
||||
@@ -239,8 +251,6 @@ func (n *node) files() {
|
||||
n.fileList[nfiles] = "srt:" + file.Name
|
||||
nfiles++
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (n *node) GetURL(path string) (string, error) {
|
||||
|
@@ -12,7 +12,7 @@ import (
|
||||
type Config struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper middleware.Skipper
|
||||
Limiter net.IPLimiter
|
||||
Limiter net.IPLimitValidator
|
||||
}
|
||||
|
||||
var DefaultConfig = Config{
|
||||
|
@@ -83,7 +83,7 @@ type Config struct {
|
||||
MimeTypesFile string
|
||||
DiskFS fs.Filesystem
|
||||
MemFS MemFSConfig
|
||||
IPLimiter net.IPLimiter
|
||||
IPLimiter net.IPLimitValidator
|
||||
Profiling bool
|
||||
Cors CorsConfig
|
||||
RTMP rtmp.Server
|
||||
|
162
net/iplimit.go
162
net/iplimit.go
@@ -3,71 +3,169 @@ package net
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// The IPLimiter interface allows to check whether a certain IP
|
||||
// is allowed.
|
||||
type IPLimiter interface {
|
||||
// The IPLimitValidator interface allows to check whether a certain IP is allowed.
|
||||
type IPLimitValidator interface {
|
||||
// Tests whether the IP is allowed in respect to the underlying implementation
|
||||
IsAllowed(ip string) bool
|
||||
}
|
||||
|
||||
type IPLimiter interface {
|
||||
// AddAllow adds a CIDR block to the allow list. If only an IP is provided
|
||||
// a CIDR will be generated.
|
||||
AddAllow(cidr string) error
|
||||
|
||||
// RemoveAllow removes a CIDR block from the allow list. If only an IP is provided
|
||||
// a CIDR will be generated.
|
||||
RemoveAllow(cidr string) error
|
||||
|
||||
// AddBlock adds a CIDR block to the block list. If only an IP is provided
|
||||
// a CIDR will be generated.
|
||||
AddBlock(cidr string) error
|
||||
|
||||
// RemoveBlock removes a CIDR block from the block list. If only an IP is provided
|
||||
// a CIDR will be generated.
|
||||
RemoveBlock(cidr string) error
|
||||
|
||||
IPLimitValidator
|
||||
}
|
||||
|
||||
// IPLimit implements the IPLimiter interface by having an allow and block list
|
||||
// of CIDR ranges.
|
||||
type iplimit struct {
|
||||
// Array of allowed IP ranges
|
||||
allowlist []*net.IPNet
|
||||
// allowList is an array of allowed IP ranges
|
||||
allowlist map[string]*net.IPNet
|
||||
|
||||
// Array of blocked IP ranges
|
||||
blocklist []*net.IPNet
|
||||
// blocklist is an array of blocked IP ranges
|
||||
blocklist map[string]*net.IPNet
|
||||
|
||||
// lock is synchronizing the acces to the allow and block lists
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewIPLimiter creates a new IPLimiter with the given IP ranges for the
|
||||
// allowed and blocked IPs. Empty strings are ignored. Returns an error
|
||||
// if an invalid IP range has been found.
|
||||
func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) {
|
||||
ipl := &iplimit{}
|
||||
ipl := &iplimit{
|
||||
allowlist: make(map[string]*net.IPNet),
|
||||
blocklist: make(map[string]*net.IPNet),
|
||||
}
|
||||
|
||||
for _, ipblock := range blocklist {
|
||||
ipblock = strings.TrimSpace(ipblock)
|
||||
if len(ipblock) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
_, cidr, err := net.ParseCIDR(ipblock)
|
||||
err := ipl.AddBlock(ipblock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("the IP block %s in the block list is invalid", ipblock)
|
||||
return nil, fmt.Errorf("block list: %w", err)
|
||||
}
|
||||
|
||||
ipl.blocklist = append(ipl.blocklist, cidr)
|
||||
}
|
||||
|
||||
for _, ipblock := range allowlist {
|
||||
ipblock = strings.TrimSpace(ipblock)
|
||||
if len(ipblock) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
_, cidr, err := net.ParseCIDR(ipblock)
|
||||
err := ipl.AddAllow(ipblock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("the IP block %s in the allow list is invalid", ipblock)
|
||||
return nil, fmt.Errorf("allow list: %w", err)
|
||||
}
|
||||
|
||||
ipl.allowlist = append(ipl.allowlist, cidr)
|
||||
}
|
||||
|
||||
return ipl, nil
|
||||
}
|
||||
|
||||
// IsAllowed checks whether the provided IP is allowed according
|
||||
// to the IP ranges in the allow- and blocklists.
|
||||
func (ipl *iplimit) validate(ipblock string) (*net.IPNet, error) {
|
||||
ipblock = strings.TrimSpace(ipblock)
|
||||
if len(ipblock) == 0 {
|
||||
return nil, fmt.Errorf("invalid IP block")
|
||||
}
|
||||
|
||||
_, cidr, err := net.ParseCIDR(ipblock)
|
||||
if err != nil {
|
||||
addr, err := netip.ParseAddr(ipblock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP block: %w", err)
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
ipblock = addr.String() + "/32"
|
||||
} else {
|
||||
ipblock = addr.String() + "/128"
|
||||
}
|
||||
|
||||
_, cidr, err = net.ParseCIDR(ipblock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP block: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return cidr, nil
|
||||
}
|
||||
|
||||
func (ipl *iplimit) AddAllow(ipblock string) error {
|
||||
cidr, err := ipl.validate(ipblock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipl.lock.Lock()
|
||||
defer ipl.lock.Unlock()
|
||||
|
||||
ipl.allowlist[cidr.String()] = cidr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ipl *iplimit) RemoveAllow(ipblock string) error {
|
||||
cidr, err := ipl.validate(ipblock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipl.lock.Lock()
|
||||
defer ipl.lock.Unlock()
|
||||
|
||||
delete(ipl.allowlist, cidr.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ipl *iplimit) AddBlock(ipblock string) error {
|
||||
cidr, err := ipl.validate(ipblock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipl.lock.Lock()
|
||||
defer ipl.lock.Unlock()
|
||||
|
||||
ipl.blocklist[cidr.String()] = cidr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ipl *iplimit) RemoveBlock(ipblock string) error {
|
||||
cidr, err := ipl.validate(ipblock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipl.lock.Lock()
|
||||
defer ipl.lock.Unlock()
|
||||
|
||||
delete(ipl.blocklist, cidr.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ipl *iplimit) IsAllowed(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ipl.lock.RLock()
|
||||
defer ipl.lock.RUnlock()
|
||||
|
||||
for _, r := range ipl.blocklist {
|
||||
if r.Contains(parsedIP) {
|
||||
return false
|
||||
@@ -87,12 +185,8 @@ func (ipl *iplimit) IsAllowed(ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type nulliplimiter struct{}
|
||||
|
||||
func NewNullIPLimiter() IPLimiter {
|
||||
return &nulliplimiter{}
|
||||
}
|
||||
ipl, _ := NewIPLimiter(nil, nil)
|
||||
|
||||
func (ipl *nulliplimiter) IsAllowed(ip string) bool {
|
||||
return true
|
||||
return ipl
|
||||
}
|
||||
|
@@ -188,7 +188,7 @@ type CollectorConfig struct {
|
||||
|
||||
// Limiter is an IPLimiter. It is used to query whether a session for an IP
|
||||
// should be created.
|
||||
Limiter net.IPLimiter
|
||||
Limiter net.IPLimitValidator
|
||||
|
||||
// InactiveTimeout is the duration of how long a not yet activated session is kept.
|
||||
// A session gets activated with the first ingress or egress bytes.
|
||||
@@ -253,7 +253,7 @@ type collector struct {
|
||||
inactiveTimeout time.Duration
|
||||
sessionTimeout time.Duration
|
||||
|
||||
limiter net.IPLimiter
|
||||
limiter net.IPLimitValidator
|
||||
|
||||
companions []Collector
|
||||
|
||||
|
Reference in New Issue
Block a user