diff --git a/app/api/api.go b/app/api/api.go index e62c2f3c..00c0e231 100644 --- a/app/api/api.go +++ b/app/api/api.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "fmt" "io" - "io/ioutil" golog "log" gonet "net" gohttp "net/http" @@ -62,23 +61,24 @@ type API interface { } type api struct { - restream restream.Restreamer - ffmpeg ffmpeg.FFmpeg - diskfs fs.Filesystem - memfs fs.Filesystem - rtmpserver rtmp.Server - srtserver srt.Server - metrics monitor.HistoryMonitor - prom prometheus.Metrics - service service.Service - sessions session.Registry - cache cache.Cacher - mainserver *gohttp.Server - sidecarserver *gohttp.Server - httpjwt jwt.JWT - update update.Checker - replacer replace.Replacer - cluster cluster.Cluster + restream restream.Restreamer + ffmpeg ffmpeg.FFmpeg + diskfs fs.Filesystem + memfs fs.Filesystem + rtmpserver rtmp.Server + srtserver srt.Server + metrics monitor.HistoryMonitor + prom prometheus.Metrics + service service.Service + sessions session.Registry + sessionsLimiter net.IPLimiter + cache cache.Cacher + mainserver *gohttp.Server + sidecarserver *gohttp.Server + httpjwt jwt.JWT + update update.Checker + replacer replace.Replacer + cluster cluster.Cluster errorChan chan error @@ -122,7 +122,7 @@ func New(configpath string, logwriter io.Writer) (API, error) { a.log.writer = logwriter if a.log.writer == nil { - a.log.writer = ioutil.Discard + a.log.writer = io.Discard } a.errorChan = make(chan error, 1) @@ -301,6 +301,8 @@ func (a *api) start() error { return fmt.Errorf("incorret IP ranges for the statistics provided: %w", err) } + a.sessionsLimiter = iplimiter + config := session.CollectorConfig{ MaxTxBitrate: cfg.Sessions.MaxBitrate * 1024 * 1024, MaxSessions: cfg.Sessions.MaxSessions, @@ -498,7 +500,8 @@ func (a *api) start() error { a.restream = restream if cluster, err := cluster.New(cluster.ClusterConfig{ - Logger: a.log.logger.core.WithComponent("Cluster"), + IPLimiter: a.sessionsLimiter, + Logger: a.log.logger.core.WithComponent("Cluster"), }); err != nil { return fmt.Errorf("unable to create cluster: %w", err) } else { @@ -778,7 +781,7 @@ func (a *api) start() error { logcontext = "HTTPS" } - var iplimiter net.IPLimiter + var iplimiter net.IPLimitValidator if cfg.TLS.Enable { limiter, err := net.NewIPLimiter(cfg.API.Access.HTTPS.Block, cfg.API.Access.HTTPS.Allow) diff --git a/cluster/cluster.go b/cluster/cluster.go index 7883d00a..e19cfc93 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -7,6 +7,7 @@ import ( "time" "github.com/datarhei/core/v16/log" + "github.com/datarhei/core/v16/net" ) type ClusterReader interface { @@ -33,7 +34,8 @@ type Cluster interface { } type ClusterConfig struct { - Logger log.Logger + IPLimiter net.IPLimiter + Logger log.Logger } type cluster struct { @@ -42,6 +44,8 @@ type cluster struct { idupdate map[string]time.Time fileid map[string]string + limiter net.IPLimiter + updates chan NodeState lock sync.RWMutex @@ -57,10 +61,15 @@ func New(config ClusterConfig) (Cluster, error) { idfiles: map[string][]string{}, idupdate: map[string]time.Time{}, fileid: map[string]string{}, + limiter: config.IPLimiter, updates: make(chan NodeState, 64), logger: config.Logger, } + if c.limiter == nil { + c.limiter = net.NewNullIPLimiter() + } + if c.logger == nil { c.logger = log.New("") } @@ -134,9 +143,15 @@ func (c *cluster) AddNode(address, username, password string) (string, error) { defer c.lock.Unlock() if _, ok := c.nodes[id]; ok { + node.stop() return id, nil } + ips := node.IPs() + for _, ip := range ips { + c.limiter.AddBlock(ip) + } + c.nodes[id] = node c.logger.Info().WithFields(log.Fields{ @@ -160,6 +175,12 @@ func (c *cluster) RemoveNode(id string) error { delete(c.nodes, id) + ips := node.IPs() + + for _, ip := range ips { + c.limiter.RemoveBlock(ip) + } + c.logger.Info().WithFields(log.Fields{ "id": id, }).Log("Removed node") diff --git a/cluster/node.go b/cluster/node.go index 7ce6836b..2a14cf14 100644 --- a/cluster/node.go +++ b/cluster/node.go @@ -17,6 +17,7 @@ import ( type NodeReader interface { Address() string + IPs() []string State() NodeState } @@ -40,6 +41,7 @@ const ( type node struct { address string + ips []string state nodeState username string password string @@ -75,8 +77,14 @@ func newNode(address, username, password string, updates chan<- NodeState) (*nod return nil, fmt.Errorf("invalid address: %w", err) } + addrs, err := net.LookupHost(host) + if err != nil { + return nil, fmt.Errorf("lookup failed: %w", err) + } + n := &node{ address: address, + ips: addrs, username: username, password: password, state: stateDisconnected, @@ -172,6 +180,10 @@ func (n *node) Address() string { return n.address } +func (n *node) IPs() []string { + return n.ips +} + func (n *node) ID() string { return n.peer.ID() } @@ -239,8 +251,6 @@ func (n *node) files() { n.fileList[nfiles] = "srt:" + file.Name nfiles++ } - - return } func (n *node) GetURL(path string) (string, error) { diff --git a/http/middleware/iplimit/iplimit.go b/http/middleware/iplimit/iplimit.go index 369551d4..75652e4c 100644 --- a/http/middleware/iplimit/iplimit.go +++ b/http/middleware/iplimit/iplimit.go @@ -12,7 +12,7 @@ import ( type Config struct { // Skipper defines a function to skip middleware. Skipper middleware.Skipper - Limiter net.IPLimiter + Limiter net.IPLimitValidator } var DefaultConfig = Config{ diff --git a/http/server.go b/http/server.go index f8bd7317..332302d2 100644 --- a/http/server.go +++ b/http/server.go @@ -83,7 +83,7 @@ type Config struct { MimeTypesFile string DiskFS fs.Filesystem MemFS MemFSConfig - IPLimiter net.IPLimiter + IPLimiter net.IPLimitValidator Profiling bool Cors CorsConfig RTMP rtmp.Server diff --git a/net/iplimit.go b/net/iplimit.go index 534cf18d..7e2662f6 100644 --- a/net/iplimit.go +++ b/net/iplimit.go @@ -3,71 +3,169 @@ package net import ( "fmt" "net" + "net/netip" "strings" + "sync" ) -// The IPLimiter interface allows to check whether a certain IP -// is allowed. -type IPLimiter interface { +// The IPLimitValidator interface allows to check whether a certain IP is allowed. +type IPLimitValidator interface { // Tests whether the IP is allowed in respect to the underlying implementation IsAllowed(ip string) bool } +type IPLimiter interface { + // AddAllow adds a CIDR block to the allow list. If only an IP is provided + // a CIDR will be generated. + AddAllow(cidr string) error + + // RemoveAllow removes a CIDR block from the allow list. If only an IP is provided + // a CIDR will be generated. + RemoveAllow(cidr string) error + + // AddBlock adds a CIDR block to the block list. If only an IP is provided + // a CIDR will be generated. + AddBlock(cidr string) error + + // RemoveBlock removes a CIDR block from the block list. If only an IP is provided + // a CIDR will be generated. + RemoveBlock(cidr string) error + + IPLimitValidator +} + // IPLimit implements the IPLimiter interface by having an allow and block list // of CIDR ranges. type iplimit struct { - // Array of allowed IP ranges - allowlist []*net.IPNet + // allowList is an array of allowed IP ranges + allowlist map[string]*net.IPNet - // Array of blocked IP ranges - blocklist []*net.IPNet + // blocklist is an array of blocked IP ranges + blocklist map[string]*net.IPNet + + // lock is synchronizing the acces to the allow and block lists + lock sync.RWMutex } // NewIPLimiter creates a new IPLimiter with the given IP ranges for the // allowed and blocked IPs. Empty strings are ignored. Returns an error // if an invalid IP range has been found. func NewIPLimiter(blocklist, allowlist []string) (IPLimiter, error) { - ipl := &iplimit{} + ipl := &iplimit{ + allowlist: make(map[string]*net.IPNet), + blocklist: make(map[string]*net.IPNet), + } for _, ipblock := range blocklist { - ipblock = strings.TrimSpace(ipblock) - if len(ipblock) == 0 { - continue - } - - _, cidr, err := net.ParseCIDR(ipblock) + err := ipl.AddBlock(ipblock) if err != nil { - return nil, fmt.Errorf("the IP block %s in the block list is invalid", ipblock) + return nil, fmt.Errorf("block list: %w", err) } - - ipl.blocklist = append(ipl.blocklist, cidr) } for _, ipblock := range allowlist { - ipblock = strings.TrimSpace(ipblock) - if len(ipblock) == 0 { - continue - } - - _, cidr, err := net.ParseCIDR(ipblock) + err := ipl.AddAllow(ipblock) if err != nil { - return nil, fmt.Errorf("the IP block %s in the allow list is invalid", ipblock) + return nil, fmt.Errorf("allow list: %w", err) } - - ipl.allowlist = append(ipl.allowlist, cidr) } return ipl, nil } -// IsAllowed checks whether the provided IP is allowed according -// to the IP ranges in the allow- and blocklists. +func (ipl *iplimit) validate(ipblock string) (*net.IPNet, error) { + ipblock = strings.TrimSpace(ipblock) + if len(ipblock) == 0 { + return nil, fmt.Errorf("invalid IP block") + } + + _, cidr, err := net.ParseCIDR(ipblock) + if err != nil { + addr, err := netip.ParseAddr(ipblock) + if err != nil { + return nil, fmt.Errorf("invalid IP block: %w", err) + } + + if addr.Is4() { + ipblock = addr.String() + "/32" + } else { + ipblock = addr.String() + "/128" + } + + _, cidr, err = net.ParseCIDR(ipblock) + if err != nil { + return nil, fmt.Errorf("invalid IP block: %w", err) + } + } + + return cidr, nil +} + +func (ipl *iplimit) AddAllow(ipblock string) error { + cidr, err := ipl.validate(ipblock) + if err != nil { + return err + } + + ipl.lock.Lock() + defer ipl.lock.Unlock() + + ipl.allowlist[cidr.String()] = cidr + + return nil +} + +func (ipl *iplimit) RemoveAllow(ipblock string) error { + cidr, err := ipl.validate(ipblock) + if err != nil { + return err + } + + ipl.lock.Lock() + defer ipl.lock.Unlock() + + delete(ipl.allowlist, cidr.String()) + + return nil +} + +func (ipl *iplimit) AddBlock(ipblock string) error { + cidr, err := ipl.validate(ipblock) + if err != nil { + return err + } + + ipl.lock.Lock() + defer ipl.lock.Unlock() + + ipl.blocklist[cidr.String()] = cidr + + return nil +} + +func (ipl *iplimit) RemoveBlock(ipblock string) error { + cidr, err := ipl.validate(ipblock) + if err != nil { + return err + } + + ipl.lock.Lock() + defer ipl.lock.Unlock() + + delete(ipl.blocklist, cidr.String()) + + return nil +} + func (ipl *iplimit) IsAllowed(ip string) bool { parsedIP := net.ParseIP(ip) if parsedIP == nil { return false } + ipl.lock.RLock() + defer ipl.lock.RUnlock() + for _, r := range ipl.blocklist { if r.Contains(parsedIP) { return false @@ -87,12 +185,8 @@ func (ipl *iplimit) IsAllowed(ip string) bool { return false } -type nulliplimiter struct{} - func NewNullIPLimiter() IPLimiter { - return &nulliplimiter{} -} + ipl, _ := NewIPLimiter(nil, nil) -func (ipl *nulliplimiter) IsAllowed(ip string) bool { - return true + return ipl } diff --git a/session/collector.go b/session/collector.go index e5f1aa7e..d46953f8 100644 --- a/session/collector.go +++ b/session/collector.go @@ -188,7 +188,7 @@ type CollectorConfig struct { // Limiter is an IPLimiter. It is used to query whether a session for an IP // should be created. - Limiter net.IPLimiter + Limiter net.IPLimitValidator // InactiveTimeout is the duration of how long a not yet activated session is kept. // A session gets activated with the first ingress or egress bytes. @@ -253,7 +253,7 @@ type collector struct { inactiveTimeout time.Duration sessionTimeout time.Duration - limiter net.IPLimiter + limiter net.IPLimitValidator companions []Collector