// Mirror of https://github.com/xaionaro-go/streamctl.git
// (synced 2025-10-10 09:50:15 +08:00; 439 lines, 10 KiB, Go).
package mainprocess
|
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"encoding/gob"
	"fmt"
	"net"
	"sync"

	"github.com/facebookincubator/go-belt"
	"github.com/facebookincubator/go-belt/tool/logger"
	"github.com/hashicorp/go-multierror"
	"github.com/sethvargo/go-password/password"
)
|
|
|
|
type OnReceivedMessageFunc func(
|
|
ctx context.Context,
|
|
source ProcessName,
|
|
content any,
|
|
) error
|
|
|
|
type LaunchClientFunc func(
|
|
ctx context.Context,
|
|
procName ProcessName,
|
|
addr string,
|
|
password string,
|
|
isRestart bool,
|
|
) error
|
|
|
|
type Manager struct {
|
|
listener net.Listener
|
|
password string
|
|
|
|
connsLocker sync.Mutex
|
|
conns map[ProcessName]net.Conn
|
|
connsChanged chan struct{}
|
|
|
|
allClientProcesses []ProcessName
|
|
|
|
LaunchClient LaunchClientFunc
|
|
}
|
|
|
|
func NewManager(
|
|
launchClient LaunchClientFunc,
|
|
expectedClients ...ProcessName,
|
|
) (*Manager, error) {
|
|
listener, err := net.Listen("tcp", "localhost:0")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
|
|
password, err := password.Generate(16, 4, 4, false, true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to generate a password: %w", err)
|
|
}
|
|
|
|
return &Manager{
|
|
LaunchClient: launchClient,
|
|
|
|
listener: listener,
|
|
password: password,
|
|
|
|
conns: map[ProcessName]net.Conn{},
|
|
connsChanged: make(chan struct{}),
|
|
|
|
allClientProcesses: expectedClients,
|
|
}, nil
|
|
}
|
|
|
|
func (m *Manager) Password() string {
|
|
return m.password
|
|
}
|
|
|
|
func (m *Manager) Addr() net.Addr {
|
|
return m.listener.Addr()
|
|
}
|
|
|
|
func (m *Manager) Close() error {
|
|
return m.listener.Close()
|
|
}
|
|
|
|
func (m *Manager) VerifyEverybodyConnected(
|
|
ctx context.Context,
|
|
) error {
|
|
m.connsLocker.Lock()
|
|
defer m.connsLocker.Unlock()
|
|
|
|
for _, name := range m.allClientProcesses {
|
|
if _, ok := m.conns[name]; !ok {
|
|
return fmt.Errorf("client '%s' is not connected", name)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) Serve(
|
|
ctx context.Context,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) error {
|
|
logger.Tracef(ctx, "serving listener at %s", m.listener.Addr())
|
|
defer logger.Tracef(ctx, "/serving listener at %s", m.listener.Addr())
|
|
|
|
ctx, cancelFn := context.WithCancel(ctx)
|
|
defer cancelFn()
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
err := m.Close()
|
|
if err != nil {
|
|
logger.Error(ctx, err)
|
|
}
|
|
}()
|
|
|
|
if m.LaunchClient != nil {
|
|
for _, name := range m.allClientProcesses {
|
|
err := m.LaunchClient(ctx, name, m.listener.Addr().String(), m.password, false)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to launch '%s': %w", name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
conn, err := m.listener.Accept()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to accept connection: %w", err)
|
|
}
|
|
logger.Tracef(ctx, "accepted a connection from '%s'", conn.RemoteAddr())
|
|
|
|
m.addNewConnection(ctx, conn, onReceivedMessage)
|
|
}
|
|
}
|
|
|
|
func (m *Manager) addNewConnection(
|
|
ctx context.Context,
|
|
conn net.Conn,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) {
|
|
go func() {
|
|
m.handleConnection(ctx, conn, onReceivedMessage)
|
|
}()
|
|
}
|
|
|
|
func (m *Manager) handleConnection(
|
|
ctx context.Context,
|
|
conn net.Conn,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) {
|
|
var regMessage RegistrationMessage
|
|
logger.Tracef(ctx, "handleConnection from %s", conn.RemoteAddr())
|
|
defer func() { logger.Tracef(ctx, "/handleConnection from %s (%s)", conn.RemoteAddr(), regMessage.Source) }()
|
|
|
|
ctx, cancelFn := context.WithCancel(ctx)
|
|
go func() {
|
|
<-ctx.Done()
|
|
conn.Close()
|
|
}()
|
|
defer cancelFn()
|
|
|
|
encoder := gob.NewEncoder(conn)
|
|
|
|
decoder := gob.NewDecoder(conn)
|
|
err := decoder.Decode(®Message)
|
|
if err != nil {
|
|
err = fmt.Errorf("unable to decode registration message: %w", err)
|
|
encoder.Encode(RegistrationResult{Error: err.Error()})
|
|
logger.Debug(ctx, err)
|
|
return
|
|
}
|
|
|
|
logger.Debugf(ctx, "received registration message: %#+v", regMessage)
|
|
if err := m.checkPassword(regMessage.Password); err != nil {
|
|
regMessage = RegistrationMessage{}
|
|
err = fmt.Errorf("invalid password: %w", err)
|
|
encoder.Encode(RegistrationResult{Error: err.Error()})
|
|
logger.Warn(ctx, err)
|
|
return
|
|
}
|
|
if err := m.registerConnection(regMessage.Source, conn); err != nil {
|
|
err = fmt.Errorf("unable to register process '%s': %w", regMessage.Source, err)
|
|
encoder.Encode(RegistrationResult{Error: err.Error()})
|
|
logger.Error(ctx, err)
|
|
return
|
|
}
|
|
defer func(sourceName ProcessName) {
|
|
m.unregisterConnection(sourceName)
|
|
}(regMessage.Source)
|
|
if err := encoder.Encode(RegistrationResult{}); err != nil {
|
|
err = fmt.Errorf("unable to encode&send the registration result to '%s': %w", regMessage.Source, err)
|
|
logger.Error(ctx, err)
|
|
return
|
|
}
|
|
ctx = belt.WithField(ctx, "client", regMessage.Source)
|
|
|
|
defer func() {
|
|
if m.LaunchClient == nil {
|
|
return
|
|
}
|
|
err := m.LaunchClient(ctx, regMessage.Source, m.listener.Addr().String(), m.password, true)
|
|
if err != nil {
|
|
logger.Error(ctx, err)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Tracef(ctx, "context was closed")
|
|
return
|
|
default:
|
|
}
|
|
var message MessageToMain
|
|
logger.Tracef(ctx, "waiting for a message from '%s'", regMessage.Source)
|
|
decoder := gob.NewDecoder(conn)
|
|
err := decoder.Decode(&message)
|
|
logger.Tracef(ctx, "getting a message from '%s': %#+v %#+v", regMessage.Source, message, err)
|
|
if err != nil {
|
|
err = fmt.Errorf(
|
|
"unable to parse the message from %s (%s): %w",
|
|
regMessage.Source,
|
|
conn.RemoteAddr().String(),
|
|
err,
|
|
)
|
|
logger.Error(ctx, err)
|
|
return
|
|
}
|
|
|
|
if err := m.processMessage(ctx, regMessage.Source, message, onReceivedMessage); err != nil {
|
|
logger.Errorf(
|
|
ctx,
|
|
"unable to process the message %#+v from %s (%s): %w",
|
|
message, regMessage.Source, conn.RemoteAddr().String(), err,
|
|
)
|
|
}
|
|
logger.Tracef(ctx, "next iteration")
|
|
}
|
|
}
|
|
|
|
func (m *Manager) processMessage(
|
|
ctx context.Context,
|
|
source ProcessName,
|
|
message MessageToMain,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) (_ret error) {
|
|
logger.Tracef(ctx, "processing message from '%s': %#+v", source, message)
|
|
defer func() { logger.Tracef(ctx, "/processing message from '%s': %#+v: %v", source, message, _ret) }()
|
|
|
|
switch message.Destination {
|
|
case "":
|
|
logger.Tracef(ctx, "a broadcast message from '%s': %#+v", source, message.Content)
|
|
var wg sync.WaitGroup
|
|
var err *multierror.Error
|
|
err = multierror.Append(err, onReceivedMessage(ctx, source, message.Content))
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
for e := range errCh {
|
|
err = multierror.Append(err, e)
|
|
}
|
|
}()
|
|
for _, dst := range m.allClientProcesses {
|
|
if dst == source {
|
|
continue
|
|
}
|
|
wg.Add(1)
|
|
go func(dst ProcessName) {
|
|
defer wg.Done()
|
|
errCh <- m.sendMessage(ctx, source, dst, message.Content)
|
|
}(dst)
|
|
}
|
|
wg.Wait()
|
|
close(errCh)
|
|
return err.ErrorOrNil()
|
|
case "main":
|
|
logger.Tracef(ctx, "a message to the main process from '%s': %#+v", source, message.Content)
|
|
return onReceivedMessage(ctx, source, message.Content)
|
|
default:
|
|
logger.Tracef(ctx, "a message to '%s' from '%s': %#+v", message.Destination, source, message.Content)
|
|
return m.sendMessage(ctx, source, message.Destination, message.Content)
|
|
}
|
|
}
|
|
|
|
type MessageFromMain struct {
|
|
Source ProcessName
|
|
Password string
|
|
Destination ProcessName
|
|
Content any
|
|
}
|
|
|
|
func (m *Manager) isExpectedProcess(
|
|
name ProcessName,
|
|
) bool {
|
|
for _, p := range m.allClientProcesses {
|
|
if name == p {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Manager) sendMessage(
|
|
ctx context.Context,
|
|
source ProcessName,
|
|
destination ProcessName,
|
|
content any,
|
|
) (_ret error) {
|
|
logger.Tracef(ctx, "sending message message %#+v from '%s' to '%s'", content, source, destination)
|
|
defer func() {
|
|
logger.Tracef(ctx, "/sending message message %#+v from '%s' to '%s': %v", content, source, destination, _ret)
|
|
}()
|
|
|
|
if !m.isExpectedProcess(destination) {
|
|
return fmt.Errorf("process '%s' is not ever expected", destination)
|
|
}
|
|
|
|
message := MessageFromMain{
|
|
Source: source,
|
|
Password: m.password,
|
|
Destination: destination,
|
|
Content: content,
|
|
}
|
|
|
|
conn, err := m.waitForProcess(destination)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to wait for process '%s': %w", destination, err)
|
|
}
|
|
|
|
encoder := gob.NewEncoder(conn)
|
|
err = encoder.Encode(message)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to encode&send message: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) waitForProcess(
|
|
name ProcessName,
|
|
) (net.Conn, error) {
|
|
if !m.isExpectedProcess(name) {
|
|
return nil, fmt.Errorf("process '%s' is not ever expected", name)
|
|
}
|
|
|
|
for {
|
|
m.connsLocker.Lock()
|
|
conn := m.conns[name]
|
|
ch := m.connsChanged
|
|
m.connsLocker.Unlock()
|
|
|
|
if conn != nil {
|
|
return conn, nil
|
|
}
|
|
|
|
<-ch
|
|
}
|
|
}
|
|
|
|
func (m *Manager) checkPassword(
|
|
password string,
|
|
) error {
|
|
return checkPassword(m.password, password)
|
|
}
|
|
|
|
func (m *Manager) registerConnection(
|
|
sourceName ProcessName,
|
|
conn net.Conn,
|
|
) error {
|
|
if !m.isExpectedProcess(sourceName) {
|
|
return fmt.Errorf("process '%s' is not ever expected", sourceName)
|
|
}
|
|
|
|
m.connsLocker.Lock()
|
|
defer m.connsLocker.Unlock()
|
|
if conn, ok := m.conns[sourceName]; ok {
|
|
return fmt.Errorf("process '%s' is already registered at %s", sourceName, conn.RemoteAddr().String())
|
|
}
|
|
m.conns[sourceName] = conn
|
|
var oldCh chan struct{}
|
|
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
|
|
close(oldCh)
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) unregisterConnection(
|
|
sourceName ProcessName,
|
|
) {
|
|
m.connsLocker.Lock()
|
|
defer m.connsLocker.Unlock()
|
|
delete(m.conns, sourceName)
|
|
var oldCh chan struct{}
|
|
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
|
|
close(oldCh)
|
|
}
|
|
|
|
type RegistrationMessage struct {
|
|
Password string
|
|
Source ProcessName
|
|
}
|
|
|
|
// RegistrationResult is the Manager's reply to a RegistrationMessage. Error
// is empty when registration succeeded; otherwise it holds the failure
// reason.
type RegistrationResult struct {
	Error string
}
|
|
|
|
type MessageToMain struct {
|
|
Password string
|
|
Destination ProcessName
|
|
Content any
|
|
}
|
|
|
|
func (m *Manager) SendMessage(
|
|
ctx context.Context,
|
|
dst ProcessName,
|
|
content any,
|
|
) error {
|
|
conn, err := m.waitForProcess(dst)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to wait for process '%s': %w", dst, err)
|
|
}
|
|
encoder := gob.NewEncoder(conn)
|
|
msg := MessageFromMain{
|
|
Source: "main",
|
|
Password: m.password,
|
|
Destination: dst,
|
|
Content: content,
|
|
}
|
|
err = encoder.Encode(msg)
|
|
logger.Tracef(ctx, "sending message %#+v: %v", msg, err)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to encode&send message %#+v: %w", msg, err)
|
|
}
|
|
return nil
|
|
}
|