diff --git a/README.md b/README.md index 3e4313e..16ee096 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,11 @@ The tool supports multiple concurrent clients, configurable message size, etc: > mqtt-benchmark --help Usage of mqtt-benchmark: -broker="tcp://localhost:1883": MQTT broker endpoint as scheme://host:port + -cert="cert.pem": File path to your client certificate in PEM format -clients=10: Number of clients to start -count=100: Number of messages to send per client -format="text": Output format: text|json + -key="key.pem": File path to your private key in PEM format -password="": MQTT password (empty if auth disabled) -qos=1: QoS for published messages -quiet=false : Suppress logs while running (except errors and the result) diff --git a/client.go b/client.go index a61bbaf..662d009 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "fmt" "log" "time" @@ -22,6 +23,7 @@ type Client struct { MsgQoS byte Quiet bool WaitTimeout time.Duration + TlsConfig *tls.Config } func (c *Client) Run(res chan *RunResults) { @@ -136,6 +138,10 @@ func (c *Client) pubMessages(in, out chan *Message, doneGen, donePub chan bool) opts.SetUsername(c.BrokerUser) opts.SetPassword(c.BrokerPass) } + if c.TlsConfig != nil { + opts.SetTLSConfig(c.TlsConfig) + } + client := mqtt.NewClient(opts) token := client.Connect() token.Wait() diff --git a/main.go b/main.go index 0c58e11..730f40e 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "crypto/tls" "encoding/json" "flag" "fmt" @@ -69,6 +70,8 @@ func main() { clients = flag.Int("clients", 10, "Number of clients to start") format = flag.String("format", "text", "Output format: text|json") quiet = flag.Bool("quiet", false, "Suppress logs while running") + cert = flag.String("cert", "", "File path to your client certificate in PEM format") + key = flag.String("key", "", "File path to your private key in PEM format") ) flag.Parse() @@ -80,6 +83,11 @@ func main() { log.Fatalf("Invalid arguments: messages count should be > 1, given: %v", *count) } + var tlsConfig *tls.Config + if *cert != "" && *key != "" { + tlsConfig = generateTlsConfig(*cert, *key) + } + resCh := make(chan *RunResults) start := time.Now() for i := 0; i < *clients; i++ { @@ -97,6 +105,7 @@ func main() { MsgQoS: byte(*qos), Quiet: *quiet, WaitTimeout: time.Duration(*wait) * time.Millisecond, + TlsConfig: tlsConfig, } go c.Run(resCh) } @@ -192,3 +201,19 @@ func printResults(results []*RunResults, totals *TotalResults, format string) { } return } + +func generateTlsConfig(certFile string, keyFile string) *tls.Config { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalf("Error reading certificate files: %v", err) + } + + cfg := tls.Config{ + ClientAuth: tls.NoClientCert, + ClientCAs: nil, + InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, + } + + return &cfg +}