Files
mq/ctx.go
2025-07-31 14:26:08 +05:45

158 lines
3.7 KiB
Go

package mq
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net"
	"os"
	"sync"
	"time"

	"github.com/oarkflow/errors"
	"github.com/oarkflow/xid/wuid"

	"github.com/oarkflow/mq/consts"
	"github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
)
// Handler is the signature of a task-processing function: it receives a
// context and the Task to process, and returns the Result of processing it.
type Handler func(ctx context.Context, task *Task) Result
// IsClosed reports whether conn can no longer be used for reading.
//
// It probes the connection with a 1-byte read bounded by a short read
// deadline, so it never blocks for more than that deadline on a live but
// idle connection. A read that times out means the connection is still
// open. Any other read error means the connection is unusable: io.EOF after
// the peer hung up, net.ErrClosed after a local Close, a reset, and so on.
// A nil conn, or one whose read deadline cannot be set (for example because
// it is already closed), is also reported as closed.
//
// Caveats:
//   - If the peer has already sent data, the probe consumes one byte of it
//     and that byte is not seen by later readers.
//   - The probe temporarily replaces the read deadline, so it must not be
//     called while another goroutine is reading from conn.
//   - On return, any read deadline on conn is cleared (set to none).
func IsClosed(conn net.Conn) bool {
	if conn == nil {
		return true
	}
	// Bound the probe so a healthy, idle connection does not block forever.
	if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
		return true
	}
	// Clear the probe deadline so later reads by the owner are unaffected.
	defer func() { _ = conn.SetReadDeadline(time.Time{}) }()

	_, err := conn.Read(make([]byte, 1))
	switch {
	case err == nil:
		// Data was pending, so the connection is alive (the byte is consumed).
		return false
	case os.IsTimeout(err):
		// Nothing arrived before the deadline: still connected.
		return false
	default:
		// EOF, closed, reset, ...: the connection cannot be read any more.
		return true
	}
}
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
hd, _ := GetHeaders(ctx)
if hd == nil {
hd = memory.New[string, string]()
}
for key, val := range headers {
hd.Set(key, val)
}
return context.WithValue(ctx, consts.HeaderKey, hd)
}
// WithHeaders returns a new map containing:
//   - the headers stored in ctx, if any;
//   - the given headers laid over them. On a key clash, the value from
//     headers wins.
//
// Neither the header map stored in ctx nor the headers argument is
// modified. Previously the ctx map was mutated as a hidden side effect. The
// returned map is a fresh copy that the caller may change freely.
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
	merged := make(map[string]string, len(headers))
	if existing, ok := GetHeaders(ctx); ok && existing != nil {
		for key, val := range existing.AsMap() {
			merged[key] = val
		}
	}
	for key, val := range headers {
		merged[key] = val
	}
	return merged
}
// GetHeaders looks up the header map stored in ctx under consts.HeaderKey.
// It returns the map and true when one is present with the expected type.
// Otherwise it returns nil and false.
func GetHeaders(ctx context.Context) (storage.IMap[string, string], bool) {
	if hd, found := ctx.Value(consts.HeaderKey).(storage.IMap[string, string]); found {
		return hd, true
	}
	return nil, false
}
// GetHeader returns the value of header key from the header map stored in
// ctx, and whether it was found. It returns "" and false when ctx carries no
// header map at all.
func GetHeader(ctx context.Context, key string) (string, bool) {
	if hd, found := GetHeaders(ctx); found {
		return hd.Get(key)
	}
	return "", false
}
// GetContentType returns the header stored under consts.ContentType in ctx,
// and whether it was present.
func GetContentType(ctx context.Context) (string, bool) {
	contentType, found := GetHeader(ctx, consts.ContentType)
	return contentType, found
}
// GetQueue returns the header stored under consts.QueueKey in ctx, and
// whether it was present.
func GetQueue(ctx context.Context) (string, bool) {
	queue, found := GetHeader(ctx, consts.QueueKey)
	return queue, found
}
// GetConsumerID returns the header stored under consts.ConsumerKey in ctx,
// and whether it was present.
func GetConsumerID(ctx context.Context) (string, bool) {
	consumerID, found := GetHeader(ctx, consts.ConsumerKey)
	return consumerID, found
}
// GetTriggerNode returns the header stored under consts.TriggerNode in ctx,
// and whether it was present.
func GetTriggerNode(ctx context.Context) (string, bool) {
	node, found := GetHeader(ctx, consts.TriggerNode)
	return node, found
}
// GetAwaitResponse returns the raw string value of the header stored under
// consts.AwaitResponseKey in ctx, and whether it was present. The value is
// not parsed (e.g. into a bool) here.
func GetAwaitResponse(ctx context.Context) (string, bool) {
	await, found := GetHeader(ctx, consts.AwaitResponseKey)
	return await, found
}
// GetPublisherID returns the header stored under consts.PublisherKey in ctx,
// and whether it was present.
func GetPublisherID(ctx context.Context) (string, bool) {
	publisherID, found := GetHeader(ctx, consts.PublisherKey)
	return publisherID, found
}
// NewID returns a new identifier obtained from wuid.New, rendered via its
// String method.
func NewID() string {
	id := wuid.New()
	return id.String()
}
// createTLSConnection dials addr over TCP and performs a TLS handshake,
// presenting the client certificate loaded from certPath and keyPath.
//
// If caPath[0] is a non-empty path, the server's certificate chain must
// verify against the PEM certificates in that file, or the handshake fails.
// Only the chain is checked. The server hostname is not matched against the
// certificate, which keeps deployments that dial by IP address working.
//
// If no CA path is given, the server certificate is NOT verified at all.
// That setup is open to man-in-the-middle attacks. It is kept only for
// backward compatibility, and callers should supply a CA.
//
// Errors from loading files and from dialing are wrapped with %w.
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load client cert/key: %w", err)
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		// The built-in chain+hostname check is disabled. When a CA is
		// configured, VerifyConnection below restores chain verification.
		// (ClientAuth/ClientCAs were dropped: they only apply to servers.)
		InsecureSkipVerify: true,
	}
	if len(caPath) > 0 && caPath[0] != "" {
		caCert, err := os.ReadFile(caPath[0])
		if err != nil {
			return nil, fmt.Errorf("failed to load CA cert: %w", err)
		}
		caCertPool := x509.NewCertPool()
		// An empty pool would make every handshake fail with an opaque
		// "unknown authority" error, so fail early instead.
		if !caCertPool.AppendCertsFromPEM(caCert) {
			return nil, fmt.Errorf("failed to load CA cert: no PEM certificates found in %q", caPath[0])
		}
		tlsConfig.VerifyConnection = func(cs tls.ConnectionState) error {
			if len(cs.PeerCertificates) == 0 {
				return fmt.Errorf("server presented no certificate")
			}
			opts := x509.VerifyOptions{
				Roots:         caCertPool,
				Intermediates: x509.NewCertPool(),
			}
			for _, intermediate := range cs.PeerCertificates[1:] {
				opts.Intermediates.AddCert(intermediate)
			}
			if _, err := cs.PeerCertificates[0].Verify(opts); err != nil {
				return fmt.Errorf("failed to verify server certificate: %w", err)
			}
			return nil
		}
	}
	conn, err := tls.Dial("tcp", addr, tlsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
	}
	return conn, nil
}
// connPool caches one connection per "addr_useTLS" key so that repeated
// GetConnection calls share it. Values are net.Conn.
var connPool sync.Map

