mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-26 20:11:16 +08:00
158 lines
3.7 KiB
Go
158 lines
3.7 KiB
Go
package mq
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
|
|
"github.com/oarkflow/errors"
|
|
"github.com/oarkflow/xid/wuid"
|
|
|
|
"github.com/oarkflow/mq/consts"
|
|
"github.com/oarkflow/mq/storage"
|
|
"github.com/oarkflow/mq/storage/memory"
|
|
)
|
|
|
|
type Handler func(context.Context, *Task) Result
|
|
|
|
func IsClosed(conn net.Conn) bool {
|
|
_, err := conn.Read(make([]byte, 1))
|
|
if err != nil {
|
|
if err == net.ErrClosed {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
|
|
hd, _ := GetHeaders(ctx)
|
|
if hd == nil {
|
|
hd = memory.New[string, string]()
|
|
}
|
|
for key, val := range headers {
|
|
hd.Set(key, val)
|
|
}
|
|
return context.WithValue(ctx, consts.HeaderKey, hd)
|
|
}
|
|
|
|
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
|
hd, _ := GetHeaders(ctx)
|
|
if hd == nil {
|
|
hd = memory.New[string, string]()
|
|
}
|
|
for key, val := range headers {
|
|
hd.Set(key, val)
|
|
}
|
|
return hd.AsMap()
|
|
}
|
|
|
|
func GetHeaders(ctx context.Context) (storage.IMap[string, string], bool) {
|
|
headers, ok := ctx.Value(consts.HeaderKey).(storage.IMap[string, string])
|
|
return headers, ok
|
|
}
|
|
|
|
func GetHeader(ctx context.Context, key string) (string, bool) {
|
|
headers, ok := GetHeaders(ctx)
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
val, ok := headers.Get(key)
|
|
return val, ok
|
|
}
|
|
|
|
func GetContentType(ctx context.Context) (string, bool) {
|
|
return GetHeader(ctx, consts.ContentType)
|
|
}
|
|
|
|
func GetQueue(ctx context.Context) (string, bool) {
|
|
return GetHeader(ctx, consts.QueueKey)
|
|
}
|
|
|
|
func GetConsumerID(ctx context.Context) (string, bool) {
|
|
return GetHeader(ctx, consts.ConsumerKey)
|
|
}
|
|
|
|
func GetTriggerNode(ctx context.Context) (string, bool) {
|
|
return GetHeader(ctx, consts.TriggerNode)
|
|
}
|
|
|
|
func GetAwaitResponse(ctx context.Context) (string, bool) {
|
|
return GetHeader(ctx, consts.AwaitResponseKey)
|
|
}
|
|
|
|
func GetPublisherID(ctx context.Context) (string, bool) {
|
|
return GetHeader(ctx, consts.PublisherKey)
|
|
}
|
|
|
|
func NewID() string {
|
|
return wuid.New().String()
|
|
}
|
|
|
|
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
|
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
|
|
}
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
InsecureSkipVerify: true,
|
|
}
|
|
if len(caPath) > 0 && caPath[0] != "" {
|
|
caCert, err := os.ReadFile(caPath[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load CA cert: %w", err)
|
|
}
|
|
caCertPool := x509.NewCertPool()
|
|
caCertPool.AppendCertsFromPEM(caCert)
|
|
tlsConfig.RootCAs = caCertPool
|
|
tlsConfig.ClientCAs = caCertPool
|
|
}
|
|
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// Global connection pool
|
|
var connPool sync.Map
|
|
|
|
// Modified GetConnection: reuse existing connection if valid.
|
|
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
|
|
key := fmt.Sprintf("%s_%t", addr, config.UseTLS)
|
|
// Check if a connection exists and reuse it if not closed.
|
|
if c, ok := connPool.Load(key); ok {
|
|
conn := c.(net.Conn)
|
|
if !IsClosed(conn) {
|
|
return conn, nil
|
|
}
|
|
// If closed, delete the stale connection.
|
|
connPool.Delete(key)
|
|
}
|
|
|
|
var conn net.Conn
|
|
var err error
|
|
if config.UseTLS {
|
|
conn, err = createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
|
|
} else {
|
|
conn, err = net.Dial("tcp", addr)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Store the new connection in the pool.
|
|
connPool.Store(key, conn)
|
|
return conn, nil
|
|
}
|
|
|
|
func WrapError(err error, msg, op string) error {
|
|
return errors.Wrap(err, msg, op)
|
|
}
|