mirror of
https://github.com/xaionaro-go/streamctl.git
synced 2025-10-05 15:37:00 +08:00
648 lines
15 KiB
Go
648 lines
15 KiB
Go
package mainprocess
|
|
|
|
import (
|
|
"context"
|
|
"encoding/gob"
|
|
"fmt"
|
|
"net"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/facebookincubator/go-belt"
|
|
"github.com/facebookincubator/go-belt/tool/logger"
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/sethvargo/go-password/password"
|
|
"github.com/xaionaro-go/lockmap"
|
|
"github.com/xaionaro-go/observability"
|
|
"github.com/xaionaro-go/xsync"
|
|
)
|
|
|
|
func init() {
|
|
gob.Register(RegistrationMessage{})
|
|
gob.Register(RegistrationResult{})
|
|
gob.Register(MessageReady{})
|
|
gob.Register(MessageReadyConfirmed{})
|
|
gob.Register(MessageFromMain{})
|
|
gob.Register(MessageToMain{})
|
|
}
|
|
|
|
type OnReceivedMessageFunc func(
|
|
ctx context.Context,
|
|
source ProcessName,
|
|
content any,
|
|
) error
|
|
|
|
type LaunchClientFunc func(
|
|
ctx context.Context,
|
|
procName ProcessName,
|
|
addr string,
|
|
password string,
|
|
isRestart bool,
|
|
) error
|
|
|
|
type Manager struct {
|
|
listener net.Listener
|
|
password string
|
|
|
|
connsLocker xsync.Mutex
|
|
conns map[ProcessName]net.Conn
|
|
connLocker *lockmap.LockMap
|
|
childReadyFor map[ProcessName]map[reflect.Type]struct{}
|
|
connsChanged chan struct{}
|
|
|
|
allClientProcesses []ProcessName
|
|
|
|
LaunchClient LaunchClientFunc
|
|
}
|
|
|
|
func NewManager(
|
|
launchClient LaunchClientFunc,
|
|
expectedClients ...ProcessName,
|
|
) (*Manager, error) {
|
|
listener, err := net.Listen("tcp", "localhost:0")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
|
|
password, err := password.Generate(16, 4, 4, false, true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to generate a password: %w", err)
|
|
}
|
|
|
|
return &Manager{
|
|
LaunchClient: launchClient,
|
|
|
|
listener: listener,
|
|
password: password,
|
|
|
|
conns: map[ProcessName]net.Conn{},
|
|
connLocker: lockmap.NewLockMap(),
|
|
childReadyFor: map[ProcessName]map[reflect.Type]struct{}{},
|
|
connsChanged: make(chan struct{}),
|
|
|
|
allClientProcesses: expectedClients,
|
|
}, nil
|
|
}
|
|
|
|
func (m *Manager) Password() string {
|
|
return m.password
|
|
}
|
|
|
|
func (m *Manager) Addr() net.Addr {
|
|
return m.listener.Addr()
|
|
}
|
|
|
|
func (m *Manager) Close() error {
|
|
return m.listener.Close()
|
|
}
|
|
|
|
func (m *Manager) VerifyEverybodyConnected(
|
|
ctx context.Context,
|
|
) error {
|
|
return xsync.DoR1(ctx, &m.connsLocker, func() error {
|
|
for _, name := range m.allClientProcesses {
|
|
if _, ok := m.conns[name]; !ok {
|
|
return fmt.Errorf("client '%s' is not connected", name)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (m *Manager) Serve(
|
|
ctx context.Context,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) error {
|
|
logger.Tracef(ctx, "serving listener at %s", m.listener.Addr())
|
|
defer logger.Tracef(ctx, "/serving listener at %s", m.listener.Addr())
|
|
|
|
ctx, cancelFn := context.WithCancel(ctx)
|
|
defer cancelFn()
|
|
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
<-ctx.Done()
|
|
err := m.Close()
|
|
if err != nil {
|
|
logger.Debug(ctx, err)
|
|
}
|
|
})
|
|
|
|
if m.LaunchClient != nil {
|
|
for _, name := range m.allClientProcesses {
|
|
err := m.LaunchClient(ctx, name, m.listener.Addr().String(), m.password, false)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to launch '%s': %w", name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
conn, err := m.listener.Accept()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to accept connection: %w", err)
|
|
}
|
|
logger.Tracef(ctx, "accepted a connection from '%s'", conn.RemoteAddr())
|
|
conn.(*net.TCPConn).SetNoDelay(true)
|
|
|
|
m.addNewConnection(ctx, conn, onReceivedMessage)
|
|
}
|
|
}
|
|
|
|
func (m *Manager) addNewConnection(
|
|
ctx context.Context,
|
|
conn net.Conn,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) {
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
m.handleConnection(ctx, conn, onReceivedMessage)
|
|
})
|
|
}
|
|
|
|
func (m *Manager) handleConnection(
|
|
ctx context.Context,
|
|
conn net.Conn,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) {
|
|
var regMessage RegistrationMessage
|
|
logger.Tracef(ctx, "handleConnection from %s", conn.RemoteAddr())
|
|
defer func() { logger.Tracef(ctx, "/handleConnection from %s (%s)", conn.RemoteAddr(), regMessage.Source) }()
|
|
|
|
ctx, cancelFn := context.WithCancel(ctx)
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
<-ctx.Done()
|
|
conn.Close()
|
|
})
|
|
defer cancelFn()
|
|
|
|
encoder := gob.NewEncoder(conn)
|
|
|
|
decoder := gob.NewDecoder(conn)
|
|
err := decoder.Decode(®Message)
|
|
if err != nil {
|
|
err = fmt.Errorf("unable to decode registration message: %w", err)
|
|
encoder.Encode(RegistrationResult{Error: err.Error()})
|
|
logger.Debug(ctx, err)
|
|
return
|
|
}
|
|
|
|
logger.Debugf(ctx, "received registration message: %#+v", regMessage)
|
|
if err := m.checkPassword(regMessage.Password); err != nil {
|
|
regMessage = RegistrationMessage{}
|
|
err = fmt.Errorf("invalid password: %w", err)
|
|
encoder.Encode(RegistrationResult{Error: err.Error()})
|
|
logger.Warn(ctx, err)
|
|
return
|
|
}
|
|
if err := m.registerConnection(ctx, regMessage.Source, conn); err != nil {
|
|
err = fmt.Errorf("unable to register process '%s': %w", regMessage.Source, err)
|
|
encoder.Encode(RegistrationResult{Error: err.Error()})
|
|
logger.Error(ctx, err)
|
|
return
|
|
}
|
|
defer func(sourceName ProcessName) {
|
|
m.unregisterConnection(ctx, sourceName)
|
|
}(regMessage.Source)
|
|
if err := encoder.Encode(RegistrationResult{}); err != nil {
|
|
err = fmt.Errorf(
|
|
"unable to encode&send the registration result to '%s': %w",
|
|
regMessage.Source,
|
|
err,
|
|
)
|
|
logger.Error(ctx, err)
|
|
return
|
|
}
|
|
ctx = belt.WithField(ctx, "client", regMessage.Source)
|
|
|
|
defer func() {
|
|
if m.LaunchClient == nil {
|
|
return
|
|
}
|
|
err := m.LaunchClient(ctx, regMessage.Source, m.listener.Addr().String(), m.password, true)
|
|
if err != nil {
|
|
logger.Error(ctx, err)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Tracef(ctx, "context was closed")
|
|
return
|
|
default:
|
|
}
|
|
var message MessageToMain
|
|
logger.Tracef(ctx, "waiting for a message from '%s'", regMessage.Source)
|
|
decoder := gob.NewDecoder(conn)
|
|
err := decoder.Decode(&message)
|
|
logger.Tracef(
|
|
ctx,
|
|
"getting a message from '%s': %#+v %#+v",
|
|
regMessage.Source,
|
|
message,
|
|
err,
|
|
)
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Tracef(ctx, "context was closed")
|
|
return
|
|
default:
|
|
}
|
|
if err != nil {
|
|
err = fmt.Errorf(
|
|
"unable to parse the message from %s (%s): %w",
|
|
regMessage.Source,
|
|
conn.RemoteAddr().String(),
|
|
err,
|
|
)
|
|
logger.Error(ctx, err)
|
|
return
|
|
}
|
|
|
|
if err := m.processMessage(ctx, regMessage.Source, message, onReceivedMessage); err != nil {
|
|
logger.Errorf(ctx, "unable to process the message %#+v from %s (%s): %v",
|
|
message, regMessage.Source, conn.RemoteAddr().String(), err,
|
|
)
|
|
}
|
|
logger.Tracef(ctx, "next iteration")
|
|
}
|
|
}
|
|
|
|
func (m *Manager) processMessage(
|
|
ctx context.Context,
|
|
source ProcessName,
|
|
message MessageToMain,
|
|
onReceivedMessage OnReceivedMessageFunc,
|
|
) (_ret error) {
|
|
logger.Tracef(ctx, "processing message from '%s': %#+v", source, message)
|
|
defer func() { logger.Tracef(ctx, "/processing message from '%s': %#+v: %v", source, message, _ret) }()
|
|
|
|
switch message.Destination {
|
|
case "":
|
|
logger.Tracef(ctx, "a broadcast message from '%s': %#+v", source, message.Content)
|
|
var wg sync.WaitGroup
|
|
var err *multierror.Error
|
|
err = multierror.Append(err, onReceivedMessage(ctx, source, message.Content))
|
|
|
|
errCh := make(chan error)
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
for e := range errCh {
|
|
err = multierror.Append(err, e)
|
|
}
|
|
})
|
|
for _, dst := range m.allClientProcesses {
|
|
if dst == source {
|
|
continue
|
|
}
|
|
wg.Add(1)
|
|
{
|
|
dst := dst
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
defer wg.Done()
|
|
errCh <- m.sendMessage(ctx, source, dst, message.Content)
|
|
})
|
|
}
|
|
}
|
|
wg.Wait()
|
|
close(errCh)
|
|
return err.ErrorOrNil()
|
|
case ProcessNameMain:
|
|
logger.Tracef(
|
|
ctx,
|
|
"got a message to the main process from '%s': %#+v",
|
|
source,
|
|
message.Content,
|
|
)
|
|
switch content := message.Content.(type) {
|
|
case MessageReady:
|
|
var result *multierror.Error
|
|
result = multierror.Append(result, m.SendMessagePreReady(ctx, source, MessageReadyConfirmed{}))
|
|
for _, sample := range content.ReadyForMessages {
|
|
result = multierror.Append(result, m.setReady(ctx, source, reflect.TypeOf(sample)))
|
|
}
|
|
return result.ErrorOrNil()
|
|
default:
|
|
return onReceivedMessage(ctx, source, message.Content)
|
|
}
|
|
default:
|
|
logger.Tracef(
|
|
ctx,
|
|
"got a message to '%s' from '%s': %#+v",
|
|
message.Destination,
|
|
source,
|
|
message.Content,
|
|
)
|
|
return m.sendMessage(ctx, source, message.Destination, message.Content)
|
|
}
|
|
}
|
|
|
|
type MessageFromMain struct {
|
|
Source ProcessName
|
|
Password string
|
|
Destination ProcessName
|
|
Content any
|
|
}
|
|
|
|
func (m *Manager) isExpectedProcess(
|
|
name ProcessName,
|
|
) bool {
|
|
for _, p := range m.allClientProcesses {
|
|
if name == p {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Manager) sendMessage(
|
|
ctx context.Context,
|
|
source ProcessName,
|
|
destination ProcessName,
|
|
content any,
|
|
) (_ret error) {
|
|
logger.Tracef(ctx, "sending message %#+v from '%s' to '%s'", content, source, destination)
|
|
defer func() {
|
|
logger.Tracef(
|
|
ctx,
|
|
"/sending message %#+v from '%s' to '%s': %v",
|
|
content,
|
|
source,
|
|
destination,
|
|
_ret,
|
|
)
|
|
}()
|
|
|
|
if !m.isExpectedProcess(destination) {
|
|
return fmt.Errorf("process '%s' is not ever expected", destination)
|
|
}
|
|
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
conn, err := m.waitForReadyProcess(ctx, destination, reflect.TypeOf(content))
|
|
if err != nil {
|
|
logger.Errorf(
|
|
ctx,
|
|
"%v",
|
|
fmt.Errorf("unable to wait for process '%s': %w", destination, err),
|
|
)
|
|
return
|
|
}
|
|
|
|
message := MessageFromMain{
|
|
Source: source,
|
|
Password: m.password,
|
|
Destination: destination,
|
|
Content: content,
|
|
}
|
|
|
|
h := m.connLocker.Lock(context.Background(), destination) // TODO: should we use the provided ctx?
|
|
defer h.Unlock()
|
|
defer time.Sleep(
|
|
100 * time.Millisecond,
|
|
) // TODO: Delete this horrible hack (that is introduced to avoid erasing messages in the buffer)
|
|
err = gob.NewEncoder(conn).Encode(message)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "%v", fmt.Errorf("unable to encode&send message: %v", err))
|
|
return
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) waitForProcess(
|
|
ctx context.Context,
|
|
name ProcessName,
|
|
) (net.Conn, error) {
|
|
if !m.isExpectedProcess(name) {
|
|
return nil, fmt.Errorf("process '%s' is not ever expected", name)
|
|
}
|
|
|
|
for {
|
|
conn, ch := xsync.DoR2(ctx, &m.connsLocker, func() (net.Conn, chan struct{}) {
|
|
return m.conns[name], m.connsChanged
|
|
})
|
|
|
|
if conn != nil {
|
|
return conn, nil
|
|
}
|
|
|
|
<-ch
|
|
}
|
|
}
|
|
|
|
func (m *Manager) waitForReadyProcess(
|
|
ctx context.Context,
|
|
name ProcessName,
|
|
msgType reflect.Type,
|
|
) (net.Conn, error) {
|
|
logger.Debugf(ctx, "waitForReadyProcess(ctx, '%s', '%s')", name, msgType)
|
|
defer logger.Debugf(ctx, "/waitForReadyProcess(ctx, '%s', '%s')", name, msgType)
|
|
if !m.isExpectedProcess(name) {
|
|
return nil, fmt.Errorf("process '%s' is not ever expected", name)
|
|
}
|
|
if msgType.Kind() == reflect.Pointer {
|
|
msgType = msgType.Elem()
|
|
}
|
|
|
|
for {
|
|
conn, ch, isReady := xsync.DoR3(
|
|
ctx,
|
|
&m.connsLocker,
|
|
func() (net.Conn, chan struct{}, bool) {
|
|
readyMap := m.childReadyFor[name]
|
|
isReady := false
|
|
if readyMap != nil {
|
|
if _, ok := readyMap[msgType]; ok {
|
|
isReady = true
|
|
}
|
|
}
|
|
return m.conns[name], m.connsChanged, isReady
|
|
},
|
|
)
|
|
|
|
if conn != nil && isReady {
|
|
logger.Debugf(
|
|
ctx,
|
|
"waitForReadyProcess(ctx, '%s', '%s'): waiting is complete",
|
|
name,
|
|
msgType,
|
|
)
|
|
return conn, nil
|
|
}
|
|
|
|
logger.Debugf(
|
|
ctx,
|
|
"waitForReadyProcess(ctx, '%s', '%s'): waiting for a change in connections",
|
|
name,
|
|
msgType,
|
|
)
|
|
<-ch
|
|
}
|
|
}
|
|
|
|
func (m *Manager) checkPassword(
|
|
password string,
|
|
) error {
|
|
return checkPassword(m.password, password)
|
|
}
|
|
|
|
func (m *Manager) registerConnection(
|
|
ctx context.Context,
|
|
sourceName ProcessName,
|
|
conn net.Conn,
|
|
) error {
|
|
logger.Debugf(ctx, "registerConnection(ctx, '%s', %s)", sourceName, conn.RemoteAddr().String())
|
|
defer logger.Debugf(
|
|
ctx,
|
|
"/registerConnection(ctx, '%s', %s)",
|
|
sourceName,
|
|
conn.RemoteAddr().String(),
|
|
)
|
|
if !m.isExpectedProcess(sourceName) {
|
|
return fmt.Errorf("process '%s' is not ever expected", sourceName)
|
|
}
|
|
|
|
m.connsLocker.Do(ctx, func() {
|
|
if conn, ok := m.conns[sourceName]; ok {
|
|
conn.Close()
|
|
logger.Warnf(ctx, "closing the old registered connection")
|
|
}
|
|
m.conns[sourceName] = conn
|
|
var oldCh chan struct{}
|
|
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
|
|
close(oldCh)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) setReady(
|
|
ctx context.Context,
|
|
sourceName ProcessName,
|
|
msgType reflect.Type,
|
|
) error {
|
|
if !m.isExpectedProcess(sourceName) {
|
|
return fmt.Errorf("process '%s' is not ever expected", sourceName)
|
|
}
|
|
if msgType.Kind() == reflect.Pointer {
|
|
msgType = msgType.Elem()
|
|
}
|
|
|
|
return xsync.DoR1(ctx, &m.connsLocker, func() error {
|
|
readyMap := m.childReadyFor[sourceName]
|
|
if readyMap == nil {
|
|
m.childReadyFor[sourceName] = map[reflect.Type]struct{}{}
|
|
readyMap = m.childReadyFor[sourceName]
|
|
}
|
|
if _, ok := readyMap[msgType]; ok {
|
|
return fmt.Errorf("process '%s' is already ready for '%s'", sourceName, msgType)
|
|
}
|
|
readyMap[msgType] = struct{}{}
|
|
var oldCh chan struct{}
|
|
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
|
|
close(oldCh)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (m *Manager) unregisterConnection(
|
|
ctx context.Context,
|
|
sourceName ProcessName,
|
|
) {
|
|
logger.Infof(ctx, "unregistering process '%s'", string(sourceName))
|
|
defer logger.Infof(ctx, "/unregistering process '%s'", string(sourceName))
|
|
m.connsLocker.Do(ctx, func() {
|
|
delete(m.conns, sourceName)
|
|
delete(m.childReadyFor, sourceName)
|
|
var oldCh chan struct{}
|
|
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
|
|
close(oldCh)
|
|
})
|
|
}
|
|
|
|
type RegistrationMessage struct {
|
|
Password string
|
|
Source ProcessName
|
|
}
|
|
|
|
type RegistrationResult struct {
|
|
Error string
|
|
}
|
|
|
|
type MessageToMain struct {
|
|
Password string
|
|
Destination ProcessName
|
|
Content any
|
|
}
|
|
|
|
func (m *Manager) SendMessagePreReady(
|
|
ctx context.Context,
|
|
dst ProcessName,
|
|
content any,
|
|
) error {
|
|
logger.Debugf(ctx, "SendMessagePreReady(ctx, '%s', %T)", dst, content)
|
|
defer logger.Debugf(ctx, "/SendMessagePreReady(ctx, '%s', %T)", dst, content)
|
|
|
|
conn, err := m.waitForProcess(ctx, dst)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to wait for process '%s': %w", dst, err)
|
|
}
|
|
encoder := gob.NewEncoder(conn)
|
|
msg := MessageFromMain{
|
|
Source: ProcessNameMain,
|
|
Password: m.password,
|
|
Destination: dst,
|
|
Content: content,
|
|
}
|
|
h := m.connLocker.Lock(context.Background(), dst)
|
|
defer h.Unlock()
|
|
defer time.Sleep(
|
|
100 * time.Millisecond,
|
|
) // TODO: Delete this horrible hack (that is introduced to avoid erasing messages in the buffer)
|
|
err = encoder.Encode(msg)
|
|
logger.Tracef(ctx, "sending message %#+v: %v", msg, err)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to encode&send message %#+v: %w", msg, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) SendMessage(
|
|
ctx context.Context,
|
|
dst ProcessName,
|
|
content any,
|
|
) error {
|
|
logger.Debugf(ctx, "SendMessage(ctx, '%s', %T)", dst, content)
|
|
defer logger.Debugf(ctx, "/SendMessage(ctx, '%s', %T)", dst, content)
|
|
|
|
conn, err := m.waitForReadyProcess(ctx, dst, reflect.TypeOf(content))
|
|
if err != nil {
|
|
return fmt.Errorf("unable to wait for process '%s': %w", dst, err)
|
|
}
|
|
encoder := gob.NewEncoder(conn)
|
|
msg := MessageFromMain{
|
|
Source: ProcessNameMain,
|
|
Password: m.password,
|
|
Destination: dst,
|
|
Content: content,
|
|
}
|
|
h := m.connLocker.Lock(context.Background(), dst)
|
|
defer h.Unlock()
|
|
err = encoder.Encode(msg)
|
|
logger.Tracef(ctx, "sending message %#+v: %v", msg, err)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to encode&send message %#+v: %w", msg, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type MessageReady struct {
|
|
ReadyForMessages []any
|
|
}
|
|
type MessageReadyConfirmed struct{}
|