// GetConnection returns the pooled connection for addr, dialing a new one
// when the pool has none or IsClosed reports the pooled one as closed.
//
// When config.UseTLS is set, the connection is created with
// createTLSConnection from config.CertPath, config.KeyPath and
// config.CAPath. Otherwise a plain TCP connection is dialed.
//
// The pool key combines addr with the UseTLS flag only. Two different TLS
// configurations for the same addr therefore share one connection.
//
// A stale pooled connection is closed before being dropped, so its file
// descriptor is released. If several goroutines dial concurrently, the
// first to store its connection wins. The others close their own freshly
// dialed connection and return the winner's, instead of silently
// overwriting (and leaking) it.
//
// NOTE(review): the returned net.Conn is shared by every caller using the
// same key; callers must coordinate concurrent reads/writes themselves.
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
	key := fmt.Sprintf("%s_%t", addr, config.UseTLS)
	if cached, ok := connPool.Load(key); ok {
		conn := cached.(net.Conn)
		if !IsClosed(conn) {
			return conn, nil
		}
		connPool.Delete(key)
		// Best effort: the connection is already unusable, but closing it
		// releases the underlying descriptor.
		_ = conn.Close()
	}

	var (
		conn net.Conn
		err  error
	)
	if config.UseTLS {
		conn, err = createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
	} else {
		conn, err = net.Dial("tcp", addr)
	}
	if err != nil {
		return nil, err
	}

	// Another goroutine may have stored a connection while we were dialing.
	if existing, loaded := connPool.LoadOrStore(key, conn); loaded {
		_ = conn.Close()
		return existing.(net.Conn), nil
	}
	return conn, nil
}
// WrapError annotates err with msg and the operation name op by delegating
// to the imported errors.Wrap helper, and returns the wrapped error.
func WrapError(err error, msg, op string) error {
	wrapped := errors.Wrap(err, msg, op)
	return wrapped
}