mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-08 01:10:09 +08:00
feat: change package name
This commit is contained in:
59
consumer.go
59
consumer.go
@@ -2,6 +2,7 @@ package mq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"github.com/oarkflow/mq/utils"
|
"github.com/oarkflow/mq/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Consumer structure to hold consumer-specific configurations and state.
|
||||||
type Consumer struct {
|
type Consumer struct {
|
||||||
id string
|
id string
|
||||||
handlers map[string]Handler
|
handlers map[string]Handler
|
||||||
@@ -21,6 +23,7 @@ type Consumer struct {
|
|||||||
opts Options
|
opts Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewConsumer initializes a new consumer with the provided options.
|
||||||
func NewConsumer(id string, opts ...Option) *Consumer {
|
func NewConsumer(id string, opts ...Option) *Consumer {
|
||||||
options := defaultOptions()
|
options := defaultOptions()
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
@@ -33,11 +36,9 @@ func NewConsumer(id string, opts ...Option) *Consumer {
|
|||||||
if options.messageHandler == nil {
|
if options.messageHandler == nil {
|
||||||
options.messageHandler = con.readConn
|
options.messageHandler = con.readConn
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.closeHandler == nil {
|
if options.closeHandler == nil {
|
||||||
options.closeHandler = con.onClose
|
options.closeHandler = con.onClose
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.errorHandler == nil {
|
if options.errorHandler == nil {
|
||||||
options.errorHandler = con.onError
|
options.errorHandler = con.onError
|
||||||
}
|
}
|
||||||
@@ -45,10 +46,12 @@ func NewConsumer(id string, opts ...Option) *Consumer {
|
|||||||
return con
|
return con
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the consumer's connection.
|
||||||
func (c *Consumer) Close() error {
|
func (c *Consumer) Close() error {
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Subscribe to a specific queue.
|
||||||
func (c *Consumer) subscribe(queue string) error {
|
func (c *Consumer) subscribe(queue string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = SetHeaders(ctx, map[string]string{
|
ctx = SetHeaders(ctx, map[string]string{
|
||||||
@@ -63,6 +66,7 @@ func (c *Consumer) subscribe(queue string) error {
|
|||||||
return Write(ctx, c.conn, subscribe)
|
return Write(ctx, c.conn, subscribe)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProcessTask handles a received task message and invokes the appropriate handler.
|
||||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||||
handler, exists := c.handlers[msg.CurrentQueue]
|
handler, exists := c.handlers[msg.CurrentQueue]
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -71,6 +75,7 @@ func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
|||||||
return handler(ctx, msg)
|
return handler(ctx, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle command message sent by the server.
|
||||||
func (c *Consumer) handleCommandMessage(msg Command) error {
|
func (c *Consumer) handleCommandMessage(msg Command) error {
|
||||||
switch msg.Command {
|
switch msg.Command {
|
||||||
case STOP:
|
case STOP:
|
||||||
@@ -83,6 +88,7 @@ func (c *Consumer) handleCommandMessage(msg Command) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle task message sent by the server.
|
||||||
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
||||||
response := c.ProcessTask(ctx, msg)
|
response := c.ProcessTask(ctx, msg)
|
||||||
response.Queue = msg.CurrentQueue
|
response.Queue = msg.CurrentQueue
|
||||||
@@ -96,10 +102,12 @@ func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
|||||||
return c.sendResult(ctx, response)
|
return c.sendResult(ctx, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send the result of task processing back to the server.
|
||||||
func (c *Consumer) sendResult(ctx context.Context, response Result) error {
|
func (c *Consumer) sendResult(ctx context.Context, response Result) error {
|
||||||
return Write(ctx, c.conn, response)
|
return Write(ctx, c.conn, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read and handle incoming messages.
|
||||||
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
||||||
var cmdMsg Command
|
var cmdMsg Command
|
||||||
var task Task
|
var task Task
|
||||||
@@ -114,16 +122,26 @@ func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
||||||
func (c *Consumer) AttemptConnect() error {
|
func (c *Consumer) AttemptConnect() error {
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
var err error
|
var err error
|
||||||
delay := c.opts.initialDelay
|
delay := c.opts.initialDelay
|
||||||
|
|
||||||
for i := 0; i < c.opts.maxRetries; i++ {
|
for i := 0; i < c.opts.maxRetries; i++ {
|
||||||
|
if c.opts.useTLS {
|
||||||
|
// Create TLS connection
|
||||||
|
conn, err = c.createTLSConnection()
|
||||||
|
} else {
|
||||||
|
// Create regular TCP connection
|
||||||
conn, err = net.Dial("tcp", c.opts.brokerAddr)
|
conn, err = net.Dial("tcp", c.opts.brokerAddr)
|
||||||
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||||||
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||||
time.Sleep(sleepDuration)
|
time.Sleep(sleepDuration)
|
||||||
@@ -136,19 +154,55 @@ func (c *Consumer) AttemptConnect() error {
|
|||||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createTLSConnection creates a TLS connection to the server.
|
||||||
|
func (c *Consumer) createTLSConnection() (net.Conn, error) {
|
||||||
|
// Load the client cert
|
||||||
|
cert, err := tls.LoadX509KeyPair(c.opts.tlsCertPath, c.opts.tlsKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
// Load CA cert for server verification
|
||||||
|
caCert, err := os.ReadFile(c.opts.tlsCAPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load CA cert: %w", err)
|
||||||
|
}
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
*/
|
||||||
|
// Configure TLS
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
// RootCAs: caCertPool,
|
||||||
|
InsecureSkipVerify: true, // Enforce server certificate validation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Establish TLS connection
|
||||||
|
conn, err := tls.Dial("tcp", c.opts.brokerAddr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readConn reads incoming messages from the connection.
|
||||||
func (c *Consumer) readConn(ctx context.Context, conn net.Conn, message []byte) error {
|
func (c *Consumer) readConn(ctx context.Context, conn net.Conn, message []byte) error {
|
||||||
return c.readMessage(ctx, message)
|
return c.readMessage(ctx, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// onClose handles connection close event.
|
||||||
func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error {
|
func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error {
|
||||||
fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr())
|
fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// onError handles errors while reading from the connection.
|
||||||
func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) {
|
func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) {
|
||||||
fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr())
|
fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Consume starts the consumer to consume tasks from the queues.
|
||||||
func (c *Consumer) Consume(ctx context.Context) error {
|
func (c *Consumer) Consume(ctx context.Context) error {
|
||||||
err := c.AttemptConnect()
|
err := c.AttemptConnect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -170,6 +224,7 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterHandler registers a handler for a queue.
|
||||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||||
c.queues = append(c.queues, queue)
|
c.queues = append(c.queues, queue)
|
||||||
c.handlers[queue] = handler
|
c.handlers[queue] = handler
|
||||||
|
@@ -2,48 +2,14 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
"github.com/oarkflow/mq/examples/tasks"
|
"github.com/oarkflow/mq/examples/tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Load consumer's certificate and private key
|
consumer := mq.NewConsumer("consumer-1", mq.WithTLS(true, "consumer.crt", "consumer.key"))
|
||||||
cert, err := tls.LoadX509KeyPair("consumer.crt", "consumer.key")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to load consumer certificate and key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the CA certificate
|
|
||||||
caCert, err := ioutil.ReadFile("ca.crt")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to read CA certificate: %v", err)
|
|
||||||
}
|
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
caCertPool.AppendCertsFromPEM(caCert)
|
|
||||||
|
|
||||||
// Configure TLS for the consumer
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
RootCAs: caCertPool,
|
|
||||||
InsecureSkipVerify: false, // Ensure we verify the server certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial TLS connection to the broker
|
|
||||||
conn, err := tls.Dial("tcp", "localhost:8443", tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to connect to broker: %v", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
consumer := mq.NewConsumer("consumer-1")
|
|
||||||
consumer.RegisterHandler("queue1", tasks.Node1)
|
consumer.RegisterHandler("queue1", tasks.Node1)
|
||||||
consumer.RegisterHandler("queue2", tasks.Node2)
|
consumer.RegisterHandler("queue2", tasks.Node2)
|
||||||
|
|
||||||
// Start consuming tasks
|
|
||||||
consumer.Consume(context.Background())
|
consumer.Consume(context.Background())
|
||||||
}
|
}
|
||||||
|
@@ -3,9 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
@@ -17,24 +15,24 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to load publisher certificate and key: %v", err)
|
log.Fatalf("Failed to load publisher certificate and key: %v", err)
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
// Load the CA certificate
|
// Load the CA certificate
|
||||||
caCert, err := ioutil.ReadFile("ca.crt")
|
caCert, err := os.ReadFile("ca.crt")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to read CA certificate: %v", err)
|
log.Fatalf("Failed to read CA certificate: %v", err)
|
||||||
}
|
}
|
||||||
caCertPool := x509.NewCertPool()
|
caCertPool := x509.NewCertPool()
|
||||||
caCertPool.AppendCertsFromPEM(caCert)
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
*/
|
||||||
// Configure TLS for the publisher
|
// Configure TLS for the publisher
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
RootCAs: caCertPool,
|
// RootCAs: caCertPool,
|
||||||
InsecureSkipVerify: false, // Ensure we verify the server certificate
|
InsecureSkipVerify: true, // Ensure we verify the server certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial TLS connection to the broker
|
// Dial TLS connection to the broker
|
||||||
conn, err := tls.Dial("tcp", "localhost:8443", tlsConfig)
|
conn, err := tls.Dial("tcp", "localhost:8080", tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to connect to broker: %v", err)
|
log.Fatalf("Failed to connect to broker: %v", err)
|
||||||
}
|
}
|
||||||
|
@@ -2,65 +2,14 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
"github.com/oarkflow/mq/examples/tasks"
|
"github.com/oarkflow/mq/examples/tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Load the server's certificate and key
|
b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "server.crt", "server.key"), mq.WithCAPath("ca.cert"))
|
||||||
cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to load server certificate and key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the CA certificate
|
|
||||||
caCert, err := ioutil.ReadFile("ca.crt")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to read CA certificate: %v", err)
|
|
||||||
}
|
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
caCertPool.AppendCertsFromPEM(caCert)
|
|
||||||
|
|
||||||
// Configure TLS for the server
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
ClientCAs: caCertPool,
|
|
||||||
ClientAuth: tls.RequireAndVerifyClientCert, // Mutual TLS
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start a TLS listener
|
|
||||||
listener, err := tls.Listen("tcp", ":8443", tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to start TLS listener: %v", err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
b := mq.NewBroker(mq.WithCallback(tasks.Callback))
|
|
||||||
b.NewQueue("queue1")
|
b.NewQueue("queue1")
|
||||||
b.NewQueue("queue2")
|
b.NewQueue("queue2")
|
||||||
|
b.Start(context.Background())
|
||||||
log.Println("TLS-enabled broker started on :8443")
|
|
||||||
|
|
||||||
// Handle incoming connections
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error accepting connection:", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go handleConnection(b, conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleConnection(b *mq.Broker, conn net.Conn) {
|
|
||||||
defer conn.Close()
|
|
||||||
ctx := context.Background()
|
|
||||||
b.Start(ctx)
|
|
||||||
}
|
}
|
||||||
|
10
options.go
10
options.go
@@ -43,12 +43,18 @@ func WithBrokerURL(url string) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option to enable/disable TLS
|
// WithTLS - Option to enable/disable TLS
|
||||||
func WithTLS(enableTLS bool, certPath, keyPath, caPath string) Option {
|
func WithTLS(enableTLS bool, certPath, keyPath string) Option {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
o.useTLS = enableTLS
|
o.useTLS = enableTLS
|
||||||
o.tlsCertPath = certPath
|
o.tlsCertPath = certPath
|
||||||
o.tlsKeyPath = keyPath
|
o.tlsKeyPath = keyPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCAPath - Option to enable/disable TLS
|
||||||
|
func WithCAPath(caPath string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
o.tlsCAPath = caPath
|
o.tlsCAPath = caPath
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user