Files
frontier/pkg/servicebound/service_manager.go
2024-02-10 23:06:20 +08:00

257 lines
6.6 KiB
Go

package servicebound
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"net"
"os"
"sync"
"github.com/jumboframes/armorigo/synchub"
"github.com/singchia/frontier/pkg/api"
"github.com/singchia/frontier/pkg/config"
"github.com/singchia/frontier/pkg/mapmap"
"github.com/singchia/frontier/pkg/repo/dao"
"github.com/singchia/frontier/pkg/repo/model"
"github.com/singchia/frontier/pkg/security"
"github.com/singchia/frontier/pkg/utils"
"github.com/singchia/geminio"
"github.com/singchia/geminio/delegate"
"github.com/singchia/geminio/pkg/id"
"github.com/singchia/geminio/server"
"github.com/singchia/go-timer/v2"
"k8s.io/klog/v2"
)
// serviceManager accepts connections from services on the servicebound
// listener, wraps each one in a geminio End, and keeps an in-memory cache of
// the connected Ends and their streams.
type serviceManager struct {
	// embedded so that serviceManager can be handed to each End as its
	// delegate (opt.SetDelegate(sm) in handleConn)
	*delegate.UnimplementedDelegate
	// presumably notified of service lifecycle events — confirm
	informer api.ServiceInformer
	// presumably used to forward traffic between services and edges — confirm
	exchange api.Exchange
	conf     *config.Configuration

	// serviceID allocator
	idFactory id.IDFactory
	// sync hub created with tmr as its timer; its callers are outside this view
	shub *synchub.SyncHub

	// cache
	// key: serviceID; value: geminio.End
	services map[uint64]geminio.End
	// guards services; the lookup/list methods take it for reading
	mtx sync.RWMutex
	// key: serviceID; subkey: streamID; value: geminio.Stream
	// streams are not stored in the dao because there may be too many of them.
	streams *mapmap.MapMap

	// dao and repo for services
	dao *dao.Dao
	// listener accepted from in Serve and closed in Close
	ln net.Listener
	// timer shared by all service ends (passed to each End in handleConn)
	tmr timer.Timer
}
// newServiceManager opens the servicebound listener described by
// conf.Servicebound.Listen and returns a serviceManager that serves on it.
//
// The listener is opened before any other manager state is built, so a
// listen failure returns immediately without constructing the sync hub or
// caches that would otherwise be thrown away.
//
// Returns the listen / certificate error (already logged) on failure.
func newServiceManager(conf *config.Configuration, dao *dao.Dao, informer api.ServiceInformer,
	exchange api.Exchange, tmr timer.Timer) (*serviceManager, error) {
	ln, err := listenServicebound(conf)
	if err != nil {
		return nil, err
	}
	sm := &serviceManager{
		conf:                  conf,
		tmr:                   tmr,
		streams:               mapmap.NewMapMap(),
		dao:                   dao,
		shub:                  synchub.NewSyncHub(synchub.OptionTimer(tmr)),
		services:              make(map[uint64]geminio.End),
		UnimplementedDelegate: &delegate.UnimplementedDelegate{},
		// a simple unix timestamp incremental id factory
		idFactory: id.DefaultIncIDCounter,
		informer:  informer,
		// previously accepted but never stored, leaving sm.exchange nil
		exchange: exchange,
		ln:       ln,
	}
	return sm, nil
}

