mirror of
https://github.com/datarhei/core.git
synced 2025-09-27 04:16:25 +08:00
444 lines
9.7 KiB
Go
444 lines
9.7 KiB
Go
// Package session is an HLS session middleware for Echo.
|
|
package session
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
urlpath "path"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/datarhei/core/v16/net"
|
|
"github.com/datarhei/core/v16/session"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/labstack/echo/v4/middleware"
|
|
"github.com/lithammer/shortuuid/v4"
|
|
)
|
|
|
|
type HLSConfig struct {
|
|
// Skipper defines a function to skip middleware.
|
|
Skipper middleware.Skipper
|
|
EgressCollector session.Collector
|
|
IngressCollector session.Collector
|
|
}
|
|
|
|
var DefaultHLSConfig = HLSConfig{
|
|
Skipper: middleware.DefaultSkipper,
|
|
EgressCollector: session.NewNullCollector(),
|
|
IngressCollector: session.NewNullCollector(),
|
|
}
|
|
|
|
// NewHTTP returns a new HTTP session middleware with default config
|
|
func NewHLS() echo.MiddlewareFunc {
|
|
return NewHLSWithConfig(DefaultHLSConfig)
|
|
}
|
|
|
|
type hls struct {
|
|
egressCollector session.Collector
|
|
ingressCollector session.Collector
|
|
reSessionID *regexp.Regexp
|
|
|
|
rxsegments map[string]int64
|
|
lock sync.Mutex
|
|
}
|
|
|
|
// NewHLS returns a new HLS session middleware
|
|
func NewHLSWithConfig(config HLSConfig) echo.MiddlewareFunc {
|
|
if config.Skipper == nil {
|
|
config.Skipper = DefaultHLSConfig.Skipper
|
|
}
|
|
|
|
if config.EgressCollector == nil {
|
|
config.EgressCollector = DefaultHLSConfig.EgressCollector
|
|
}
|
|
|
|
if config.IngressCollector == nil {
|
|
config.IngressCollector = DefaultHLSConfig.IngressCollector
|
|
}
|
|
|
|
hls := hls{
|
|
egressCollector: config.EgressCollector,
|
|
ingressCollector: config.IngressCollector,
|
|
reSessionID: regexp.MustCompile(`^[` + regexp.QuoteMeta(shortuuid.DefaultAlphabet) + `]{22}$`),
|
|
rxsegments: make(map[string]int64),
|
|
}
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if config.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
|
|
req := c.Request()
|
|
|
|
if req.Method == "PUT" || req.Method == "POST" {
|
|
return hls.handleIngress(c, next)
|
|
} else if req.Method == "GET" || req.Method == "HEAD" {
|
|
return hls.handleEgress(c, next)
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *hls) handleIngress(c echo.Context, next echo.HandlerFunc) error {
|
|
req := c.Request()
|
|
|
|
path := req.URL.Path
|
|
|
|
if strings.HasSuffix(path, ".m3u8") {
|
|
// Read out the path of the .ts files and look them up in the ts-map.
|
|
// Add it as ingress for the respective "sessionId". The "sessionId" is the .m3u8 file name.
|
|
reader := req.Body
|
|
r := &bodyReader{
|
|
reader: req.Body,
|
|
}
|
|
req.Body = r
|
|
|
|
defer func() {
|
|
req.Body = reader
|
|
|
|
if r.size == 0 {
|
|
return
|
|
}
|
|
|
|
if !h.ingressCollector.IsKnownSession(path) {
|
|
// Register a new session
|
|
reference := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path))
|
|
h.ingressCollector.RegisterAndActivate(path, reference, path, "")
|
|
h.ingressCollector.Extra(path, req.Header.Get("User-Agent"))
|
|
}
|
|
|
|
h.ingressCollector.Ingress(path, headerSize(req.Header))
|
|
h.ingressCollector.Ingress(path, r.size)
|
|
|
|
segments := r.getSegments(urlpath.Dir(path))
|
|
|
|
if len(segments) != 0 {
|
|
h.lock.Lock()
|
|
for _, s := range segments {
|
|
if size, ok := h.rxsegments[s]; ok {
|
|
// Update ingress value
|
|
h.ingressCollector.Ingress(path, size)
|
|
delete(h.rxsegments, s)
|
|
}
|
|
}
|
|
h.lock.Unlock()
|
|
}
|
|
}()
|
|
} else if strings.HasSuffix(path, ".ts") {
|
|
// Get the size of the .ts file and store it in the ts-map for later use.
|
|
reader := req.Body
|
|
r := &bodysizeReader{
|
|
reader: req.Body,
|
|
}
|
|
req.Body = r
|
|
|
|
defer func() {
|
|
req.Body = reader
|
|
|
|
if r.size != 0 {
|
|
h.lock.Lock()
|
|
h.rxsegments[path] = r.size + headerSize(req.Header)
|
|
h.lock.Unlock()
|
|
}
|
|
}()
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
|
|
func (h *hls) handleEgress(c echo.Context, next echo.HandlerFunc) error {
|
|
req := c.Request()
|
|
res := c.Response()
|
|
|
|
if !h.egressCollector.IsCollectableIP(c.RealIP()) {
|
|
return next(c)
|
|
}
|
|
|
|
path := req.URL.Path
|
|
sessionID := c.QueryParam("session")
|
|
|
|
isM3U8 := strings.HasSuffix(path, ".m3u8")
|
|
isTS := strings.HasSuffix(path, ".ts")
|
|
|
|
rewrite := false
|
|
|
|
if isM3U8 {
|
|
if !h.egressCollector.IsKnownSession(sessionID) {
|
|
if h.egressCollector.IsSessionsExceeded() {
|
|
return echo.NewHTTPError(509, "Number of sessions exceeded")
|
|
}
|
|
|
|
streamBitrate := h.ingressCollector.SessionTopIngressBitrate(path) * 2.0 // Multiply by 2 to cover the initial peak
|
|
maxBitrate := h.egressCollector.MaxEgressBitrate()
|
|
|
|
if maxBitrate > 0.0 {
|
|
currentBitrate := h.egressCollector.CompanionTopEgressBitrate() * 1.15
|
|
|
|
// Add the new session's top bitrate to the ingress top bitrate
|
|
resultingBitrate := currentBitrate + streamBitrate
|
|
|
|
if resultingBitrate <= 0.5 || resultingBitrate >= maxBitrate {
|
|
return echo.NewHTTPError(509, "Bitrate limit exceeded")
|
|
}
|
|
}
|
|
|
|
if len(sessionID) != 0 {
|
|
if !h.reSessionID.MatchString(sessionID) {
|
|
return echo.NewHTTPError(http.StatusForbidden)
|
|
}
|
|
|
|
referrer := req.Header.Get("Referer")
|
|
if u, err := url.Parse(referrer); err == nil {
|
|
referrer = u.Host
|
|
}
|
|
|
|
ip, _ := net.AnonymizeIPString(c.RealIP())
|
|
extra := "[" + ip + "] " + req.Header.Get("User-Agent")
|
|
|
|
reference := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path))
|
|
|
|
// Register a new session
|
|
h.egressCollector.Register(sessionID, reference, path, referrer)
|
|
h.egressCollector.Extra(sessionID, extra)
|
|
|
|
// Give the new session an initial top bitrate
|
|
h.egressCollector.SessionSetTopEgressBitrate(sessionID, streamBitrate)
|
|
}
|
|
}
|
|
|
|
rewrite = true
|
|
}
|
|
|
|
var rewriter *sessionRewriter
|
|
|
|
// Keep the current writer for later
|
|
writer := res.Writer
|
|
|
|
if rewrite {
|
|
// Put the session rewriter in the middle. This will collect
|
|
// the data that we need to rewrite.
|
|
rewriter = &sessionRewriter{
|
|
ResponseWriter: res.Writer,
|
|
}
|
|
|
|
res.Writer = rewriter
|
|
}
|
|
|
|
if err := next(c); err != nil {
|
|
c.Error(err)
|
|
}
|
|
|
|
// Restore the original writer
|
|
res.Writer = writer
|
|
|
|
if rewrite {
|
|
if res.Status != 200 {
|
|
res.Write(rewriter.buffer.Bytes())
|
|
return nil
|
|
}
|
|
|
|
// Rewrite the data befor sending it to the client
|
|
rewriter.rewriteHLS(sessionID, c.Request().URL)
|
|
|
|
res.Header().Set("Cache-Control", "private")
|
|
res.Write(rewriter.buffer.Bytes())
|
|
}
|
|
|
|
if isM3U8 || isTS {
|
|
// Collect how many bytes we've written in this session
|
|
h.egressCollector.Egress(sessionID, headerSize(res.Header()))
|
|
h.egressCollector.Egress(sessionID, res.Size)
|
|
|
|
if isTS {
|
|
// Activate the session. If the session is already active, this is a noop
|
|
h.egressCollector.Activate(sessionID)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// headerSize returns the number of bytes that header occupies when
// serialized in wire format by http.Header.Write.
func headerSize(header http.Header) int64 {
	buf := new(bytes.Buffer)

	// Writing into an in-memory buffer does not return an error.
	_ = header.Write(buf)

	return int64(buf.Len())
}
|
|
|
|
// bodyReader wraps a request body. Everything read is passed on to the
// caller unchanged, while a copy of it and the number of bytes read are
// kept, so that an uploaded playlist can be inspected after the handler ran.
type bodyReader struct {
	reader io.ReadCloser
	buffer bytes.Buffer
	size   int64
}

// Read reads from the wrapped body and records the bytes that were read.
func (r *bodyReader) Read(b []byte) (int, error) {
	n, err := r.reader.Read(b)
	if n > 0 {
		r.buffer.Write(b[:n])
	}
	r.size += int64(n)

	return n, err
}

// Close closes the wrapped body.
func (r *bodyReader) Close() error {
	return r.reader.Close()
}

// getSegments returns the paths of the .ts segments referenced by the
// playlist recorded so far. Relative URIs are resolved against dir.
//
// The following lines are skipped:
//   - empty lines,
//   - tags and comments ("#..."),
//   - unparsable URIs,
//   - URLs that name a scheme or a host, including scheme-relative ones
//     such as "//cdn.example.com/seg.ts".
//
// The recorded buffer is consumed by this call. Scanning stops at the first
// line bufio.Scanner cannot handle, and the segments found up to that point
// are returned.
func (r *bodyReader) getSegments(dir string) []string {
	var segments []string

	// Find all segments URLS in the .m3u8
	scanner := bufio.NewScanner(&r.buffer)
	for scanner.Scan() {
		line := scanner.Text()

		// Ignore empty lines, tags and comments
		if len(line) == 0 || strings.HasPrefix(line, "#") {
			continue
		}

		u, err := url.Parse(line)
		if err != nil {
			// Invalid URL
			continue
		}

		if u.Scheme != "" || u.Host != "" {
			// Ignore full URLs. A scheme-relative URL ("//host/path") has no
			// scheme but still points to another host, so its path must not
			// be mistaken for a local one.
			continue
		}

		// Ignore anything that doesn't end in .ts
		if !strings.HasSuffix(u.Path, ".ts") {
			continue
		}

		path := u.Path
		if !strings.HasPrefix(path, "/") {
			path = urlpath.Join(dir, path)
		}

		segments = append(segments, path)
	}

	return segments
}
|
|
|
|
// bodysizeReader wraps a request body and counts the bytes read from it,
// without keeping a copy of them.
type bodysizeReader struct {
	reader io.ReadCloser
	size   int64
}

// Read reads from the wrapped body and adds the number of bytes read to size.
func (r *bodysizeReader) Read(p []byte) (int, error) {
	count, err := r.reader.Read(p)
	r.size += int64(count)
	return count, err
}

// Close closes the wrapped body.
func (r *bodysizeReader) Close() error {
	return r.reader.Close()
}
|
|
|
|
type sessionRewriter struct {
|
|
http.ResponseWriter
|
|
buffer bytes.Buffer
|
|
}
|
|
|
|
func (g *sessionRewriter) Write(data []byte) (int, error) {
|
|
// Write the data into internal buffer for later rewrite
|
|
w, err := g.buffer.Write(data)
|
|
|
|
return w, err
|
|
}
|
|
|
|
func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL) {
|
|
var buffer bytes.Buffer
|
|
|
|
isMaster := false
|
|
|
|
// Find all URLS in the .m3u8 and add the session ID to the query string
|
|
scanner := bufio.NewScanner(&g.buffer)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
// Write empty lines unmodified
|
|
if len(line) == 0 {
|
|
buffer.WriteString(line + "\n")
|
|
continue
|
|
}
|
|
|
|
// Write comments unmodified
|
|
if strings.HasPrefix(line, "#") {
|
|
buffer.WriteString(line + "\n")
|
|
continue
|
|
}
|
|
|
|
u, err := url.Parse(line)
|
|
if err != nil {
|
|
buffer.WriteString(line + "\n")
|
|
continue
|
|
}
|
|
|
|
// Write anything that doesn't end in .m3u8 or .ts unmodified
|
|
if !strings.HasSuffix(u.Path, ".m3u8") && !strings.HasSuffix(u.Path, ".ts") {
|
|
buffer.WriteString(line + "\n")
|
|
continue
|
|
}
|
|
|
|
q := u.Query()
|
|
|
|
// If this is a master manifest (i.e. an m3u8 which contains references to other m3u8), then
|
|
// we give each substream an own session ID if they don't have already.
|
|
if strings.HasSuffix(u.Path, ".m3u8") {
|
|
q.Set("session", shortuuid.New())
|
|
|
|
isMaster = true
|
|
} else {
|
|
q.Set("session", sessionID)
|
|
}
|
|
|
|
u.RawQuery = q.Encode()
|
|
|
|
buffer.WriteString(u.String() + "\n")
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return
|
|
}
|
|
|
|
// If this is not a master manifest and there isn't a session ID, we add a new session ID.
|
|
if !isMaster && len(sessionID) == 0 {
|
|
sessionID = shortuuid.New()
|
|
|
|
buffer.Reset()
|
|
|
|
buffer.WriteString("#EXTM3U\n#EXT-X-VERSION:3\n#EXT-X-STREAM-INF:BANDWIDTH=1024\n")
|
|
|
|
// Add the session ID to the query string
|
|
q := requestURL.Query()
|
|
q.Set("session", sessionID)
|
|
|
|
buffer.WriteString(urlpath.Base(requestURL.Path) + "?" + q.Encode())
|
|
}
|
|
|
|
g.buffer = buffer
|
|
}
|