Files
streamctl/pkg/mainprocess/main_process.go
2025-06-22 14:59:00 +01:00

648 lines
15 KiB
Go

package mainprocess
import (
"context"
"encoding/gob"
"fmt"
"net"
"reflect"
"sync"
"time"
"github.com/facebookincubator/go-belt"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/hashicorp/go-multierror"
"github.com/sethvargo/go-password/password"
"github.com/xaionaro-go/lockmap"
"github.com/xaionaro-go/observability"
"github.com/xaionaro-go/xsync"
)
func init() {
gob.Register(RegistrationMessage{})
gob.Register(RegistrationResult{})
gob.Register(MessageReady{})
gob.Register(MessageReadyConfirmed{})
gob.Register(MessageFromMain{})
gob.Register(MessageToMain{})
}
type OnReceivedMessageFunc func(
ctx context.Context,
source ProcessName,
content any,
) error
type LaunchClientFunc func(
ctx context.Context,
procName ProcessName,
addr string,
password string,
isRestart bool,
) error
type Manager struct {
listener net.Listener
password string
connsLocker xsync.Mutex
conns map[ProcessName]net.Conn
connLocker *lockmap.LockMap
childReadyFor map[ProcessName]map[reflect.Type]struct{}
connsChanged chan struct{}
allClientProcesses []ProcessName
LaunchClient LaunchClientFunc
}
func NewManager(
launchClient LaunchClientFunc,
expectedClients ...ProcessName,
) (*Manager, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
}
password, err := password.Generate(16, 4, 4, false, true)
if err != nil {
return nil, fmt.Errorf("unable to generate a password: %w", err)
}
return &Manager{
LaunchClient: launchClient,
listener: listener,
password: password,
conns: map[ProcessName]net.Conn{},
connLocker: lockmap.NewLockMap(),
childReadyFor: map[ProcessName]map[reflect.Type]struct{}{},
connsChanged: make(chan struct{}),
allClientProcesses: expectedClients,
}, nil
}
func (m *Manager) Password() string {
return m.password
}
func (m *Manager) Addr() net.Addr {
return m.listener.Addr()
}
func (m *Manager) Close() error {
return m.listener.Close()
}
func (m *Manager) VerifyEverybodyConnected(
ctx context.Context,
) error {
return xsync.DoR1(ctx, &m.connsLocker, func() error {
for _, name := range m.allClientProcesses {
if _, ok := m.conns[name]; !ok {
return fmt.Errorf("client '%s' is not connected", name)
}
}
return nil
})
}
func (m *Manager) Serve(
ctx context.Context,
onReceivedMessage OnReceivedMessageFunc,
) error {
logger.Tracef(ctx, "serving listener at %s", m.listener.Addr())
defer logger.Tracef(ctx, "/serving listener at %s", m.listener.Addr())
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
observability.Go(ctx, func(ctx context.Context) {
<-ctx.Done()
err := m.Close()
if err != nil {
logger.Debug(ctx, err)
}
})
if m.LaunchClient != nil {
for _, name := range m.allClientProcesses {
err := m.LaunchClient(ctx, name, m.listener.Addr().String(), m.password, false)
if err != nil {
return fmt.Errorf("unable to launch '%s': %w", name, err)
}
}
}
for {
select {
case <-ctx.Done():
return nil
default:
}
conn, err := m.listener.Accept()
if err != nil {
return fmt.Errorf("unable to accept connection: %w", err)
}
logger.Tracef(ctx, "accepted a connection from '%s'", conn.RemoteAddr())
conn.(*net.TCPConn).SetNoDelay(true)
m.addNewConnection(ctx, conn, onReceivedMessage)
}
}
func (m *Manager) addNewConnection(
ctx context.Context,
conn net.Conn,
onReceivedMessage OnReceivedMessageFunc,
) {
observability.Go(ctx, func(ctx context.Context) {
m.handleConnection(ctx, conn, onReceivedMessage)
})
}
func (m *Manager) handleConnection(
ctx context.Context,
conn net.Conn,
onReceivedMessage OnReceivedMessageFunc,
) {
var regMessage RegistrationMessage
logger.Tracef(ctx, "handleConnection from %s", conn.RemoteAddr())
defer func() { logger.Tracef(ctx, "/handleConnection from %s (%s)", conn.RemoteAddr(), regMessage.Source) }()
ctx, cancelFn := context.WithCancel(ctx)
observability.Go(ctx, func(ctx context.Context) {
<-ctx.Done()
conn.Close()
})
defer cancelFn()
encoder := gob.NewEncoder(conn)
decoder := gob.NewDecoder(conn)
err := decoder.Decode(&regMessage)
if err != nil {
err = fmt.Errorf("unable to decode registration message: %w", err)
encoder.Encode(RegistrationResult{Error: err.Error()})
logger.Debug(ctx, err)
return
}
logger.Debugf(ctx, "received registration message: %#+v", regMessage)
if err := m.checkPassword(regMessage.Password); err != nil {
regMessage = RegistrationMessage{}
err = fmt.Errorf("invalid password: %w", err)
encoder.Encode(RegistrationResult{Error: err.Error()})
logger.Warn(ctx, err)
return
}
if err := m.registerConnection(ctx, regMessage.Source, conn); err != nil {
err = fmt.Errorf("unable to register process '%s': %w", regMessage.Source, err)
encoder.Encode(RegistrationResult{Error: err.Error()})
logger.Error(ctx, err)
return
}
defer func(sourceName ProcessName) {
m.unregisterConnection(ctx, sourceName)
}(regMessage.Source)
if err := encoder.Encode(RegistrationResult{}); err != nil {
err = fmt.Errorf(
"unable to encode&send the registration result to '%s': %w",
regMessage.Source,
err,
)
logger.Error(ctx, err)
return
}
ctx = belt.WithField(ctx, "client", regMessage.Source)
defer func() {
if m.LaunchClient == nil {
return
}
err := m.LaunchClient(ctx, regMessage.Source, m.listener.Addr().String(), m.password, true)
if err != nil {
logger.Error(ctx, err)
}
}()
for {
select {
case <-ctx.Done():
logger.Tracef(ctx, "context was closed")
return
default:
}
var message MessageToMain
logger.Tracef(ctx, "waiting for a message from '%s'", regMessage.Source)
decoder := gob.NewDecoder(conn)
err := decoder.Decode(&message)
logger.Tracef(
ctx,
"getting a message from '%s': %#+v %#+v",
regMessage.Source,
message,
err,
)
select {
case <-ctx.Done():
logger.Tracef(ctx, "context was closed")
return
default:
}
if err != nil {
err = fmt.Errorf(
"unable to parse the message from %s (%s): %w",
regMessage.Source,
conn.RemoteAddr().String(),
err,
)
logger.Error(ctx, err)
return
}
if err := m.processMessage(ctx, regMessage.Source, message, onReceivedMessage); err != nil {
logger.Errorf(ctx, "unable to process the message %#+v from %s (%s): %v",
message, regMessage.Source, conn.RemoteAddr().String(), err,
)
}
logger.Tracef(ctx, "next iteration")
}
}
func (m *Manager) processMessage(
ctx context.Context,
source ProcessName,
message MessageToMain,
onReceivedMessage OnReceivedMessageFunc,
) (_ret error) {
logger.Tracef(ctx, "processing message from '%s': %#+v", source, message)
defer func() { logger.Tracef(ctx, "/processing message from '%s': %#+v: %v", source, message, _ret) }()
switch message.Destination {
case "":
logger.Tracef(ctx, "a broadcast message from '%s': %#+v", source, message.Content)
var wg sync.WaitGroup
var err *multierror.Error
err = multierror.Append(err, onReceivedMessage(ctx, source, message.Content))
errCh := make(chan error)
observability.Go(ctx, func(ctx context.Context) {
for e := range errCh {
err = multierror.Append(err, e)
}
})
for _, dst := range m.allClientProcesses {
if dst == source {
continue
}
wg.Add(1)
{
dst := dst
observability.Go(ctx, func(ctx context.Context) {
defer wg.Done()
errCh <- m.sendMessage(ctx, source, dst, message.Content)
})
}
}
wg.Wait()
close(errCh)
return err.ErrorOrNil()
case ProcessNameMain:
logger.Tracef(
ctx,
"got a message to the main process from '%s': %#+v",
source,
message.Content,
)
switch content := message.Content.(type) {
case MessageReady:
var result *multierror.Error
result = multierror.Append(result, m.SendMessagePreReady(ctx, source, MessageReadyConfirmed{}))
for _, sample := range content.ReadyForMessages {
result = multierror.Append(result, m.setReady(ctx, source, reflect.TypeOf(sample)))
}
return result.ErrorOrNil()
default:
return onReceivedMessage(ctx, source, message.Content)
}
default:
logger.Tracef(
ctx,
"got a message to '%s' from '%s': %#+v",
message.Destination,
source,
message.Content,
)
return m.sendMessage(ctx, source, message.Destination, message.Content)
}
}
type MessageFromMain struct {
Source ProcessName
Password string
Destination ProcessName
Content any
}
func (m *Manager) isExpectedProcess(
name ProcessName,
) bool {
for _, p := range m.allClientProcesses {
if name == p {
return true
}
}
return false
}
func (m *Manager) sendMessage(
ctx context.Context,
source ProcessName,
destination ProcessName,
content any,
) (_ret error) {
logger.Tracef(ctx, "sending message %#+v from '%s' to '%s'", content, source, destination)
defer func() {
logger.Tracef(
ctx,
"/sending message %#+v from '%s' to '%s': %v",
content,
source,
destination,
_ret,
)
}()
if !m.isExpectedProcess(destination) {
return fmt.Errorf("process '%s' is not ever expected", destination)
}
observability.Go(ctx, func(ctx context.Context) {
conn, err := m.waitForReadyProcess(ctx, destination, reflect.TypeOf(content))
if err != nil {
logger.Errorf(
ctx,
"%v",
fmt.Errorf("unable to wait for process '%s': %w", destination, err),
)
return
}
message := MessageFromMain{
Source: source,
Password: m.password,
Destination: destination,
Content: content,
}
h := m.connLocker.Lock(context.Background(), destination) // TODO: should we use the provided ctx?
defer h.Unlock()
defer time.Sleep(
100 * time.Millisecond,
) // TODO: Delete this horrible hack (that is introduced to avoid erasing messages in the buffer)
err = gob.NewEncoder(conn).Encode(message)
if err != nil {
logger.Errorf(ctx, "%v", fmt.Errorf("unable to encode&send message: %v", err))
return
}
})
return nil
}
func (m *Manager) waitForProcess(
ctx context.Context,
name ProcessName,
) (net.Conn, error) {
if !m.isExpectedProcess(name) {
return nil, fmt.Errorf("process '%s' is not ever expected", name)
}
for {
conn, ch := xsync.DoR2(ctx, &m.connsLocker, func() (net.Conn, chan struct{}) {
return m.conns[name], m.connsChanged
})
if conn != nil {
return conn, nil
}
<-ch
}
}
func (m *Manager) waitForReadyProcess(
ctx context.Context,
name ProcessName,
msgType reflect.Type,
) (net.Conn, error) {
logger.Debugf(ctx, "waitForReadyProcess(ctx, '%s', '%s')", name, msgType)
defer logger.Debugf(ctx, "/waitForReadyProcess(ctx, '%s', '%s')", name, msgType)
if !m.isExpectedProcess(name) {
return nil, fmt.Errorf("process '%s' is not ever expected", name)
}
if msgType.Kind() == reflect.Pointer {
msgType = msgType.Elem()
}
for {
conn, ch, isReady := xsync.DoR3(
ctx,
&m.connsLocker,
func() (net.Conn, chan struct{}, bool) {
readyMap := m.childReadyFor[name]
isReady := false
if readyMap != nil {
if _, ok := readyMap[msgType]; ok {
isReady = true
}
}
return m.conns[name], m.connsChanged, isReady
},
)
if conn != nil && isReady {
logger.Debugf(
ctx,
"waitForReadyProcess(ctx, '%s', '%s'): waiting is complete",
name,
msgType,
)
return conn, nil
}
logger.Debugf(
ctx,
"waitForReadyProcess(ctx, '%s', '%s'): waiting for a change in connections",
name,
msgType,
)
<-ch
}
}
func (m *Manager) checkPassword(
password string,
) error {
return checkPassword(m.password, password)
}
func (m *Manager) registerConnection(
ctx context.Context,
sourceName ProcessName,
conn net.Conn,
) error {
logger.Debugf(ctx, "registerConnection(ctx, '%s', %s)", sourceName, conn.RemoteAddr().String())
defer logger.Debugf(
ctx,
"/registerConnection(ctx, '%s', %s)",
sourceName,
conn.RemoteAddr().String(),
)
if !m.isExpectedProcess(sourceName) {
return fmt.Errorf("process '%s' is not ever expected", sourceName)
}
m.connsLocker.Do(ctx, func() {
if conn, ok := m.conns[sourceName]; ok {
conn.Close()
logger.Warnf(ctx, "closing the old registered connection")
}
m.conns[sourceName] = conn
var oldCh chan struct{}
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
close(oldCh)
})
return nil
}
func (m *Manager) setReady(
ctx context.Context,
sourceName ProcessName,
msgType reflect.Type,
) error {
if !m.isExpectedProcess(sourceName) {
return fmt.Errorf("process '%s' is not ever expected", sourceName)
}
if msgType.Kind() == reflect.Pointer {
msgType = msgType.Elem()
}
return xsync.DoR1(ctx, &m.connsLocker, func() error {
readyMap := m.childReadyFor[sourceName]
if readyMap == nil {
m.childReadyFor[sourceName] = map[reflect.Type]struct{}{}
readyMap = m.childReadyFor[sourceName]
}
if _, ok := readyMap[msgType]; ok {
return fmt.Errorf("process '%s' is already ready for '%s'", sourceName, msgType)
}
readyMap[msgType] = struct{}{}
var oldCh chan struct{}
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
close(oldCh)
return nil
})
}
func (m *Manager) unregisterConnection(
ctx context.Context,
sourceName ProcessName,
) {
logger.Infof(ctx, "unregistering process '%s'", string(sourceName))
defer logger.Infof(ctx, "/unregistering process '%s'", string(sourceName))
m.connsLocker.Do(ctx, func() {
delete(m.conns, sourceName)
delete(m.childReadyFor, sourceName)
var oldCh chan struct{}
oldCh, m.connsChanged = m.connsChanged, make(chan struct{})
close(oldCh)
})
}
type RegistrationMessage struct {
Password string
Source ProcessName
}
type RegistrationResult struct {
Error string
}
type MessageToMain struct {
Password string
Destination ProcessName
Content any
}
func (m *Manager) SendMessagePreReady(
ctx context.Context,
dst ProcessName,
content any,
) error {
logger.Debugf(ctx, "SendMessagePreReady(ctx, '%s', %T)", dst, content)
defer logger.Debugf(ctx, "/SendMessagePreReady(ctx, '%s', %T)", dst, content)
conn, err := m.waitForProcess(ctx, dst)
if err != nil {
return fmt.Errorf("unable to wait for process '%s': %w", dst, err)
}
encoder := gob.NewEncoder(conn)
msg := MessageFromMain{
Source: ProcessNameMain,
Password: m.password,
Destination: dst,
Content: content,
}
h := m.connLocker.Lock(context.Background(), dst)
defer h.Unlock()
defer time.Sleep(
100 * time.Millisecond,
) // TODO: Delete this horrible hack (that is introduced to avoid erasing messages in the buffer)
err = encoder.Encode(msg)
logger.Tracef(ctx, "sending message %#+v: %v", msg, err)
if err != nil {
return fmt.Errorf("unable to encode&send message %#+v: %w", msg, err)
}
return nil
}
func (m *Manager) SendMessage(
ctx context.Context,
dst ProcessName,
content any,
) error {
logger.Debugf(ctx, "SendMessage(ctx, '%s', %T)", dst, content)
defer logger.Debugf(ctx, "/SendMessage(ctx, '%s', %T)", dst, content)
conn, err := m.waitForReadyProcess(ctx, dst, reflect.TypeOf(content))
if err != nil {
return fmt.Errorf("unable to wait for process '%s': %w", dst, err)
}
encoder := gob.NewEncoder(conn)
msg := MessageFromMain{
Source: ProcessNameMain,
Password: m.password,
Destination: dst,
Content: content,
}
h := m.connLocker.Lock(context.Background(), dst)
defer h.Unlock()
err = encoder.Encode(msg)
logger.Tracef(ctx, "sending message %#+v: %v", msg, err)
if err != nil {
return fmt.Errorf("unable to encode&send message %#+v: %w", msg, err)
}
return nil
}
type MessageReady struct {
ReadyForMessages []any
}
type MessageReadyConfirmed struct{}