// listenServicebound opens the servicebound listener from
// conf.Servicebound.Listen: plain TCP (or whatever network is configured)
// when TLS is disabled, TLS when enabled, and mutual TLS (client certificate
// required and verified against the configured CA files) when MTLS is set.
func listenServicebound(conf *config.Configuration) (net.Listener, error) {
	listen := &conf.Servicebound.Listen
	network, addr := listen.Network, listen.Addr

	if !listen.TLS.Enable {
		ln, err := net.Listen(network, addr)
		if err != nil {
			klog.Errorf("service manager net listen err: %s, network: %s, addr: %s", err, network, addr)
			return nil, err
		}
		return ln, nil
	}

	// load all certs to listen; a pair that fails to load is logged and
	// skipped (tls.Listen itself rejects a config with no certificates)
	certs := []tls.Certificate{}
	for _, certFile := range listen.TLS.Certs {
		cert, err := tls.LoadX509KeyPair(certFile.Cert, certFile.Key)
		if err != nil {
			klog.Errorf("service manager tls load x509 cert err: %s, cert: %s, key: %s", err, certFile.Cert, certFile.Key)
			continue
		}
		certs = append(certs, cert)
	}
	tlsConf := &tls.Config{
		MinVersion:   tls.VersionTLS12,
		CipherSuites: security.CiperSuites,
		Certificates: certs,
	}

	if listen.TLS.MTLS {
		// mtls: require and verify a client cert against the loaded CA pool
		caPool := x509.NewCertPool()
		for _, caFile := range listen.TLS.CACerts {
			ca, err := os.ReadFile(caFile)
			if err != nil {
				klog.Errorf("service manager read ca cert err: %s, file: %s", err, caFile)
				return nil, err
			}
			// AppendCertsFromPEM reports failure only via its bool; there is
			// no error value to log here.
			if !caPool.AppendCertsFromPEM(ca) {
				klog.Warningf("service manager append ca cert to ca pool failed, no valid PEM cert, file: %s", caFile)
			}
		}
		tlsConf.ClientCAs = caPool
		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
	}

	ln, err := tls.Listen(network, addr, tlsConf)
	if err != nil {
		klog.Errorf("service manager tls listen err: %s, network: %s, addr: %s", err, network, addr)
		return nil, err
	}
	return ln, nil
}
// Serve accepts service connections until Accept fails (for instance after
// Close has closed the listener), handling each connection in its own
// goroutine.
func (sm *serviceManager) Serve() {
	for {
		conn, err := sm.ln.Accept()
		if err == nil {
			go sm.handleConn(conn)
			continue
		}
		klog.V(4).Infof("service manager listener accept err: %s", err)
		return
	}
}
// handleConn turns an accepted connection into a geminio End, decodes the
// service's api.Meta from the End's meta bytes, records the topics the
// service claims, brings the service online and starts forwarding.
//
// On any failure the connection (or the End wrapping it) is closed, so a
// rejected service does not leave an open socket behind. The returned error
// is informational only: Serve runs this in a goroutine and discards it.
func (sm *serviceManager) handleConn(conn net.Conn) error {
	// options for geminio End
	opt := server.NewEndOptions()
	opt.SetTimer(sm.tmr)
	opt.SetDelegate(sm)
	// stream handler
	opt.SetAcceptStreamFunc(sm.acceptStream)
	opt.SetClosedStreamFunc(sm.closedStream)
	end, err := server.NewEndWithConn(conn, opt)
	if err != nil {
		klog.Errorf("service manager geminio server new end err: %s", err)
		// NOTE(review): unclear whether geminio already closed conn here;
		// closing a net.Conn twice only yields an ignorable error.
		conn.Close()
		return err
	}
	meta := &api.Meta{}
	if err = json.Unmarshal(end.Meta(), meta); err != nil {
		klog.Errorf("handle conn, json unmarshal err: %s", err)
		// previously leaked: the End stayed open with no registered service
		end.Close()
		return err
	}
	// register topics claim of end; errors are logged inside and the service
	// is still brought online (best-effort, as before)
	_ = sm.remoteReceiveClaim(end.ClientID(), meta.Topics)
	// handle online event for end
	if err = sm.online(end, meta); err != nil {
		// don't keep a connected End for a service that failed to come online
		// NOTE(review): confirm online does not already close the End itself
		end.Close()
		return err
	}
	// forward and stream up to edge
	sm.forward(meta, end)
	return nil
}
// remoteReceiveClaim persists, through the dao, one ServiceTopic row per
// topic binding it to serviceID. It stops at the first topic that fails to
// be created and returns that error; rows created before it are kept.
func (sm *serviceManager) remoteReceiveClaim(serviceID uint64, topics []string) error {
	klog.V(5).Infof("service remote receive claim, topics: %v, serviceID: %d", topics, serviceID)
	// original note says "memdb": the dao is presumably in-memory — confirm
	for i := range topics {
		topic := topics[i]
		row := &model.ServiceTopic{Topic: topic, ServiceID: serviceID}
		if err := sm.dao.CreateServiceTopic(row); err != nil {
			klog.Errorf("service remote receive claim, create service topic: %s, err: %s", topic, err)
			return err
		}
	}
	return nil
}
// GetServiceByID returns the End cached under serviceID, or nil when no
// service with that ID is in the cache.
func (sm *serviceManager) GetServiceByID(serviceID uint64) geminio.End {
	sm.mtx.RLock()
	end := sm.services[serviceID]
	sm.mtx.RUnlock()
	return end
}
// GetServiceByRPC looks up which service registered rpc (via the dao) and
// returns that service's cached End. A dao error is logged and returned; if
// the dao knows the rpc but the service is not cached, the End is nil and
// the error is nil.
func (sm *serviceManager) GetServiceByRPC(rpc string) (geminio.End, error) {
	mrpc, err := sm.dao.GetServiceRPC(rpc)
	if err == nil {
		return sm.GetServiceByID(mrpc.ServiceID), nil
	}
	klog.Errorf("get service by rpc: %s, err: %s", rpc, err)
	return nil, err
}
// GetServiceByTopic looks up which service claimed topic (via the dao) and
// returns that service's cached End. A dao error is logged and returned; if
// the dao knows the topic but the service is not cached, the End is nil and
// the error is nil.
func (sm *serviceManager) GetServiceByTopic(topic string) (geminio.End, error) {
	mtopic, err := sm.dao.GetServiceTopic(topic)
	if err == nil {
		return sm.GetServiceByID(mtopic.ServiceID), nil
	}
	klog.Errorf("get service by topic: %s, err: %s", topic, err)
	return nil, err
}
// ListService returns a snapshot of every cached service End, in map
// iteration (i.e. unspecified) order. The slice is never nil; it is empty
// when no service is cached.
func (sm *serviceManager) ListService() []geminio.End {
	sm.mtx.RLock()
	defer sm.mtx.RUnlock()
	// size is known under the lock, so allocate once instead of regrowing
	ends := make([]geminio.End, 0, len(sm.services))
	for _, end := range sm.services {
		ends = append(ends, end)
	}
	return ends
}
// CountServices reports how many service Ends are currently cached.
func (sm *serviceManager) CountServices() int {
	sm.mtx.RLock()
	count := len(sm.services)
	sm.mtx.RUnlock()
	return count
}
// ListStreams returns the streams tracked for serviceID, converted from the
// mapmap entries by utils.Slice2streams.
func (sm *serviceManager) ListStreams(serviceID uint64) []geminio.Stream {
	return utils.Slice2streams(sm.streams.MGetAll(serviceID))
}
// Close stops accepting new service connections by closing the listener,
// which makes Serve return. Service Ends that are already connected are not
// closed by this method.
func (sm *serviceManager) Close() error {
	return sm.ln.Close()
}