Initial commit, pt. 46

This commit is contained in:
Dmitrii Okunev
2024-07-20 20:52:38 +01:00
parent a37fa68e10
commit 1332d9a855
37 changed files with 3011 additions and 616 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/facebookincubator/go-belt/tool/logger" "github.com/facebookincubator/go-belt/tool/logger"
"github.com/facebookincubator/go-belt/tool/logger/implementation/logrus" "github.com/facebookincubator/go-belt/tool/logger/implementation/logrus"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/kraken-hpc/go-fork"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/xaionaro-go/streamctl/pkg/observability" "github.com/xaionaro-go/streamctl/pkg/observability"
"github.com/xaionaro-go/streamctl/pkg/streamcontrol" "github.com/xaionaro-go/streamctl/pkg/streamcontrol"
@@ -26,6 +27,15 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
func init() {
fork.RegisterFunc("streamd", streamd)
fork.Init()
}
func streamd(remoteAddr string) {
}
const forceNetPProfOnAndroid = true const forceNetPProfOnAndroid = true
func main() { func main() {
@@ -39,8 +49,13 @@ func main() {
heapProfile := pflag.String("go-profile-heap", "", "file to write memory profile to") heapProfile := pflag.String("go-profile-heap", "", "file to write memory profile to")
sentryDSN := pflag.String("sentry-dsn", "", "DSN of a Sentry instance to send error reports") sentryDSN := pflag.String("sentry-dsn", "", "DSN of a Sentry instance to send error reports")
page := pflag.String("page", string(consts.PageControl), "DSN of a Sentry instance to send error reports") page := pflag.String("page", string(consts.PageControl), "DSN of a Sentry instance to send error reports")
splitProcess := pflag.Bool("split-process", !isMobile(), "split the process into multiple processes for better stability")
pflag.Parse() pflag.Parse()
l := logrus.Default().WithLevel(loggerLevel) l := logrus.Default().WithLevel(loggerLevel)
logger.Default = func() logger.Logger {
return l
}
if *cpuProfile != "" { if *cpuProfile != "" {
f, err := os.Create(*cpuProfile) f, err := os.Create(*cpuProfile)
@@ -81,17 +96,29 @@ func main() {
runtime.GOMAXPROCS(16) runtime.GOMAXPROCS(16)
} }
if *splitProcess && *listenAddr == "" {
listenAddr = ptr("localhost:0")
}
listener, err := net.Listen("tcp", *listenAddr) listener, err := net.Listen("tcp", *listenAddr)
if err != nil { if err != nil {
l.Fatalf("failed to listen: %v", err) l.Fatalf("failed to listen: %v", err)
} }
if *splitProcess {
fork.Fork("streamd", listener.Addr().String())
}
ctx := context.Background() ctx := context.Background()
go func() { go func() {
<-ctx.Done() <-ctx.Done()
listener.Close() listener.Close()
}() }()
if *splitProcess && *remoteAddr == "" {
remoteAddr = ptr(listener.Addr().String())
}
var opts []streampanel.Option var opts []streampanel.Option
if *remoteAddr != "" { if *remoteAddr != "" {
opts = append(opts, streampanel.OptionRemoteStreamDAddr(*remoteAddr)) opts = append(opts, streampanel.OptionRemoteStreamDAddr(*remoteAddr))
@@ -101,7 +128,7 @@ func main() {
} }
panel, panelErr := streampanel.New(*configPath, opts...) panel, panelErr := streampanel.New(*configPath, opts...)
if panel.Config.SentryDSN != "" { if panel != nil && panel.Config.SentryDSN != "" {
l.Infof("setting up Sentry at DSN '%s'", panel.Config.SentryDSN) l.Infof("setting up Sentry at DSN '%s'", panel.Config.SentryDSN)
sentryClient, err := sentry.NewClient(sentry.ClientOptions{ sentryClient, err := sentry.NewClient(sentry.ClientOptions{
Dsn: panel.Config.SentryDSN, Dsn: panel.Config.SentryDSN,

3
go.mod
View File

@@ -210,6 +210,7 @@ require (
github.com/andreykaipov/goobs v1.4.1 github.com/andreykaipov/goobs v1.4.1
github.com/anthonynsimon/bild v0.14.0 github.com/anthonynsimon/bild v0.14.0
github.com/chai2010/webp v1.1.1 github.com/chai2010/webp v1.1.1
github.com/dustin/go-humanize v1.0.1
github.com/getsentry/sentry-go v0.28.1 github.com/getsentry/sentry-go v0.28.1
github.com/go-git/go-git/v5 v5.12.0 github.com/go-git/go-git/v5 v5.12.0
github.com/go-ng/xmath v0.0.0-20230704233441-028f5ea62335 github.com/go-ng/xmath v0.0.0-20230704233441-028f5ea62335
@@ -218,12 +219,14 @@ require (
github.com/hyprspace/hyprspace v0.10.1 github.com/hyprspace/hyprspace v0.10.1
github.com/immune-gmbh/attestation-sdk v0.0.0-20230711173209-f44e4502aeca github.com/immune-gmbh/attestation-sdk v0.0.0-20230711173209-f44e4502aeca
github.com/kbinani/screenshot v0.0.0-20230812210009-b87d31814237 github.com/kbinani/screenshot v0.0.0-20230812210009-b87d31814237
github.com/kraken-hpc/go-fork v0.1.1
github.com/libp2p/go-libp2p v0.33.2 github.com/libp2p/go-libp2p v0.33.2
github.com/libp2p/go-libp2p-kad-dht v0.25.2 github.com/libp2p/go-libp2p-kad-dht v0.25.2
github.com/multiformats/go-multiaddr v0.12.3 github.com/multiformats/go-multiaddr v0.12.3
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
github.com/sethvargo/go-password v0.3.1
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/xaionaro-go/datacounter v1.0.4 github.com/xaionaro-go/datacounter v1.0.4
github.com/xaionaro-go/unsafetools v0.0.0-20210722164218-75ba48cf7b3c github.com/xaionaro-go/unsafetools v0.0.0-20210722164218-75ba48cf7b3c

6
go.sum
View File

@@ -139,6 +139,8 @@ github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
github.com/elastic/gosigar v0.14.2 h1:Dg80n8cr90OZ7x+bAax/QjoW/XqTI11RmA79ZwIm9/4= github.com/elastic/gosigar v0.14.2 h1:Dg80n8cr90OZ7x+bAax/QjoW/XqTI11RmA79ZwIm9/4=
github.com/elastic/gosigar v0.14.2/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= github.com/elastic/gosigar v0.14.2/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
@@ -452,6 +454,8 @@ github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kraken-hpc/go-fork v0.1.1 h1:O3X/ynoNy/eS7UIcZYef8ndFq2RXEIOue9kZqyzF0Sk=
github.com/kraken-hpc/go-fork v0.1.1/go.mod h1:uu0e5h+V4ONH5Qk/xuVlyNXJXy/swhqGIEMK7w+9dNc=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8= github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8=
@@ -650,6 +654,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8=
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sethvargo/go-password v0.3.1 h1:WqrLTjo7X6AcVYfC6R7GtSyuUQR9hGyAj/f1PYQZCJU=
github.com/sethvargo/go-password v0.3.1/go.mod h1:rXofC1zT54N7R8K/h1WDUdkf9BOx5OptoxrMBcrXzvs=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=
github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0=

119
pkg/mainprocess/client.go Normal file
View File

@@ -0,0 +1,119 @@
package mainprocess
import (
"context"
"encoding/gob"
"fmt"
"net"
"github.com/facebookincubator/go-belt/tool/logger"
)
type Client struct {
Conn net.Conn
Password string
OnReceivedMessage OnReceivedMessageFunc
}
func NewClient(
myName string,
addr string,
password string,
onReceivedMessage OnReceivedMessageFunc,
) (*Client, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, fmt.Errorf("unable to connect to '%s': %w", addr, err)
}
logger.Default().Tracef("connected to '%s' as '%s'", conn.RemoteAddr(), conn.LocalAddr())
msg := RegistrationMessage{
Password: password,
Source: myName,
}
encoder := gob.NewEncoder(conn)
if err := encoder.Encode(msg); err != nil {
return nil, fmt.Errorf("unable to encode&send the registration message %#+v: %w", msg, err)
}
var regResult RegistrationResult
decoder := gob.NewDecoder(conn)
if err := decoder.Decode(&regResult); err != nil {
return nil, fmt.Errorf("unable to decode&receive the registration result: %w", err)
}
if regResult.Error != "" {
return nil, fmt.Errorf("registration error: %s", regResult.Error)
}
logger.Default().Tracef("successfully registered the process '%s'", myName)
return &Client{
Conn: conn,
Password: password,
OnReceivedMessage: onReceivedMessage,
}, nil
}
func (c *Client) SendMessage(
ctx context.Context,
dst string,
content any,
) error {
encoder := gob.NewEncoder(c.Conn)
msg := MessageToMain{
Password: c.Password,
Destination: dst,
Content: content,
}
err := encoder.Encode(msg)
logger.Tracef(ctx, "sending message %#+v: %v", msg, err)
if err != nil {
return fmt.Errorf("unable to encode&send message %#+v: %w", msg, err)
}
return nil
}
func (c *Client) Close() error {
return c.Conn.Close()
}
func (c *Client) Serve(ctx context.Context) error {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
go func() {
<-ctx.Done()
err := c.Close()
if err != nil {
logger.Error(ctx, err)
}
}()
for {
select {
case <-ctx.Done():
return nil
default:
}
var msg MessageFromMain
decoder := gob.NewDecoder(c.Conn)
err := decoder.Decode(&msg)
if err != nil {
return fmt.Errorf("unable to receive&decode message: %w", err)
}
if err := c.onReceivedMessage(ctx, msg); err != nil {
logger.Error(ctx, err)
}
}
}
func (c *Client) onReceivedMessage(
ctx context.Context,
msg MessageFromMain,
) error {
if c.OnReceivedMessage == nil {
return fmt.Errorf("OnReceivedMessage function is not set")
}
return c.OnReceivedMessage(ctx, msg.Source, msg.Content)
}

View File

@@ -0,0 +1,382 @@
package mainprocess
import (
"context"
"encoding/gob"
"fmt"
"net"
"sync"
"github.com/facebookincubator/go-belt"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/hashicorp/go-multierror"
"github.com/sethvargo/go-password/password"
)
type OnReceivedMessageFunc func(
ctx context.Context,
source string,
content any,
) error
type Manager struct {
listener net.Listener
password string
connsLocker sync.Mutex
conns map[string]net.Conn
connsChanged chan struct{}
allClientProcesses []string
OnReceivedMessage OnReceivedMessageFunc
}
func NewManager(
onReceivedMessage OnReceivedMessageFunc,
expectedClients ...string,
) (*Manager, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
}
password, err := password.Generate(16, 4, 4, false, true)
if err != nil {
return nil, fmt.Errorf("unable to generate a password: %w", err)
}
return &Manager{
OnReceivedMessage: onReceivedMessage,
listener: listener,
password: password,
conns: map[string]net.Conn{},
connsChanged: make(chan struct{}),
allClientProcesses: expectedClients,
}, nil
}
func (m *Manager) Password() string {
return m.password
}
func (m *Manager) Addr() net.Addr {
return m.listener.Addr()
}
func (m *Manager) Close() error {
return m.listener.Close()
}
func (m *Manager) Serve(ctx context.Context) error {
logger.Tracef(ctx, "serving listener at %s", m.listener.Addr())
defer logger.Tracef(ctx, "/serving listener at %s", m.listener.Addr())
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
go func() {
<-ctx.Done()
err := m.Close()
if err != nil {
logger.Error(ctx, err)
}
}()
for {
select {
case <-ctx.Done():
return nil
default:
}
conn, err := m.listener.Accept()
if err != nil {
return fmt.Errorf("unable to accept connection: %w", err)
}
logger.Tracef(ctx, "accepted a connection from '%s'", conn.RemoteAddr())
m.addNewConnection(ctx, conn)
}
}
func (m *Manager) addNewConnection(
ctx context.Context,
conn net.Conn,
) {
go func() {
m.handleConnection(ctx, conn)
}()
}
func (m *Manager) handleConnection(
ctx context.Context,
conn net.Conn,
) {
var regMessage RegistrationMessage
logger.Tracef(ctx, "handleConnection from %s", conn.RemoteAddr())
defer func() { logger.Tracef(ctx, "/handleConnection from %s (%s)", conn.RemoteAddr(), regMessage.Source) }()
ctx, cancelFn := context.WithCancel(ctx)
go func() {
<-ctx.Done()
conn.Close()
}()
defer cancelFn()
encoder := gob.NewEncoder(conn)
decoder := gob.NewDecoder(conn)
err := decoder.Decode(&regMessage)
if err != nil {
err = fmt.Errorf("unable to decode registration message: %w", err)
encoder.Encode(RegistrationResult{Error: err.Error()})
logger.Debug(ctx, err)
return
}
logger.Debugf(ctx, "received registration message: %#+v", regMessage)
if err := m.checkPassword(regMessage.Password); err != nil {
regMessage = RegistrationMessage{}
err = fmt.Errorf("invalid password: %w", err)
encoder.Encode(RegistrationResult{Error: err.Error()})
logger.Warn(ctx, err)
return
}
if err := m.registerConnection(regMessage.Source, conn); err != nil {
err = fmt.Errorf("unable to register process '%s': %w", regMessage.Source, err)
encoder.Encode(RegistrationResult{Error: err.Error()})
logger.Error(ctx, err)
return
}
defer func(sourceName string) {
m.unregisterConnection(sourceName)
}(regMessage.Source)
if err := encoder.Encode(RegistrationResult{}); err != nil {
err = fmt.Errorf("unable to encode&send the registration result to '%s': %w", regMessage.Source, err)
logger.Error(ctx, err)
return
}
ctx = belt.WithField(ctx, "client", regMessage.Source)
for {
select {
case <-ctx.Done():
logger.Tracef(ctx, "context was closed")
return
default:
}
var message MessageToMain
logger.Tracef(ctx, "waiting for a message from '%s'", regMessage.Source)
decoder := gob.NewDecoder(conn)
err := decoder.Decode(&message)
logger.Tracef(ctx, "getting a message from '%s': %#+v %#+v", regMessage.Source, message, err)
if err != nil {
err = fmt.Errorf(
"unable to parse the message from %s (%s): %w",
regMessage.Source,
conn.RemoteAddr().String(),
err,
)
logger.Error(ctx, err)
return
}
if err := m.processMessage(ctx, regMessage.Source, message); err != nil {
logger.Errorf(
ctx,
"unable to process the message %#+v from %s (%s): %w",
message, regMessage.Source, conn.RemoteAddr().String(), err,
)
}
logger.Tracef(ctx, "next iteration")
}
}
func (m *Manager) processMessage(
ctx context.Context,
source string,
message MessageToMain,
) (_ret error) {
logger.Tracef(ctx, "processing message from '%s': %#+v", source, message)
defer func() { logger.Tracef(ctx, "/processing message from '%s': %#+v: %v", source, message, _ret) }()
switch message.Destination {
case "":
logger.Tracef(ctx, "a broadcast message from '%s': %#+v", source, message.Content)
var wg sync.WaitGroup
var err *multierror.Error
err = multierror.Append(err, m.onReceivedMessage(ctx, source, message.Content))
errCh := make(chan error)
go func() {
for e := range errCh {
err = multierror.Append(err, e)
}
}()
for _, dst := range m.allClientProcesses {
if dst == source {
continue
}
wg.Add(1)
go func(dst string) {
defer wg.Done()
errCh <- m.sendMessage(ctx, source, dst, message.Content)
}(dst)
}
wg.Wait()
close(errCh)
return err.ErrorOrNil()
case "main":
logger.Tracef(ctx, "a message to the main process from '%s': %#+v", source, message.Content)
return m.onReceivedMessage(ctx, source, message.Content)
default:
logger.Tracef(ctx, "a message to '%s' from '%s': %#+v", message.Destination, source, message.Content)
return m.sendMessage(ctx, source, message.Destination, message.Content)
}
}
func (m *Manager) onReceivedMessage(
ctx context.Context,
source string,
content any,
) error {
if m.OnReceivedMessage == nil {
err := fmt.Errorf("OnReceivedMessage is not set")
logger.Tracef(ctx, "%v", err)
return err
}
logger.Tracef(ctx, "calling the OnReceivedMessage function")
return m.OnReceivedMessage(ctx, source, content)
}
type MessageFromMain struct {
Source string
Password string
Destination string
Content any
}
func (m *Manager) isExpectedProcess(
name string,
) bool {
for _, p := range m.allClientProcesses {
if name == p {
return true
}
}
return false
}
func (m *Manager) sendMessage(
ctx context.Context,
source string,
destination string,
content any,
) (_ret error) {
logger.Tracef(ctx, "sending message message %#+v from '%s' to '%s'", content, source, destination)
defer func() {
logger.Tracef(ctx, "/sending message message %#+v from '%s' to '%s': %v", content, source, destination, _ret)
}()
if !m.isExpectedProcess(destination) {
return fmt.Errorf("process '%s' is not ever expected", destination)
}
message := MessageFromMain{
Source: source,
Password: m.password,
Destination: destination,
Content: content,
}
conn, err := m.waitForProcess(destination)
if err != nil {
return fmt.Errorf("unable to wait for process '%s': %w", destination, err)
}
encoder := gob.NewEncoder(conn)
err = encoder.Encode(message)
if err != nil {
return fmt.Errorf("unable to encode&send message: %w", err)
}
return nil
}
func (m *Manager) waitForProcess(
name string,
) (net.Conn, error) {
if !m.isExpectedProcess(name) {
return nil, fmt.Errorf("process '%s' is not ever expected", name)
}
for {
m.connsLocker.Lock()
conn := m.conns[name]
ch := m.connsChanged
m.connsLocker.Unlock()
if conn != nil {
return conn, nil
}
<-ch
}
}
func (m *Manager) checkPassword(
password string,
) error {
return checkPassword(m.password, password)
}
func (m *Manager) registerConnection(
sourceName string,
conn net.Conn,
) error {
if !m.isExpectedProcess(sourceName) {
return fmt.Errorf("process '%s' is not ever expected", sourceName)
}
m.connsLocker.Lock()
defer m.connsLocker.Unlock()
if conn, ok := m.conns[sourceName]; ok {
return fmt.Errorf("process '%s' is already registered at %s", sourceName, conn.RemoteAddr().String())
}
m.conns[sourceName] = conn
var oldCh chan struct{}
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
close(oldCh)
return nil
}
func (m *Manager) unregisterConnection(
sourceName string,
) {
m.connsLocker.Lock()
defer m.connsLocker.Unlock()
delete(m.conns, sourceName)
var oldCh chan struct{}
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
close(oldCh)
}
type RegistrationMessage struct {
Password string
Source string
}
type RegistrationResult struct {
Error string
}
type MessageToMain struct {
Password string
Destination string
Content any
}

View File

@@ -0,0 +1,27 @@
package mainprocess
import (
"crypto/sha1"
"fmt"
)
func checkPassword(
a, b string,
) error {
// naive mostly-timing-attack-resistant comparison algo
h0 := sha1.Sum([]byte(a))
h1 := sha1.Sum([]byte(b))
match := true
for idx := range h0 {
charMatches := h0[idx] == h1[idx]
match = match && charMatches
}
if !match {
return fmt.Errorf("the password does not match")
}
return nil
}

View File

@@ -0,0 +1,104 @@
package mainprocess
import (
"context"
"encoding/gob"
"fmt"
"testing"
"github.com/facebookincubator/go-belt"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/facebookincubator/go-belt/tool/logger/implementation/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test(t *testing.T) {
l := logrus.Default().WithLevel(logger.LevelTrace)
logger.Default = func() logger.Logger {
return l
}
ctx := logger.CtxWithLogger(context.Background(), l)
defer belt.Flush(ctx)
ctx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
type messageContent struct {
Integer int
String string
}
gob.Register(messageContent{})
handleCallHappened := map[string]chan struct{}{
"main": make(chan struct{}),
"child0": make(chan struct{}),
"child1": make(chan struct{}),
}
callCount := map[string]int{}
handleCall := func(procName string, content any) {
logger.Tracef(ctx, "handleCall('%s', %#+v)", procName, content)
count := callCount[procName]
count++
callCount[procName] = count
msg := content.(messageContent)
assert.Equal(t, count, msg.Integer, procName)
assert.Equal(t, fmt.Sprint(count), msg.String, procName)
var oldCh chan struct{}
oldCh, handleCallHappened[procName] = handleCallHappened[procName], make(chan struct{})
close(oldCh)
}
m, err := NewManager(
func(ctx context.Context, source string, content any) error {
handleCall("main", content)
return nil
},
"child0", "child1",
)
require.NoError(t, err)
defer m.Close()
go m.Serve(belt.WithField(ctx, "process", "main"))
c0, err := NewClient("child0", m.Addr().String(), m.Password(), func(ctx context.Context, source string, content any) error {
handleCall("child0", content)
return nil
})
require.NoError(t, err)
defer c0.Close()
go c0.Serve(belt.WithField(ctx, "process", "child0"))
c1, err := NewClient("child1", m.Addr().String(), m.Password(), func(ctx context.Context, source string, content any) error {
handleCall("child1", content)
return nil
})
require.NoError(t, err)
defer c1.Close()
go c1.Serve(belt.WithField(ctx, "process", "child1"))
_, err = NewClient("child2", m.Addr().String(), m.Password(), func(ctx context.Context, source string, content any) error {
return nil
})
require.Error(t, err)
waitCh0 := handleCallHappened["main"]
waitCh1 := handleCallHappened["child1"]
err = c0.SendMessage(ctx, "", messageContent{Integer: 1, String: "1"})
require.NoError(t, err)
<-waitCh0
<-waitCh1
waitCh0 = handleCallHappened["main"]
err = c1.SendMessage(ctx, "main", messageContent{Integer: 2, String: "2"})
require.NoError(t, err)
<-waitCh0
waitCh0 = handleCallHappened["child0"]
err = c1.SendMessage(ctx, "child0", messageContent{Integer: 1, String: "1"})
require.NoError(t, err)
<-waitCh0
require.Equal(t, 2, callCount["main"])
require.Equal(t, 1, callCount["child0"])
require.Equal(t, 1, callCount["child1"])
}

View File

@@ -79,6 +79,14 @@ type StreamD interface {
ctx context.Context, ctx context.Context,
listenAddr string, listenAddr string,
) error ) error
AddIncomingStream(
ctx context.Context,
streamID StreamID,
) error
RemoveIncomingStream(
ctx context.Context,
streamID StreamID,
) error
ListIncomingStreams( ListIncomingStreams(
ctx context.Context, ctx context.Context,
) ([]IncomingStream, error) ) ([]IncomingStream, error)
@@ -99,8 +107,15 @@ type StreamD interface {
) ([]StreamForward, error) ) ([]StreamForward, error)
AddStreamForward( AddStreamForward(
ctx context.Context, ctx context.Context,
streamIDSrc StreamID, streamID StreamID,
destinationID DestinationID, destinationID DestinationID,
enabled bool,
) error
UpdateStreamForward(
ctx context.Context,
streamID StreamID,
destinationID DestinationID,
enabled bool,
) error ) error
RemoveStreamForward( RemoveStreamForward(
ctx context.Context, ctx context.Context,
@@ -153,6 +168,9 @@ func ParseStreamServerType(in string) StreamServerType {
type StreamServer struct { type StreamServer struct {
Type StreamServerType Type StreamServerType
ListenAddr string ListenAddr string
NumBytesConsumerWrote uint64
NumBytesProducerRead uint64
} }
type StreamDestination struct { type StreamDestination struct {
@@ -161,8 +179,11 @@ type StreamDestination struct {
} }
type StreamForward struct { type StreamForward struct {
Enabled bool
StreamID StreamID StreamID StreamID
DestinationID DestinationID DestinationID DestinationID
NumBytesWrote uint64
NumBytesRead uint64
} }
type IncomingStream struct { type IncomingStream struct {

View File

@@ -645,13 +645,15 @@ func (c *Client) ListStreamServers(
} }
var result []api.StreamServer var result []api.StreamServer
for _, server := range reply.GetStreamServers() { for _, server := range reply.GetStreamServers() {
t, err := grpcconv.StreamServerTypeGRPC2Go(server.GetServerType()) t, err := grpcconv.StreamServerTypeGRPC2Go(server.Config.GetServerType())
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to convert the server type value: %w", err) return nil, fmt.Errorf("unable to convert the server type value: %w", err)
} }
result = append(result, api.StreamServer{ result = append(result, api.StreamServer{
Type: t, Type: t,
ListenAddr: server.GetListenAddr(), ListenAddr: server.Config.GetListenAddr(),
NumBytesConsumerWrote: uint64(server.GetStatistics().GetNumBytesConsumerWrote()),
NumBytesProducerRead: uint64(server.GetStatistics().GetNumBytesProducerRead()),
}) })
} }
return result, nil return result, nil
@@ -703,6 +705,44 @@ func (c *Client) StopStreamServer(
return nil return nil
} }
func (c *Client) AddIncomingStream(
ctx context.Context,
streamID api.StreamID,
) error {
client, conn, err := c.grpcClient()
if err != nil {
return err
}
defer conn.Close()
_, err = client.AddIncomingStream(ctx, &streamd_grpc.AddIncomingStreamRequest{
StreamID: string(streamID),
})
if err != nil {
return fmt.Errorf("unable to request to add the incoming stream: %w", err)
}
return nil
}
func (c *Client) RemoveIncomingStream(
ctx context.Context,
streamID api.StreamID,
) error {
client, conn, err := c.grpcClient()
if err != nil {
return err
}
defer conn.Close()
_, err = client.RemoveIncomingStream(ctx, &streamd_grpc.RemoveIncomingStreamRequest{
StreamID: string(streamID),
})
if err != nil {
return fmt.Errorf("unable to request to remove the incoming stream: %w", err)
}
return nil
}
func (c *Client) ListIncomingStreams( func (c *Client) ListIncomingStreams(
ctx context.Context, ctx context.Context,
) ([]api.IncomingStream, error) { ) ([]api.IncomingStream, error) {
@@ -809,8 +849,11 @@ func (c *Client) ListStreamForwards(
var result []api.StreamForward var result []api.StreamForward
for _, forward := range reply.GetStreamForwards() { for _, forward := range reply.GetStreamForwards() {
result = append(result, api.StreamForward{ result = append(result, api.StreamForward{
StreamID: api.StreamID(forward.GetStreamID()), Enabled: forward.Config.Enabled,
DestinationID: api.DestinationID(forward.GetDestinationID()), StreamID: api.StreamID(forward.Config.GetStreamID()),
DestinationID: api.DestinationID(forward.Config.GetDestinationID()),
NumBytesWrote: uint64(forward.Statistics.NumBytesWrote),
NumBytesRead: uint64(forward.Statistics.NumBytesRead),
}) })
} }
return result, nil return result, nil
@@ -820,6 +863,7 @@ func (c *Client) AddStreamForward(
ctx context.Context, ctx context.Context,
streamID api.StreamID, streamID api.StreamID,
destinationID api.DestinationID, destinationID api.DestinationID,
enabled bool,
) error { ) error {
client, conn, err := c.grpcClient() client, conn, err := c.grpcClient()
if err != nil { if err != nil {
@@ -831,6 +875,32 @@ func (c *Client) AddStreamForward(
Config: &streamd_grpc.StreamForward{ Config: &streamd_grpc.StreamForward{
StreamID: string(streamID), StreamID: string(streamID),
DestinationID: string(destinationID), DestinationID: string(destinationID),
Enabled: enabled,
},
})
if err != nil {
return fmt.Errorf("unable to request to add the stream forward: %w", err)
}
return nil
}
func (c *Client) UpdateStreamForward(
ctx context.Context,
streamID api.StreamID,
destinationID api.DestinationID,
enabled bool,
) error {
client, conn, err := c.grpcClient()
if err != nil {
return err
}
defer conn.Close()
_, err = client.UpdateStreamForward(ctx, &streamd_grpc.UpdateStreamForwardRequest{
Config: &streamd_grpc.StreamForward{
StreamID: string(streamID),
DestinationID: string(destinationID),
Enabled: enabled,
}, },
}) })
if err != nil { if err != nil {

View File

@@ -19,12 +19,6 @@ type ProfileMetadata struct {
MaxOrder int MaxOrder int
} }
type GitRepoConfig struct {
Enable *bool
URL string `yaml:"url,omitempty"`
PrivateKey string `yaml:"private_key,omitempty"`
LatestSyncCommit string `yaml:"latest_sync_commit,omitempty"` // TODO: deprecate this field, it's just a non-needed mechanism (better to check against git history).
}
type config struct { type config struct {
CachePath *string `yaml:"cache_path"` CachePath *string `yaml:"cache_path"`

View File

@@ -1,10 +1,12 @@
package config package config
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io" "io"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/goccy/go-yaml" "github.com/goccy/go-yaml"
"github.com/xaionaro-go/streamctl/pkg/streamcontrol" "github.com/xaionaro-go/streamctl/pkg/streamcontrol"
"github.com/xaionaro-go/streamctl/pkg/streamcontrol/obs" "github.com/xaionaro-go/streamctl/pkg/streamcontrol/obs"
@@ -22,11 +24,32 @@ func (cfg *Config) Read(
return len(b), cfg.UnmarshalYAML(b) return len(b), cfg.UnmarshalYAML(b)
} }
func (cfg *Config) traceDump() {
l := logger.Default()
if l.Level() < logger.LevelTrace {
return
}
if cfg == nil {
l.Tracef("streamd config == nil")
return
}
var buf bytes.Buffer
_, err := cfg.WriteTo(&buf)
if err != nil {
l.Error(err)
return
}
l.Tracef("streamd config == %#+v: %s", *cfg, buf.String())
}
func (cfg *Config) UnmarshalYAML(b []byte) error { func (cfg *Config) UnmarshalYAML(b []byte) error {
logger.Default().Tracef("unparsed streamd config == %s", b)
err := yaml.Unmarshal(b, (*config)(cfg)) err := yaml.Unmarshal(b, (*config)(cfg))
if err != nil { if err != nil {
return fmt.Errorf("unable to unserialize data: %w", err) return fmt.Errorf("unable to unserialize data: %w", err)
} }
cfg.traceDump()
if cfg.Backends == nil { if cfg.Backends == nil {
cfg.Backends = streamcontrol.Config{} cfg.Backends = streamcontrol.Config{}

View File

@@ -45,12 +45,12 @@ func (cfg Config) MarshalYAML() ([]byte, error) {
m := map[string]any{} m := map[string]any{}
err = goyaml.Unmarshal(b, &m) err = goyaml.Unmarshal(b, &m)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to unserialize data %#+v: %w", cfg, err) return nil, fmt.Errorf("unable to unserialize data %s: %w", b, err)
} }
b, err = goyaml.Marshal(m) b, err = goyaml.Marshal(m)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to re-serialize data %#+v: %w", cfg, err) return nil, fmt.Errorf("unable to re-serialize data %#+v: %w", m, err)
} }
return b, nil return b, nil

File diff suppressed because it is too large Load Diff

View File

@@ -53,9 +53,12 @@ type StreamDClient interface {
ListStreamDestinations(ctx context.Context, in *ListStreamDestinationsRequest, opts ...grpc.CallOption) (*ListStreamDestinationsReply, error) ListStreamDestinations(ctx context.Context, in *ListStreamDestinationsRequest, opts ...grpc.CallOption) (*ListStreamDestinationsReply, error)
AddStreamDestination(ctx context.Context, in *AddStreamDestinationRequest, opts ...grpc.CallOption) (*AddStreamDestinationReply, error) AddStreamDestination(ctx context.Context, in *AddStreamDestinationRequest, opts ...grpc.CallOption) (*AddStreamDestinationReply, error)
RemoveStreamDestination(ctx context.Context, in *RemoveStreamDestinationRequest, opts ...grpc.CallOption) (*RemoveStreamDestinationReply, error) RemoveStreamDestination(ctx context.Context, in *RemoveStreamDestinationRequest, opts ...grpc.CallOption) (*RemoveStreamDestinationReply, error)
AddIncomingStream(ctx context.Context, in *AddIncomingStreamRequest, opts ...grpc.CallOption) (*AddIncomingStreamReply, error)
RemoveIncomingStream(ctx context.Context, in *RemoveIncomingStreamRequest, opts ...grpc.CallOption) (*RemoveIncomingStreamReply, error)
ListIncomingStreams(ctx context.Context, in *ListIncomingStreamsRequest, opts ...grpc.CallOption) (*ListIncomingStreamsReply, error) ListIncomingStreams(ctx context.Context, in *ListIncomingStreamsRequest, opts ...grpc.CallOption) (*ListIncomingStreamsReply, error)
ListStreamForwards(ctx context.Context, in *ListStreamForwardsRequest, opts ...grpc.CallOption) (*ListStreamForwardsReply, error) ListStreamForwards(ctx context.Context, in *ListStreamForwardsRequest, opts ...grpc.CallOption) (*ListStreamForwardsReply, error)
AddStreamForward(ctx context.Context, in *AddStreamForwardRequest, opts ...grpc.CallOption) (*AddStreamForwardReply, error) AddStreamForward(ctx context.Context, in *AddStreamForwardRequest, opts ...grpc.CallOption) (*AddStreamForwardReply, error)
UpdateStreamForward(ctx context.Context, in *UpdateStreamForwardRequest, opts ...grpc.CallOption) (*UpdateStreamForwardReply, error)
RemoveStreamForward(ctx context.Context, in *RemoveStreamForwardRequest, opts ...grpc.CallOption) (*RemoveStreamForwardReply, error) RemoveStreamForward(ctx context.Context, in *RemoveStreamForwardRequest, opts ...grpc.CallOption) (*RemoveStreamForwardReply, error)
} }
@@ -369,6 +372,24 @@ func (c *streamDClient) RemoveStreamDestination(ctx context.Context, in *RemoveS
return out, nil return out, nil
} }
func (c *streamDClient) AddIncomingStream(ctx context.Context, in *AddIncomingStreamRequest, opts ...grpc.CallOption) (*AddIncomingStreamReply, error) {
out := new(AddIncomingStreamReply)
err := c.cc.Invoke(ctx, "/StreamD/AddIncomingStream", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *streamDClient) RemoveIncomingStream(ctx context.Context, in *RemoveIncomingStreamRequest, opts ...grpc.CallOption) (*RemoveIncomingStreamReply, error) {
out := new(RemoveIncomingStreamReply)
err := c.cc.Invoke(ctx, "/StreamD/RemoveIncomingStream", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *streamDClient) ListIncomingStreams(ctx context.Context, in *ListIncomingStreamsRequest, opts ...grpc.CallOption) (*ListIncomingStreamsReply, error) { func (c *streamDClient) ListIncomingStreams(ctx context.Context, in *ListIncomingStreamsRequest, opts ...grpc.CallOption) (*ListIncomingStreamsReply, error) {
out := new(ListIncomingStreamsReply) out := new(ListIncomingStreamsReply)
err := c.cc.Invoke(ctx, "/StreamD/ListIncomingStreams", in, out, opts...) err := c.cc.Invoke(ctx, "/StreamD/ListIncomingStreams", in, out, opts...)
@@ -396,6 +417,15 @@ func (c *streamDClient) AddStreamForward(ctx context.Context, in *AddStreamForwa
return out, nil return out, nil
} }
func (c *streamDClient) UpdateStreamForward(ctx context.Context, in *UpdateStreamForwardRequest, opts ...grpc.CallOption) (*UpdateStreamForwardReply, error) {
out := new(UpdateStreamForwardReply)
err := c.cc.Invoke(ctx, "/StreamD/UpdateStreamForward", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *streamDClient) RemoveStreamForward(ctx context.Context, in *RemoveStreamForwardRequest, opts ...grpc.CallOption) (*RemoveStreamForwardReply, error) { func (c *streamDClient) RemoveStreamForward(ctx context.Context, in *RemoveStreamForwardRequest, opts ...grpc.CallOption) (*RemoveStreamForwardReply, error) {
out := new(RemoveStreamForwardReply) out := new(RemoveStreamForwardReply)
err := c.cc.Invoke(ctx, "/StreamD/RemoveStreamForward", in, out, opts...) err := c.cc.Invoke(ctx, "/StreamD/RemoveStreamForward", in, out, opts...)
@@ -440,9 +470,12 @@ type StreamDServer interface {
ListStreamDestinations(context.Context, *ListStreamDestinationsRequest) (*ListStreamDestinationsReply, error) ListStreamDestinations(context.Context, *ListStreamDestinationsRequest) (*ListStreamDestinationsReply, error)
AddStreamDestination(context.Context, *AddStreamDestinationRequest) (*AddStreamDestinationReply, error) AddStreamDestination(context.Context, *AddStreamDestinationRequest) (*AddStreamDestinationReply, error)
RemoveStreamDestination(context.Context, *RemoveStreamDestinationRequest) (*RemoveStreamDestinationReply, error) RemoveStreamDestination(context.Context, *RemoveStreamDestinationRequest) (*RemoveStreamDestinationReply, error)
AddIncomingStream(context.Context, *AddIncomingStreamRequest) (*AddIncomingStreamReply, error)
RemoveIncomingStream(context.Context, *RemoveIncomingStreamRequest) (*RemoveIncomingStreamReply, error)
ListIncomingStreams(context.Context, *ListIncomingStreamsRequest) (*ListIncomingStreamsReply, error) ListIncomingStreams(context.Context, *ListIncomingStreamsRequest) (*ListIncomingStreamsReply, error)
ListStreamForwards(context.Context, *ListStreamForwardsRequest) (*ListStreamForwardsReply, error) ListStreamForwards(context.Context, *ListStreamForwardsRequest) (*ListStreamForwardsReply, error)
AddStreamForward(context.Context, *AddStreamForwardRequest) (*AddStreamForwardReply, error) AddStreamForward(context.Context, *AddStreamForwardRequest) (*AddStreamForwardReply, error)
UpdateStreamForward(context.Context, *UpdateStreamForwardRequest) (*UpdateStreamForwardReply, error)
RemoveStreamForward(context.Context, *RemoveStreamForwardRequest) (*RemoveStreamForwardReply, error) RemoveStreamForward(context.Context, *RemoveStreamForwardRequest) (*RemoveStreamForwardReply, error)
mustEmbedUnimplementedStreamDServer() mustEmbedUnimplementedStreamDServer()
} }
@@ -544,6 +577,12 @@ func (UnimplementedStreamDServer) AddStreamDestination(context.Context, *AddStre
func (UnimplementedStreamDServer) RemoveStreamDestination(context.Context, *RemoveStreamDestinationRequest) (*RemoveStreamDestinationReply, error) { func (UnimplementedStreamDServer) RemoveStreamDestination(context.Context, *RemoveStreamDestinationRequest) (*RemoveStreamDestinationReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method RemoveStreamDestination not implemented") return nil, status.Errorf(codes.Unimplemented, "method RemoveStreamDestination not implemented")
} }
func (UnimplementedStreamDServer) AddIncomingStream(context.Context, *AddIncomingStreamRequest) (*AddIncomingStreamReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method AddIncomingStream not implemented")
}
func (UnimplementedStreamDServer) RemoveIncomingStream(context.Context, *RemoveIncomingStreamRequest) (*RemoveIncomingStreamReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method RemoveIncomingStream not implemented")
}
func (UnimplementedStreamDServer) ListIncomingStreams(context.Context, *ListIncomingStreamsRequest) (*ListIncomingStreamsReply, error) { func (UnimplementedStreamDServer) ListIncomingStreams(context.Context, *ListIncomingStreamsRequest) (*ListIncomingStreamsReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListIncomingStreams not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListIncomingStreams not implemented")
} }
@@ -553,6 +592,9 @@ func (UnimplementedStreamDServer) ListStreamForwards(context.Context, *ListStrea
func (UnimplementedStreamDServer) AddStreamForward(context.Context, *AddStreamForwardRequest) (*AddStreamForwardReply, error) { func (UnimplementedStreamDServer) AddStreamForward(context.Context, *AddStreamForwardRequest) (*AddStreamForwardReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method AddStreamForward not implemented") return nil, status.Errorf(codes.Unimplemented, "method AddStreamForward not implemented")
} }
func (UnimplementedStreamDServer) UpdateStreamForward(context.Context, *UpdateStreamForwardRequest) (*UpdateStreamForwardReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method UpdateStreamForward not implemented")
}
func (UnimplementedStreamDServer) RemoveStreamForward(context.Context, *RemoveStreamForwardRequest) (*RemoveStreamForwardReply, error) { func (UnimplementedStreamDServer) RemoveStreamForward(context.Context, *RemoveStreamForwardRequest) (*RemoveStreamForwardReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method RemoveStreamForward not implemented") return nil, status.Errorf(codes.Unimplemented, "method RemoveStreamForward not implemented")
} }
@@ -1130,6 +1172,42 @@ func _StreamD_RemoveStreamDestination_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _StreamD_AddIncomingStream_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AddIncomingStreamRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(StreamDServer).AddIncomingStream(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/StreamD/AddIncomingStream",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(StreamDServer).AddIncomingStream(ctx, req.(*AddIncomingStreamRequest))
}
return interceptor(ctx, in, info, handler)
}
func _StreamD_RemoveIncomingStream_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveIncomingStreamRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(StreamDServer).RemoveIncomingStream(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/StreamD/RemoveIncomingStream",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(StreamDServer).RemoveIncomingStream(ctx, req.(*RemoveIncomingStreamRequest))
}
return interceptor(ctx, in, info, handler)
}
func _StreamD_ListIncomingStreams_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _StreamD_ListIncomingStreams_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListIncomingStreamsRequest) in := new(ListIncomingStreamsRequest)
if err := dec(in); err != nil { if err := dec(in); err != nil {
@@ -1184,6 +1262,24 @@ func _StreamD_AddStreamForward_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _StreamD_UpdateStreamForward_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(UpdateStreamForwardRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(StreamDServer).UpdateStreamForward(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/StreamD/UpdateStreamForward",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(StreamDServer).UpdateStreamForward(ctx, req.(*UpdateStreamForwardRequest))
}
return interceptor(ctx, in, info, handler)
}
func _StreamD_RemoveStreamForward_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _StreamD_RemoveStreamForward_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveStreamForwardRequest) in := new(RemoveStreamForwardRequest)
if err := dec(in); err != nil { if err := dec(in); err != nil {
@@ -1329,6 +1425,14 @@ var StreamD_ServiceDesc = grpc.ServiceDesc{
MethodName: "RemoveStreamDestination", MethodName: "RemoveStreamDestination",
Handler: _StreamD_RemoveStreamDestination_Handler, Handler: _StreamD_RemoveStreamDestination_Handler,
}, },
{
MethodName: "AddIncomingStream",
Handler: _StreamD_AddIncomingStream_Handler,
},
{
MethodName: "RemoveIncomingStream",
Handler: _StreamD_RemoveIncomingStream_Handler,
},
{ {
MethodName: "ListIncomingStreams", MethodName: "ListIncomingStreams",
Handler: _StreamD_ListIncomingStreams_Handler, Handler: _StreamD_ListIncomingStreams_Handler,
@@ -1341,6 +1445,10 @@ var StreamD_ServiceDesc = grpc.ServiceDesc{
MethodName: "AddStreamForward", MethodName: "AddStreamForward",
Handler: _StreamD_AddStreamForward_Handler, Handler: _StreamD_AddStreamForward_Handler,
}, },
{
MethodName: "UpdateStreamForward",
Handler: _StreamD_UpdateStreamForward_Handler,
},
{ {
MethodName: "RemoveStreamForward", MethodName: "RemoveStreamForward",
Handler: _StreamD_RemoveStreamForward_Handler, Handler: _StreamD_RemoveStreamForward_Handler,

View File

@@ -38,9 +38,12 @@ service StreamD {
rpc ListStreamDestinations(ListStreamDestinationsRequest) returns (ListStreamDestinationsReply) {} rpc ListStreamDestinations(ListStreamDestinationsRequest) returns (ListStreamDestinationsReply) {}
rpc AddStreamDestination(AddStreamDestinationRequest) returns (AddStreamDestinationReply) {} rpc AddStreamDestination(AddStreamDestinationRequest) returns (AddStreamDestinationReply) {}
rpc RemoveStreamDestination(RemoveStreamDestinationRequest) returns (RemoveStreamDestinationReply) {} rpc RemoveStreamDestination(RemoveStreamDestinationRequest) returns (RemoveStreamDestinationReply) {}
rpc AddIncomingStream(AddIncomingStreamRequest) returns (AddIncomingStreamReply) {}
rpc RemoveIncomingStream(RemoveIncomingStreamRequest) returns (RemoveIncomingStreamReply) {}
rpc ListIncomingStreams(ListIncomingStreamsRequest) returns (ListIncomingStreamsReply) {} rpc ListIncomingStreams(ListIncomingStreamsRequest) returns (ListIncomingStreamsReply) {}
rpc ListStreamForwards(ListStreamForwardsRequest) returns (ListStreamForwardsReply) {} rpc ListStreamForwards(ListStreamForwardsRequest) returns (ListStreamForwardsReply) {}
rpc AddStreamForward(AddStreamForwardRequest) returns (AddStreamForwardReply) {} rpc AddStreamForward(AddStreamForwardRequest) returns (AddStreamForwardReply) {}
rpc UpdateStreamForward(UpdateStreamForwardRequest) returns (UpdateStreamForwardReply) {}
rpc RemoveStreamForward(RemoveStreamForwardRequest) returns (RemoveStreamForwardReply) {} rpc RemoveStreamForward(RemoveStreamForwardRequest) returns (RemoveStreamForwardReply) {}
} }
@@ -197,9 +200,19 @@ message StreamServer {
string listenAddr = 2; string listenAddr = 2;
} }
message StreamServerStatistics {
int64 NumBytesConsumerWrote = 1;
int64 NumBytesProducerRead = 2;
}
message StreamServerWithStatistics {
StreamServer config = 1;
StreamServerStatistics statistics = 2;
}
message ListStreamServersRequest {} message ListStreamServersRequest {}
message ListStreamServersReply { message ListStreamServersReply {
repeated StreamServer streamServers = 1; repeated StreamServerWithStatistics streamServers = 1;
} }
message StartStreamServerRequest { message StartStreamServerRequest {
@@ -237,6 +250,16 @@ message IncomingStream {
string streamID = 1; string streamID = 1;
} }
message AddIncomingStreamRequest {
string streamID = 1;
}
message AddIncomingStreamReply {}
message RemoveIncomingStreamRequest {
string streamID = 1;
}
message RemoveIncomingStreamReply {}
message ListIncomingStreamsRequest {} message ListIncomingStreamsRequest {}
message ListIncomingStreamsReply { message ListIncomingStreamsReply {
repeated IncomingStream incomingStreams = 1; repeated IncomingStream incomingStreams = 1;
@@ -245,11 +268,22 @@ message ListIncomingStreamsReply {
message StreamForward { message StreamForward {
string streamID = 1; string streamID = 1;
string destinationID = 2; string destinationID = 2;
bool enabled = 3;
}
message StreamForwardStatistics {
int64 numBytesWrote = 1;
int64 numBytesRead = 2;
}
message StreamForwardWithStatistics {
StreamForward config = 1;
StreamForwardStatistics statistics = 2;
} }
message ListStreamForwardsRequest {} message ListStreamForwardsRequest {}
message ListStreamForwardsReply { message ListStreamForwardsReply {
repeated StreamForward streamForwards = 1; repeated StreamForwardWithStatistics streamForwards = 1;
} }
message AddStreamForwardRequest { message AddStreamForwardRequest {
@@ -257,6 +291,11 @@ message AddStreamForwardRequest {
} }
message AddStreamForwardReply {} message AddStreamForwardReply {}
message UpdateStreamForwardRequest {
StreamForward config = 1;
}
message UpdateStreamForwardReply {}
message RemoveStreamForwardRequest { message RemoveStreamForwardRequest {
StreamForward config = 1; StreamForward config = 1;
} }

View File

@@ -706,16 +706,22 @@ func (grpc *GRPCServer) ListStreamServers(
return nil, err return nil, err
} }
var result []*streamd_grpc.StreamServer var result []*streamd_grpc.StreamServerWithStatistics
for _, srv := range servers { for _, srv := range servers {
t, err := grpcconv.StreamServerTypeGo2GRPC(srv.Type) t, err := grpcconv.StreamServerTypeGo2GRPC(srv.Type)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to convert the server type value: %w", err) return nil, fmt.Errorf("unable to convert the server type value: %w", err)
} }
result = append(result, &streamd_grpc.StreamServer{ result = append(result, &streamd_grpc.StreamServerWithStatistics{
ServerType: t, Config: &streamd_grpc.StreamServer{
ListenAddr: srv.ListenAddr, ServerType: t,
ListenAddr: srv.ListenAddr,
},
Statistics: &streamd_grpc.StreamServerStatistics{
NumBytesConsumerWrote: int64(srv.NumBytesConsumerWrote),
NumBytesProducerRead: int64(srv.NumBytesProducerRead),
},
}) })
} }
return &streamd_grpc.ListStreamServersReply{ return &streamd_grpc.ListStreamServersReply{
@@ -809,6 +815,28 @@ func (grpc *GRPCServer) RemoveStreamDestination(
return &streamd_grpc.RemoveStreamDestinationReply{}, nil return &streamd_grpc.RemoveStreamDestinationReply{}, nil
} }
func (grpc *GRPCServer) AddIncomingStream(
ctx context.Context,
req *streamd_grpc.AddIncomingStreamRequest,
) (*streamd_grpc.AddIncomingStreamReply, error) {
err := grpc.StreamD.AddIncomingStream(ctx, api.StreamID(req.GetStreamID()))
if err != nil {
return nil, err
}
return &streamd_grpc.AddIncomingStreamReply{}, nil
}
func (grpc *GRPCServer) RemoveIncomingStream(
ctx context.Context,
req *streamd_grpc.RemoveIncomingStreamRequest,
) (*streamd_grpc.RemoveIncomingStreamReply, error) {
err := grpc.StreamD.RemoveIncomingStream(ctx, api.StreamID(req.GetStreamID()))
if err != nil {
return nil, err
}
return &streamd_grpc.RemoveIncomingStreamReply{}, nil
}
func (grpc *GRPCServer) ListIncomingStreams( func (grpc *GRPCServer) ListIncomingStreams(
ctx context.Context, ctx context.Context,
req *streamd_grpc.ListIncomingStreamsRequest, req *streamd_grpc.ListIncomingStreamsRequest,
@@ -840,11 +868,18 @@ func (grpc *GRPCServer) ListStreamForwards(
return nil, err return nil, err
} }
var result []*streamd_grpc.StreamForward var result []*streamd_grpc.StreamForwardWithStatistics
for _, s := range streamFwds { for _, s := range streamFwds {
result = append(result, &streamd_grpc.StreamForward{ result = append(result, &streamd_grpc.StreamForwardWithStatistics{
StreamID: string(s.StreamID), Config: &streamd_grpc.StreamForward{
DestinationID: string(s.DestinationID), StreamID: string(s.StreamID),
DestinationID: string(s.DestinationID),
Enabled: s.Enabled,
},
Statistics: &streamd_grpc.StreamForwardStatistics{
NumBytesWrote: int64(s.NumBytesWrote),
NumBytesRead: int64(s.NumBytesRead),
},
}) })
} }
return &streamd_grpc.ListStreamForwardsReply{ return &streamd_grpc.ListStreamForwardsReply{
@@ -860,6 +895,7 @@ func (grpc *GRPCServer) AddStreamForward(
ctx, ctx,
api.StreamID(req.GetConfig().GetStreamID()), api.StreamID(req.GetConfig().GetStreamID()),
api.DestinationID(req.GetConfig().GetDestinationID()), api.DestinationID(req.GetConfig().GetDestinationID()),
req.Config.Enabled,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -867,6 +903,22 @@ func (grpc *GRPCServer) AddStreamForward(
return &streamd_grpc.AddStreamForwardReply{}, nil return &streamd_grpc.AddStreamForwardReply{}, nil
} }
func (grpc *GRPCServer) UpdateStreamForward(
ctx context.Context,
req *streamd_grpc.UpdateStreamForwardRequest,
) (*streamd_grpc.UpdateStreamForwardReply, error) {
err := grpc.StreamD.UpdateStreamForward(
ctx,
api.StreamID(req.GetConfig().GetStreamID()),
api.DestinationID(req.GetConfig().GetDestinationID()),
req.Config.Enabled,
)
if err != nil {
return nil, err
}
return &streamd_grpc.UpdateStreamForwardReply{}, nil
}
func (grpc *GRPCServer) RemoveStreamForward( func (grpc *GRPCServer) RemoveStreamForward(
ctx context.Context, ctx context.Context,
req *streamd_grpc.RemoveStreamForwardRequest, req *streamd_grpc.RemoveStreamForwardRequest,

View File

@@ -2,6 +2,7 @@ package streamd
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@@ -34,7 +35,7 @@ func (d *StreamD) EXPERIMENTAL_ReinitStreamControllers(ctx context.Context) erro
case strings.ToLower(string(youtube.ID)): case strings.ToLower(string(youtube.ID)):
err = d.initYouTubeBackend(ctx) err = d.initYouTubeBackend(ctx)
} }
if err == ErrSkipBackend { if errors.Is(err, ErrSkipBackend) {
continue continue
} }
if err != nil { if err != nil {

View File

@@ -88,7 +88,10 @@ func New(
return d, nil return d, nil
} }
func (d *StreamD) Run(ctx context.Context) error { func (d *StreamD) Run(ctx context.Context) (_ret error) {
logger.Debugf(ctx, "StreamD.Run()")
defer func() { logger.Debugf(ctx, "/StreamD.Run(): %v", _ret) }()
d.UI.SetStatus("Initializing remote GIT storage...") d.UI.SetStatus("Initializing remote GIT storage...")
err := d.FetchConfig(ctx) err := d.FetchConfig(ctx)
if err != nil { if err != nil {
@@ -117,6 +120,7 @@ func (d *StreamD) Run(ctx context.Context) error {
func (d *StreamD) InitStreamServer(ctx context.Context) error { func (d *StreamD) InitStreamServer(ctx context.Context) error {
d.StreamServer = streamserver.New(&d.Config.StreamServer) d.StreamServer = streamserver.New(&d.Config.StreamServer)
assert(d.StreamServer != nil)
return d.StreamServer.Init(ctx) return d.StreamServer.Init(ctx)
} }
@@ -715,6 +719,8 @@ func (d *StreamD) ListStreamServers(
d.StreamServerLocker.Lock() d.StreamServerLocker.Lock()
defer d.StreamServerLocker.Unlock() defer d.StreamServerLocker.Unlock()
assert(d.StreamServer != nil)
servers := d.StreamServer.ListServers(ctx) servers := d.StreamServer.ListServers(ctx)
var result []api.StreamServer var result []api.StreamServer
@@ -722,6 +728,9 @@ func (d *StreamD) ListStreamServers(
result = append(result, api.StreamServer{ result = append(result, api.StreamServer{
Type: api.ServerTypeServer2API(src.Type()), Type: api.ServerTypeServer2API(src.Type()),
ListenAddr: src.ListenAddr(), ListenAddr: src.ListenAddr(),
NumBytesConsumerWrote: src.NumBytesConsumerWrote(),
NumBytesProducerRead: src.NumBytesProducerRead(),
}) })
} }
@@ -797,6 +806,52 @@ func (d *StreamD) StopStreamServer(
return nil return nil
} }
func (d *StreamD) AddIncomingStream(
ctx context.Context,
streamID api.StreamID,
) error {
logger.Debugf(ctx, "AddIncomingStream")
defer logger.Debugf(ctx, "/AddIncomingStream")
d.StreamServerLocker.Lock()
defer d.StreamServerLocker.Unlock()
err := d.StreamServer.AddIncomingStream(ctx, types.StreamID(streamID))
if err != nil {
return fmt.Errorf("unable to add an incoming stream: %w", err)
}
err = d.SaveConfig(ctx)
if err != nil {
return fmt.Errorf("unable to save the config: %w", err)
}
return nil
}
func (d *StreamD) RemoveIncomingStream(
ctx context.Context,
streamID api.StreamID,
) error {
logger.Debugf(ctx, "RemoveIncomingStream")
defer logger.Debugf(ctx, "/RemoveIncomingStream")
d.StreamServerLocker.Lock()
defer d.StreamServerLocker.Unlock()
err := d.StreamServer.RemoveIncomingStream(ctx, types.StreamID(streamID))
if err != nil {
return fmt.Errorf("unable to remove an incoming stream: %w", err)
}
err = d.SaveConfig(ctx)
if err != nil {
return fmt.Errorf("unable to save the config: %w", err)
}
return nil
}
func (d *StreamD) ListIncomingStreams( func (d *StreamD) ListIncomingStreams(
ctx context.Context, ctx context.Context,
) ([]api.IncomingStream, error) { ) ([]api.IncomingStream, error) {
@@ -899,8 +954,11 @@ func (d *StreamD) listStreamForwards(
} }
for _, streamFwd := range streamForwards { for _, streamFwd := range streamForwards {
result = append(result, api.StreamForward{ result = append(result, api.StreamForward{
Enabled: streamFwd.Enabled,
StreamID: api.StreamID(streamFwd.StreamID), StreamID: api.StreamID(streamFwd.StreamID),
DestinationID: api.DestinationID(streamFwd.DestinationID), DestinationID: api.DestinationID(streamFwd.DestinationID),
NumBytesWrote: streamFwd.NumBytesWrote,
NumBytesRead: streamFwd.NumBytesRead,
}) })
} }
return result, nil return result, nil
@@ -922,6 +980,7 @@ func (d *StreamD) AddStreamForward(
ctx context.Context, ctx context.Context,
streamID api.StreamID, streamID api.StreamID,
destinationID api.DestinationID, destinationID api.DestinationID,
enabled bool,
) error { ) error {
logger.Debugf(ctx, "AddStreamForward") logger.Debugf(ctx, "AddStreamForward")
defer logger.Debugf(ctx, "/AddStreamForward") defer logger.Debugf(ctx, "/AddStreamForward")
@@ -933,6 +992,37 @@ func (d *StreamD) AddStreamForward(
resetContextCancellers(ctx), resetContextCancellers(ctx),
types.StreamID(streamID), types.StreamID(streamID),
types.DestinationID(destinationID), types.DestinationID(destinationID),
enabled,
)
if err != nil {
return fmt.Errorf("unable to add the stream forwarding: %w", err)
}
err = d.SaveConfig(ctx)
if err != nil {
return fmt.Errorf("unable to save the config: %w", err)
}
return nil
}
func (d *StreamD) UpdateStreamForward(
ctx context.Context,
streamID api.StreamID,
destinationID api.DestinationID,
enabled bool,
) error {
logger.Debugf(ctx, "AddStreamForward")
defer logger.Debugf(ctx, "/AddStreamForward")
d.StreamServerLocker.Lock()
defer d.StreamServerLocker.Unlock()
err := d.StreamServer.UpdateStreamForward(
resetContextCancellers(ctx),
types.StreamID(streamID),
types.DestinationID(destinationID),
enabled,
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to add the stream forwarding: %w", err) return fmt.Errorf("unable to add the stream forwarding: %w", err)

View File

@@ -3,3 +3,9 @@ package streamd
func ptr[T any](in T) *T { func ptr[T any](in T) *T {
return &in return &in
} }
func assert(b bool) {
if !b {
panic("assertion failed")
}
}

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"os" "os"
@@ -37,7 +38,16 @@ func ReadConfigFromPath[CFG Config](
return fmt.Errorf("unable to read file '%s': %w", cfgPath, err) return fmt.Errorf("unable to read file '%s': %w", cfgPath, err)
} }
logger.Default().Debugf("unparsed config == %s", b)
_, err = cfg.Read(b) _, err = cfg.Read(b)
var cfgSerialized bytes.Buffer
if _, _err := cfg.WriteTo(&cfgSerialized); _err != nil {
logger.Default().Error(_err)
} else {
logger.Default().Debugf("parsed config == %s", cfgSerialized.String())
}
return err return err
} }

View File

@@ -13,7 +13,11 @@ var _ io.ReaderFrom = (*Config)(nil)
func (cfg *Config) Read( func (cfg *Config) Read(
b []byte, b []byte,
) (int, error) { ) (int, error) {
return len(b), yaml.Unmarshal(b, cfg) n := len(b)
if err := yaml.Unmarshal(b, cfg); err != nil {
return n, fmt.Errorf("unable to unmarshal the config: %w", err)
}
return n, nil
} }
func (cfg *Config) ReadFrom( func (cfg *Config) ReadFrom(

View File

@@ -2,8 +2,10 @@ package config
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
goyaml "github.com/go-yaml/yaml"
"github.com/goccy/go-yaml" "github.com/goccy/go-yaml"
"github.com/xaionaro-go/datacounter" "github.com/xaionaro-go/datacounter"
) )
@@ -19,9 +21,36 @@ func (cfg Config) Write(b []byte) (int, error) {
func (cfg Config) WriteTo( func (cfg Config) WriteTo(
w io.Writer, w io.Writer,
) (int64, error) { ) (int64, error) {
// There is bug in github.com/goccy/go-yaml that makes wrong intention
// in cfg.BuiltinStreamD.GitRepo.PrivateKey makes the whole value unparsable
//
// Working this around...
key := cfg.BuiltinStreamD.GitRepo.PrivateKey
cfg.BuiltinStreamD.GitRepo.PrivateKey = ""
b, err := yaml.Marshal(cfg) b, err := yaml.Marshal(cfg)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("unable to serialize data %#+v: %w", cfg, err)
}
m := map[any]any{}
err = goyaml.Unmarshal(b, &m)
if err != nil {
return 0, fmt.Errorf("unable to unserialize data %s: %w", b, err)
}
if v, ok := m["streamd_builtin"]; ok {
if m2, ok := v.(map[any]any); ok {
if v, ok := m2["gitrepo"]; ok {
if m3, ok := v.(map[any]any); ok {
m3["private_key"] = key
}
}
}
}
b, err = goyaml.Marshal(m)
if err != nil {
return 0, fmt.Errorf("unable to re-serialize data %#+v: %w", m, err)
} }
counter := datacounter.NewWriterCounter(w) counter := datacounter.NewWriterCounter(w)

View File

@@ -41,7 +41,7 @@ func (p *Panel) InputGitUserData(
gitRepo.SetPlaceHolder("git@github.com:myname/myrepo.git") gitRepo.SetPlaceHolder("git@github.com:myname/myrepo.git")
gitPrivateKey := widget.NewMultiLineEntry() gitPrivateKey := widget.NewMultiLineEntry()
gitPrivateKey.SetText(cfg.GitRepo.PrivateKey) gitPrivateKey.SetText(string(cfg.GitRepo.PrivateKey))
gitPrivateKey.SetMinRowsVisible(10) gitPrivateKey.SetMinRowsVisible(10)
gitPrivateKey.TextStyle.Monospace = true gitPrivateKey.TextStyle.Monospace = true
gitPrivateKey.SetPlaceHolder(`-----BEGIN OPENSSH PRIVATE KEY----- gitPrivateKey.SetPlaceHolder(`-----BEGIN OPENSSH PRIVATE KEY-----

View File

@@ -119,6 +119,10 @@ type Panel struct {
streamsWidget *fyne.Container streamsWidget *fyne.Container
destinationsWidget *fyne.Container destinationsWidget *fyne.Container
restreamsWidget *fyne.Container restreamsWidget *fyne.Container
previousNumBytesLocker sync.Mutex
previousNumBytes map[any][4]uint64
previousNumBytesTS map[any]time.Time
} }
func New( func New(
@@ -136,13 +140,16 @@ func New(
return nil, fmt.Errorf("unable to read the config from path '%s': %w", configPath, err) return nil, fmt.Errorf("unable to read the config from path '%s': %w", configPath, err)
} }
return &Panel{ p := &Panel{
configPath: configPath, configPath: configPath,
Config: Options(opts).ApplyOverrides(cfg), Config: Options(opts).ApplyOverrides(cfg),
Screenshoter: screenshoter.New(screenshot.Implementation{}), Screenshoter: screenshoter.New(screenshot.Implementation{}),
imageLastDownloaded: map[consts.ImageID][]byte{}, imageLastDownloaded: map[consts.ImageID][]byte{},
streamStatus: map[streamcontrol.PlatformName]*widget.Label{}, streamStatus: map[streamcontrol.PlatformName]*widget.Label{},
}, nil previousNumBytes: map[any][4]uint64{},
previousNumBytesTS: map[any]time.Time{},
}
return p, nil
} }
func (p *Panel) SetStatus(msg string) { func (p *Panel) SetStatus(msg string) {
@@ -178,15 +185,30 @@ func (opt LoopOptionStartingPage) apply(cfg *loopConfig) {
cfg.StartingPage = consts.Page(opt) cfg.StartingPage = consts.Page(opt)
} }
func (p *Panel) dumpConfig(ctx context.Context) {
if logger.FromCtx(ctx).Level() < logger.LevelTrace {
return
}
var buf bytes.Buffer
_, err := p.Config.WriteTo(&buf)
if err != nil {
logger.Error(ctx, err)
return
}
logger.Tracef(ctx, "the current config is: %s", buf.String())
}
func (p *Panel) Loop(ctx context.Context, opts ...LoopOption) error { func (p *Panel) Loop(ctx context.Context, opts ...LoopOption) error {
if p.defaultContext != nil { if p.defaultContext != nil {
return fmt.Errorf("Loop was already used, and cannot be used the second time") return fmt.Errorf("Loop was already used, and cannot be used the second time")
} }
p.dumpConfig(ctx)
initCfg := loopOptions(opts).Config() initCfg := loopOptions(opts).Config()
p.defaultContext = ctx p.defaultContext = ctx
logger.Debug(ctx, "config", p.Config)
if p.Config.RemoteStreamDAddr != "" { if p.Config.RemoteStreamDAddr != "" {
if err := p.initRemoteStreamD(ctx); err != nil { if err := p.initRemoteStreamD(ctx); err != nil {
@@ -230,6 +252,9 @@ func (p *Panel) Loop(ctx context.Context, opts ...LoopOption) error {
p.DisplayError(fmt.Errorf("unable to initialize the streaming controllers: %w", err)) p.DisplayError(fmt.Errorf("unable to initialize the streaming controllers: %w", err))
} }
p.setStatusFunc = nil p.setStatusFunc = nil
if streamD, ok := p.StreamD.(*streamd.StreamD); ok {
assert(streamD.StreamServer != nil)
}
p.reinitScreenshoter(ctx) p.reinitScreenshoter(ctx)
@@ -241,7 +266,7 @@ func (p *Panel) Loop(ctx context.Context, opts ...LoopOption) error {
if p.Config.RemoteStreamDAddr == "" { if p.Config.RemoteStreamDAddr == "" {
logger.Tracef(ctx, "hiding the loading window") logger.Tracef(ctx, "hiding the loading window")
loadingWindow.Hide() hideWindow(loadingWindow)
} }
logger.Tracef(ctx, "ended stream controllers initialization") logger.Tracef(ctx, "ended stream controllers initialization")
@@ -1531,11 +1556,17 @@ func (p *Panel) initMainWindow(
p.openAddStreamServerWindow(ctx) p.openAddStreamServerWindow(ctx)
}) })
p.streamsWidget = container.NewVBox() p.streamsWidget = container.NewVBox()
addStreamButton := widget.NewButtonWithIcon("Add stream", theme.ContentAddIcon(), p.openAddStreamWindow) addStreamButton := widget.NewButtonWithIcon("Add stream", theme.ContentAddIcon(), func() {
p.openAddStreamWindow(ctx)
})
p.destinationsWidget = container.NewVBox() p.destinationsWidget = container.NewVBox()
addDestination := widget.NewButtonWithIcon("Add destination", theme.ContentAddIcon(), p.openAddDestinationWindow) addDestination := widget.NewButtonWithIcon("Add destination", theme.ContentAddIcon(), func() {
p.openAddDestinationWindow(ctx)
})
p.restreamsWidget = container.NewVBox() p.restreamsWidget = container.NewVBox()
addRestream := widget.NewButtonWithIcon("Add restream", theme.ContentAddIcon(), p.openAddRestreamWindow) addRestream := widget.NewButtonWithIcon("Add restream", theme.ContentAddIcon(), func() {
p.openAddRestreamWindow(ctx)
})
restreamPage := container.NewBorder( restreamPage := container.NewBorder(
nil, nil,
nil, nil,
@@ -2517,7 +2548,9 @@ func (p *Panel) DisplayError(err error) {
func (p *Panel) waitForResponse(callback func()) { func (p *Panel) waitForResponse(callback func()) {
p.showWaitWindow() p.showWaitWindow()
defer p.hideWaitWindow() defer func() {
p.hideWaitWindow()
}()
callback() callback()
} }
@@ -2540,6 +2573,8 @@ func (p *Panel) showWaitWindow() {
func (p *Panel) hideWaitWindow() { func (p *Panel) hideWaitWindow() {
p.waitWindowLocker.Lock() p.waitWindowLocker.Lock()
defer p.waitWindowLocker.Unlock() defer p.waitWindowLocker.Unlock()
p.waitWindow.Hide()
time.Sleep(100 * time.Millisecond)
p.waitWindow.Close() p.waitWindow.Close()
p.waitWindow = nil p.waitWindow = nil
} }

View File

@@ -3,14 +3,18 @@ package streampanel
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"sort"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/fyne/v2/container" "fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/theme" "fyne.io/fyne/v2/theme"
"fyne.io/fyne/v2/widget" "fyne.io/fyne/v2/widget"
"github.com/dustin/go-humanize"
"github.com/facebookincubator/go-belt/tool/logger" "github.com/facebookincubator/go-belt/tool/logger"
"github.com/xaionaro-go/streamctl/pkg/streamd/api" "github.com/xaionaro-go/streamctl/pkg/streamd/api"
"github.com/xaionaro-go/streamctl/pkg/xfyne" "github.com/xaionaro-go/streamctl/pkg/xfyne"
@@ -28,7 +32,7 @@ func (p *Panel) startRestreamPage(
ctx, cancelFn := context.WithCancel(ctx) ctx, cancelFn := context.WithCancel(ctx)
p.restreamPageUpdaterCancel = cancelFn p.restreamPageUpdaterCancel = cancelFn
p.initRestartPage(ctx) p.initRestreamPage(ctx)
go func(ctx context.Context) { go func(ctx context.Context) {
p.updateRestreamPage(ctx) p.updateRestreamPage(ctx)
@@ -46,11 +50,11 @@ func (p *Panel) startRestreamPage(
}(ctx) }(ctx)
} }
func (p *Panel) initRestartPage( func (p *Panel) initRestreamPage(
ctx context.Context, ctx context.Context,
) { ) {
logger.Debugf(ctx, "initRestartPage") logger.Debugf(ctx, "initRestreamPage")
defer logger.Debugf(ctx, "/initRestartPage") defer logger.Debugf(ctx, "/initRestreamPage")
streamServers, err := p.StreamD.ListStreamServers(ctx) streamServers, err := p.StreamD.ListStreamServers(ctx)
if err != nil { if err != nil {
@@ -116,12 +120,15 @@ func (p *Panel) openAddStreamServerWindow(ctx context.Context) {
saveButton := widget.NewButtonWithIcon("Save", theme.DocumentSaveIcon(), func() { saveButton := widget.NewButtonWithIcon("Save", theme.DocumentSaveIcon(), func() {
listenHost := listenHostEntry.Text listenHost := listenHostEntry.Text
err := p.addStreamServer(ctx, currentProtocol, listenHost, listenPort) p.waitForResponse(func() {
if err != nil { err := p.addStreamServer(ctx, currentProtocol, listenHost, listenPort)
p.DisplayError(err) if err != nil {
} p.DisplayError(err)
return
w.Close() }
w.Close()
p.initRestreamPage(ctx)
})
}) })
w.SetContent(container.NewBorder( w.SetContent(container.NewBorder(
@@ -160,25 +167,123 @@ func (p *Panel) displayStreamServers(
logger.Debugf(ctx, "displayStreamServers") logger.Debugf(ctx, "displayStreamServers")
defer logger.Debugf(ctx, "/displayStreamServers") defer logger.Debugf(ctx, "/displayStreamServers")
c := widget.NewList(
func() int {
return len(streamServers)
},
func() fyne.CanvasObject {
return widget.NewLabel("")
},
func(idx widget.ListItemID, co fyne.CanvasObject) {
o := co.(*widget.Label)
srv := streamServers[idx]
o.SetText(fmt.Sprintf("%s://%s", srv.Type, srv.ListenAddr))
},
)
p.streamServersWidget.RemoveAll() p.streamServersWidget.RemoveAll()
p.streamServersWidget.Add(c) for idx, srv := range streamServers {
logger.Tracef(ctx, "streamServer[%3d] == %#+v", idx, srv)
c := container.NewHBox()
button := widget.NewButtonWithIcon("", theme.DeleteIcon(), func() {
w := dialog.NewConfirm(
fmt.Sprintf("Delete Stream Server %s://%s ?", srv.Type, srv.ListenAddr),
"",
func(b bool) {
if !b {
return
}
logger.Debugf(ctx, "remove stream server")
defer logger.Debugf(ctx, "/remove stream server")
p.waitForResponse(func() {
err := p.StreamD.StopStreamServer(ctx, srv.ListenAddr)
if err != nil {
p.DisplayError(err)
return
}
})
p.initRestreamPage(ctx)
},
p.mainWindow,
)
w.Show()
})
label := widget.NewLabel(fmt.Sprintf("%s://%s", srv.Type, srv.ListenAddr))
c.RemoveAll()
c.Add(button)
c.Add(label)
c.Add(widget.NewSeparator())
type numBytesID struct {
ID string
}
key := numBytesID{ID: srv.ListenAddr}
p.previousNumBytesLocker.Lock()
prevNumBytes := p.previousNumBytes[key]
now := time.Now()
bwText := widget.NewRichTextWithText(bwString(srv.NumBytesProducerRead, prevNumBytes[0], srv.NumBytesConsumerWrote, prevNumBytes[1], now, p.previousNumBytesTS[key]))
p.previousNumBytes[key] = [4]uint64{srv.NumBytesProducerRead, srv.NumBytesConsumerWrote}
p.previousNumBytesTS[key] = now
p.previousNumBytesLocker.Unlock()
c.Add(bwText)
p.streamServersWidget.Add(c)
}
} }
func (p *Panel) openAddStreamWindow() {} func bwString(
nRead, nReadPrev uint64,
nWrote, nWrotePrev uint64,
ts, tsPrev time.Time,
) string {
var nReadStr, nWroteStr string
duration := ts.Sub(tsPrev)
if nRead != math.MaxUint64 {
n := 8 * (nRead - nReadPrev)
nReadStr = humanize.Bytes(uint64(float64(n) * float64(time.Second) / float64(duration)))
nReadStr = nReadStr[:len(nReadStr)-1] + "bps"
}
if nWrote != math.MaxUint64 {
n := 8 * (nWrote - nWrotePrev)
nWroteStr = humanize.Bytes(uint64(float64(n) * float64(time.Second) / float64(duration)))
nWroteStr = nWroteStr[:len(nWroteStr)-1] + "bps"
}
if nReadStr == "" && nWroteStr == "" {
return ""
}
return fmt.Sprintf("%s | %s", nReadStr, nWroteStr)
}
func (p *Panel) openAddStreamWindow(ctx context.Context) {
w := p.app.NewWindow(appName + ": Add incoming stream")
resizeWindow(w, fyne.NewSize(400, 300))
streamIDEntry := widget.NewEntry()
streamIDEntry.SetPlaceHolder("stream name")
saveButton := widget.NewButtonWithIcon("Save", theme.DocumentSaveIcon(), func() {
p.waitForResponse(func() {
err := p.addIncomingStream(ctx, api.StreamID(streamIDEntry.Text))
if err != nil {
p.DisplayError(err)
return
}
w.Close()
p.initRestreamPage(ctx)
})
})
w.SetContent(container.NewBorder(
nil,
container.NewHBox(saveButton),
nil,
nil,
container.NewVBox(
streamIDEntry,
),
))
w.Show()
}
func (p *Panel) addIncomingStream(
ctx context.Context,
streamID api.StreamID,
) error {
logger.Debugf(ctx, "addIncomingStream")
defer logger.Debugf(ctx, "/addIncomingStream")
return p.StreamD.AddIncomingStream(ctx, streamID)
}
func (p *Panel) displayIncomingServers( func (p *Panel) displayIncomingServers(
ctx context.Context, ctx context.Context,
@@ -186,10 +291,91 @@ func (p *Panel) displayIncomingServers(
) { ) {
logger.Debugf(ctx, "displayIncomingServers") logger.Debugf(ctx, "displayIncomingServers")
defer logger.Debugf(ctx, "/displayIncomingServers") defer logger.Debugf(ctx, "/displayIncomingServers")
sort.Slice(inStreams, func(i, j int) bool {
return inStreams[i].StreamID < inStreams[j].StreamID
})
p.streamsWidget.RemoveAll()
for idx, stream := range inStreams {
logger.Tracef(ctx, "inStream[%3d] == %#+v", idx, stream)
c := container.NewHBox()
button := widget.NewButtonWithIcon("", theme.DeleteIcon(), func() {
w := dialog.NewConfirm(
fmt.Sprintf("Delete incoming server %s ?", stream.StreamID),
"",
func(b bool) {
if !b {
return
}
logger.Debugf(ctx, "remove incoming stream")
defer logger.Debugf(ctx, "/remove incoming stream")
p.waitForResponse(func() {
err := p.StreamD.RemoveIncomingStream(ctx, stream.StreamID)
if err != nil {
p.DisplayError(err)
return
}
})
p.initRestreamPage(ctx)
},
p.mainWindow,
)
w.Show()
p.initRestreamPage(ctx)
})
label := widget.NewLabel(string(stream.StreamID))
c.RemoveAll()
c.Add(button)
c.Add(label)
p.streamsWidget.Add(c)
}
p.streamsWidget.Refresh()
} }
func (p *Panel) openAddDestinationWindow() {} func (p *Panel) openAddDestinationWindow(ctx context.Context) {
w := p.app.NewWindow(appName + ": Add stream destination")
resizeWindow(w, fyne.NewSize(400, 300))
destinationIDEntry := widget.NewEntry()
destinationIDEntry.SetPlaceHolder("destination ID")
urlEntry := widget.NewEntry()
urlEntry.SetPlaceHolder("URL")
saveButton := widget.NewButtonWithIcon("Save", theme.DocumentSaveIcon(), func() {
p.waitForResponse(func() {
err := p.addStreamDestination(ctx, api.DestinationID(destinationIDEntry.Text), urlEntry.Text)
if err != nil {
p.DisplayError(err)
return
}
w.Close()
p.initRestreamPage(ctx)
})
})
w.SetContent(container.NewBorder(
nil,
container.NewHBox(saveButton),
nil,
nil,
container.NewVBox(
destinationIDEntry,
urlEntry,
),
))
w.Show()
}
func (p *Panel) addStreamDestination(
ctx context.Context,
destinationID api.DestinationID,
url string,
) error {
logger.Debugf(ctx, "addStreamDestination")
defer logger.Debugf(ctx, "/addStreamDestination")
return p.StreamD.AddStreamDestination(ctx, destinationID, url)
}
func (p *Panel) displayStreamDestinations( func (p *Panel) displayStreamDestinations(
ctx context.Context, ctx context.Context,
@@ -198,17 +384,220 @@ func (p *Panel) displayStreamDestinations(
logger.Debugf(ctx, "displayStreamDestinations") logger.Debugf(ctx, "displayStreamDestinations")
defer logger.Debugf(ctx, "/displayStreamDestinations") defer logger.Debugf(ctx, "/displayStreamDestinations")
p.destinationsWidget.RemoveAll()
for idx, dst := range dsts {
logger.Tracef(ctx, "dsts[%3d] == %#+v", idx, dst)
c := container.NewHBox()
deleteButton := widget.NewButtonWithIcon("", theme.DeleteIcon(), func() {
w := dialog.NewConfirm(
fmt.Sprintf("Delete destination %s ?", dst.ID),
"",
func(b bool) {
if !b {
return
}
logger.Debugf(ctx, "remove destination")
defer logger.Debugf(ctx, "/remove destination")
p.waitForResponse(func() {
err := p.StreamD.RemoveStreamDestination(ctx, dst.ID)
if err != nil {
p.DisplayError(err)
return
}
})
p.initRestreamPage(ctx)
},
p.mainWindow,
)
w.Show()
p.initRestreamPage(ctx)
})
label := widget.NewLabel(string(dst.ID) + ": " + string(dst.URL))
c.RemoveAll()
c.Add(deleteButton)
c.Add(label)
p.destinationsWidget.Add(c)
}
} }
func (p *Panel) openAddRestreamWindow() {} func (p *Panel) openAddRestreamWindow(ctx context.Context) {
w := p.app.NewWindow(appName + ": Add restreaming (stream forwarding)")
resizeWindow(w, fyne.NewSize(400, 300))
enabledCheck := widget.NewCheck("Enable", func(b bool) {})
inStreams, err := p.StreamD.ListIncomingStreams(ctx)
if err != nil {
p.DisplayError(err)
return
}
dsts, err := p.StreamD.ListStreamDestinations(ctx)
if err != nil {
p.DisplayError(err)
return
}
var inStreamStrs []string
for _, inStream := range inStreams {
inStreamStrs = append(inStreamStrs, string(inStream.StreamID))
}
inStreamsSelect := widget.NewSelect(inStreamStrs, func(s string) {})
var dstStrs []string
dstMap := map[string]api.DestinationID{}
for _, dst := range dsts {
k := string(dst.ID) + ": " + dst.URL
dstStrs = append(dstStrs, k)
dstMap[k] = dst.ID
}
dstSelect := widget.NewSelect(dstStrs, func(s string) {})
saveButton := widget.NewButtonWithIcon("Save", theme.DocumentSaveIcon(), func() {
p.waitForResponse(func() {
err := p.addStreamForward(
ctx,
api.StreamID(inStreamsSelect.Selected),
dstMap[dstSelect.Selected],
enabledCheck.Checked,
)
if err != nil {
p.DisplayError(err)
return
}
w.Close()
p.initRestreamPage(ctx)
})
})
w.SetContent(container.NewBorder(
nil,
container.NewHBox(saveButton),
nil,
nil,
container.NewVBox(
widget.NewLabel("From:"),
inStreamsSelect,
widget.NewLabel("To:"),
dstSelect,
),
))
w.Show()
}
func (p *Panel) addStreamForward(
ctx context.Context,
streamID api.StreamID,
dstID api.DestinationID,
enabled bool,
) error {
logger.Debugf(ctx, "addStreamForward")
defer logger.Debugf(ctx, "/addStreamForward")
return p.StreamD.AddStreamForward(
ctx,
streamID,
dstID,
enabled,
)
}
func (p *Panel) displayStreamForwards( func (p *Panel) displayStreamForwards(
ctx context.Context, ctx context.Context,
dsts []api.StreamForward, fwds []api.StreamForward,
) { ) {
logger.Debugf(ctx, "displayStreamForwards") logger.Debugf(ctx, "displayStreamForwards")
defer logger.Debugf(ctx, "/displayStreamForwards") defer logger.Debugf(ctx, "/displayStreamForwards")
p.restreamsWidget.RemoveAll()
for idx, fwd := range fwds {
logger.Tracef(ctx, "fwds[%3d] == %#+v", idx, fwd)
c := container.NewHBox()
deleteButton := widget.NewButtonWithIcon("", theme.DeleteIcon(), func() {
w := dialog.NewConfirm(
fmt.Sprintf("Delete restreaming (stream forwarding) %s -> %s ?", fwd.StreamID, fwd.DestinationID),
"",
func(b bool) {
if !b {
return
}
logger.Debugf(ctx, "remove restreaming (stream forwarding)")
defer logger.Debugf(ctx, "/remove restreaming (stream forwarding)")
p.waitForResponse(func() {
err := p.StreamD.RemoveStreamForward(ctx, fwd.StreamID, fwd.DestinationID)
if err != nil {
p.DisplayError(err)
return
}
})
p.initRestreamPage(ctx)
},
p.mainWindow,
)
w.Show()
p.initRestreamPage(ctx)
})
icon := theme.MediaPauseIcon()
label := "Pause"
title := fmt.Sprintf("Pause forwarding %s -> %s ?", fwd.StreamID, fwd.DestinationID)
if !fwd.Enabled {
icon = theme.MediaPlayIcon()
label = "Unpause"
title = fmt.Sprintf("Unpause forwarding %s -> %s ?", fwd.StreamID, fwd.DestinationID)
}
playPauseButton := widget.NewButtonWithIcon(label, icon, func() {
w := dialog.NewConfirm(
title,
"",
func(b bool) {
if !b {
return
}
logger.Debugf(ctx, "pause/unpause restreaming (stream forwarding): disabled:%v->%v", fwd.Enabled, !fwd.Enabled)
defer logger.Debugf(ctx, "/pause/unpause restreaming (stream forwarding): disabled:%v->%v", !fwd.Enabled, fwd.Enabled)
p.waitForResponse(func() {
err := p.StreamD.UpdateStreamForward(
ctx,
fwd.StreamID,
fwd.DestinationID,
!fwd.Enabled,
)
if err != nil {
p.DisplayError(err)
return
}
})
p.initRestreamPage(ctx)
},
p.mainWindow,
)
w.Show()
p.initRestreamPage(ctx)
})
caption := widget.NewLabel(string(fwd.StreamID) + " -> " + string(fwd.DestinationID))
c.RemoveAll()
c.Add(deleteButton)
c.Add(playPauseButton)
c.Add(caption)
if fwd.Enabled {
c.Add(widget.NewSeparator())
type numBytesID struct {
StrID api.StreamID
DstID api.DestinationID
}
key := numBytesID{StrID: fwd.StreamID, DstID: fwd.DestinationID}
now := time.Now()
p.previousNumBytesLocker.Lock()
prevNumBytes := p.previousNumBytes[key]
bwText := widget.NewRichTextWithText(bwString(fwd.NumBytesRead, prevNumBytes[0], fwd.NumBytesWrote, prevNumBytes[1], now, p.previousNumBytesTS[key]))
p.previousNumBytes[key] = [4]uint64{fwd.NumBytesRead, fwd.NumBytesWrote}
p.previousNumBytesTS[key] = now
p.previousNumBytesLocker.Unlock()
c.Add(bwText)
}
p.restreamsWidget.Add(c)
}
} }
func (p *Panel) stopRestreamPage( func (p *Panel) stopRestreamPage(
@@ -239,6 +628,7 @@ func (p *Panel) updateRestreamPage(
go func() { go func() {
defer wg.Done() defer wg.Done()
p.initRestreamPage(ctx)
// whatever // whatever
}() }()
wg.Wait() wg.Wait()

View File

@@ -14,16 +14,19 @@ import (
"github.com/facebookincubator/go-belt/tool/experimental/errmon" "github.com/facebookincubator/go-belt/tool/experimental/errmon"
"github.com/facebookincubator/go-belt/tool/logger" "github.com/facebookincubator/go-belt/tool/logger"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/xaionaro-go/datacounter"
"github.com/xaionaro-go/streamctl/pkg/streamserver/consts" "github.com/xaionaro-go/streamctl/pkg/streamserver/consts"
"github.com/xaionaro-go/streamctl/pkg/streamserver/server"
"github.com/xaionaro-go/streamctl/pkg/streamserver/streams" "github.com/xaionaro-go/streamctl/pkg/streamserver/streams"
"github.com/xaionaro-go/streamctl/pkg/streamserver/types" "github.com/xaionaro-go/streamctl/pkg/streamserver/types"
) )
type RTMPServer struct { type RTMPServer struct {
Config Config Config Config
StreamHandler *streams.StreamHandler StreamHandler *streams.StreamHandler
Listener net.Listener Listener net.Listener
CancelFn context.CancelFunc CancelFn context.CancelFunc
TrafficCounter server.TrafficCounter
} }
type Config struct { type Config struct {
@@ -123,7 +126,12 @@ func (s *RTMPServer) tcpHandle(netConn net.Conn) error {
return err return err
} }
_, _ = cons.WriteTo(rtmpConn) wc := datacounter.NewWriterCounter(rtmpConn)
s.TrafficCounter.Lock()
s.TrafficCounter.WriterCounter = wc
s.TrafficCounter.Unlock()
_, _ = cons.WriteTo(wc)
return nil return nil
@@ -146,7 +154,15 @@ func (s *RTMPServer) tcpHandle(netConn net.Conn) error {
defer stream.RemoveProducer(prod) defer stream.RemoveProducer(prod)
_ = prod.Start() rc := server.NewIntPtrCounter(&prod.Recv)
s.TrafficCounter.Lock()
s.TrafficCounter.ReaderCounter = rc
s.TrafficCounter.Unlock()
err = prod.Start()
if err != nil {
logger.Default().Error(err)
}
return nil return nil
} }
@@ -154,18 +170,31 @@ func (s *RTMPServer) tcpHandle(netConn net.Conn) error {
return errors.New("rtmp: unknown command: " + rtmpConn.Intent) return errors.New("rtmp: unknown command: " + rtmpConn.Intent)
} }
func (s *RTMPServer) NumBytesConsumerWrote() uint64 {
return s.TrafficCounter.NumBytesWrote()
}
func (s *RTMPServer) NumBytesProducerRead() uint64 {
return s.TrafficCounter.NumBytesRead()
}
func StreamsHandle(url string) (core.Producer, error) { func StreamsHandle(url string) (core.Producer, error) {
return rtmp.DialPlay(url) return rtmp.DialPlay(url)
} }
func StreamsConsumerHandle(url string) (core.Consumer, func(context.Context) error, error) { func StreamsConsumerHandle(url string) (core.Consumer, server.NumBytesReaderWroter, func(context.Context) error, error) {
cons := flv.NewConsumer() cons := flv.NewConsumer()
trafficCounter := &server.TrafficCounter{}
run := func(ctx context.Context) error { run := func(ctx context.Context) error {
wr, err := rtmp.DialPublish(url) wr, err := rtmp.DialPublish(url)
if err != nil { if err != nil {
return fmt.Errorf("unable to connect to '%s': %w", url, err) return fmt.Errorf("unable to connect to '%s': %w", url, err)
} }
wrc := datacounter.NewWriterCounter(wr)
trafficCounter.Lock()
trafficCounter.WriterCounter = wrc
trafficCounter.Unlock()
ctx, cancelFn := context.WithCancel(ctx) ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn() defer cancelFn()
go func() { go func() {
@@ -175,14 +204,14 @@ func StreamsConsumerHandle(url string) (core.Consumer, func(context.Context) err
errmon.ObserveErrorCtx(ctx, err) errmon.ObserveErrorCtx(ctx, err)
}() }()
_, err = cons.WriteTo(wr) _, err = cons.WriteTo(wrc)
if err != nil { if err != nil {
return fmt.Errorf("unable to write: %w", err) return fmt.Errorf("unable to write: %w", err)
} }
return nil return nil
} }
return cons, run, nil return cons, trafficCounter, run, nil
} }
func (s *RTMPServer) apiHandle(w http.ResponseWriter, r *http.Request) { func (s *RTMPServer) apiHandle(w http.ResponseWriter, r *http.Request) {

View File

@@ -14,17 +14,19 @@ import (
"github.com/facebookincubator/go-belt/tool/experimental/errmon" "github.com/facebookincubator/go-belt/tool/experimental/errmon"
"github.com/facebookincubator/go-belt/tool/logger" "github.com/facebookincubator/go-belt/tool/logger"
"github.com/xaionaro-go/streamctl/pkg/streamserver/consts" "github.com/xaionaro-go/streamctl/pkg/streamserver/consts"
"github.com/xaionaro-go/streamctl/pkg/streamserver/server"
"github.com/xaionaro-go/streamctl/pkg/streamserver/streams" "github.com/xaionaro-go/streamctl/pkg/streamserver/streams"
"github.com/xaionaro-go/streamctl/pkg/streamserver/types" "github.com/xaionaro-go/streamctl/pkg/streamserver/types"
) )
type RTSPServer struct { type RTSPServer struct {
Config Config Config Config
Listener net.Listener Listener net.Listener
DefaultMedias []*core.Media DefaultMedias []*core.Media
StreamHandler *streams.StreamHandler StreamHandler *streams.StreamHandler
Handlers []HandlerFunc Handlers []HandlerFunc
CancelFn context.CancelFunc CancelFn context.CancelFunc
TrafficCounter server.TrafficCounter
} }
type Config struct { type Config struct {
@@ -280,6 +282,13 @@ func (s *RTSPServer) Close() error {
return nil return nil
} }
func (s *RTSPServer) NumBytesConsumerWrote() uint64 {
return s.TrafficCounter.NumBytesWrote()
}
func (s *RTSPServer) NumBytesProducerRead() uint64 {
return s.TrafficCounter.NumBytesRead()
}
func ParseQuery(query map[string][]string) []*core.Media { func ParseQuery(query map[string][]string) []*core.Media {
if v := query["mp4"]; v != nil { if v := query["mp4"]; v != nil {
return []*core.Media{ return []*core.Media{

View File

@@ -3,8 +3,10 @@ package streamserver
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"github.com/facebookincubator/go-belt/tool/logger"
rtmpserver "github.com/xaionaro-go/streamctl/pkg/streamserver/server/rtmp" rtmpserver "github.com/xaionaro-go/streamctl/pkg/streamserver/server/rtmp"
rtspserver "github.com/xaionaro-go/streamctl/pkg/streamserver/server/rtsp" rtspserver "github.com/xaionaro-go/streamctl/pkg/streamserver/server/rtsp"
"github.com/xaionaro-go/streamctl/pkg/streamserver/streams" "github.com/xaionaro-go/streamctl/pkg/streamserver/streams"
@@ -20,9 +22,9 @@ type StreamServer struct {
} }
func New(cfg *types.Config) *StreamServer { func New(cfg *types.Config) *StreamServer {
if cfg == nil { assert(cfg != nil)
cfg = &types.Config{} logger.Default().Debugf("config == %#+v", *cfg)
}
if cfg.Streams == nil { if cfg.Streams == nil {
cfg.Streams = map[types.StreamID]*types.StreamConfig{} cfg.Streams = map[types.StreamID]*types.StreamConfig{}
} }
@@ -52,6 +54,7 @@ func (s *StreamServer) Init(ctx context.Context) error {
defer s.Unlock() defer s.Unlock()
cfg := s.Config cfg := s.Config
logger.Debugf(ctx, "config == %#+v", *cfg)
for _, srv := range cfg.Servers { for _, srv := range cfg.Servers {
err := s.startServer(ctx, srv.Type, srv.Listen) err := s.startServer(ctx, srv.Type, srv.Listen)
@@ -73,10 +76,12 @@ func (s *StreamServer) Init(ctx context.Context) error {
return fmt.Errorf("unable to initialize stream '%s': %w", streamID, err) return fmt.Errorf("unable to initialize stream '%s': %w", streamID, err)
} }
for _, fwd := range streamCfg.Forwardings { for dstID, fwd := range streamCfg.Forwardings {
err := s.addStreamForward(ctx, streamID, fwd) if !fwd.Disabled {
if err != nil { err := s.addStreamForward(ctx, streamID, dstID)
return fmt.Errorf("unable to launch stream forward from '%s' to '%s': %w", streamID, fwd, err) if err != nil {
return fmt.Errorf("unable to launch stream forward from '%s' to '%s': %w", streamID, dstID, err)
}
} }
} }
} }
@@ -200,7 +205,7 @@ func (s *StreamServer) addIncomingStream(
if s.StreamHandler.Get(string(streamID)) != nil { if s.StreamHandler.Get(string(streamID)) != nil {
return fmt.Errorf("stream '%s' already exists", streamID) return fmt.Errorf("stream '%s' already exists", streamID)
} }
_, err := s.StreamHandler.New(string(streamID), "") _, err := s.StreamHandler.New(string(streamID), nil)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the stream '%s': %w", streamID, err) return fmt.Errorf("unable to create the stream '%s': %w", streamID, err)
} }
@@ -209,6 +214,9 @@ func (s *StreamServer) addIncomingStream(
type IncomingStream struct { type IncomingStream struct {
StreamID types.StreamID StreamID types.StreamID
NumBytesWrote uint64
NumBytesRead uint64
} }
func (s *StreamServer) ListIncomingStreams( func (s *StreamServer) ListIncomingStreams(
@@ -258,21 +266,33 @@ func (s *StreamServer) removeIncomingStream(
type StreamForward struct { type StreamForward struct {
StreamID types.StreamID StreamID types.StreamID
DestinationID types.DestinationID DestinationID types.DestinationID
Enabled bool
NumBytesWrote uint64
NumBytesRead uint64
} }
func (s *StreamServer) AddStreamForward( func (s *StreamServer) AddStreamForward(
ctx context.Context, ctx context.Context,
streamID types.StreamID, streamID types.StreamID,
destinationID types.DestinationID, destinationID types.DestinationID,
enabled bool,
) error { ) error {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
err := s.addStreamForward(ctx, streamID, destinationID)
if err != nil {
return err
}
streamConfig := s.Config.Streams[streamID] streamConfig := s.Config.Streams[streamID]
streamConfig.Forwardings = append(streamConfig.Forwardings, destinationID) if _, ok := streamConfig.Forwardings[destinationID]; ok {
return fmt.Errorf("the forwarding %s->%s already exists", streamID, destinationID)
}
if enabled {
err := s.addStreamForward(ctx, streamID, destinationID)
if err != nil {
return err
}
}
streamConfig.Forwardings[destinationID] = types.ForwardingConfig{
Disabled: !enabled,
}
return nil return nil
} }
@@ -282,8 +302,8 @@ func (s *StreamServer) addStreamForward(
destinationID types.DestinationID, destinationID types.DestinationID,
) error { ) error {
streamSrc := s.StreamHandler.Get(string(streamID)) streamSrc := s.StreamHandler.Get(string(streamID))
if streamSrc != nil { if streamSrc == nil {
return fmt.Errorf("unable to find stream ID '%s'", streamID) return fmt.Errorf("unable to find stream ID '%s', available stream IDs: %s", streamID, strings.Join(s.StreamHandler.GetAll(), ", "))
} }
dst, err := s.findStreamDestinationByID(ctx, destinationID) dst, err := s.findStreamDestinationByID(ctx, destinationID)
if err != nil { if err != nil {
@@ -296,12 +316,82 @@ func (s *StreamServer) addStreamForward(
return nil return nil
} }
func (s *StreamServer) UpdateStreamForward(
ctx context.Context,
streamID types.StreamID,
destinationID types.DestinationID,
enabled bool,
) error {
s.Lock()
defer s.Unlock()
streamConfig := s.Config.Streams[streamID]
fwdCfg, ok := streamConfig.Forwardings[destinationID]
if !ok {
return fmt.Errorf("the forwarding %s->%s does not exist", streamID, destinationID)
}
if fwdCfg.Disabled && enabled {
err := s.addStreamForward(ctx, streamID, destinationID)
if err != nil {
return err
}
}
if !fwdCfg.Disabled && !enabled {
err := s.removeStreamForward(ctx, streamID, destinationID)
if err != nil {
return err
}
}
streamConfig.Forwardings[destinationID] = types.ForwardingConfig{
Disabled: !enabled,
}
return nil
}
func (s *StreamServer) ListStreamForwards( func (s *StreamServer) ListStreamForwards(
ctx context.Context, ctx context.Context,
) ([]StreamForward, error) { ) ([]StreamForward, error) {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
return s.listStreamForwards(ctx)
activeStreamForwards, err := s.listStreamForwards(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get the list of active stream forwardings: %w", err)
}
type fwdID struct {
StreamID types.StreamID
DestID types.DestinationID
}
m := map[fwdID]*StreamForward{}
for idx := range activeStreamForwards {
fwd := &activeStreamForwards[idx]
m[fwdID{
StreamID: fwd.StreamID,
DestID: fwd.DestinationID,
}] = fwd
}
var result []StreamForward
for streamID, stream := range s.Config.Streams {
for dstID, cfg := range stream.Forwardings {
item := StreamForward{
StreamID: streamID,
DestinationID: dstID,
Enabled: !cfg.Disabled,
}
if activeFwd, ok := m[fwdID{
StreamID: streamID,
DestID: dstID,
}]; ok {
item.NumBytesWrote = activeFwd.NumBytesWrote
item.NumBytesRead = activeFwd.NumBytesRead
}
logger.Tracef(ctx, "stream forwarding '%s->%s': %#+v", streamID, dstID, cfg)
result = append(result, item)
}
}
return result, nil
} }
func (s *StreamServer) listStreamForwards( func (s *StreamServer) listStreamForwards(
@@ -322,6 +412,9 @@ func (s *StreamServer) listStreamForwards(
result = append(result, StreamForward{ result = append(result, StreamForward{
StreamID: streamIDSrc, StreamID: streamIDSrc,
DestinationID: streamDst.ID, DestinationID: streamDst.ID,
Enabled: true,
NumBytesWrote: fwd.TrafficCounter.NumBytesWrote(),
NumBytesRead: fwd.TrafficCounter.NumBytesRead(),
}) })
} }
} }
@@ -336,13 +429,10 @@ func (s *StreamServer) RemoveStreamForward(
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
streamCfg := s.Config.Streams[streamID] streamCfg := s.Config.Streams[streamID]
for idx, _dstID := range streamCfg.Forwardings { if _, ok := streamCfg.Forwardings[dstID]; !ok {
if _dstID != dstID { return fmt.Errorf("the forwarding %s->%s does not exist", streamID, dstID)
continue
}
streamCfg.Forwardings = append(streamCfg.Forwardings[:idx], streamCfg.Forwardings[idx+1:]...)
break
} }
delete(streamCfg.Forwardings, dstID)
return s.removeStreamForward(ctx, streamID, dstID) return s.removeStreamForward(ctx, streamID, dstID)
} }
@@ -424,13 +514,7 @@ func (s *StreamServer) RemoveStreamDestination(
s.Mutex.Lock() s.Mutex.Lock()
defer s.Mutex.Unlock() defer s.Mutex.Unlock()
for _, streamCfg := range s.Config.Streams { for _, streamCfg := range s.Config.Streams {
for fIdx, destID := range streamCfg.Forwardings { delete(streamCfg.Forwardings, destinationID)
if destID != destinationID {
continue
}
streamCfg.Forwardings = append(streamCfg.Forwardings[:fIdx], streamCfg.Forwardings[fIdx+1:]...)
break
}
} }
delete(s.Config.Destinations, destinationID) delete(s.Config.Destinations, destinationID)
return s.removeStreamDestination(ctx, destinationID) return s.removeStreamDestination(ctx, destinationID)

View File

@@ -2,6 +2,7 @@ package streams
import ( import (
"errors" "errors"
"fmt"
"strings" "strings"
"github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/core"
@@ -12,6 +13,10 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
// support for multiple simultaneous pending from different consumers // support for multiple simultaneous pending from different consumers
consN := s.pending.Add(1) - 1 consN := s.pending.Add(1) - 1
if len(s.producers) == 0 {
return ErrNoProducer{}
}
var prodErrors = make([]error, len(s.producers)) var prodErrors = make([]error, len(s.producers))
var prodMedias []*core.Media var prodMedias []*core.Media
var prodStarts []*Producer var prodStarts []*Producer
@@ -29,8 +34,8 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
} }
if err = prod.Dial(); err != nil { if err = prod.Dial(); err != nil {
logger.Default().WithField("error", err).Tracef("[streams] dial cons=%d prod=%d", consN, prodN) logger.Default().Tracef("[streams] dial cons=%d prod=%d err=%v", consN, prodN, err)
prodErrors[prodN] = err prodErrors[prodN] = fmt.Errorf("unable to Dial(): %w", err)
continue continue
} }
@@ -53,13 +58,13 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
// Step 4. Get recvonly track from producer // Step 4. Get recvonly track from producer
if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil { if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil {
logger.Default().WithField("error", err).Info("[streams] can't get track") logger.Default().Info("[streams] can't get track; err=%v", err)
prodErrors[prodN] = err prodErrors[prodN] = fmt.Errorf("unable to GetTrack(): %w", err)
continue continue
} }
// Step 5. Add track to consumer // Step 5. Add track to consumer
if err = cons.AddTrack(consMedia, consCodec, track); err != nil { if err = cons.AddTrack(consMedia, consCodec, track); err != nil {
logger.Default().WithField("error", err).Info("[streams] can't add track") logger.Default().Info("[streams] can't add track; err=%v", err)
continue continue
} }
@@ -68,13 +73,13 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
// Step 4. Get recvonly track from consumer (backchannel) // Step 4. Get recvonly track from consumer (backchannel)
if track, err = cons.(core.Producer).GetTrack(consMedia, consCodec); err != nil { if track, err = cons.(core.Producer).GetTrack(consMedia, consCodec); err != nil {
logger.Default().WithField("error", err).Info("[streams] can't get track") logger.Default().Info("[streams] can't get track; err=%v", err)
continue continue
} }
// Step 5. Add track to producer // Step 5. Add track to producer
if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil { if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil {
logger.Default().WithField("error", err).Info("[streams] can't add track") logger.Default().Info("[streams] can't add track; err=%v", err)
prodErrors[prodN] = err prodErrors[prodN] = fmt.Errorf("unable to AddTrack(): %w", err)
continue continue
} }
} }

View File

@@ -3,9 +3,11 @@ package streams
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/core"
"github.com/xaionaro-go/streamctl/pkg/streamserver/server"
) )
type Handler func(source string) (core.Producer, error) type Handler func(source string) (core.Producer, error)
@@ -30,26 +32,34 @@ func (s *StreamHandler) HasProducer(url string) bool {
return false return false
} }
type ErrNoProducer struct{}
func (err ErrNoProducer) Error() string {
return "no producers"
}
func (s *StreamHandler) GetProducer(url string) (core.Producer, error) { func (s *StreamHandler) GetProducer(url string) (core.Producer, error) {
if i := strings.IndexByte(url, ':'); i > 0 { i := strings.IndexByte(url, ':')
scheme := url[:i] if i <= 0 {
return nil, fmt.Errorf("streams: empty scheme in URL: '%s'", url)
}
scheme := url[:i]
if redirect, ok := s.redirects[scheme]; ok { if redirect, ok := s.redirects[scheme]; ok {
location, err := redirect(url) location, err := redirect(url)
if err != nil { if err != nil {
return nil, err return nil, err
}
if location != "" {
return s.GetProducer(location)
}
} }
if location != "" {
if handler, ok := s.handlers[scheme]; ok { return s.GetProducer(location)
return handler(url)
} }
} }
return nil, errors.New("streams: unsupported scheme: " + url) if handler, ok := s.handlers[scheme]; ok {
return handler(url)
}
return nil, errors.New("streams: unsupported scheme in URL: " + url)
} }
// Redirect can return: location URL or error or empty URL and error // Redirect can return: location URL or error or empty URL and error
@@ -73,13 +83,13 @@ func (s *StreamHandler) Location(url string) (string, error) {
// TODO: rework // TODO: rework
type ConsumerHandler func(url string) (core.Consumer, func(context.Context) error, error) type ConsumerHandler func(url string) (core.Consumer, server.NumBytesReaderWroter, func(context.Context) error, error)
func (s *StreamHandler) HandleConsumerFunc(scheme string, handler ConsumerHandler) { func (s *StreamHandler) HandleConsumerFunc(scheme string, handler ConsumerHandler) {
s.consumerHandlers[scheme] = handler s.consumerHandlers[scheme] = handler
} }
func (s *StreamHandler) GetConsumer(url string) (core.Consumer, func(context.Context) error, error) { func (s *StreamHandler) GetConsumer(url string) (core.Consumer, server.NumBytesReaderWroter, func(context.Context) error, error) {
if i := strings.IndexByte(url, ':'); i > 0 { if i := strings.IndexByte(url, ':'); i > 0 {
scheme := url[:i] scheme := url[:i]
@@ -88,5 +98,5 @@ func (s *StreamHandler) GetConsumer(url string) (core.Consumer, func(context.Con
} }
} }
return nil, nil, errors.New("streams: unsupported scheme: " + url) return nil, nil, nil, errors.New("streams: unsupported scheme: " + url)
} }

View File

@@ -55,7 +55,7 @@ func (s *Stream) Play(source string) error {
for _, producer := range s.producers { for _, producer := range s.producers {
// start new client // start new client
dst, err := s.streamHandler.GetProducer(producer.url) dst, err := s.streamHandler.GetProducer(producer.urlFunc())
if err != nil { if err != nil {
continue continue
} }

View File

@@ -25,7 +25,7 @@ const (
type Producer struct { type Producer struct {
core.Listener core.Listener
url string urlFunc func() string
template string template string
conn core.Producer conn core.Producer
@@ -41,20 +41,19 @@ type Producer struct {
const SourceTemplate = "{input}" const SourceTemplate = "{input}"
func (s *StreamHandler) NewProducer(source string) *Producer { func (s *StreamHandler) NewProducer(source func() string) *Producer {
if strings.Contains(source, SourceTemplate) { if strings.Contains(source(), SourceTemplate) {
return &Producer{streamHandler: s, template: source} return &Producer{streamHandler: s, template: source()}
} }
return &Producer{streamHandler: s, url: source} return &Producer{streamHandler: s, urlFunc: source}
} }
func (p *Producer) SetSource(s string) { func (p *Producer) SetSource(s string) {
if p.template == "" { if p.template != "" {
p.url = s s = strings.Replace(p.template, SourceTemplate, s, 1)
} else {
p.url = strings.Replace(p.template, SourceTemplate, s, 1)
} }
p.urlFunc = func() string { return s }
} }
func (p *Producer) Dial() error { func (p *Producer) Dial() error {
@@ -62,7 +61,7 @@ func (p *Producer) Dial() error {
defer p.mu.Unlock() defer p.mu.Unlock()
if p.state == stateNone { if p.state == stateNone {
conn, err := p.streamHandler.GetProducer(p.url) conn, err := p.streamHandler.GetProducer(p.urlFunc())
if err != nil { if err != nil {
return err return err
} }
@@ -138,7 +137,7 @@ func (p *Producer) MarshalJSON() ([]byte, error) {
if conn := p.conn; conn != nil { if conn := p.conn; conn != nil {
return json.Marshal(conn) return json.Marshal(conn)
} }
info := map[string]string{"url": p.url} info := map[string]string{"url": p.urlFunc()}
return json.Marshal(info) return json.Marshal(info)
} }
@@ -150,7 +149,7 @@ func (p *Producer) start() {
return return
} }
logger.Default().Debugf("[streams] start producer url=%s", p.url) logger.Default().Debugf("[streams] start producer url=%s", p.urlFunc)
p.state = stateStart p.state = stateStart
p.workerID++ p.workerID++
@@ -168,7 +167,7 @@ func (p *Producer) worker(conn core.Producer, workerID int) {
return return
} }
logger.Default().Warn(struct{ URL string }{URL: p.url}, err) logger.Default().Warn(struct{ URL string }{URL: p.urlFunc()}, err)
} }
p.reconnect(workerID, 0) p.reconnect(workerID, 0)
@@ -179,13 +178,13 @@ func (p *Producer) reconnect(workerID, retry int) {
defer p.mu.Unlock() defer p.mu.Unlock()
if p.workerID != workerID { if p.workerID != workerID {
logger.Default().Tracef("[streams] stop reconnect url=%s", p.url) logger.Default().Tracef("[streams] stop reconnect url=%s", p.urlFunc)
return return
} }
logger.Default().Debugf("[streams] retry=%d to url=%s", retry, p.url) logger.Default().Debugf("[streams] retry=%d to url=%s", retry, p.urlFunc)
conn, err := p.streamHandler.GetProducer(p.url) conn, err := p.streamHandler.GetProducer(p.urlFunc())
if err != nil { if err != nil {
logger.Default().Debugf("[streams] producer=%s", err) logger.Default().Debugf("[streams] producer=%s", err)
@@ -258,7 +257,7 @@ func (p *Producer) stop() {
p.workerID++ p.workerID++
} }
logger.Default().Tracef("[streams] stop producer url=%s", p.url) logger.Default().Tracef("[streams] stop producer url=%s", p.urlFunc)
if p.conn != nil { if p.conn != nil {
_ = p.conn.Stop() _ = p.conn.Stop()

View File

@@ -22,6 +22,11 @@ type Stream struct {
func (s *StreamHandler) NewStream(source any) *Stream { func (s *StreamHandler) NewStream(source any) *Stream {
switch source := source.(type) { switch source := source.(type) {
case string: case string:
return &Stream{
producers: []*Producer{s.NewProducer(func() string { return source })},
streamHandler: s,
}
case func() string:
return &Stream{ return &Stream{
producers: []*Producer{s.NewProducer(source)}, producers: []*Producer{s.NewProducer(source)},
streamHandler: s, streamHandler: s,
@@ -35,15 +40,15 @@ func (s *StreamHandler) NewStream(source any) *Stream {
logger.Default().Errorf("[stream] NewStream: Expected string, got %v", src) logger.Default().Errorf("[stream] NewStream: Expected string, got %v", src)
continue continue
} }
stream.producers = append(stream.producers, s.NewProducer(str)) stream.producers = append(stream.producers, s.NewProducer(func() string { return str }))
} }
return stream return stream
case map[string]any: case map[string]any:
return s.NewStream(source["url"]) return s.NewStream(source["url"])
case nil: case nil:
stream := new(Stream) return &Stream{
stream.streamHandler = s streamHandler: s,
return stream }
default: default:
panic(core.Caller()) panic(core.Caller())
} }
@@ -51,7 +56,7 @@ func (s *StreamHandler) NewStream(source any) *Stream {
func (s *Stream) Sources() (sources []string) { func (s *Stream) Sources() (sources []string) {
for _, prod := range s.producers { for _, prod := range s.producers {
sources = append(sources, prod.url) sources = append(sources, prod.urlFunc())
} }
return return
} }

View File

@@ -2,22 +2,26 @@ package streams
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync" "sync"
"time" "time"
"github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/core"
"github.com/facebookincubator/go-belt/tool/experimental/errmon" "github.com/facebookincubator/go-belt/tool/experimental/errmon"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/xaionaro-go/streamctl/pkg/streamserver/server"
) )
type StreamForwarding struct { type StreamForwarding struct {
sync.Mutex sync.Mutex
Stream *Stream Stream *Stream
Consumer core.Consumer Consumer core.Consumer
StreamHandler *StreamHandler StreamHandler *StreamHandler
CancelFunc context.CancelFunc CancelFunc context.CancelFunc
URL string URL string
TrafficCounter server.NumBytesReaderWroter
} }
func NewStreamForwarding(streamHandler *StreamHandler) *StreamForwarding { func NewStreamForwarding(streamHandler *StreamHandler) *StreamForwarding {
@@ -32,21 +36,39 @@ func (sf *StreamForwarding) Start(
sf.Lock() sf.Lock()
defer sf.Unlock() defer sf.Unlock()
cons, run, err := sf.StreamHandler.GetConsumer(url) cons, trafficCounter, run, err := sf.StreamHandler.GetConsumer(url)
if err != nil { if err != nil {
return fmt.Errorf("unable to initialize consumer of '%s': %w", url, err) return fmt.Errorf("unable to initialize consumer of '%s': %w", url, err)
} }
sf.Stream = s sf.Stream = s
sf.URL = url sf.URL = url
sf.Consumer = cons sf.Consumer = cons
sf.TrafficCounter = trafficCounter
if err = s.AddConsumer(cons); err != nil {
return fmt.Errorf("unable to add consumer: %w", err)
}
ctx, cancelFn := context.WithCancel(ctx) ctx, cancelFn := context.WithCancel(ctx)
sf.CancelFunc = cancelFn sf.CancelFunc = cancelFn
go func() { go func() {
for {
select {
case <-ctx.Done():
return
default:
}
err = s.AddConsumer(cons)
if errors.Is(err, ErrNoProducer{}) {
logger.Debugf(ctx, "waiting for a producer")
time.Sleep(time.Second)
continue
}
if err != nil {
logger.Errorf(ctx, "unable to add consumer of '%s': %v", sf.URL, err)
time.Sleep(time.Second * 5)
continue
}
break
}
err := run(ctx) err := run(ctx)
errmon.ObserveErrorCtx(ctx, err) errmon.ObserveErrorCtx(ctx, err)
s.RemoveConsumer(cons) s.RemoveConsumer(cons)

View File

@@ -22,11 +22,7 @@ func (s *StreamHandler) Validate(source string) error {
return nil return nil
} }
func (s *StreamHandler) New(name string, source string) (*Stream, error) { func (s *StreamHandler) New(name string, source any) (*Stream, error) {
if err := s.Validate(source); err != nil {
return nil, err
}
stream := s.NewStream(source) stream := s.NewStream(source)
s.streams[name] = stream s.streams[name] = stream
return stream, nil return stream, nil

View File

@@ -5,8 +5,12 @@ type Server struct {
Listen string `yaml:"listen"` Listen string `yaml:"listen"`
} }
type ForwardingConfig struct {
Disabled bool `yaml:"disabled,omitempty"`
}
type StreamConfig struct { type StreamConfig struct {
Forwardings []DestinationID `yaml:"forwardings"` Forwardings map[DestinationID]ForwardingConfig `yaml:"forwardings"`
} }
type DestinationConfig struct { type DestinationConfig struct {

View File

@@ -83,6 +83,9 @@ type ServerHandler interface {
Type() ServerType Type() ServerType
ListenAddr() string ListenAddr() string
NumBytesConsumerWrote() uint64
NumBytesProducerRead() uint64
} }
type StreamDestination struct { type StreamDestination struct {