Files
core/http/middleware/session/session.go
2024-10-14 10:51:35 +02:00

225 lines
4.8 KiB
Go

package session
import (
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"github.com/datarhei/core/v16/glob"
"github.com/datarhei/core/v16/http/api"
"github.com/datarhei/core/v16/http/handler/util"
"github.com/datarhei/core/v16/mem"
"github.com/datarhei/core/v16/net"
"github.com/datarhei/core/v16/session"
"github.com/lithammer/shortuuid/v4"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
// Config configures the session middleware returned by NewWithConfig.
//
// Any nil field is replaced by the corresponding value from DefaultConfig.
type Config struct {
	// Skipper is called for every request; when it returns true the request
	// is passed straight to the next handler without any session handling.
	Skipper middleware.Skipper

	// HTTPCollector is the collector for plain HTTP requests, i.e. requests
	// whose path ends neither in ".m3u8" nor in ".ts".
	HTTPCollector session.Collector

	// HLSEgressCollector is handed to the HLS handler, which serves requests
	// whose path ends in ".m3u8" or ".ts". Presumably it accounts for data
	// sent to clients — TODO confirm against the HLS handler.
	HLSEgressCollector session.Collector

	// HLSIngressCollector is also handed to the HLS handler. Presumably it
	// accounts for data received from publishers — TODO confirm against the
	// HLS handler.
	HLSIngressCollector session.Collector
}
// DefaultConfig is the configuration used by New. NewWithConfig also falls
// back to these values for every nil field of the Config it is given.
//
// It uses echo's middleware.DefaultSkipper and a separate collector from
// session.NewNullCollector for each of the three collector fields.
var DefaultConfig = Config{
	Skipper:             middleware.DefaultSkipper,
	HTTPCollector:       session.NewNullCollector(),
	HLSEgressCollector:  session.NewNullCollector(),
	HLSIngressCollector: session.NewNullCollector(),
}
// handler holds the state of one middleware instance created by
// NewWithConfig. A single value is created per middleware and shared by all
// requests passing through it.
type handler struct {
	// reSessionID matches a string of exactly 22 characters taken from
	// shortuuid.DefaultAlphabet. It is not referenced in the code shown here;
	// presumably the HLS handler uses it to recognize session IDs — TODO confirm.
	reSessionID *regexp.Regexp
	// httpCollector is Config.HTTPCollector, or the DefaultConfig value when nil.
	httpCollector session.Collector
	// hlsEgressCollector is Config.HLSEgressCollector, or the DefaultConfig
	// value when nil.
	hlsEgressCollector session.Collector
	// hlsIngressCollector is Config.HLSIngressCollector, or the DefaultConfig
	// value when nil.
	hlsIngressCollector session.Collector
	// rxsegments starts out as an empty map and is not referenced in the code
	// shown here. NOTE(review): presumably per-segment byte counts used by the
	// HLS handler — confirm key/value meaning there.
	rxsegments map[string]int64
	// lock is not referenced in the code shown here; presumably it guards
	// rxsegments — TODO confirm.
	lock sync.Mutex
}
// New builds the session middleware using DefaultConfig. It is shorthand for
// NewWithConfig(DefaultConfig).
func New() echo.MiddlewareFunc {
	cfg := DefaultConfig
	return NewWithConfig(cfg)
}
// NewWithConfig returns a session middleware configured by config.
//
// Nil fields in config fall back to the corresponding DefaultConfig values.
// For every request that config.Skipper does not skip, the middleware:
//   - verifies the "session" value stored in the echo context against the
//     request path and the Referer header, rejecting the request with
//     403 Forbidden if that fails;
//   - adds the user, request method, user agent and anonymized client IP to
//     the session data;
//   - dispatches to the HLS handler for paths ending in ".m3u8" or ".ts" and
//     to the plain HTTP handler for everything else.
func NewWithConfig(config Config) echo.MiddlewareFunc {
	if config.Skipper == nil {
		config.Skipper = DefaultConfig.Skipper
	}
	if config.HTTPCollector == nil {
		config.HTTPCollector = DefaultConfig.HTTPCollector
	}
	if config.HLSEgressCollector == nil {
		config.HLSEgressCollector = DefaultConfig.HLSEgressCollector
	}
	if config.HLSIngressCollector == nil {
		config.HLSIngressCollector = DefaultConfig.HLSIngressCollector
	}

	// One handler is shared by every request through this middleware. It
	// embeds a sync.Mutex, so keep exactly one instance behind a pointer.
	// reSessionID accepts exactly 22 characters from the shortuuid alphabet.
	h := &handler{
		httpCollector:       config.HTTPCollector,
		hlsEgressCollector:  config.HLSEgressCollector,
		hlsIngressCollector: config.HLSIngressCollector,
		reSessionID:         regexp.MustCompile(`^[` + regexp.QuoteMeta(shortuuid.DefaultAlphabet) + `]{22}$`),
		rxsegments:          make(map[string]int64),
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			ctxuser := util.DefaultContext(c, "user", "")
			req := c.Request()
			path := req.URL.Path
			referrer := req.Header.Get("Referer")

			data, err := verifySession(util.DefaultContext[interface{}](c, "session", nil), path, referrer)
			if err != nil {
				return api.Err(http.StatusForbidden, "", "verifying session failed: %s", err.Error())
			}

			data["name"] = ctxuser
			data["method"] = req.Method
			data["user_agent"] = req.Header.Get("User-Agent")

			// Only the anonymized client address is recorded. The error from
			// AnonymizeIPString is ignored; whatever string it returns
			// alongside the error is stored as-is.
			ip, _ := net.AnonymizeIPString(c.RealIP())
			data["ip"] = ip

			if strings.HasSuffix(path, ".m3u8") || strings.HasSuffix(path, ".ts") {
				return h.handleHLS(c, ctxuser, data, next)
			}

			return h.handleHTTP(c, ctxuser, data, next)
		}
	}
}
// verifySession checks the session data of a request and returns it.
//
// raw is the value stored under "session" in the echo context. If it is nil,
// an empty map is returned and the request is allowed. Otherwise raw must be
// a map[string]interface{}; any other type is rejected. These keys are
// evaluated:
//
//   - "match" (string): a glob pattern (with '/' as separator) that path must
//     match.
//   - "remote" ([]interface{} of strings): glob patterns (with '.' as
//     separator) of which at least one must match the host of the referrer
//     URL. Non-string entries are ignored. While this list is non-empty, a
//     missing or unparseable referrer, or one without a host, is rejected.
//
// On success the returned map is never nil, because the caller adds further
// keys to it. On error the returned map must not be used.
func verifySession(raw interface{}, path, referrer string) (map[string]interface{}, error) {
	if raw == nil {
		return map[string]interface{}{}, nil
	}

	data, ok := raw.(map[string]interface{})
	if !ok {
		return data, fmt.Errorf("invalid session data")
	}

	// A typed nil map passes the assertion above. The caller writes into the
	// returned map, so replace it with an empty one instead of returning nil.
	if data == nil {
		data = map[string]interface{}{}
	}

	if match, ok := data["match"].(string); ok {
		matched, err := glob.Match(match, path, '/')
		if !matched {
			if err != nil {
				return data, fmt.Errorf("no match for '%s' in %s: %w", match, path, err)
			}
			return data, fmt.Errorf("no match for '%s' in %s", match, path)
		}
	}

	remote, ok := data["remote"].([]interface{})
	if !ok || len(remote) == 0 {
		return data, nil
	}

	// The patterns describe hosts, so only the host[:port] part of the
	// referrer is matched. A referrer that does not parse yields no host and
	// is rejected rather than being matched verbatim.
	host := ""
	if u, err := url.Parse(referrer); err == nil {
		host = u.Host
	}
	if host == "" {
		return data, fmt.Errorf("remote not allowed")
	}

	for _, r := range remote {
		pattern, ok := r.(string)
		if !ok {
			continue
		}
		// Errors from glob.Match are ignored; only a true result counts.
		if matched, _ := glob.Match(pattern, host, '.'); matched {
			return data, nil
		}
	}

	return data, fmt.Errorf("remote not allowed")
}
// headerSize returns the number of bytes http.Header.Write produces for
// header, i.e. the header fields in wire format (without the request or
// status line). buffer is used as scratch space: it is reset first and holds
// the serialized header afterwards.
func headerSize(header http.Header, buffer *mem.Buffer) int64 {
	buffer.Reset()

	// The write error is ignored; the size is whatever ended up in buffer.
	_ = header.Write(buffer)

	written := buffer.Len()
	return int64(written)
}
// bodysizeReader wraps an io.ReadCloser (presumably a request body) and
// counts how many bytes have been read through it.
type bodysizeReader struct {
	reader io.ReadCloser // the wrapped reader
	size   int64         // total bytes returned by Read so far
}

