mirror of
				https://github.com/datarhei/core.git
				synced 2025-10-31 11:26:52 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			444 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			444 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Package session provides an HLS session middleware for Echo
 | |
| package session
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"bytes"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 	"net/url"
 | |
| 	urlpath "path"
 | |
| 	"path/filepath"
 | |
| 	"regexp"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 
 | |
| 	"github.com/datarhei/core/v16/net"
 | |
| 	"github.com/datarhei/core/v16/session"
 | |
| 
 | |
| 	"github.com/labstack/echo/v4"
 | |
| 	"github.com/labstack/echo/v4/middleware"
 | |
| 	"github.com/lithammer/shortuuid/v4"
 | |
| )
 | |
| 
 | |
| type HLSConfig struct {
 | |
| 	// Skipper defines a function to skip middleware.
 | |
| 	Skipper          middleware.Skipper
 | |
| 	EgressCollector  session.Collector
 | |
| 	IngressCollector session.Collector
 | |
| }
 | |
| 
 | |
| var DefaultHLSConfig = HLSConfig{
 | |
| 	Skipper:          middleware.DefaultSkipper,
 | |
| 	EgressCollector:  session.NewNullCollector(),
 | |
| 	IngressCollector: session.NewNullCollector(),
 | |
| }
 | |
| 
 | |
| // NewHTTP returns a new HTTP session middleware with default config
 | |
| func NewHLS() echo.MiddlewareFunc {
 | |
| 	return NewHLSWithConfig(DefaultHLSConfig)
 | |
| }
 | |
| 
 | |
| type hls struct {
 | |
| 	egressCollector  session.Collector
 | |
| 	ingressCollector session.Collector
 | |
| 	reSessionID      *regexp.Regexp
 | |
| 
 | |
| 	rxsegments map[string]int64
 | |
| 	lock       sync.Mutex
 | |
| }
 | |
| 
 | |
| // NewHLS returns a new HLS session middleware
 | |
