Files
streamctl/pkg/streamserver/streamforward/stream_forwards.go
2025-02-15 22:41:29 +00:00

884 lines
22 KiB
Go

package streamforward
import (
"context"
"fmt"
"net/url"
"sort"
"time"
"github.com/facebookincubator/go-belt"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/xaionaro-go/lockmap"
"github.com/xaionaro-go/observability"
"github.com/xaionaro-go/recoder"
"github.com/xaionaro-go/streamctl/pkg/secret"
"github.com/xaionaro-go/streamctl/pkg/streamcontrol/youtube"
"github.com/xaionaro-go/streamctl/pkg/streamd/memoize"
"github.com/xaionaro-go/streamctl/pkg/streamserver/types"
"github.com/xaionaro-go/streamctl/pkg/streamtypes"
"github.com/xaionaro-go/typing/ordered"
"github.com/xaionaro-go/xsync"
)
// ForwardingKey identifies one stream forwarding by its source stream and its
// destination. It is the key type of StreamForwards.ActiveStreamForwardings.
type ForwardingKey struct {
// StreamID is the ID of the stream being forwarded.
StreamID types.StreamID
// DestinationID is the ID of the destination the stream is forwarded to.
DestinationID types.DestinationID
}
// StreamServer is the subset of stream server functionality that
// StreamForwards is built on top of. Within this file it is used to access
// the configuration (WithConfig; presumably provided by types.WithConfiger)
// and to enumerate the listening port servers (GetPortServers).
type StreamServer interface {
types.WithConfiger
types.WaitPublisherChaner
types.PubsubNameser
types.GetPortServerser
}
// StreamForwards manages stream destinations and the forwardings of streams
// to those destinations, keeping the runtime state below in step with the
// configuration reachable through StreamServer.
type StreamForwards struct {
StreamServer
types.PlatformsController
// Mutex is the lock the exported methods of this type pass to the xsync
// helpers before calling their unexported counterparts.
Mutex xsync.Mutex
// DestinationStreamingLocker is created by NewStreamForwards; it is not
// used by the code visible in this file.
DestinationStreamingLocker *lockmap.LockMap
// ActiveStreamForwardings holds the currently running forwardings.
ActiveStreamForwardings map[ForwardingKey]*ActiveStreamForwarding
// StreamDestinations is the in-memory list of known destinations; it
// mirrors the destinations stored in the config.
StreamDestinations []types.StreamDestination
// RecoderFactory is stored by NewStreamForwards; it is not used by the
// code visible in this file.
RecoderFactory recoder.Factory
}
// NewStreamForwards returns a StreamForwards bound to the given stream
// server, recoder factory and platforms controller. The result has an empty
// set of active forwardings and a fresh destination lock map; Init has to be
// called to load destinations and forwardings from the config.
func NewStreamForwards(
	s StreamServer,
	recoderFactory recoder.Factory,
	pc types.PlatformsController,
) *StreamForwards {
	fwds := &StreamForwards{
		DestinationStreamingLocker: lockmap.NewLockMap(),
		ActiveStreamForwardings:    make(map[ForwardingKey]*ActiveStreamForwarding),
	}
	fwds.StreamServer = s
	fwds.PlatformsController = pc
	fwds.RecoderFactory = recoderFactory
	return fwds
}
// Init loads the destinations and launches the enabled forwardings described
// in the config. It runs init while holding s.Mutex.
func (s *StreamForwards) Init(
	ctx context.Context,
	opts ...types.InitOption,
) error {
	locked := func() error {
		return s.init(ctx, opts...)
	}
	return xsync.DoR1(ctx, &s.Mutex, locked)
}
// init registers every destination found in the config and then launches
// every forwarding that is not marked as disabled. It stops at the first
// failure and returns it; anything registered before the failure stays
// registered. The init options are currently ignored.
//
// Callers in this file invoke it while holding s.Mutex.
func (s *StreamForwards) init(
	ctx context.Context,
	_ ...types.InitOption,
) (_ret error) {
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		for dstID, dstCfg := range cfg.Destinations {
			err := s.addActiveStreamDestination(ctx, dstID, dstCfg.URL, dstCfg.StreamKey)
			if err != nil {
				// Report only the URL: the config entry also carries the
				// stream key, which must not leak into errors or logs.
				_ret = fmt.Errorf(
					"unable to initialize stream destination '%s' to '%s': %w",
					dstID,
					dstCfg.URL,
					err,
				)
				return
			}
		}

		for streamID, streamCfg := range cfg.Streams {
			for dstID, fwd := range streamCfg.Forwardings {
				if fwd.Disabled {
					continue
				}
				_, err := s.newActiveStreamForward(ctx, streamID, dstID, fwd.Quirks)
				if err != nil {
					_ret = fmt.Errorf(
						"unable to launch stream forward from '%s' to '%s': %w",
						streamID,
						dstID,
						err,
					)
					return
				}
			}
		}
	})
	return
}
// AddStreamForward adds a forwarding of stream streamID to destination
// destinationID to the config and, if enabled is true, starts it right away.
// It runs addStreamForward while holding s.Mutex.
func (s *StreamForwards) AddStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	destinationID types.DestinationID,
	enabled bool,
	quirks types.ForwardingQuirks,
) (*StreamForward, error) {
	locked := func() (*StreamForward, error) {
		return s.addStreamForward(ctx, streamID, destinationID, enabled, quirks)
	}
	return xsync.DoR2(ctx, &s.Mutex, locked)
}
// addStreamForward records the forwarding streamID->destinationID in the
// config and, when enabled, launches it.
//
// It fails if the stream is not present in the config or if the forwarding
// is already configured. For a disabled forwarding the returned value only
// describes the config entry and carries no active forwarding.
//
// NOTE(review): the config entry is written before the forwarding is
// launched and is kept if the launch fails; confirm this is intended.
func (s *StreamForwards) addStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	destinationID types.DestinationID,
	enabled bool,
	quirks types.ForwardingQuirks,
) (*StreamForward, error) {
	ctx = belt.WithField(ctx, "module", "StreamServer")

	var cfgErr error
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		streamCfg := cfg.Streams[streamID]
		if streamCfg == nil {
			cfgErr = fmt.Errorf("stream '%s' does not exist", streamID)
			return
		}

		// Lazily create the map so that the write below cannot panic.
		if streamCfg.Forwardings == nil {
			streamCfg.Forwardings = make(map[streamtypes.DestinationID]types.ForwardingConfig)
		}

		if _, exists := streamCfg.Forwardings[destinationID]; exists {
			cfgErr = fmt.Errorf("the forwarding %s->%s already exists", streamID, destinationID)
			return
		}

		streamCfg.Forwardings[destinationID] = types.ForwardingConfig{
			Quirks:   quirks,
			Disabled: !enabled,
		}
	})
	if cfgErr != nil {
		return nil, cfgErr
	}

	if !enabled {
		return &StreamForward{
			StreamID:      streamID,
			DestinationID: destinationID,
			Enabled:       enabled,
			Quirks:        quirks,
		}, nil
	}

	activeFwd, launchErr := s.newActiveStreamForward(ctx, streamID, destinationID, quirks)
	if launchErr != nil {
		return nil, launchErr
	}
	return activeFwd, nil
}
// getLocalhostRTMP builds a "<type>://<listen address>" URL from the first
// port server reported by the stream server.
//
// It returns an error if the port servers cannot be listed, if there are
// none, or if the resulting string is not a valid URL.
//
// NOTE(review): despite the name, the first port server is used whatever its
// type is; confirm that it is always the RTMP one.
func (s *StreamForwards) getLocalhostRTMP(ctx context.Context) (*url.URL, error) {
	portSrvs, err := s.StreamServer.GetPortServers(ctx)
	if err != nil {
		return nil, fmt.Errorf("unable to get port servers info: %w", err)
	}
	// Indexing an empty list would panic; report it as an error instead.
	if len(portSrvs) == 0 {
		return nil, fmt.Errorf("there are no port servers to build a localhost URL from")
	}
	portSrv := portSrvs[0]

	urlString := fmt.Sprintf("%s://%s", portSrv.Type, portSrv.ListenAddr)
	urlParsed, err := url.Parse(urlString)
	if err != nil {
		return nil, fmt.Errorf("unable to parse '%s': %w", urlString, err)
	}
	return urlParsed, nil
}
// newActiveStreamForward starts forwarding stream streamID to destination
// destinationID and registers the result in s.ActiveStreamForwardings.
//
// It fails if such a forwarding is already active, if the destination is
// unknown, or if its URL cannot be parsed. A destination URL without a host
// is replaced by the URL of the local port server (see getLocalhostRTMP).
//
// Quirks:
//   - StartAfterYoutubeRecognizedStream: the callback handed to
//     NewActiveStreamForward blocks until the platforms controller reports
//     the YouTube stream as started, or until ctx is done.
//   - RestartUntilYoutubeRecognizesStream: a background goroutine keeps
//     restarting the forwarding until YouTube reports the stream.
//
// The two quirks are mutually exclusive; if both are enabled an error is
// logged and the first one is skipped.
//
// Callers in this file invoke it while holding s.Mutex.
func (s *StreamForwards) newActiveStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	destinationID types.DestinationID,
	quirks types.ForwardingQuirks,
	opts ...Option,
) (*StreamForward, error) {
	ctx = belt.WithField(ctx, "stream_forward", fmt.Sprintf("%s->%s", streamID, destinationID))

	key := ForwardingKey{
		StreamID:      streamID,
		DestinationID: destinationID,
	}
	if _, ok := s.ActiveStreamForwardings[key]; ok {
		return nil, fmt.Errorf(
			"there is already an active stream forwarding to '%s'",
			destinationID,
		)
	}

	dst, err := s.findStreamDestinationByID(ctx, destinationID)
	if err != nil {
		return nil, fmt.Errorf("unable to find stream destination '%s': %w", destinationID, err)
	}

	urlParsed, err := url.Parse(dst.URL)
	if err != nil {
		return nil, fmt.Errorf("unable to parse URL '%s': %w", dst.URL, err)
	}
	if urlParsed.Host == "" {
		urlParsed, err = s.getLocalhostRTMP(ctx)
		if err != nil {
			return nil, fmt.Errorf("unable to get the URL of the output endpoint: %w", err)
		}
	}

	result := &StreamForward{
		StreamID:      streamID,
		DestinationID: destinationID,
		Enabled:       true,
		Quirks:        quirks,
		NumBytesWrote: 0,
		NumBytesRead:  0,
	}

	fwd, err := s.NewActiveStreamForward(
		ctx,
		streamID,
		urlParsed.String(),
		dst.StreamKey.Get(),
		func(
			ctx context.Context,
			fwd *ActiveStreamForwarding,
		) {
			if !quirks.StartAfterYoutubeRecognizedStream.Enabled {
				return
			}
			if quirks.RestartUntilYoutubeRecognizesStream.Enabled {
				logger.Errorf(
					ctx,
					"StartAfterYoutubeRecognizedStream should not be used together with RestartUntilYoutubeRecognizesStream",
				)
				return
			}
			// Same guard as in restartUntilYoutubeRecognizesStream: without
			// a controller there is nothing to poll, and calling through a
			// nil interface would panic.
			if s.PlatformsController == nil {
				logger.Errorf(ctx, "PlatformsController is nil")
				return
			}

			logger.Debugf(ctx, "fwd %s->%s is waiting for YouTube to recognize the stream", streamID, destinationID)
			started, err := s.PlatformsController.CheckStreamStartedByPlatformID(
				memoize.SetNoCache(ctx, true),
				youtube.ID,
			)
			logger.Debugf(ctx, "youtube status check: %v %v", started, err)
			if started {
				return
			}

			t := time.NewTicker(time.Second)
			// Release the ticker on every exit path of the polling loop.
			defer t.Stop()
			for {
				select {
				case <-ctx.Done():
					return
				case <-t.C:
				}
				// NOTE(review): unlike the first check above, this one does
				// not bypass the memoize cache; confirm this is intended.
				started, err := s.PlatformsController.CheckStreamStartedByPlatformID(
					ctx,
					youtube.ID,
				)
				logger.Debugf(ctx, "youtube status check: %v %v", started, err)
				if started {
					return
				}
			}
		},
		opts...,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to run the stream forwarding: %w", err)
	}
	s.ActiveStreamForwardings[key] = fwd
	result.ActiveForwarding = fwd

	if quirks.RestartUntilYoutubeRecognizesStream.Enabled {
		observability.Go(ctx, func() {
			s.restartUntilYoutubeRecognizesStream(
				ctx,
				result,
				quirks.RestartUntilYoutubeRecognizesStream,
			)
		})
	}

	return result, nil
}
// restartUntilYoutubeRecognizesStream keeps restarting the given forwarding
// until the platforms controller reports the YouTube stream as started.
//
// After a publisher appears, each round:
//  1. waits cfg.StartTimeout;
//  2. checks the stream status, bypassing the memoize cache. If it is
//     started, waits cfg.StopStartDelay and checks again; the function
//     returns only if the stream is still reported as started;
//  3. otherwise stops the forwarding, waits cfg.StopStartDelay and starts it
//     again.
//
// A failed status check is retried after one second. Every wait is aborted
// by ctx, in which case the function returns. The function is meant to run
// in its own goroutine and reports problems only through the logger.
func (s *StreamForwards) restartUntilYoutubeRecognizesStream(
	ctx context.Context,
	fwd *StreamForward,
	cfg types.RestartUntilYoutubeRecognizesStream,
) {
	ctx = belt.WithField(ctx, "module", "restartUntilYoutubeRecognizesStream")
	ctx = belt.WithField(
		ctx,
		"stream_forward",
		fmt.Sprintf("%s->%s", fwd.StreamID, fwd.DestinationID),
	)

	logger.Debugf(ctx, "restartUntilYoutubeRecognizesStream(ctx, %#+v, %#+v)", fwd, cfg)
	// The leading slash marks the exit line, as the other functions of this
	// file do, so that it can be told apart from the entry line above.
	defer func() { logger.Debugf(ctx, "/restartUntilYoutubeRecognizesStream(ctx, %#+v, %#+v)", fwd, cfg) }()

	if !cfg.Enabled {
		logger.Errorf(
			ctx,
			"an attempt to start restartUntilYoutubeRecognizesStream when the hack is disabled for this stream forwarder: %#+v",
			cfg,
		)
		return
	}

	if s.PlatformsController == nil {
		logger.Errorf(ctx, "PlatformsController is nil")
		return
	}

	if fwd.ActiveForwarding == nil {
		logger.Error(ctx, "ActiveForwarding is nil")
		return
	}

	_, err := fwd.ActiveForwarding.WaitForPublisher(ctx)
	if err != nil {
		logger.Error(ctx, err)
		return
	}

	// wait pauses for d and reports whether the pause completed; it returns
	// false as soon as ctx is done.
	wait := func(d time.Duration) bool {
		select {
		case <-ctx.Done():
			return false
		case <-time.After(d):
			return true
		}
	}

	for {
		if !wait(cfg.StartTimeout) {
			return
		}
		logger.Debugf(
			ctx,
			"waited %v, checking if the remote platform accepted the stream",
			cfg.StartTimeout,
		)

		for {
			streamOK, err := s.PlatformsController.CheckStreamStartedByPlatformID(
				memoize.SetNoCache(ctx, true),
				youtube.ID,
			)
			logger.Debugf(
				ctx,
				"the result of checking the stream on the remote platform: %v %v",
				streamOK,
				err,
			)
			if err != nil {
				logger.Errorf(
					ctx,
					"unable to check if the stream with URL '%s' is started: %v",
					fwd.ActiveForwarding.DestinationURL,
					err,
				)
				// The retry delay must honour ctx: a cancelled ctx makes
				// every check fail, and an unconditional sleep would keep
				// this goroutine retrying forever.
				if !wait(time.Second) {
					return
				}
				continue
			}
			if streamOK {
				logger.Debugf(
					ctx,
					"waiting %v to recheck if the stream will be still OK",
					cfg.StopStartDelay,
				)
				if !wait(cfg.StopStartDelay) {
					return
				}
				streamOK, err := s.PlatformsController.CheckStreamStartedByPlatformID(
					memoize.SetNoCache(ctx, true),
					youtube.ID,
				)
				logger.Debugf(
					ctx,
					"the result of checking the stream on the remote platform: %v %v",
					streamOK,
					err,
				)
				if err != nil {
					logger.Errorf(
						ctx,
						"unable to check if the stream with URL '%s' is started: %v",
						fwd.ActiveForwarding.DestinationURL,
						err,
					)
					if !wait(time.Second) {
						return
					}
					continue
				}
				if streamOK {
					return
				}
			}
			break
		}

		logger.Infof(
			ctx,
			"the remote platform still does not see the stream, restarting the stream forwarding: stopping...",
		)

		err := fwd.ActiveForwarding.Stop()
		if err != nil {
			logger.Errorf(ctx, "unable to stop stream forwarding: %v", err)
		}

		if !wait(cfg.StopStartDelay) {
			return
		}

		logger.Infof(
			ctx,
			"the remote platform still does not see the stream, restarting the stream forwarding: starting...",
		)

		err = fwd.ActiveForwarding.Start(ctx)
		if err != nil {
			logger.Errorf(ctx, "unable to start stream forwarding: %v", err)
		}
	}
}
// UpdateStreamForward changes the enabled flag and the quirks of an existing
// forwarding streamID->destinationID. It runs updateStreamForward while
// holding s.Mutex.
func (s *StreamForwards) UpdateStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	destinationID types.DestinationID,
	enabled bool,
	quirks types.ForwardingQuirks,
) (*StreamForward, error) {
	locked := func() (*StreamForward, error) {
		return s.updateStreamForward(ctx, streamID, destinationID, enabled, quirks)
	}
	return xsync.DoR2(ctx, &s.Mutex, locked)
}
// updateStreamForward applies a new enabled flag and new quirks to the
// configured forwarding streamID->destinationID.
//
// A forwarding that goes from disabled to enabled is launched; one that goes
// from enabled to disabled is stopped. The config entry is rewritten only
// after that transition succeeded. It fails if the stream or the forwarding
// is not present in the config.
//
// NOTE(review): if the forwarding stays enabled, the running forwarding is
// not restarted, so changed quirks only reach the config; confirm this is
// intended.
func (s *StreamForwards) updateStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	destinationID types.DestinationID,
	enabled bool,
	quirks types.ForwardingQuirks,
) (_ret *StreamForward, _err error) {
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		streamConfig := cfg.Streams[streamID]
		// An unknown stream yields a nil entry; dereferencing it below
		// would panic.
		if streamConfig == nil {
			_err = fmt.Errorf("stream '%s' does not exist", streamID)
			return
		}
		fwdCfg, ok := streamConfig.Forwardings[destinationID]
		if !ok {
			_err = fmt.Errorf("the forwarding %s->%s does not exist", streamID, destinationID)
			return
		}

		var fwd *StreamForward
		if fwdCfg.Disabled && enabled {
			var err error
			fwd, err = s.newActiveStreamForward(ctx, streamID, destinationID, quirks)
			if err != nil {
				_err = fmt.Errorf("unable to activate the stream: %w", err)
				return
			}
		}
		if !fwdCfg.Disabled && !enabled {
			err := s.removeActiveStreamForward(ctx, streamID, destinationID)
			if err != nil {
				_err = fmt.Errorf("unable to deactivate the stream: %w", err)
				return
			}
		}

		streamConfig.Forwardings[destinationID] = types.ForwardingConfig{
			Disabled: !enabled,
			Quirks:   quirks,
		}

		r := &StreamForward{
			StreamID:      streamID,
			DestinationID: destinationID,
			Enabled:       enabled,
			Quirks:        quirks,
			NumBytesWrote: 0,
			NumBytesRead:  0,
		}
		if fwd != nil {
			r.ActiveForwarding = fwd.ActiveForwarding
		}
		_ret = r
	})
	return
}
// ListStreamForwards returns every configured forwarding, enabled or not,
// with byte counters filled in for the active ones. The result is also
// written to the trace log.
func (s *StreamForwards) ListStreamForwards(
	ctx context.Context,
) (_ret []StreamForward, _err error) {
	defer func() {
		logger.Tracef(ctx, "/ListStreamForwards(): %#+v %v", _ret, _err)
	}()

	// The filter accepts every stream and every destination.
	acceptAll := func(types.StreamID, ordered.Optional[types.DestinationID]) bool {
		return true
	}
	return xsync.DoR2(ctx, &s.Mutex, func() ([]StreamForward, error) {
		return s.getStreamForwards(ctx, acceptAll)
	})
}
// getStreamForwards lists the forwardings found in the config that pass
// filterFunc, adding the byte counters of the matching active forwardings.
//
// filterFunc is called with an unset destination to decide whether a stream
// should be looked at at all, and with a set destination for each individual
// forwarding. Active forwardings without a config entry are not reported.
func (s *StreamForwards) getStreamForwards(
	ctx context.Context,
	filterFunc func(types.StreamID, ordered.Optional[types.DestinationID]) bool,
) (_ret []StreamForward, _err error) {
	active, err := s.listActiveStreamForwards(ctx)
	if err != nil {
		return nil, fmt.Errorf("unable to get the list of active stream forwardings: %w", err)
	}
	logger.Tracef(ctx, "len(activeStreamForwards) == %d", len(active))

	// Index the active forwardings that pass the filter for quick lookup.
	activeByKey := map[ForwardingKey]*StreamForward{}
	for i := range active {
		candidate := &active[i]
		if filterFunc(candidate.StreamID, ordered.Opt(candidate.DestinationID)) {
			lookupKey := ForwardingKey{
				StreamID:      candidate.StreamID,
				DestinationID: candidate.DestinationID,
			}
			activeByKey[lookupKey] = candidate
		}
	}

	var forwards []StreamForward
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		logger.Tracef(ctx, "len(s.Config.Streams) == %d", len(cfg.Streams))

		var noDestination ordered.Optional[types.DestinationID]
		for streamID, streamCfg := range cfg.Streams {
			if !filterFunc(streamID, noDestination) {
				continue
			}
			logger.Tracef(
				ctx,
				"len(s.Config.Streams[%s].Forwardings) == %d",
				streamID,
				len(streamCfg.Forwardings),
			)

			for dstID, fwdCfg := range streamCfg.Forwardings {
				if !filterFunc(streamID, ordered.Opt(dstID)) {
					continue
				}

				entry := StreamForward{
					StreamID:      streamID,
					DestinationID: dstID,
					Enabled:       !fwdCfg.Disabled,
					Quirks:        fwdCfg.Quirks,
				}
				stats, isActive := activeByKey[ForwardingKey{
					StreamID:      streamID,
					DestinationID: dstID,
				}]
				if isActive {
					entry.NumBytesWrote = stats.NumBytesWrote
					entry.NumBytesRead = stats.NumBytesRead
				}

				logger.Tracef(ctx, "stream forwarding '%s->%s': %#+v", streamID, dstID, fwdCfg)
				forwards = append(forwards, entry)
			}
		}
	})
	return forwards, nil
}
// listActiveStreamForwards returns one entry per active forwarding, carrying
// its current write and read byte counters. The order of the entries is
// unspecified, and the returned error is always nil.
func (s *StreamForwards) listActiveStreamForwards(
	_ context.Context,
) ([]StreamForward, error) {
	var snapshot []StreamForward
	for key, active := range s.ActiveStreamForwardings {
		entry := StreamForward{
			StreamID:      key.StreamID,
			DestinationID: key.DestinationID,
			Enabled:       true,
		}
		entry.NumBytesWrote = active.WriteCount.Load()
		entry.NumBytesRead = active.ReadCount.Load()
		snapshot = append(snapshot, entry)
	}
	return snapshot, nil
}
// RemoveStreamForward deletes the forwarding streamID->dstID from the config
// and stops it if it is running. It runs removeStreamForward while holding
// s.Mutex.
func (s *StreamForwards) RemoveStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	dstID types.DestinationID,
) error {
	ctx = belt.WithField(ctx, "module", "StreamServer")
	return xsync.DoR1(ctx, &s.Mutex, func() error {
		return s.removeStreamForward(ctx, streamID, dstID)
	})
}
// removeStreamForward deletes the forwarding streamID->dstID from the config
// and then stops the matching active forwarding, if there is one.
//
// It fails if the stream or the forwarding is not present in the config, or
// if the active forwarding cannot be closed. In the last case the config
// entry has already been deleted.
func (s *StreamForwards) removeStreamForward(
	ctx context.Context,
	streamID types.StreamID,
	dstID types.DestinationID,
) (err error) {
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		streamCfg := cfg.Streams[streamID]
		// An unknown stream yields a nil entry; dereferencing it below
		// would panic.
		if streamCfg == nil {
			err = fmt.Errorf("stream '%s' does not exist", streamID)
			return
		}
		if _, ok := streamCfg.Forwardings[dstID]; !ok {
			err = fmt.Errorf("the forwarding %s->%s does not exist", streamID, dstID)
			return
		}
		delete(streamCfg.Forwardings, dstID)
		err = s.removeActiveStreamForward(ctx, streamID, dstID)
	})
	return
}
// removeActiveStreamForward unregisters and closes the active forwarding
// streamID->dstID. It does nothing and returns nil if no such forwarding is
// active. The forwarding is unregistered before it is closed, so it is gone
// from s.ActiveStreamForwardings even when closing fails.
func (s *StreamForwards) removeActiveStreamForward(
	_ context.Context,
	streamID types.StreamID,
	dstID types.DestinationID,
) error {
	key := ForwardingKey{StreamID: streamID, DestinationID: dstID}

	active, found := s.ActiveStreamForwardings[key]
	if !found || active == nil {
		return nil
	}

	delete(s.ActiveStreamForwardings, key)
	if err := active.Close(); err != nil {
		return fmt.Errorf("unable to close stream forwarding: %w", err)
	}
	return nil
}
// GetStreamForwardsByDestination returns the configured forwardings that
// target destination destID, across all streams. The call and its result
// are written to the debug log.
func (s *StreamForwards) GetStreamForwardsByDestination(
	ctx context.Context,
	destID types.DestinationID,
) (_ret []StreamForward, _err error) {
	ctx = belt.WithField(ctx, "module", "StreamServer")
	logger.Debugf(ctx, "GetStreamForwardsByDestination()")
	defer func() {
		logger.Debugf(ctx, "/GetStreamForwardsByDestination(): %#+v %v", _ret, _err)
	}()

	// Accept every stream (unset destination) and, within a stream, only
	// the forwardings that go to destID.
	byDestination := func(_ types.StreamID, dstID ordered.Optional[types.DestinationID]) bool {
		if !dstID.IsSet() {
			return true
		}
		return dstID.Get() == destID
	}
	return xsync.DoR2(ctx, &s.Mutex, func() ([]StreamForward, error) {
		return s.getStreamForwards(ctx, byDestination)
	})
}
// ListStreamDestinations returns a copy of the list of known stream
// destinations. It runs listStreamDestinations while holding s.Mutex.
func (s *StreamForwards) ListStreamDestinations(
	ctx context.Context,
) ([]types.StreamDestination, error) {
	ctx = belt.WithField(ctx, "module", "StreamServer")
	return xsync.DoR2(ctx, &s.Mutex, func() ([]types.StreamDestination, error) {
		return s.listStreamDestinations(ctx)
	})
}
// listStreamDestinations returns a copy of s.StreamDestinations, so callers
// can keep or modify the result without affecting the internal list. The
// returned error is always nil.
func (s *StreamForwards) listStreamDestinations(
	_ context.Context,
) ([]types.StreamDestination, error) {
	snapshot := make([]types.StreamDestination, len(s.StreamDestinations))
	for i := range s.StreamDestinations {
		snapshot[i] = s.StreamDestinations[i]
	}
	return snapshot, nil
}
// AddStreamDestination registers a new stream destination with the given
// ID, URL and stream key. It runs addStreamDestination while holding
// s.Mutex.
func (s *StreamForwards) AddStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
	url string,
	streamKey string,
) error {
	ctx = belt.WithField(ctx, "module", "StreamServer")
	return xsync.DoR1(ctx, &s.Mutex, func() error {
		return s.addStreamDestination(ctx, destinationID, url, streamKey)
	})
}
// addStreamDestination adds the destination both to the in-memory list and
// to the config.
//
// It fails if a destination with the same ID is already configured: adding
// it again would leave two entries in the in-memory list, and lookups would
// keep returning the older one. Use updateStreamDestination to change an
// existing destination.
func (s *StreamForwards) addStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
	url string,
	streamKey string,
) (_ret error) {
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		if _, ok := cfg.Destinations[destinationID]; ok {
			_ret = fmt.Errorf("the destination '%s' already exists", destinationID)
			return
		}

		err := s.addActiveStreamDestination(ctx, destinationID, url, streamKey)
		if err != nil {
			_ret = fmt.Errorf("unable to add an active stream destination: %w", err)
			return
		}

		// Writing to a nil map panics; create it on first use.
		if cfg.Destinations == nil {
			cfg.Destinations = make(map[types.DestinationID]*types.DestinationConfig)
		}
		cfg.Destinations[destinationID] = &types.DestinationConfig{
			URL:       url,
			StreamKey: streamKey,
		}
	})
	return
}
// UpdateStreamDestination replaces the URL and the stream key of the
// destination destinationID. It runs updateStreamDestination while holding
// s.Mutex.
func (s *StreamForwards) UpdateStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
	url string,
	streamKey string,
) error {
	ctx = belt.WithField(ctx, "module", "StreamServer")
	return xsync.DoR1(ctx, &s.Mutex, func() error {
		return s.updateStreamDestination(ctx, destinationID, url, streamKey)
	})
}
// updateStreamDestination replaces the destination destinationID by
// removing it from the runtime state and adding it back with the new URL and
// stream key, and then rewrites its config entry.
//
// The update is refused while any forwarding to the destination is active.
func (s *StreamForwards) updateStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
	url string,
	streamKey string,
) (_ret error) {
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		// A destination that is currently streamed to must not be changed.
		for activeKey := range s.ActiveStreamForwardings {
			if activeKey.DestinationID != destinationID {
				continue
			}
			_ret = fmt.Errorf(
				"there is already an active stream forwarding to '%s'",
				destinationID,
			)
			return
		}

		if err := s.removeActiveStreamDestination(ctx, destinationID); err != nil {
			_ret = fmt.Errorf(
				"unable to remove (to then re-add) the active stream destination: %w",
				err,
			)
			return
		}

		if err := s.addActiveStreamDestination(ctx, destinationID, url, streamKey); err != nil {
			_ret = fmt.Errorf("unable to re-add the active stream destination: %w", err)
			return
		}

		updated := &types.DestinationConfig{
			URL:       url,
			StreamKey: streamKey,
		}
		cfg.Destinations[destinationID] = updated
	})
	return
}
// addActiveStreamDestination appends the destination to the in-memory list
// s.StreamDestinations, wrapping the stream key into a secret. It performs
// no duplicate check and always returns nil.
//
// TODO: delete this function, we already store the exact same information in the config
func (s *StreamForwards) addActiveStreamDestination(
	_ context.Context,
	destinationID types.DestinationID,
	url string,
	streamKey string,
) error {
	dst := types.StreamDestination{
		ID:        destinationID,
		URL:       url,
		StreamKey: secret.New(streamKey),
	}
	s.StreamDestinations = append(s.StreamDestinations, dst)
	return nil
}
// RemoveStreamDestination removes the destination destinationID together
// with the forwardings that target it. It runs removeStreamDestination while
// holding s.Mutex.
func (s *StreamForwards) RemoveStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
) error {
	ctx = belt.WithField(ctx, "module", "StreamServer")
	return xsync.DoR1(ctx, &s.Mutex, func() error {
		return s.removeStreamDestination(ctx, destinationID)
	})
}
// removeStreamDestination removes the destination destinationID from the
// runtime state and from the config, together with every forwarding that
// targets it.
//
// The config is purged even if tearing down the runtime state fails; the
// teardown error is returned in that case.
func (s *StreamForwards) removeStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
) (err error) {
	s.WithConfig(ctx, func(ctx context.Context, cfg *types.Config) {
		// Tear down the runtime state before purging the config, so that
		// the teardown never depends on entries that were just deleted and
		// running forwardings to this destination are not left behind.
		err = s.removeActiveStreamDestination(ctx, destinationID)

		for _, streamCfg := range cfg.Streams {
			if streamCfg == nil {
				continue
			}
			delete(streamCfg.Forwardings, destinationID)
		}
		delete(cfg.Destinations, destinationID)
	})
	return
}
// removeActiveStreamDestination removes the destination destinationID from
// the runtime state only: it stops every active forwarding to it and drops
// it from s.StreamDestinations. The config is left untouched.
//
// The active forwardings are found in s.ActiveStreamForwardings rather than
// through the config, so they are stopped even when their config entries are
// already gone. A forwarding that fails to close is logged and skipped, so
// the destination is still removed.
//
// It returns an error if the destination is not in s.StreamDestinations.
func (s *StreamForwards) removeActiveStreamDestination(
	ctx context.Context,
	destinationID types.DestinationID,
) error {
	// Deleting entries from a map while ranging over it is safe in Go.
	for key := range s.ActiveStreamForwardings {
		if key.DestinationID != destinationID {
			continue
		}
		err := s.removeActiveStreamForward(ctx, key.StreamID, key.DestinationID)
		if err != nil {
			logger.Errorf(
				ctx,
				"unable to remove the active stream forwarding %s->%s: %v",
				key.StreamID,
				key.DestinationID,
				err,
			)
		}
	}

	for i := range s.StreamDestinations {
		if s.StreamDestinations[i].ID == destinationID {
			s.StreamDestinations = append(s.StreamDestinations[:i], s.StreamDestinations[i+1:]...)
			return nil
		}
	}

	return fmt.Errorf("have not found stream destination with id %s", destinationID)
}
// findStreamDestinationByID returns the first destination in
// s.StreamDestinations whose ID equals destinationID, or an error if there
// is none.
func (s *StreamForwards) findStreamDestinationByID(
	_ context.Context,
	destinationID types.DestinationID,
) (types.StreamDestination, error) {
	for i := range s.StreamDestinations {
		if s.StreamDestinations[i].ID == destinationID {
			return s.StreamDestinations[i], nil
		}
	}
	// The lookup key is a destination ID, not a stream ID.
	return types.StreamDestination{}, fmt.Errorf(
		"unable to find a stream destination by ID '%s'",
		destinationID,
	)
}
// GetLocalhostEndpoint returns the URL of a local port server, preferring a
// non-TLS one. Among equally preferred servers the one listed first by the
// stream server is chosen. The scheme is the server type, with an "s"
// appended for TLS servers.
//
// It returns an error if the port servers cannot be listed, if there are
// none, or if the resulting string is not a valid URL. The result is written
// to the debug log.
func (s *StreamForwards) GetLocalhostEndpoint(ctx context.Context) (_ret *url.URL, _err error) {
	defer func() { logger.Debugf(ctx, "GetLocalhostEndpoint result: %v %v", _ret, _err) }()

	portSrvs, err := s.StreamServer.GetPortServers(ctx)
	if err != nil {
		return nil, fmt.Errorf("unable to get port servers info: %w", err)
	}
	// Indexing an empty list would panic; report it as an error instead.
	if len(portSrvs) == 0 {
		return nil, fmt.Errorf("there are no port servers to build a localhost URL from")
	}

	// Non-TLS servers go first. The sort is stable so that the choice among
	// equally preferred servers is deterministic.
	sort.SliceStable(portSrvs, func(i, j int) bool {
		return !portSrvs[i].IsTLS && portSrvs[j].IsTLS
	})
	portSrv := portSrvs[0]
	logger.Debugf(ctx, "getLocalhostEndpoint: chosen portSrv == %#+v", portSrv)

	protoString := portSrv.Type.String()
	if portSrv.IsTLS {
		protoString += "s"
	}
	urlString := fmt.Sprintf("%s://%s", protoString, portSrv.ListenAddr)
	urlParsed, err := url.Parse(urlString)
	if err != nil {
		return nil, fmt.Errorf("unable to parse '%s': %w", urlString, err)
	}

	return urlParsed, nil
}