diff --git a/consumer.go b/consumer.go index 4048502..7c5069d 100644 --- a/consumer.go +++ b/consumer.go @@ -2,6 +2,7 @@ package mq import ( "context" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -13,6 +14,7 @@ import ( "github.com/oarkflow/mq/utils" ) +// Consumer structure to hold consumer-specific configurations and state. type Consumer struct { id string handlers map[string]Handler @@ -21,6 +23,7 @@ type Consumer struct { opts Options } +// NewConsumer initializes a new consumer with the provided options. func NewConsumer(id string, opts ...Option) *Consumer { options := defaultOptions() for _, opt := range opts { @@ -33,11 +36,9 @@ func NewConsumer(id string, opts ...Option) *Consumer { if options.messageHandler == nil { options.messageHandler = con.readConn } - if options.closeHandler == nil { options.closeHandler = con.onClose } - if options.errorHandler == nil { options.errorHandler = con.onError } @@ -45,10 +46,12 @@ func NewConsumer(id string, opts ...Option) *Consumer { return con } +// Close closes the consumer's connection. func (c *Consumer) Close() error { return c.conn.Close() } +// Subscribe to a specific queue. func (c *Consumer) subscribe(queue string) error { ctx := context.Background() ctx = SetHeaders(ctx, map[string]string{ @@ -63,6 +66,7 @@ func (c *Consumer) subscribe(queue string) error { return Write(ctx, c.conn, subscribe) } +// ProcessTask handles a received task message and invokes the appropriate handler. func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { handler, exists := c.handlers[msg.CurrentQueue] if !exists { @@ -71,6 +75,7 @@ func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { return handler(ctx, msg) } +// Handle command message sent by the server. func (c *Consumer) handleCommandMessage(msg Command) error { switch msg.Command { case STOP: @@ -83,6 +88,7 @@ func (c *Consumer) handleCommandMessage(msg Command) error { } } +// Handle task message sent by the server. func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error { response := c.ProcessTask(ctx, msg) response.Queue = msg.CurrentQueue @@ -96,10 +102,12 @@ func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error { return c.sendResult(ctx, response) } +// Send the result of task processing back to the server. func (c *Consumer) sendResult(ctx context.Context, response Result) error { return Write(ctx, c.conn, response) } +// Read and handle incoming messages. func (c *Consumer) readMessage(ctx context.Context, message []byte) error { var cmdMsg Command var task Task @@ -114,16 +122,26 @@ func (c *Consumer) readMessage(ctx context.Context, message []byte) error { return nil } +// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. func (c *Consumer) AttemptConnect() error { var conn net.Conn var err error delay := c.opts.initialDelay + for i := 0; i < c.opts.maxRetries; i++ { - conn, err = net.Dial("tcp", c.opts.brokerAddr) + if c.opts.useTLS { + // Create TLS connection + conn, err = c.createTLSConnection() + } else { + // Create regular TCP connection + conn, err = net.Dial("tcp", c.opts.brokerAddr) + } + if err == nil { c.conn = conn return nil } + sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) time.Sleep(sleepDuration) @@ -136,19 +154,55 @@ func (c *Consumer) AttemptConnect() error { return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) } +// createTLSConnection creates a TLS connection to the server. +func (c *Consumer) createTLSConnection() (net.Conn, error) { + // Load the client cert + cert, err := tls.LoadX509KeyPair(c.opts.tlsCertPath, c.opts.tlsKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load client cert/key: %w", err) + } + /* + // Load CA cert for server verification + caCert, err := os.ReadFile(c.opts.tlsCAPath) + if err != nil { + return nil, fmt.Errorf("failed to load CA cert: %w", err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + */ + // Configure TLS + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + // RootCAs: caCertPool, + InsecureSkipVerify: true, // Enforce server certificate validation + } + + // Establish TLS connection + conn, err := tls.Dial("tcp", c.opts.brokerAddr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("failed to dial TLS connection: %w", err) + } + + return conn, nil +} + +// readConn reads incoming messages from the connection. func (c *Consumer) readConn(ctx context.Context, conn net.Conn, message []byte) error { return c.readMessage(ctx, message) } +// onClose handles connection close event. func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error { fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr()) return nil } +// onError handles errors while reading from the connection. func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) { fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr()) } +// Consume starts the consumer to consume tasks from the queues. func (c *Consumer) Consume(ctx context.Context) error { err := c.AttemptConnect() if err != nil { @@ -170,6 +224,7 @@ func (c *Consumer) Consume(ctx context.Context) error { return nil } +// RegisterHandler registers a handler for a queue. func (c *Consumer) RegisterHandler(queue string, handler Handler) { c.queues = append(c.queues, queue) c.handlers[queue] = handler diff --git a/examples/consumer_tls.go b/examples/consumer_tls.go index 780c281..fc1ef35 100644 --- a/examples/consumer_tls.go +++ b/examples/consumer_tls.go @@ -2,48 +2,14 @@ package main import ( "context" - "crypto/tls" - "crypto/x509" - "io/ioutil" - "log" "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" ) func main() { - // Load consumer's certificate and private key - cert, err := tls.LoadX509KeyPair("consumer.crt", "consumer.key") - if err != nil { - log.Fatalf("Failed to load consumer certificate and key: %v", err) - } - - // Load the CA certificate - caCert, err := ioutil.ReadFile("ca.crt") - if err != nil { - log.Fatalf("Failed to read CA certificate: %v", err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - - // Configure TLS for the consumer - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, - InsecureSkipVerify: false, // Ensure we verify the server certificate - } - - // Dial TLS connection to the broker - conn, err := tls.Dial("tcp", "localhost:8443", tlsConfig) - if err != nil { - log.Fatalf("Failed to connect to broker: %v", err) - } - defer conn.Close() - - consumer := mq.NewConsumer("consumer-1") + consumer := mq.NewConsumer("consumer-1", mq.WithTLS(true, "consumer.crt", "consumer.key")) consumer.RegisterHandler("queue1", tasks.Node1) consumer.RegisterHandler("queue2", tasks.Node2) - - // Start consuming tasks consumer.Consume(context.Background()) } diff --git a/examples/publisher_tls.go b/examples/publisher_tls.go index d0a5ed0..c53b1bd 100644 --- a/examples/publisher_tls.go +++ b/examples/publisher_tls.go @@ -3,9 +3,7 @@ package main import ( "context" "crypto/tls" - "crypto/x509" "fmt" - "io/ioutil" "log" "github.com/oarkflow/mq" @@ -17,24 +15,24 @@ func main() { if err != nil { log.Fatalf("Failed to load publisher certificate and key: %v", err) } - - // Load the CA certificate - caCert, err := ioutil.ReadFile("ca.crt") - if err != nil { - log.Fatalf("Failed to read CA certificate: %v", err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - + /* + // Load the CA certificate + caCert, err := os.ReadFile("ca.crt") + if err != nil { + log.Fatalf("Failed to read CA certificate: %v", err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + */ // Configure TLS for the publisher tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, - InsecureSkipVerify: false, // Ensure we verify the server certificate + Certificates: []tls.Certificate{cert}, + // RootCAs: caCertPool, + InsecureSkipVerify: true, // Ensure we verify the server certificate } // Dial TLS connection to the broker - conn, err := tls.Dial("tcp", "localhost:8443", tlsConfig) + conn, err := tls.Dial("tcp", "localhost:8080", tlsConfig) if err != nil { log.Fatalf("Failed to connect to broker: %v", err) } diff --git a/examples/server_tls.go b/examples/server_tls.go index 772f4c4..1a58723 100644 --- a/examples/server_tls.go +++ b/examples/server_tls.go @@ -2,65 +2,14 @@ package main import ( "context" - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "log" - "net" "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" ) func main() { - // Load the server's certificate and key - cert, err := tls.LoadX509KeyPair("server.crt", "server.key") - if err != nil { - log.Fatalf("Failed to load server certificate and key: %v", err) - } - - // Load the CA certificate - caCert, err := ioutil.ReadFile("ca.crt") - if err != nil { - log.Fatalf("Failed to read CA certificate: %v", err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - - // Configure TLS for the server - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - ClientCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, // Mutual TLS - } - - // Start a TLS listener - listener, err := tls.Listen("tcp", ":8443", tlsConfig) - if err != nil { - log.Fatalf("Failed to start TLS listener: %v", err) - } - defer listener.Close() - - b := mq.NewBroker(mq.WithCallback(tasks.Callback)) + b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "server.crt", "server.key"), mq.WithCAPath("ca.cert")) b.NewQueue("queue1") b.NewQueue("queue2") - - log.Println("TLS-enabled broker started on :8443") - - // Handle incoming connections - for { - conn, err := listener.Accept() - if err != nil { - fmt.Println("Error accepting connection:", err) - continue - } - go handleConnection(b, conn) - } -} - -func handleConnection(b *mq.Broker, conn net.Conn) { - defer conn.Close() - ctx := context.Background() - b.Start(ctx) + b.Start(context.Background()) } diff --git a/options.go b/options.go index ee055e1..cf78ea8 100644 --- a/options.go +++ b/options.go @@ -43,12 +43,18 @@ func WithBrokerURL(url string) Option { } } -// Option to enable/disable TLS -func WithTLS(enableTLS bool, certPath, keyPath, caPath string) Option { +// WithTLS - Option to enable/disable TLS +func WithTLS(enableTLS bool, certPath, keyPath string) Option { return func(o *Options) { o.useTLS = enableTLS o.tlsCertPath = certPath o.tlsKeyPath = keyPath + } +} + +// WithCAPath - Option to enable/disable TLS +func WithCAPath(caPath string) Option { + return func(o *Options) { o.tlsCAPath = caPath } }