var _ io.ReadCloser = (*bodysizeReader)(nil)

// Read reads from the wrapped reader and adds the number of bytes read to
// size, including on a call that also returns an error.
func (br *bodysizeReader) Read(p []byte) (n int, err error) {
	n, err = br.reader.Read(p)
	br.size += int64(n)
	return n, err
}

// Close closes the wrapped reader and returns its error.
func (br *bodysizeReader) Close() error {
	return br.reader.Close()
}
// bodysizeWriter wraps an http.ResponseWriter and records the number of body
// bytes written through it together with the response status code.
type bodysizeWriter struct {
	http.ResponseWriter
	size int64 // body bytes written so far
	code int   // status passed to WriteHeader, or 200 after an implicit header
}

var (
	_ http.ResponseWriter = (*bodysizeWriter)(nil)
	_ http.Flusher        = (*bodysizeWriter)(nil)
)

// WriteHeader forwards statusCode to the wrapped writer and records it.
func (w *bodysizeWriter) WriteHeader(statusCode int) {
	w.ResponseWriter.WriteHeader(statusCode)
	w.code = statusCode
}

// Write forwards body to the wrapped writer and adds the number of bytes
// written to size.
//
// When Write is called before WriteHeader, net/http sends an implicit
// "200 OK". The wrapped writer does that internally without going through
// this wrapper's WriteHeader, so the implicit status is recorded here;
// otherwise code would stay 0.
func (w *bodysizeWriter) Write(body []byte) (int, error) {
	if w.code == 0 {
		w.code = http.StatusOK
	}
	n, err := w.ResponseWriter.Write(body)
	w.size += int64(n)
	return n, err
}

// Flush flushes the wrapped writer if it implements http.Flusher and does
// nothing otherwise.
func (w *bodysizeWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer. http.ResponseController uses it to reach
// optional interfaces (e.g. http.Hijacker) that this wrapper does not expose.
func (w *bodysizeWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}