| func NewHLSWithConfig(config HLSConfig) echo.MiddlewareFunc {
 | |
| 	if config.Skipper == nil {
 | |
| 		config.Skipper = DefaultHLSConfig.Skipper
 | |
| 	}
 | |
| 
 | |
| 	if config.EgressCollector == nil {
 | |
| 		config.EgressCollector = DefaultHLSConfig.EgressCollector
 | |
| 	}
 | |
| 
 | |
| 	if config.IngressCollector == nil {
 | |
| 		config.IngressCollector = DefaultHLSConfig.IngressCollector
 | |
| 	}
 | |
| 
 | |
| 	hls := hls{
 | |
| 		egressCollector:  config.EgressCollector,
 | |
| 		ingressCollector: config.IngressCollector,
 | |
| 		reSessionID:      regexp.MustCompile(`^[` + regexp.QuoteMeta(shortuuid.DefaultAlphabet) + `]{22}$`),
 | |
| 		rxsegments:       make(map[string]int64),
 | |
| 	}
 | |
| 
 | |
| 	return func(next echo.HandlerFunc) echo.HandlerFunc {
 | |
| 		return func(c echo.Context) error {
 | |
| 			if config.Skipper(c) {
 | |
| 				return next(c)
 | |
| 			}
 | |
| 
 | |
| 			req := c.Request()
 | |
| 
 | |
| 			if req.Method == "PUT" || req.Method == "POST" {
 | |
| 				return hls.handleIngress(c, next)
 | |
| 			} else if req.Method == "GET" || req.Method == "HEAD" {
 | |
| 				return hls.handleEgress(c, next)
 | |
| 			}
 | |
| 
 | |
| 			return next(c)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (h *hls) handleIngress(c echo.Context, next echo.HandlerFunc) error {
 | |
| 	req := c.Request()
 | |
| 
 | |
| 	path := req.URL.Path
 | |
| 
 | |
| 	if strings.HasSuffix(path, ".m3u8") {
 | |
| 		// Read out the path of the .ts files and look them up in the ts-map.
 | |
| 		// Add it as ingress for the respective "sessionId". The "sessionId" is the .m3u8 file name.
 | |
| 		reader := req.Body
 | |
| 		r := &bodyReader{
 | |
| 			reader: req.Body,
 | |
| 		}
 | |
| 		req.Body = r
 | |
| 
 | |
| 		defer func() {
 | |
| 			req.Body = reader
 | |
| 
 | |
| 			if r.size == 0 {
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			if !h.ingressCollector.IsKnownSession(path) {
 | |
| 				// Register a new session
 | |
| 				reference := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path))
 | |
| 				h.ingressCollector.RegisterAndActivate(path, reference, path, "")
 | |
| 				h.ingressCollector.Extra(path, req.Header.Get("User-Agent"))
 | |
| 			}
 | |
| 
 | |
| 			h.ingressCollector.Ingress(path, headerSize(req.Header))
 | |
| 			h.ingressCollector.Ingress(path, r.size)
 | |
| 
 | |
| 			segments := r.getSegments(urlpath.Dir(path))
 | |
| 
 | |
| 			if len(segments) != 0 {
 | |
| 				h.lock.Lock()
 | |
| 				for _, s := range segments {
 | |
| 					if size, ok := h.rxsegments[s]; ok {
 | |
| 						// Update ingress value
 | |
| 						h.ingressCollector.Ingress(path, size)
 | |
| 						delete(h.rxsegments, s)
 | |
| 					}
 | |
| 				}
 | |
| 				h.lock.Unlock()
 | |
| 			}
 | |
| 		}()
 | |
| 	} else if strings.HasSuffix(path, ".ts") {
 | |
| 		// Get the size of the .ts file and store it in the ts-map for later use.
 | |
| 		reader := req.Body
 | |
| 		r := &bodysizeReader{
 | |
| 			reader: req.Body,
 | |
| 		}
 | |
| 		req.Body = r
 | |
| 
 | |
| 		defer func() {
 | |
| 			req.Body = reader
 | |
| 
 | |
| 			if r.size != 0 {
 | |
| 				h.lock.Lock()
 | |
| 				h.rxsegments[path] = r.size + headerSize(req.Header)
 | |
| 				h.lock.Unlock()
 | |
| 			}
 | |
| 		}()
 | |
| 	}
 | |
| 
 | |
| 	return next(c)
 | |
| }
 | |
| 
 | |
| func (h *hls) handleEgress(c echo.Context, next echo.HandlerFunc) error {
 | |
| 	req := c.Request()
 | |
| 	res := c.Response()
 | |
| 
 | |
| 	if !h.egressCollector.IsCollectableIP(c.RealIP()) {
 | |
| 		return next(c)
 | |
| 	}
 | |
| 
 | |
| 	path := req.URL.Path
 | |
| 	sessionID := c.QueryParam("session")
 | |
| 
 | |
| 	isM3U8 := strings.HasSuffix(path, ".m3u8")
 | |
| 	isTS := strings.HasSuffix(path, ".ts")
 | |
| 
 | |
| 	rewrite := false
 | |
| 
 | |
| 	if isM3U8 {
 | |
| 		if !h.egressCollector.IsKnownSession(sessionID) {
 | |
| 			if h.egressCollector.IsSessionsExceeded() {
 | |
| 				return echo.NewHTTPError(509, "Number of sessions exceeded")
 | |
| 			}
 | |
| 
 | |
| 			streamBitrate := h.ingressCollector.SessionTopIngressBitrate(path) * 2.0 // Multiply by 2 to cover the initial peak
 | |
| 			maxBitrate := h.egressCollector.MaxEgressBitrate()
 | |
| 
 | |
| 			if maxBitrate > 0.0 {
 | |
| 				currentBitrate := h.egressCollector.CompanionTopEgressBitrate() * 1.15
 | |
| 
 | |
| 				// Add the new session's top bitrate to the ingress top bitrate
 | |
| 				resultingBitrate := currentBitrate + streamBitrate
 | |
| 
 | |
| 				if resultingBitrate <= 0.5 || resultingBitrate >= maxBitrate {
 | |
| 					return echo.NewHTTPError(509, "Bitrate limit exceeded")
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			if len(sessionID) != 0 {
 | |
| 				if !h.reSessionID.MatchString(sessionID) {
 | |
| 					return echo.NewHTTPError(http.StatusForbidden)
 | |
| 				}
 | |
| 
 | |
| 				referrer := req.Header.Get("Referer")
 | |
| 				if u, err := url.Parse(referrer); err == nil {
 | |
| 					referrer = u.Host
 | |
| 				}
 | |
| 
 | |
| 				ip, _ := net.AnonymizeIPString(c.RealIP())
 | |
| 				extra := "[" + ip + "] " + req.Header.Get("User-Agent")
 | |
| 
 | |
| 				reference := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path))
 | |
| 
 | |
| 				// Register a new session
 | |
| 				h.egressCollector.Register(sessionID, reference, path, referrer)
 | |
| 				h.egressCollector.Extra(sessionID, extra)
 | |
| 
 | |
| 				// Give the new session an initial top bitrate
 | |
| 				h.egressCollector.SessionSetTopEgressBitrate(sessionID, streamBitrate)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		rewrite = true
 | |
| 	}
 | |
| 
 | |
| 	var rewriter *sessionRewriter
 | |
| 
 | |
| 	// Keep the current writer for later
 | |
| 	writer := res.Writer
 | |
| 
 | |
| 	if rewrite {
 | |
| 		// Put the session rewriter in the middle. This will collect
 | |
| 		// the data that we need to rewrite.
 | |
| 		rewriter = &sessionRewriter{
 | |
| 			ResponseWriter: res.Writer,
 | |
| 		}
 | |
| 
 | |
| 		res.Writer = rewriter
 | |
| 	}
 | |
| 
 | |
| 	if err := next(c); err != nil {
 | |
| 		c.Error(err)
 | |
| 	}
 | |
| 
 | |
| 	// Restore the original writer
 | |
| 	res.Writer = writer
 | |
| 
 | |
| 	if rewrite {
 | |
| 		if res.Status != 200 {
 | |
| 			res.Write(rewriter.buffer.Bytes())
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		// Rewrite the data befor sending it to the client
 | |
| 		rewriter.rewriteHLS(sessionID, c.Request().URL)
 | |
| 
 | |
| 		res.Header().Set("Cache-Control", "private")
 | |
| 		res.Write(rewriter.buffer.Bytes())
 | |
| 	}
 | |
| 
 | |
| 	if isM3U8 || isTS {
 | |
| 		// Collect how many bytes we've written in this session
 | |
| 		h.egressCollector.Egress(sessionID, headerSize(res.Header()))
 | |
| 		h.egressCollector.Egress(sessionID, res.Size)
 | |
| 
 | |
| 		if isTS {
 | |
| 			// Activate the session. If the session is already active, this is a noop
 | |
| 			h.egressCollector.Activate(sessionID)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func headerSize(header http.Header) int64 {
 | |
| 	var buffer bytes.Buffer
 | |
| 
 | |
| 	header.Write(&buffer)
 | |
| 
 | |
| 	return int64(buffer.Len())
 | |
| }
 | |
| 
 | |
// bodyReader wraps a request body. It keeps a copy of everything read through
// it in buffer and counts the bytes read in size.
type bodyReader struct {
	reader io.ReadCloser
	buffer bytes.Buffer
	size   int64
}

// Read reads from the wrapped body, appending the bytes read to buffer and
// adding their count to size.
func (r *bodyReader) Read(b []byte) (int, error) {
	n, err := r.reader.Read(b)
	if n > 0 {
		r.buffer.Write(b[:n])
	}
	r.size += int64(n)

	return n, err
}

// Close closes the wrapped body.
func (r *bodyReader) Close() error {
	return r.reader.Close()
}

// getSegments parses the buffered data as an HLS playlist and returns the URL
// paths of all referenced .ts segments that resolve to this server. Relative
// references are resolved against dir.
//
// The following lines are skipped:
//   - empty lines and tags/comments;
//   - lines that cannot be parsed as URLs;
//   - references that carry a scheme or a host (absolute and
//     protocol-relative URLs);
//   - anything whose path does not end in .ts.
//
// The buffer is consumed by this call.
func (r *bodyReader) getSegments(dir string) []string {
	segments := []string{}

	// Find all segment URLs in the .m3u8
	scanner := bufio.NewScanner(&r.buffer)
	for scanner.Scan() {
		line := scanner.Text()

		// Ignore empty lines and comments
		if len(line) == 0 || strings.HasPrefix(line, "#") {
			continue
		}

		u, err := url.Parse(line)
		if err != nil {
			// Invalid URL
			continue
		}

		// Segments hosted elsewhere ("http://host/x.ts" or the
		// protocol-relative "//host/x.ts") were not uploaded here.
		if u.Scheme != "" || u.Host != "" {
			continue
		}

		// Ignore anything that doesn't end in .ts
		if !strings.HasSuffix(u.Path, ".ts") {
			continue
		}

		path := u.Path
		if !strings.HasPrefix(path, "/") {
			path = urlpath.Join(dir, path)
		}

		segments = append(segments, path)
	}

	// A scan error (e.g. a line longer than bufio.MaxScanTokenSize) ends the
	// search early; the segments found so far are still returned.
	return segments
}
 | |
| 
 | |
// bodysizeReader wraps a request body and counts the bytes read through it
// without retaining them.
type bodysizeReader struct {
	reader io.ReadCloser
	size   int64
}

// Compile-time check that bodysizeReader can stand in for a request body.
var _ io.ReadCloser = (*bodysizeReader)(nil)

// Read reads from the wrapped body and adds the number of bytes read to size.
func (r *bodysizeReader) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)
	r.size += int64(n)
	return
}

// Close closes the wrapped body.
func (r *bodysizeReader) Close() error { return r.reader.Close() }
 | |
| 
 | |
// sessionRewriter is an http.ResponseWriter that captures the response body
// in memory instead of passing it on, so that it can be rewritten before it
// is sent. Header and WriteHeader are served by the embedded ResponseWriter.
type sessionRewriter struct {
	http.ResponseWriter
	buffer bytes.Buffer
}

// Write appends data to the internal buffer for a later rewrite. Nothing is
// forwarded to the embedded ResponseWriter.
func (g *sessionRewriter) Write(data []byte) (int, error) {
	return g.buffer.Write(data)
}
 | |
| 
 | |
| func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL) {
 | |
| 	var buffer bytes.Buffer
 | |
| 
 | |
| 	isMaster := false
 | |
| 
 | |
| 	// Find all URLS in the .m3u8 and add the session ID to the query string
 | |
| 	scanner := bufio.NewScanner(&g.buffer)
 | |
| 	for scanner.Scan() {
 | |
| 		line := scanner.Text()
 | |
| 
 | |
| 		// Write empty lines unmodified
 | |
| 		if len(line) == 0 {
 | |
| 			buffer.WriteString(line + "\n")
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// Write comments unmodified
 | |
| 		if strings.HasPrefix(line, "#") {
 | |
| 			buffer.WriteString(line + "\n")
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		u, err := url.Parse(line)
 | |
| 		if err != nil {
 | |
| 			buffer.WriteString(line + "\n")
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// Write anything that doesn't end in .m3u8 or .ts unmodified
 | |
| 		if !strings.HasSuffix(u.Path, ".m3u8") && !strings.HasSuffix(u.Path, ".ts") {
 | |
| 			buffer.WriteString(line + "\n")
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		q := u.Query()
 | |
| 
 | |
| 		// If this is a master manifest (i.e. an m3u8 which contains references to other m3u8), then
 | |
| 		// we give each substream an own session ID if they don't have already.
 | |
| 		if strings.HasSuffix(u.Path, ".m3u8") {
 | |
| 			q.Set("session", shortuuid.New())
 | |
| 
 | |
| 			isMaster = true
 | |
| 		} else {
 | |
| 			q.Set("session", sessionID)
 | |
| 		}
 | |
| 
 | |
| 		u.RawQuery = q.Encode()
 | |
| 
 | |
| 		buffer.WriteString(u.String() + "\n")
 | |
| 	}
 | |
| 
 | |
| 	if err := scanner.Err(); err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// If this is not a master manifest and there isn't a session ID, we add a new session ID.
 | |
| 	if !isMaster && len(sessionID) == 0 {
 | |
| 		sessionID = shortuuid.New()
 | |
| 
 | |
| 		buffer.Reset()
 | |
| 
 | |
| 		buffer.WriteString("#EXTM3U\n#EXT-X-VERSION:3\n#EXT-X-STREAM-INF:BANDWIDTH=1024\n")
 | |
| 
 | |
| 		// Add the session ID to the query string
 | |
| 		q := requestURL.Query()
 | |
| 		q.Set("session", sessionID)
 | |
| 
 | |
| 		buffer.WriteString(urlpath.Base(requestURL.Path) + "?" + q.Encode())
 | |
| 	}
 | |
| 
 | |
| 	g.buffer = buffer
 | |
| }
 | 
