Files
streamctl/pkg/streamserver/server/rtmp/rtmp_server.go
2024-07-20 20:52:38 +01:00

271 lines
5.7 KiB
Go

package rtmp
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"github.com/AlexxIT/go2rtc/pkg/core"
"github.com/AlexxIT/go2rtc/pkg/flv"
"github.com/AlexxIT/go2rtc/pkg/rtmp"
"github.com/facebookincubator/go-belt/tool/experimental/errmon"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/rs/zerolog/log"
"github.com/xaionaro-go/datacounter"
"github.com/xaionaro-go/streamctl/pkg/streamserver/consts"
"github.com/xaionaro-go/streamctl/pkg/streamserver/server"
"github.com/xaionaro-go/streamctl/pkg/streamserver/streams"
"github.com/xaionaro-go/streamctl/pkg/streamserver/types"
)
type (
	// RTMPServer is an RTMP server bound to a single TCP listener. Incoming
	// connections are resolved against StreamHandler using the RTMP
	// application name.
	RTMPServer struct {
		// Config is the configuration the server was started with.
		Config Config
		// StreamHandler resolves stream names to streams.
		StreamHandler *streams.StreamHandler
		// Listener accepts the incoming TCP connections.
		Listener net.Listener
		// CancelFn cancels the server context created by New, which stops
		// the accept loop and closes Listener.
		CancelFn context.CancelFunc
		// TrafficCounter holds the byte counters of the play/publish
		// sessions; tcpHandle replaces the counter on every new session.
		TrafficCounter server.TrafficCounter
	}

	// Config configures an RTMPServer.
	Config struct {
		// Listen is the TCP address to listen on; New falls back to
		// "127.0.0.1:1935" when it is empty.
		Listen string `yaml:"listen" json:"listen"`
	}
)
// New starts an RTMP server listening on cfg.Listen (defaulting to
// "127.0.0.1:1935" when empty). Every accepted connection is served in its
// own goroutine by tcpHandle, with streamHandler used to resolve stream
// names.
//
// The server runs until ctx is cancelled or Close is called; either one
// closes the listener, which terminates the accept loop. Each accepted
// connection is closed once its handler returns.
func New(
	ctx context.Context,
	cfg Config,
	streamHandler *streams.StreamHandler,
) (*RTMPServer, error) {
	if cfg.Listen == "" {
		cfg.Listen = "127.0.0.1:1935"
	}
	ln, err := net.Listen("tcp", cfg.Listen)
	if err != nil {
		return nil, fmt.Errorf("unable to start listening '%s': %w", cfg.Listen, err)
	}
	ctx, cancelFn := context.WithCancel(ctx)
	s := &RTMPServer{
		Config:        cfg,
		StreamHandler: streamHandler,
		CancelFn:      cancelFn,
		Listener:      ln,
	}

	// Closing the listener is what unblocks ln.Accept below on shutdown.
	go func() {
		<-ctx.Done()
		logger.Infof(ctx, "closing %s", cfg.Listen)
		err := ln.Close()
		errmon.ObserveErrorCtx(ctx, err)
	}()

	logger.Infof(ctx, "started RTMP server at %s", cfg.Listen)
	go func() {
		for {
			if ctx.Err() != nil {
				return
			}
			conn, err := ln.Accept()
			if err != nil {
				// An Accept failure caused by our own shutdown (cancelled
				// ctx / closed listener) is expected and not reported.
				if ctx.Err() == nil && !errors.Is(err, net.ErrClosed) {
					errmon.ObserveErrorCtx(ctx, err)
				}
				return
			}
			go func() {
				// The accept loop owns conn, so release it however tcpHandle
				// ends. The RTMP layer may already have closed it, so the
				// Close error is deliberately ignored.
				defer func() { _ = conn.Close() }()
				if err := s.tcpHandle(conn); err != nil {
					errmon.ObserveErrorCtx(ctx, err)
				}
			}()
		}
	}()
	return s, nil
}
// Type identifies this server implementation; it is always
// types.ServerTypeRTMP.
func (*RTMPServer) Type() types.ServerType {
	return types.ServerTypeRTMP
}
// ListenAddr returns the address the underlying listener is bound to, as
// reported by net.Listener.Addr.
func (s *RTMPServer) ListenAddr() string {
	addr := s.Listener.Addr()
	return addr.String()
}
// Close stops the server by cancelling its context. The listener itself is
// closed asynchronously by the goroutine New started for that purpose.
// Close always returns nil.
func (s *RTMPServer) Close() error {
	logger.Default().Tracef("(*RTMPServer).Close()")
	cancel := s.CancelFn
	cancel()
	return nil
}
// tcpHandle wraps netConn in an RTMP server connection, reads the client's
// commands and then serves the session according to its intent:
//
//   - play: an FLV consumer is attached to the stream named by the RTMP app,
//     and its output is written to the client until writing stops;
//   - publish: the client becomes a producer of that stream until
//     producer.Start returns (its error is logged, not returned).
//
// Any other intent is rejected with an error. The play/publish session
// replaces the corresponding counter in s.TrafficCounter.
func (s *RTMPServer) tcpHandle(netConn net.Conn) error {
	conn, err := rtmp.NewServer(netConn)
	if err != nil {
		return err
	}
	if err := conn.ReadCommands(); err != nil {
		return err
	}

	switch conn.Intent {
	case rtmp.CommandPlay:
		stream := s.StreamHandler.Get(conn.App)
		if stream == nil {
			return errors.New("stream not found: " + conn.App)
		}
		consumer := flv.NewConsumer()
		if err := stream.AddConsumer(consumer); err != nil {
			return err
		}
		defer stream.RemoveConsumer(consumer)

		if err := conn.WriteStart(); err != nil {
			return err
		}

		counter := datacounter.NewWriterCounter(conn)
		s.TrafficCounter.Lock()
		s.TrafficCounter.WriterCounter = counter
		s.TrafficCounter.Unlock()

		// Write errors end the session; there is nobody to report them to.
		_, _ = consumer.WriteTo(counter)
		return nil

	case rtmp.CommandPublish:
		stream := s.StreamHandler.Get(conn.App)
		if stream == nil {
			return errors.New("stream not found: " + conn.App)
		}
		if err := conn.WriteStart(); err != nil {
			return err
		}

		producer, err := conn.Producer()
		if err != nil {
			return err
		}
		stream.AddProducer(producer)
		defer stream.RemoveProducer(producer)

		counter := server.NewIntPtrCounter(&producer.Recv)
		s.TrafficCounter.Lock()
		s.TrafficCounter.ReaderCounter = counter
		s.TrafficCounter.Unlock()

		if err := producer.Start(); err != nil {
			logger.Default().Error(err)
		}
		return nil

	default:
		return errors.New("rtmp: unknown command: " + conn.Intent)
	}
}
// NumBytesConsumerWrote reports the number of bytes written to RTMP players,
// as tracked by s.TrafficCounter.
// NOTE(review): tcpHandle replaces the writer counter for every play
// session, so this presumably reflects only the most recent player.
func (s *RTMPServer) NumBytesConsumerWrote() uint64 {
	tc := &s.TrafficCounter
	return tc.NumBytesWrote()
}
// NumBytesProducerRead reports the number of bytes received from RTMP
// publishers, as tracked by s.TrafficCounter.
// NOTE(review): tcpHandle replaces the reader counter for every publish
// session, so this presumably reflects only the most recent publisher.
func (s *RTMPServer) NumBytesProducerRead() uint64 {
	tc := &s.TrafficCounter
	return tc.NumBytesRead()
}
// StreamsHandle dials url with RTMP "play" and returns the resulting
// producer.
//
// On failure it returns a literal nil producer together with the error.
// Returning rtmp.DialPlay's result directly could otherwise hand callers a
// non-nil core.Producer that wraps a typed nil pointer.
func StreamsHandle(url string) (core.Producer, error) {
	prod, err := rtmp.DialPlay(url)
	if err != nil {
		return nil, err
	}
	return prod, nil
}
// StreamsConsumerHandle prepares an FLV consumer whose data is re-published
// to the RTMP URL url.
//
// It returns:
//   - the consumer, to be attached to a stream;
//   - a traffic counter tracking the bytes sent to the remote side;
//   - a run function that dials url with RTMP "publish" and pumps the
//     consumer's data to it until writing fails or ctx is cancelled. When
//     run exits, the connection is closed if it implements io.Closer,
//     which also unblocks a pending write on cancellation;
//   - an error, which is always nil.
func StreamsConsumerHandle(url string) (core.Consumer, server.NumBytesReaderWroter, func(context.Context) error, error) {
	cons := flv.NewConsumer()
	trafficCounter := &server.TrafficCounter{}
	run := func(ctx context.Context) error {
		wr, err := rtmp.DialPublish(url)
		if err != nil {
			return fmt.Errorf("unable to connect to '%s': %w", url, err)
		}
		wrc := datacounter.NewWriterCounter(wr)
		trafficCounter.Lock()
		trafficCounter.WriterCounter = wrc
		trafficCounter.Unlock()

		ctx, cancelFn := context.WithCancel(ctx)
		defer cancelFn()
		go func() {
			// Fires when the caller cancels ctx or when run returns (the
			// deferred cancelFn); either way the connection must be closed.
			<-ctx.Done()
			closer, ok := wr.(io.Closer)
			if !ok {
				// Checked assertion: an unchecked one would panic here.
				return
			}
			errmon.ObserveErrorCtx(ctx, closer.Close())
		}()

		if _, err := cons.WriteTo(wrc); err != nil {
			return fmt.Errorf("unable to write: %w", err)
		}
		return nil
	}
	return cons, trafficCounter, run, nil
}
// apiHandle dispatches an HTTP request by method. A POST feeds an uploaded
// FLV body into a stream (inputFLV); any other method streams FLV out of
// one (outputFLV).
func (s *RTMPServer) apiHandle(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodPost {
		s.inputFLV(w, r)
		return
	}
	s.outputFLV(w, r)
}
// outputFLV streams the stream named by the "src" query parameter to the
// HTTP client as FLV (Content-Type "video/x-flv").
//
// It responds 404 when the stream does not exist and 500 when a consumer
// cannot be attached to it. The consumer is detached from the stream
// however the transfer ends.
func (s *RTMPServer) outputFLV(w http.ResponseWriter, r *http.Request) {
	src := r.URL.Query().Get("src")
	stream := s.StreamHandler.Get(src)
	if stream == nil {
		http.Error(w, consts.StreamNotFound, http.StatusNotFound)
		return
	}
	cons := flv.NewConsumer()
	cons.WithRequest(r)
	if err := stream.AddConsumer(cons); err != nil {
		log.Error().Err(err).Caller().Send()
		// Without an explicit error the client would get an empty 200.
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer stream.RemoveConsumer(cons)

	w.Header().Set("Content-Type", "video/x-flv")
	// A write error here presumably means the client went away; there is no
	// one left to report it to.
	_, _ = cons.WriteTo(w)
}
// inputFLV opens the request body as FLV and registers it as a producer of
// the stream named by the "dst" query parameter, until client.Start returns.
//
// It responds 404 when the stream does not exist and 500 when the body
// cannot be opened as FLV. The producer is removed from the stream however
// Start ends. An EOF from Start, possibly wrapped, is treated as a normal
// end of upload; any other error is logged.
func (s *RTMPServer) inputFLV(w http.ResponseWriter, r *http.Request) {
	dst := r.URL.Query().Get("dst")
	stream := s.StreamHandler.Get(dst)
	if stream == nil {
		http.Error(w, consts.StreamNotFound, http.StatusNotFound)
		return
	}
	client, err := flv.Open(r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	stream.AddProducer(client)
	defer stream.RemoveProducer(client)

	if err := client.Start(); err != nil && !errors.Is(err, io.EOF) {
		log.Warn().Err(err).Caller().Send()
	}
}