mirror of
				https://github.com/mochi-mqtt/server.git
				synced 2025-10-31 19:42:38 +08:00 
			
		
		
		
	Compare commits
	
		
			49 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | ca849131eb | ||
|   | ba7e534122 | ||
|   | db760c34a5 | ||
|   | ae3ee81bb4 | ||
|   | c2ca02d149 | ||
|   | 77a64d9c87 | ||
|   | 8dec9cc962 | ||
|   | f90e52328d | ||
|   | 50aae47618 | ||
|   | 0d79f2d63b | ||
|   | 300152413c | ||
|   | 0de1d731db | ||
|   | 80746abc52 | ||
|   | a73cf4ca0e | ||
|   | bc549ee7ed | ||
|   | c464b46713 | ||
|   | 05ce56008c | ||
|   | 8254cb0cbc | ||
|   | 4ae58b79e3 | ||
|   | b895d688e0 | ||
|   | a600cd4ead | ||
|   | cdb44990cf | ||
|   | 2d9c128111 | ||
|   | a0d5bdb39f | ||
|   | 4ebcef3cb6 | ||
|   | fb8d4720d7 | ||
|   | 4080c89127 | ||
|   | 1b67e6f3f6 | ||
|   | 1adb02e087 | ||
|   | 4d4140aa99 | ||
|   | e31840a37d | ||
|   | 7d2e16f2d3 | ||
|   | 92cd935a16 | ||
|   | 25ce27ce2d | ||
|   | 527d084a4b | ||
|   | bb9f937bb0 | ||
|   | 511fe88684 | ||
|   | 75504ff201 | ||
|   | a556feb325 | ||
|   | d06f47f4b9 | ||
|   | 8d4cc091b4 | ||
|   | d8f28cb843 | ||
|   | 88861c219d | ||
|   | 7ba6cf28d9 | ||
|   | c174cfdc6b | ||
|   | 4f198a99dd | ||
|   | 2a9c9fcc40 | ||
|   | 835a85c8bf | ||
|   | fe5d9ffa61 | 
							
								
								
									
										53
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								README.md
									
									
									
									
									
								
							| @@ -2,7 +2,7 @@ | |||||||
| <p align="center"> | <p align="center"> | ||||||
|  |  | ||||||
|   |   | ||||||
| [](https://coveralls.io/github/mochi-co/mqtt?branch=master) | [](https://coveralls.io/github/mochi-co/mqtt?branch=master) | ||||||
| [](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2) | [](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2) | ||||||
| [](https://pkg.go.dev/github.com/mochi-co/mqtt/v2) | [](https://pkg.go.dev/github.com/mochi-co/mqtt/v2) | ||||||
| [](https://github.com/mochi-co/mqtt/issues) | [](https://github.com/mochi-co/mqtt/issues) | ||||||
| @@ -83,22 +83,26 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest | |||||||
| Importing Mochi MQTT as a package requires just a few lines of code to get started. | Importing Mochi MQTT as a package requires just a few lines of code to get started. | ||||||
| ``` go | ``` go | ||||||
| import ( | import ( | ||||||
|  |   "log" | ||||||
|  |  | ||||||
|   "github.com/mochi-co/mqtt/v2" |   "github.com/mochi-co/mqtt/v2" | ||||||
|  |   "github.com/mochi-co/mqtt/v2/hooks/auth" | ||||||
|  |   "github.com/mochi-co/mqtt/v2/listeners" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
|   // Create the new MQTT Server. |   // Create the new MQTT Server. | ||||||
|   server := mqtt.New(nil) |   server := mqtt.New(nil) | ||||||
|  |    | ||||||
|   // Allow all connections. |   // Allow all connections. | ||||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) |   _ = server.AddHook(new(auth.AllowHook), nil) | ||||||
|  |    | ||||||
|   // Create a TCP listener on a standard port. |   // Create a TCP listener on a standard port. | ||||||
| 	tcp := listeners.NewTCP("t1", *tcpAddr, nil) |   tcp := listeners.NewTCP("t1", ":1883", nil) | ||||||
| 	err := server.AddListener(tcp) |   err := server.AddListener(tcp) | ||||||
| 	if err != nil { |   if err != nil { | ||||||
| 		log.Fatal(err) |     log.Fatal(err) | ||||||
| 	} |   } | ||||||
|    |    | ||||||
|   err = server.Serve() |   err = server.Serve() | ||||||
|   if err != nil { |   if err != nil { | ||||||
| @@ -112,10 +116,15 @@ Examples of running the broker with various configurations can be found in the [ | |||||||
| #### Network Listeners | #### Network Listeners | ||||||
| The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are: | The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are: | ||||||
|  |  | ||||||
| - `listeners.NewTCP(...)` - A TCP listener. | | Listener | Usage | | ||||||
| - `listeners.NewWebsocket(...)` A Websocket listener. | | --- | --- | | ||||||
| - `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard. | | listeners.NewTCP | A TCP listener | | ||||||
| - Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know! | | listeners.NewUnixSock | A Unix Socket listener | | ||||||
|  | | listeners.NewNet | A net.Listener listener | | ||||||
|  | | listeners.NewWebsocket | A Websocket listener | | ||||||
|  | | listeners.NewHTTPStats | An HTTP $SYS info dashboard | | ||||||
|  |  | ||||||
|  | > Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know! | ||||||
|  |  | ||||||
| A `*listeners.Config` may be passed to configure TLS.  | A `*listeners.Config` may be passed to configure TLS.  | ||||||
|  |  | ||||||
| @@ -296,7 +305,6 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found | |||||||
| | OnWillSent | Called when an LWT message has been issued from a disconnecting client. |  | | OnWillSent | Called when an LWT message has been issued from a disconnecting client. |  | ||||||
| | OnClientExpired | Called when a client session has expired and should be deleted. |  | | OnClientExpired | Called when a client session has expired and should be deleted. |  | ||||||
| | OnRetainedExpired | Called when a retained message has expired and should be deleted. |  | | OnRetainedExpired | Called when a retained message has expired and should be deleted. |  | ||||||
| | OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|  |  | ||||||
| | StoredClients |  Returns clients, eg. from a persistent store. |  | | StoredClients |  Returns clients, eg. from a persistent store. |  | ||||||
| | StoredSubscriptions |  Returns client subscriptions, eg. from a persistent store. |  | | StoredSubscriptions |  Returns client subscriptions, eg. from a persistent store. |  | ||||||
| | StoredInflightMessages | Returns inflight messages, eg. from a persistent store.  |  | | StoredInflightMessages | Returns inflight messages, eg. from a persistent store.  |  | ||||||
| @@ -305,13 +313,22 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found | |||||||
|  |  | ||||||
| If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`. | If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`. | ||||||
|  |  | ||||||
| ### Packet Injection |  | ||||||
| It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.  |  | ||||||
|  |  | ||||||
| Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement). | ### Direct Publish | ||||||
|  | To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method. | ||||||
|  |  | ||||||
| ```go | ```go | ||||||
| cl := mqtt.NewInlineClient("inline", "local") | err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) | ||||||
|  | ``` | ||||||
|  | > The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec. | ||||||
|  |  | ||||||
|  | ### Packet Injection | ||||||
|  | If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.  | ||||||
|  |  | ||||||
|  | Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements). | ||||||
|  |  | ||||||
|  | ```go | ||||||
|  | cl := server.NewClient(nil, "local", "inline", true) | ||||||
| server.InjectPacket(cl, packets.Packet{ | server.InjectPacket(cl, packets.Packet{ | ||||||
|   FixedHeader: packets.FixedHeader{ |   FixedHeader: packets.FixedHeader{ | ||||||
|     Type: packets.Publish, |     Type: packets.Publish, | ||||||
|   | |||||||
							
								
								
									
										91
									
								
								clients.go
									
									
									
									
									
								
							
							
						
						
									
										91
									
								
								clients.go
									
									
									
									
									
								
							| @@ -106,7 +106,7 @@ type Client struct { | |||||||
|  |  | ||||||
| // ClientConnection contains the connection transport and metadata for the client. | // ClientConnection contains the connection transport and metadata for the client. | ||||||
| type ClientConnection struct { | type ClientConnection struct { | ||||||
| 	conn     net.Conn          // the net.Conn used to establish the connection | 	Conn     net.Conn          // the net.Conn used to establish the connection | ||||||
| 	bconn    *bufio.ReadWriter // a buffered net.Conn for reading packets | 	bconn    *bufio.ReadWriter // a buffered net.Conn for reading packets | ||||||
| 	Remote   string            // the remote address of the client | 	Remote   string            // the remote address of the client | ||||||
| 	Listener string            // listener id of the client | 	Listener string            // listener id of the client | ||||||
| @@ -146,14 +146,10 @@ type ClientState struct { | |||||||
| 	keepalive     uint16         // the number of seconds the connection can wait | 	keepalive     uint16         // the number of seconds the connection can wait | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewClient returns a new instance of Client. | // newClient returns a new instance of Client. This is almost exclusively used by Server | ||||||
| func NewClient(c net.Conn, o *ops) *Client { | // for creating new clients, but it lives here because it's not dependent. | ||||||
|  | func newClient(c net.Conn, o *ops) *Client { | ||||||
| 	cl := &Client{ | 	cl := &Client{ | ||||||
| 		Net: ClientConnection{ |  | ||||||
| 			conn:   c, |  | ||||||
| 			bconn:  bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)), |  | ||||||
| 			Remote: c.RemoteAddr().String(), |  | ||||||
| 		}, |  | ||||||
| 		State: ClientState{ | 		State: ClientState{ | ||||||
| 			Inflight:      NewInflights(), | 			Inflight:      NewInflights(), | ||||||
| 			Subscriptions: NewSubscriptions(), | 			Subscriptions: NewSubscriptions(), | ||||||
| @@ -166,46 +162,19 @@ func NewClient(c net.Conn, o *ops) *Client { | |||||||
| 		ops: o, | 		ops: o, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if c != nil { | ||||||
|  | 		cl.Net = ClientConnection{ | ||||||
|  | 			Conn:   c, | ||||||
|  | 			bconn:  bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)), | ||||||
|  | 			Remote: c.RemoteAddr().String(), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	cl.refreshDeadline(cl.State.keepalive) | 	cl.refreshDeadline(cl.State.keepalive) | ||||||
|  |  | ||||||
| 	return cl | 	return cl | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewInlineClient returns a client used when publishing from the embedding system. |  | ||||||
| func NewInlineClient(id, remote string) *Client { |  | ||||||
| 	return &Client{ |  | ||||||
| 		ID: id, |  | ||||||
| 		Net: ClientConnection{ |  | ||||||
| 			Remote: remote, |  | ||||||
| 			Inline: true, |  | ||||||
| 		}, |  | ||||||
| 		State: ClientState{ |  | ||||||
| 			Inflight:      NewInflights(), |  | ||||||
| 			Subscriptions: NewSubscriptions(), |  | ||||||
| 			TopicAliases:  NewTopicAliases(0), |  | ||||||
| 		}, |  | ||||||
| 		Properties: ClientProperties{ |  | ||||||
| 			ProtocolVersion: defaultClientProtocolVersion, // default protocol version |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // newClientStub returns an instance of Client with minimal initializations, such as |  | ||||||
| // restoring client data from a db. In particular, the client is marked as offline (done). |  | ||||||
| func newClientStub() *Client { |  | ||||||
| 	return &Client{ |  | ||||||
| 		State: ClientState{ |  | ||||||
| 			Inflight:      NewInflights(), |  | ||||||
| 			Subscriptions: NewSubscriptions(), |  | ||||||
| 			TopicAliases:  NewTopicAliases(0), |  | ||||||
| 			done:          1, |  | ||||||
| 		}, |  | ||||||
| 		Properties: ClientProperties{ |  | ||||||
| 			ProtocolVersion: defaultClientProtocolVersion, // default protocol version |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ParseConnect parses the connect parameters and properties for a client. | // ParseConnect parses the connect parameters and properties for a client. | ||||||
| func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | ||||||
| 	cl.Net.Listener = lid | 	cl.Net.Listener = lid | ||||||
| @@ -254,13 +223,13 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | |||||||
|  |  | ||||||
| // refreshDeadline refreshes the read/write deadline for the net.Conn connection. | // refreshDeadline refreshes the read/write deadline for the net.Conn connection. | ||||||
| func (cl *Client) refreshDeadline(keepalive uint16) { | func (cl *Client) refreshDeadline(keepalive uint16) { | ||||||
| 	if cl.Net.conn != nil { | 	if cl.Net.Conn != nil { | ||||||
| 		var expiry time.Time // nil time can be used to disable deadline if keepalive = 0 | 		var expiry time.Time // nil time can be used to disable deadline if keepalive = 0 | ||||||
| 		if keepalive > 0 { | 		if keepalive > 0 { | ||||||
| 			expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22] | 			expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22] | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22] | 		_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22] | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -303,7 +272,7 @@ func (cl *Client) ResendInflightMessages(force bool) error { | |||||||
| 			tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3] | 			tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3] | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		//	cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends) | 		cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0) | ||||||
| 		err := cl.WritePacket(tk) | 		err := cl.WritePacket(tk) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -321,17 +290,18 @@ func (cl *Client) ResendInflightMessages(force bool) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. | // ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. | ||||||
| func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 { | func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 { | ||||||
| 	var deleted int64 | 	deleted := []uint16{} | ||||||
| 	for _, tk := range cl.State.Inflight.GetAll(false) { | 	for _, tk := range cl.State.Inflight.GetAll(false) { | ||||||
| 		if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now { | 		if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now { | ||||||
| 			if ok := cl.State.Inflight.Delete(tk.PacketID); ok { | 			if ok := cl.State.Inflight.Delete(tk.PacketID); ok { | ||||||
| 				cl.ops.hooks.OnQosDropped(cl, tk) | 				cl.ops.hooks.OnQosDropped(cl, tk) | ||||||
| 				atomic.AddInt64(&cl.ops.info.Inflight, -1) | 				atomic.AddInt64(&cl.ops.info.Inflight, -1) | ||||||
| 				deleted++ | 				deleted = append(deleted, tk.PacketID) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return deleted | 	return deleted | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -371,8 +341,8 @@ func (cl *Client) Stop(err error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cl.State.endOnce.Do(func() { | 	cl.State.endOnce.Do(func() { | ||||||
| 		if cl.Net.conn != nil { | 		if cl.Net.Conn != nil { | ||||||
| 			_ = cl.Net.conn.Close() // omit close error | 			_ = cl.Net.Conn.Close() // omit close error | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -392,6 +362,11 @@ func (cl *Client) StopCause() error { | |||||||
| 	return cl.State.stopCause.Load().(error) | 	return cl.State.stopCause.Load().(error) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Closed returns true if client connection is closed. | ||||||
|  | func (cl *Client) Closed() bool { | ||||||
|  | 	return atomic.LoadUint32(&cl.State.done) == 1 | ||||||
|  | } | ||||||
|  |  | ||||||
| // ReadFixedHeader reads in the values of the next packet's fixed header. | // ReadFixedHeader reads in the values of the next packet's fixed header. | ||||||
| func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | ||||||
| 	if cl.Net.bconn == nil { | 	if cl.Net.bconn == nil { | ||||||
| @@ -414,6 +389,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize { | ||||||
|  | 		return packets.ErrPacketTooLarge // [MQTT-3.2.2-15] | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1)) | 	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1)) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -485,11 +464,10 @@ func (cl *Client) WritePacket(pk packets.Packet) error { | |||||||
| 		return ErrConnectionClosed | 		return ErrConnectionClosed | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if cl.Net.conn == nil { | 	if cl.Net.Conn == nil { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	defer cl.refreshDeadline(cl.State.keepalive) |  | ||||||
| 	if pk.Expiry > 0 { | 	if pk.Expiry > 0 { | ||||||
| 		pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6] | 		pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6] | ||||||
| 	} | 	} | ||||||
| @@ -503,8 +481,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error { | |||||||
| 		pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set | 		pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo { | 	if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo { | ||||||
| 		pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode | 		pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	pk = cl.ops.hooks.OnPacketEncode(cl, pk) | 	pk = cl.ops.hooks.OnPacketEncode(cl, pk) | ||||||
| @@ -554,7 +532,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	nb := net.Buffers{buf.Bytes()} | 	nb := net.Buffers{buf.Bytes()} | ||||||
| 	n, err := nb.WriteTo(cl.Net.conn) | 	n, err := nb.WriteTo(cl.Net.Conn) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -565,6 +543,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error { | |||||||
| 		atomic.AddInt64(&cl.ops.info.MessagesSent, 1) | 		atomic.AddInt64(&cl.ops.info.MessagesSent, 1) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	cl.refreshDeadline(cl.State.keepalive) | ||||||
| 	cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes()) | 	cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes()) | ||||||
|  |  | ||||||
| 	return err | 	return err | ||||||
|   | |||||||
							
								
								
									
										133
									
								
								clients_test.go
									
									
									
									
									
								
							
							
						
						
									
										133
									
								
								clients_test.go
									
									
									
									
									
								
							| @@ -22,10 +22,10 @@ const pkInfo = "packet type %v, %s" | |||||||
|  |  | ||||||
| var errClientStop = errors.New("test stop") | var errClientStop = errors.New("test stop") | ||||||
|  |  | ||||||
| func newClient() (cl *Client, r net.Conn, w net.Conn) { | func newTestClient() (cl *Client, r net.Conn, w net.Conn) { | ||||||
| 	r, w = net.Pipe() | 	r, w = net.Pipe() | ||||||
|  |  | ||||||
| 	cl = NewClient(w, &ops{ | 	cl = newClient(w, &ops{ | ||||||
| 		info:  new(system.Info), | 		info:  new(system.Info), | ||||||
| 		hooks: new(Hooks), | 		hooks: new(Hooks), | ||||||
| 		log:   &logger, | 		log:   &logger, | ||||||
| @@ -119,34 +119,21 @@ func TestClientsGetByListener(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestNewClient(t *testing.T) { | func TestNewClient(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	require.NotNil(t, cl) | 	require.NotNil(t, cl) | ||||||
| 	require.NotNil(t, cl.State.Inflight.internal) | 	require.NotNil(t, cl.State.Inflight.internal) | ||||||
| 	require.NotNil(t, cl.State.Subscriptions) | 	require.NotNil(t, cl.State.Subscriptions) | ||||||
| 	require.Nil(t, cl.StopCause()) | 	require.NotNil(t, cl.State.TopicAliases) | ||||||
| } | 	require.Equal(t, defaultKeepalive, cl.State.keepalive) | ||||||
|  | 	require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) | ||||||
| func TestNewClientStub(t *testing.T) { | 	require.NotNil(t, cl.Net.Conn) | ||||||
| 	cl := newClientStub() | 	require.NotNil(t, cl.Net.bconn) | ||||||
| 	require.NotNil(t, cl) | 	require.False(t, cl.Net.Inline) | ||||||
| 	require.NotNil(t, cl.State.Inflight.internal) |  | ||||||
| 	require.NotNil(t, cl.State.Subscriptions) |  | ||||||
| 	require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done)) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestNewInlineClient(t *testing.T) { |  | ||||||
| 	cl := NewInlineClient("inline", "local") |  | ||||||
| 	require.NotNil(t, cl) |  | ||||||
| 	require.NotNil(t, cl.State.Inflight.internal) |  | ||||||
| 	require.NotNil(t, cl.State.Subscriptions) |  | ||||||
| 	require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done)) |  | ||||||
| 	require.Equal(t, "inline", cl.ID) |  | ||||||
| 	require.Equal(t, "local", cl.Net.Remote) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientParseConnect(t *testing.T) { | func TestClientParseConnect(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	pk := packets.Packet{ | 	pk := packets.Packet{ | ||||||
| 		ProtocolVersion: 4, | 		ProtocolVersion: 4, | ||||||
| @@ -183,7 +170,7 @@ func TestClientParseConnect(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientParseConnectOverrideWillDelay(t *testing.T) { | func TestClientParseConnectOverrideWillDelay(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	pk := packets.Packet{ | 	pk := packets.Packet{ | ||||||
| 		ProtocolVersion: 4, | 		ProtocolVersion: 4, | ||||||
| @@ -208,13 +195,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientParseConnectNoID(t *testing.T) { | func TestClientParseConnectNoID(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.ParseConnect("tcp1", packets.Packet{}) | 	cl.ParseConnect("tcp1", packets.Packet{}) | ||||||
| 	require.NotEmpty(t, cl.ID) | 	require.NotEmpty(t, cl.ID) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientNextPacketID(t *testing.T) { | func TestClientNextPacketID(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	i, err := cl.NextPacketID() | 	i, err := cl.NextPacketID() | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| @@ -226,7 +213,7 @@ func TestClientNextPacketID(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientNextPacketIDInUse(t *testing.T) { | func TestClientNextPacketIDInUse(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	// skip over 2 | 	// skip over 2 | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||||
| @@ -249,7 +236,7 @@ func TestClientNextPacketIDInUse(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientNextPacketIDExhausted(t *testing.T) { | func TestClientNextPacketIDExhausted(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	for i := 0; i <= 65535; i++ { | 	for i := 0; i <= 65535; i++ { | ||||||
| 		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) | 		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) | ||||||
| 	} | 	} | ||||||
| @@ -261,7 +248,7 @@ func TestClientNextPacketIDExhausted(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientNextPacketIDOverflow(t *testing.T) { | func TestClientNextPacketIDOverflow(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	cl.State.packetID = uint32(65534) | 	cl.State.packetID = uint32(65534) | ||||||
|  |  | ||||||
| @@ -275,7 +262,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientClearInflights(t *testing.T) { | func TestClientClearInflights(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	n := time.Now().Unix() | 	n := time.Now().Unix() | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) | ||||||
| @@ -285,13 +272,15 @@ func TestClientClearInflights(t *testing.T) { | |||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n}) | ||||||
| 	require.Equal(t, 5, cl.State.Inflight.Len()) | 	require.Equal(t, 5, cl.State.Inflight.Len()) | ||||||
|  |  | ||||||
| 	cl.ClearInflights(n, 4) | 	deleted := cl.ClearInflights(n, 4) | ||||||
|  | 	require.Len(t, deleted, 3) | ||||||
|  | 	require.ElementsMatch(t, []uint16{1, 2, 5}, deleted) | ||||||
| 	require.Equal(t, 2, cl.State.Inflight.Len()) | 	require.Equal(t, 2, cl.State.Inflight.Len()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientResendInflightMessages(t *testing.T) { | func TestClientResendInflightMessages(t *testing.T) { | ||||||
| 	pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback) | 	pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback) | ||||||
| 	cl, r, w := newClient() | 	cl, r, w := newTestClient() | ||||||
|  |  | ||||||
| 	cl.State.Inflight.Set(*pk1.Packet) | 	cl.State.Inflight.Set(*pk1.Packet) | ||||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||||
| @@ -311,7 +300,7 @@ func TestClientResendInflightMessages(t *testing.T) { | |||||||
|  |  | ||||||
| func TestClientResendInflightMessagesWriteFailure(t *testing.T) { | func TestClientResendInflightMessagesWriteFailure(t *testing.T) { | ||||||
| 	pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) | 	pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	r.Close() | 	r.Close() | ||||||
|  |  | ||||||
| 	cl.State.Inflight.Set(*pk1.Packet) | 	cl.State.Inflight.Set(*pk1.Packet) | ||||||
| @@ -323,19 +312,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientResendInflightMessagesNoMessages(t *testing.T) { | func TestClientResendInflightMessagesNoMessages(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	err := cl.ResendInflightMessages(true) | 	err := cl.ResendInflightMessages(true) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientRefreshDeadline(t *testing.T) { | func TestClientRefreshDeadline(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.refreshDeadline(10) | 	cl.refreshDeadline(10) | ||||||
| 	require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline? | 	require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline? | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadFixedHeader(t *testing.T) { | func TestClientReadFixedHeader(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
|  |  | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| @@ -350,7 +339,7 @@ func TestClientReadFixedHeader(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadFixedHeaderDecodeError(t *testing.T) { | func TestClientReadFixedHeaderDecodeError(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| @@ -363,8 +352,24 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) { | |||||||
| 	require.Error(t, err) | 	require.Error(t, err) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestClientReadFixedHeaderPacketOversized(t *testing.T) { | ||||||
|  | 	cl, r, _ := newTestClient() | ||||||
|  | 	cl.ops.capabilities.MaximumPacketSize = 2 | ||||||
|  | 	defer cl.Stop(errClientStop) | ||||||
|  |  | ||||||
|  | 	go func() { | ||||||
|  | 		r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) | ||||||
|  | 		r.Close() | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	fh := new(packets.FixedHeader) | ||||||
|  | 	err := cl.ReadFixedHeader(fh) | ||||||
|  | 	require.Error(t, err) | ||||||
|  | 	require.ErrorIs(t, err, packets.ErrPacketTooLarge) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestClientReadFixedHeaderReadEOF(t *testing.T) { | func TestClientReadFixedHeaderReadEOF(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| @@ -378,7 +383,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { | func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| @@ -392,7 +397,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadOK(t *testing.T) { | func TestClientReadOK(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		r.Write([]byte{ | 		r.Write([]byte{ | ||||||
| @@ -446,7 +451,7 @@ func TestClientReadOK(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadDone(t *testing.T) { | func TestClientReadDone(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	cl.State.done = 1 | 	cl.State.done = 1 | ||||||
|  |  | ||||||
| @@ -461,15 +466,23 @@ func TestClientReadDone(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientStop(t *testing.T) { | func TestClientStop(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.Stop(nil) | 	cl.Stop(nil) | ||||||
| 	require.Equal(t, nil, cl.State.stopCause.Load()) | 	require.Equal(t, nil, cl.State.stopCause.Load()) | ||||||
| 	require.Equal(t, time.Now().Unix(), cl.State.disconnected) | 	require.Equal(t, time.Now().Unix(), cl.State.disconnected) | ||||||
| 	require.Equal(t, uint32(1), cl.State.done) | 	require.Equal(t, uint32(1), cl.State.done) | ||||||
|  | 	require.Equal(t, nil, cl.StopCause()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestClientClosed(t *testing.T) { | ||||||
|  | 	cl, _, _ := newTestClient() | ||||||
|  | 	require.False(t, cl.Closed()) | ||||||
|  | 	cl.Stop(nil) | ||||||
|  | 	require.True(t, cl.Closed()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadFixedHeaderError(t *testing.T) { | func TestClientReadFixedHeaderError(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		r.Write([]byte{ | 		r.Write([]byte{ | ||||||
| @@ -486,7 +499,7 @@ func TestClientReadFixedHeaderError(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadReadHandlerErr(t *testing.T) { | func TestClientReadReadHandlerErr(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		r.Write([]byte{ | 		r.Write([]byte{ | ||||||
| @@ -506,7 +519,7 @@ func TestClientReadReadHandlerErr(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadReadPacketOK(t *testing.T) { | func TestClientReadReadPacketOK(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		r.Write([]byte{ | 		r.Write([]byte{ | ||||||
| @@ -538,7 +551,7 @@ func TestClientReadReadPacketOK(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadPacket(t *testing.T) { | func TestClientReadPacket(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
|  |  | ||||||
| 	for _, tx := range pkTable { | 	for _, tx := range pkTable { | ||||||
| @@ -571,9 +584,17 @@ func TestClientReadPacket(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestClientReadPacketInvalidTypeError(t *testing.T) { | ||||||
|  | 	cl, _, _ := newTestClient() | ||||||
|  | 	cl.Net.Conn.Close() | ||||||
|  | 	_, err := cl.ReadPacket(&packets.FixedHeader{}) | ||||||
|  | 	require.Error(t, err) | ||||||
|  | 	require.Contains(t, err.Error(), "invalid packet type") | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestClientWritePacket(t *testing.T) { | func TestClientWritePacket(t *testing.T) { | ||||||
| 	for _, tt := range pkTable { | 	for _, tt := range pkTable { | ||||||
| 		cl, r, _ := newClient() | 		cl, r, _ := newTestClient() | ||||||
| 		defer cl.Stop(errClientStop) | 		defer cl.Stop(errClientStop) | ||||||
|  |  | ||||||
| 		cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion | 		cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion | ||||||
| @@ -589,7 +610,7 @@ func TestClientWritePacket(t *testing.T) { | |||||||
| 		require.NoError(t, err, pkInfo, tt.Case, tt.Desc) | 		require.NoError(t, err, pkInfo, tt.Case, tt.Desc) | ||||||
|  |  | ||||||
| 		time.Sleep(2 * time.Millisecond) | 		time.Sleep(2 * time.Millisecond) | ||||||
| 		cl.Net.conn.Close() | 		cl.Net.Conn.Close() | ||||||
|  |  | ||||||
| 		require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) | 		require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) | ||||||
|  |  | ||||||
| @@ -613,7 +634,7 @@ func TestClientWritePacket(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestWriteClientOversizePacket(t *testing.T) { | func TestWriteClientOversizePacket(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.Properties.Props.MaximumPacketSize = 2 | 	cl.Properties.Props.MaximumPacketSize = 2 | ||||||
| 	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet | 	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet | ||||||
| 	err := cl.WritePacket(pk) | 	err := cl.WritePacket(pk) | ||||||
| @@ -622,7 +643,7 @@ func TestWriteClientOversizePacket(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadPacketReadingError(t *testing.T) { | func TestClientReadPacketReadingError(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		r.Write([]byte{ | 		r.Write([]byte{ | ||||||
| @@ -642,7 +663,7 @@ func TestClientReadPacketReadingError(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientReadPacketReadUnknown(t *testing.T) { | func TestClientReadPacketReadUnknown(t *testing.T) { | ||||||
| 	cl, r, _ := newClient() | 	cl, r, _ := newTestClient() | ||||||
| 	defer cl.Stop(errClientStop) | 	defer cl.Stop(errClientStop) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		r.Write([]byte{ | 		r.Write([]byte{ | ||||||
| @@ -661,7 +682,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientWritePacketWriteNoConn(t *testing.T) { | func TestClientWritePacketWriteNoConn(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.Stop(errClientStop) | 	cl.Stop(errClientStop) | ||||||
|  |  | ||||||
| 	err := cl.WritePacket(*pkTable[1].Packet) | 	err := cl.WritePacket(*pkTable[1].Packet) | ||||||
| @@ -670,15 +691,15 @@ func TestClientWritePacketWriteNoConn(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientWritePacketWriteError(t *testing.T) { | func TestClientWritePacketWriteError(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.Net.conn.Close() | 	cl.Net.Conn.Close() | ||||||
|  |  | ||||||
| 	err := cl.WritePacket(*pkTable[1].Packet) | 	err := cl.WritePacket(*pkTable[1].Packet) | ||||||
| 	require.Error(t, err) | 	require.Error(t, err) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestClientWritePacketInvalidPacket(t *testing.T) { | func TestClientWritePacketInvalidPacket(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	err := cl.WritePacket(packets.Packet{}) | 	err := cl.WritePacket(packets.Packet{}) | ||||||
| 	require.Error(t, err) | 	require.Error(t, err) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -36,7 +36,7 @@ func main() { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = server.AddHook(new(debug.Hook), &debug.Options{ | 	err = server.AddHook(new(debug.Hook), &debug.Options{ | ||||||
| 		ShowPacketData: true, | 		// ShowPacketData: true, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal(err) | 		log.Fatal(err) | ||||||
|   | |||||||
| @@ -52,15 +52,30 @@ func main() { | |||||||
| 	// `server.Publish` method. Subscribe to `direct/publish` using your | 	// `server.Publish` method. Subscribe to `direct/publish` using your | ||||||
| 	// MQTT client to see the messages. | 	// MQTT client to see the messages. | ||||||
| 	go func() { | 	go func() { | ||||||
| 		cl := mqtt.NewInlineClient("inline", "local") | 		cl := server.NewClient(nil, "local", "inline", true) | ||||||
| 		for range time.Tick(time.Second * 10) { | 		for range time.Tick(time.Second * 1) { | ||||||
| 			server.InjectPacket(cl, packets.Packet{ | 			err := server.InjectPacket(cl, packets.Packet{ | ||||||
| 				FixedHeader: packets.FixedHeader{ | 				FixedHeader: packets.FixedHeader{ | ||||||
| 					Type: packets.Publish, | 					Type: packets.Publish, | ||||||
| 				}, | 				}, | ||||||
| 				TopicName: "direct/publish", | 				TopicName: "direct/publish", | ||||||
| 				Payload:   []byte("scheduled message"), | 				Payload:   []byte("injected scheduled message"), | ||||||
| 			}) | 			}) | ||||||
|  | 			if err != nil { | ||||||
|  | 				server.Log.Error().Err(err).Msg("server.InjectPacket") | ||||||
|  | 			} | ||||||
|  | 			server.Log.Info().Msgf("main.go injected packet to direct/publish") | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	// There is also a shorthand convenience function, Publish, for easily sending | ||||||
|  | 	// publish packets if you are not concerned with creating your own packets. | ||||||
|  | 	go func() { | ||||||
|  | 		for range time.Tick(time.Second * 5) { | ||||||
|  | 			err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) | ||||||
|  | 			if err != nil { | ||||||
|  | 				server.Log.Error().Err(err).Msg("server.Publish") | ||||||
|  | 			} | ||||||
| 			server.Log.Info().Msgf("main.go issued direct message to direct/publish") | 			server.Log.Info().Msgf("main.go issued direct message to direct/publish") | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|   | |||||||
| @@ -29,7 +29,6 @@ func main() { | |||||||
| 	server.Options.Capabilities.ServerKeepAlive = 60 | 	server.Options.Capabilities.ServerKeepAlive = 60 | ||||||
| 	server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true | 	server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true | ||||||
| 	server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true | 	server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true | ||||||
| 	server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true |  | ||||||
|  |  | ||||||
| 	_ = server.AddHook(new(pahoAuthHook), nil) | 	_ = server.AddHook(new(pahoAuthHook), nil) | ||||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||||
|   | |||||||
| @@ -30,12 +30,15 @@ func main() { | |||||||
| 	server := mqtt.New(nil) | 	server := mqtt.New(nil) | ||||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||||
|  |  | ||||||
| 	err := server.AddHook(new(bolt.Hook), bolt.Options{ | 	err := server.AddHook(new(bolt.Hook), &bolt.Options{ | ||||||
| 		Path: "bolt.db", | 		Path: "bolt.db", | ||||||
| 		Options: &bbolt.Options{ | 		Options: &bbolt.Options{ | ||||||
| 			Timeout: 500 * time.Millisecond, | 			Timeout: 500 * time.Millisecond, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||||
| 	err = server.AddListener(tcp) | 	err = server.AddListener(tcp) | ||||||
|   | |||||||
							
								
								
									
										121
									
								
								hooks.go
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								hooks.go
									
									
									
									
									
								
							| @@ -1,6 +1,6 @@ | |||||||
| // SPDX-License-Identifier: MIT | // SPDX-License-Identifier: MIT | ||||||
| // SPDX-FileCopyrightText: 2022 mochi-co | // SPDX-FileCopyrightText: 2022 mochi-co | ||||||
| // SPDX-FileContributor: mochi-co | // SPDX-FileContributor: mochi-co, thedevop | ||||||
|  |  | ||||||
| package mqtt | package mqtt | ||||||
|  |  | ||||||
| @@ -47,7 +47,6 @@ const ( | |||||||
| 	OnWillSent | 	OnWillSent | ||||||
| 	OnClientExpired | 	OnClientExpired | ||||||
| 	OnRetainedExpired | 	OnRetainedExpired | ||||||
| 	OnExpireInflights |  | ||||||
| 	StoredClients | 	StoredClients | ||||||
| 	StoredSubscriptions | 	StoredSubscriptions | ||||||
| 	StoredInflightMessages | 	StoredInflightMessages | ||||||
| @@ -96,7 +95,6 @@ type Hook interface { | |||||||
| 	OnWillSent(cl *Client, pk packets.Packet) | 	OnWillSent(cl *Client, pk packets.Packet) | ||||||
| 	OnClientExpired(cl *Client) | 	OnClientExpired(cl *Client) | ||||||
| 	OnRetainedExpired(filter string) | 	OnRetainedExpired(filter string) | ||||||
| 	OnExpireInflights(cl *Client, expiry int64) |  | ||||||
| 	StoredClients() ([]storage.Client, error) | 	StoredClients() ([]storage.Client, error) | ||||||
| 	StoredSubscriptions() ([]storage.Subscription, error) | 	StoredSubscriptions() ([]storage.Subscription, error) | ||||||
| 	StoredInflightMessages() ([]storage.Message, error) | 	StoredInflightMessages() ([]storage.Message, error) | ||||||
| @@ -112,10 +110,10 @@ type HookOptions struct { | |||||||
| // Hooks is a slice of Hook interfaces to be called in sequence. | // Hooks is a slice of Hook interfaces to be called in sequence. | ||||||
| type Hooks struct { | type Hooks struct { | ||||||
| 	Log        *zerolog.Logger // a logger for the hook (from the server) | 	Log        *zerolog.Logger // a logger for the hook (from the server) | ||||||
| 	internal   []Hook          // a slice of hooks | 	internal   atomic.Value    // a slice of []Hook | ||||||
| 	wg         sync.WaitGroup  // a waitgroup for syncing hook shutdown | 	wg         sync.WaitGroup  // a waitgroup for syncing hook shutdown | ||||||
| 	qty        int64           // the number of hooks in use | 	qty        int64           // the number of hooks in use | ||||||
| 	sync.Mutex                 // a mutex | 	sync.Mutex                 // a mutex for locking when adding hooks | ||||||
| } | } | ||||||
|  |  | ||||||
| // Len returns the number of hooks added. | // Len returns the number of hooks added. | ||||||
| @@ -125,7 +123,7 @@ func (h *Hooks) Len() int64 { | |||||||
|  |  | ||||||
| // Provides returns true if any one hook provides any of the requested hook methods. | // Provides returns true if any one hook provides any of the requested hook methods. | ||||||
| func (h *Hooks) Provides(b ...byte) bool { | func (h *Hooks) Provides(b ...byte) bool { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		for _, hb := range b { | 		for _, hb := range b { | ||||||
| 			if hook.Provides(hb) { | 			if hook.Provides(hb) { | ||||||
| 				return true | 				return true | ||||||
| @@ -140,26 +138,39 @@ func (h *Hooks) Provides(b ...byte) bool { | |||||||
| func (h *Hooks) Add(hook Hook, config any) error { | func (h *Hooks) Add(hook Hook, config any) error { | ||||||
| 	h.Lock() | 	h.Lock() | ||||||
| 	defer h.Unlock() | 	defer h.Unlock() | ||||||
| 	if h.internal == nil { |  | ||||||
| 		h.internal = []Hook{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err := hook.Init(config) | 	err := hook.Init(config) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err) | 		return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.internal = append(h.internal, hook) | 	i, ok := h.internal.Load().([]Hook) | ||||||
|  | 	if !ok { | ||||||
|  | 		i = []Hook{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	i = append(i, hook) | ||||||
|  | 	h.internal.Store(i) | ||||||
| 	atomic.AddInt64(&h.qty, 1) | 	atomic.AddInt64(&h.qty, 1) | ||||||
| 	h.wg.Add(1) | 	h.wg.Add(1) | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // GetAll returns a slice of all the hooks. | ||||||
|  | func (h *Hooks) GetAll() []Hook { | ||||||
|  | 	i, ok := h.internal.Load().([]Hook) | ||||||
|  | 	if !ok { | ||||||
|  | 		return []Hook{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return i | ||||||
|  | } | ||||||
|  |  | ||||||
| // Stop indicates all attached hooks to gracefully end. | // Stop indicates all attached hooks to gracefully end. | ||||||
| func (h *Hooks) Stop() { | func (h *Hooks) Stop() { | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for _, hook := range h.internal { | 		for _, hook := range h.GetAll() { | ||||||
| 			h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook") | 			h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook") | ||||||
| 			if err := hook.Stop(); err != nil { | 			if err := hook.Stop(); err != nil { | ||||||
| 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook") | 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook") | ||||||
| @@ -174,7 +185,7 @@ func (h *Hooks) Stop() { | |||||||
|  |  | ||||||
| // OnSysInfoTick is called when the $SYS topic values are published out. | // OnSysInfoTick is called when the $SYS topic values are published out. | ||||||
| func (h *Hooks) OnSysInfoTick(sys *system.Info) { | func (h *Hooks) OnSysInfoTick(sys *system.Info) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnSysInfoTick) { | 		if hook.Provides(OnSysInfoTick) { | ||||||
| 			hook.OnSysInfoTick(sys) | 			hook.OnSysInfoTick(sys) | ||||||
| 		} | 		} | ||||||
| @@ -183,7 +194,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) { | |||||||
|  |  | ||||||
| // OnStarted is called when the server has successfully started. | // OnStarted is called when the server has successfully started. | ||||||
| func (h *Hooks) OnStarted() { | func (h *Hooks) OnStarted() { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnStarted) { | 		if hook.Provides(OnStarted) { | ||||||
| 			hook.OnStarted() | 			hook.OnStarted() | ||||||
| 		} | 		} | ||||||
| @@ -192,7 +203,7 @@ func (h *Hooks) OnStarted() { | |||||||
|  |  | ||||||
| // OnStopped is called when the server has successfully stopped. | // OnStopped is called when the server has successfully stopped. | ||||||
| func (h *Hooks) OnStopped() { | func (h *Hooks) OnStopped() { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnStopped) { | 		if hook.Provides(OnStopped) { | ||||||
| 			hook.OnStopped() | 			hook.OnStopped() | ||||||
| 		} | 		} | ||||||
| @@ -201,7 +212,7 @@ func (h *Hooks) OnStopped() { | |||||||
|  |  | ||||||
| // OnConnect is called when a new client connects. | // OnConnect is called when a new client connects. | ||||||
| func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) { | func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnConnect) { | 		if hook.Provides(OnConnect) { | ||||||
| 			hook.OnConnect(cl, pk) | 			hook.OnConnect(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -210,7 +221,7 @@ func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) { | |||||||
|  |  | ||||||
| // OnSessionEstablished is called when a new client establishes a session (after OnConnect). | // OnSessionEstablished is called when a new client establishes a session (after OnConnect). | ||||||
| func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { | func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnSessionEstablished) { | 		if hook.Provides(OnSessionEstablished) { | ||||||
| 			hook.OnSessionEstablished(cl, pk) | 			hook.OnSessionEstablished(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -219,7 +230,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { | |||||||
|  |  | ||||||
| // OnDisconnect is called when a client is disconnected for any reason. | // OnDisconnect is called when a client is disconnected for any reason. | ||||||
| func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { | func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnDisconnect) { | 		if hook.Provides(OnDisconnect) { | ||||||
| 			hook.OnDisconnect(cl, err, expire) | 			hook.OnDisconnect(cl, err, expire) | ||||||
| 		} | 		} | ||||||
| @@ -229,7 +240,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { | |||||||
| // OnPacketRead is called when a packet is received from a client. | // OnPacketRead is called when a packet is received from a client. | ||||||
| func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||||
| 	pkx = pk | 	pkx = pk | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnPacketRead) { | 		if hook.Provides(OnPacketRead) { | ||||||
| 			npk, err := hook.OnPacketRead(cl, pkx) | 			npk, err := hook.OnPacketRead(cl, pkx) | ||||||
| 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | ||||||
| @@ -250,7 +261,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, | |||||||
| // to create their own auth packet handling mechanisms. | // to create their own auth packet handling mechanisms. | ||||||
| func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||||
| 	pkx = pk | 	pkx = pk | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnAuthPacket) { | 		if hook.Provides(OnAuthPacket) { | ||||||
| 			npk, err := hook.OnAuthPacket(cl, pkx) | 			npk, err := hook.OnAuthPacket(cl, pkx) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -266,7 +277,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, | |||||||
|  |  | ||||||
| // OnPacketEncode is called immediately before a packet is encoded to be sent to a client. | // OnPacketEncode is called immediately before a packet is encoded to be sent to a client. | ||||||
| func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnPacketEncode) { | 		if hook.Provides(OnPacketEncode) { | ||||||
| 			pk = hook.OnPacketEncode(cl, pk) | 			pk = hook.OnPacketEncode(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -277,7 +288,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | |||||||
|  |  | ||||||
| // OnPacketProcessed is called when a packet has been received and successfully handled by the broker. | // OnPacketProcessed is called when a packet has been received and successfully handled by the broker. | ||||||
| func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { | func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnPacketProcessed) { | 		if hook.Provides(OnPacketProcessed) { | ||||||
| 			hook.OnPacketProcessed(cl, pk, err) | 			hook.OnPacketProcessed(cl, pk, err) | ||||||
| 		} | 		} | ||||||
| @@ -287,7 +298,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { | |||||||
| // OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter | // OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter | ||||||
| // containing the bytes sent. | // containing the bytes sent. | ||||||
| func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { | func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnPacketSent) { | 		if hook.Provides(OnPacketSent) { | ||||||
| 			hook.OnPacketSent(cl, pk, b) | 			hook.OnPacketSent(cl, pk, b) | ||||||
| 		} | 		} | ||||||
| @@ -299,7 +310,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { | |||||||
| // before the packet is processed. The return values of the hook methods are passed-through | // before the packet is processed. The return values of the hook methods are passed-through | ||||||
| // in the order the hooks were attached. | // in the order the hooks were attached. | ||||||
| func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnSubscribe) { | 		if hook.Provides(OnSubscribe) { | ||||||
| 			pk = hook.OnSubscribe(cl, pk) | 			pk = hook.OnSubscribe(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -309,7 +320,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | |||||||
|  |  | ||||||
| // OnSubscribed is called when a client subscribes to one or more filters. | // OnSubscribed is called when a client subscribes to one or more filters. | ||||||
| func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) { | func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnSubscribed) { | 		if hook.Provides(OnSubscribed) { | ||||||
| 			hook.OnSubscribed(cl, pk, reasonCodes) | 			hook.OnSubscribed(cl, pk, reasonCodes) | ||||||
| 		} | 		} | ||||||
| @@ -321,7 +332,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) | |||||||
| // remove or add clients to a publish to subscribers process, or to select the subscriber for a shared | // remove or add clients to a publish to subscribers process, or to select the subscriber for a shared | ||||||
| // group in a custom manner (such as based on client id, ip, etc). | // group in a custom manner (such as based on client id, ip, etc). | ||||||
| func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { | func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnSelectSubscribers) { | 		if hook.Provides(OnSelectSubscribers) { | ||||||
| 			subs = hook.OnSelectSubscribers(subs, pk) | 			subs = hook.OnSelectSubscribers(subs, pk) | ||||||
| 		} | 		} | ||||||
| @@ -334,7 +345,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc | |||||||
| // before the packet is processed. The return values of the hook methods are passed-through | // before the packet is processed. The return values of the hook methods are passed-through | ||||||
| // in the order the hooks were attached. | // in the order the hooks were attached. | ||||||
| func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnUnsubscribe) { | 		if hook.Provides(OnUnsubscribe) { | ||||||
| 			pk = hook.OnUnsubscribe(cl, pk) | 			pk = hook.OnUnsubscribe(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -344,19 +355,19 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | |||||||
|  |  | ||||||
| // OnUnsubscribed is called when a client unsubscribes from one or more filters. | // OnUnsubscribed is called when a client unsubscribes from one or more filters. | ||||||
| func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) { | func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnUnsubscribed) { | 		if hook.Provides(OnUnsubscribed) { | ||||||
| 			hook.OnUnsubscribed(cl, pk) | 			hook.OnUnsubscribed(cl, pk) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // OnPublish is called when a client publishes a message. This method differs from OnMessage | // OnPublish is called when a client publishes a message. This method differs from OnPublished | ||||||
| // in that it allows you to modify you to modify the incoming packet before it is processed. | // in that it allows you to modify you to modify the incoming packet before it is processed. | ||||||
| // The return values of the hook methods are passed-through in the order the hooks were attached. | // The return values of the hook methods are passed-through in the order the hooks were attached. | ||||||
| func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||||
| 	pkx = pk | 	pkx = pk | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnPublish) { | 		if hook.Provides(OnPublish) { | ||||||
| 			npk, err := hook.OnPublish(cl, pkx) | 			npk, err := hook.OnPublish(cl, pkx) | ||||||
| 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | ||||||
| @@ -375,7 +386,7 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er | |||||||
|  |  | ||||||
| // OnPublished is called when a client has published a message to subscribers. | // OnPublished is called when a client has published a message to subscribers. | ||||||
| func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) { | func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnPublished) { | 		if hook.Provides(OnPublished) { | ||||||
| 			hook.OnPublished(cl, pk) | 			hook.OnPublished(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -384,7 +395,7 @@ func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) { | |||||||
|  |  | ||||||
| // OnRetainMessage is called then a published message is retained. | // OnRetainMessage is called then a published message is retained. | ||||||
| func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) { | func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnRetainMessage) { | 		if hook.Provides(OnRetainMessage) { | ||||||
| 			hook.OnRetainMessage(cl, pk, r) | 			hook.OnRetainMessage(cl, pk, r) | ||||||
| 		} | 		} | ||||||
| @@ -395,7 +406,7 @@ func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) { | |||||||
| // In other words, this method is called when a new inflight message is created or resent. | // In other words, this method is called when a new inflight message is created or resent. | ||||||
| // It is typically used to store a new inflight message. | // It is typically used to store a new inflight message. | ||||||
| func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) { | func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnQosPublish) { | 		if hook.Provides(OnQosPublish) { | ||||||
| 			hook.OnQosPublish(cl, pk, sent, resends) | 			hook.OnQosPublish(cl, pk, sent, resends) | ||||||
| 		} | 		} | ||||||
| @@ -406,7 +417,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends | |||||||
| // In other words, when an inflight message is resolved. | // In other words, when an inflight message is resolved. | ||||||
| // It is typically used to delete an inflight message from a store. | // It is typically used to delete an inflight message from a store. | ||||||
| func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { | func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnQosComplete) { | 		if hook.Provides(OnQosComplete) { | ||||||
| 			hook.OnQosComplete(cl, pk) | 			hook.OnQosComplete(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -414,10 +425,10 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // OnQosDropped is called the Qos flow for a message expires. In other words, when | // OnQosDropped is called the Qos flow for a message expires. In other words, when | ||||||
| // an inflight message expires or is abandoned. | // an inflight message expires or is abandoned. It is typically used to delete an | ||||||
| // It is typically used to delete an inflight message from a store. | // inflight message from a store. | ||||||
| func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) { | func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnQosDropped) { | 		if hook.Provides(OnQosDropped) { | ||||||
| 			hook.OnQosDropped(cl, pk) | 			hook.OnQosDropped(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -429,7 +440,7 @@ func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) { | |||||||
| // published. The return values of the hook methods are passed-through in the order | // published. The return values of the hook methods are passed-through in the order | ||||||
| // the hooks were attached. | // the hooks were attached. | ||||||
| func (h *Hooks) OnWill(cl *Client, will Will) Will { | func (h *Hooks) OnWill(cl *Client, will Will) Will { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnWill) { | 		if hook.Provides(OnWill) { | ||||||
| 			mlwt, err := hook.OnWill(cl, will) | 			mlwt, err := hook.OnWill(cl, will) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -445,7 +456,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will { | |||||||
|  |  | ||||||
| // OnWillSent is called when an LWT message has been issued from a disconnecting client. | // OnWillSent is called when an LWT message has been issued from a disconnecting client. | ||||||
| func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { | func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnWillSent) { | 		if hook.Provides(OnWillSent) { | ||||||
| 			hook.OnWillSent(cl, pk) | 			hook.OnWillSent(cl, pk) | ||||||
| 		} | 		} | ||||||
| @@ -454,7 +465,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { | |||||||
|  |  | ||||||
| // OnClientExpired is called when a client session has expired and should be deleted. | // OnClientExpired is called when a client session has expired and should be deleted. | ||||||
| func (h *Hooks) OnClientExpired(cl *Client) { | func (h *Hooks) OnClientExpired(cl *Client) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnClientExpired) { | 		if hook.Provides(OnClientExpired) { | ||||||
| 			hook.OnClientExpired(cl) | 			hook.OnClientExpired(cl) | ||||||
| 		} | 		} | ||||||
| @@ -463,7 +474,7 @@ func (h *Hooks) OnClientExpired(cl *Client) { | |||||||
|  |  | ||||||
| // OnRetainedExpired is called when a retained message has expired and should be deleted. | // OnRetainedExpired is called when a retained message has expired and should be deleted. | ||||||
| func (h *Hooks) OnRetainedExpired(filter string) { | func (h *Hooks) OnRetainedExpired(filter string) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnRetainedExpired) { | 		if hook.Provides(OnRetainedExpired) { | ||||||
| 			hook.OnRetainedExpired(filter) | 			hook.OnRetainedExpired(filter) | ||||||
| 		} | 		} | ||||||
| @@ -473,7 +484,7 @@ func (h *Hooks) OnRetainedExpired(filter string) { | |||||||
| // StoredClients returns all clients, e.g. from a persistent store, is used to | // StoredClients returns all clients, e.g. from a persistent store, is used to | ||||||
| // populate the server clients list before start. | // populate the server clients list before start. | ||||||
| func (h *Hooks) StoredClients() (v []storage.Client, err error) { | func (h *Hooks) StoredClients() (v []storage.Client, err error) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(StoredClients) { | 		if hook.Provides(StoredClients) { | ||||||
| 			v, err := hook.StoredClients() | 			v, err := hook.StoredClients() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -493,7 +504,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) { | |||||||
| // StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is | // StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is | ||||||
| // used to populate the server subscriptions list before start. | // used to populate the server subscriptions list before start. | ||||||
| func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { | func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(StoredSubscriptions) { | 		if hook.Provides(StoredSubscriptions) { | ||||||
| 			v, err := hook.StoredSubscriptions() | 			v, err := hook.StoredSubscriptions() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -513,7 +524,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { | |||||||
| // StoredInflightMessages returns all inflight messages, e.g. from a persistent store, | // StoredInflightMessages returns all inflight messages, e.g. from a persistent store, | ||||||
| // and is used to populate the restored clients with inflight messages before start. | // and is used to populate the restored clients with inflight messages before start. | ||||||
| func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { | func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(StoredInflightMessages) { | 		if hook.Provides(StoredInflightMessages) { | ||||||
| 			v, err := hook.StoredInflightMessages() | 			v, err := hook.StoredInflightMessages() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -533,7 +544,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { | |||||||
| // StoredRetainedMessages returns all retained messages, e.g. from a persistent store, | // StoredRetainedMessages returns all retained messages, e.g. from a persistent store, | ||||||
| // and is used to populate the server topics with retained messages before start. | // and is used to populate the server topics with retained messages before start. | ||||||
| func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { | func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(StoredRetainedMessages) { | 		if hook.Provides(StoredRetainedMessages) { | ||||||
| 			v, err := hook.StoredRetainedMessages() | 			v, err := hook.StoredRetainedMessages() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -552,7 +563,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { | |||||||
|  |  | ||||||
| // StoredSysInfo returns a set of system info values. | // StoredSysInfo returns a set of system info values. | ||||||
| func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { | func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(StoredSysInfo) { | 		if hook.Provides(StoredSysInfo) { | ||||||
| 			v, err := hook.StoredSysInfo() | 			v, err := hook.StoredSysInfo() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -574,7 +585,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { | |||||||
| // server (see hooks/auth/allow_all or basic). It can be used in custom hooks to | // server (see hooks/auth/allow_all or basic). It can be used in custom hooks to | ||||||
| // check connecting users against an existing user database. | // check connecting users against an existing user database. | ||||||
| func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnConnectAuthenticate) { | 		if hook.Provides(OnConnectAuthenticate) { | ||||||
| 			if ok := hook.OnConnectAuthenticate(cl, pk); ok { | 			if ok := hook.OnConnectAuthenticate(cl, pk); ok { | ||||||
| 				return true | 				return true | ||||||
| @@ -590,7 +601,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | |||||||
| // (see hooks/auth/allow_all or basic). It can be used in custom hooks to | // (see hooks/auth/allow_all or basic). It can be used in custom hooks to | ||||||
| // check publishing and subscribing users against an existing permissions or roles database. | // check publishing and subscribing users against an existing permissions or roles database. | ||||||
| func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { | func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||||
| 	for _, hook := range h.internal { | 	for _, hook := range h.GetAll() { | ||||||
| 		if hook.Provides(OnACLCheck) { | 		if hook.Provides(OnACLCheck) { | ||||||
| 			if ok := hook.OnACLCheck(cl, topic, write); ok { | 			if ok := hook.OnACLCheck(cl, topic, write); ok { | ||||||
| 				return true | 				return true | ||||||
| @@ -601,19 +612,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
| // OnExpireInflights is called when the server issues a clear request for expired |  | ||||||
| // inflight messages. Expiry should be the time after which the message is no longer |  | ||||||
| // valid (usually some time in the past). A message has expired if it's created time |  | ||||||
| // is older than time.Now() minus Inflight TTL. This method can be used to expire |  | ||||||
| // old inflight messages in a persistent store which doesnt support per-item TTL. |  | ||||||
| func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) { |  | ||||||
| 	for _, hook := range h.internal { |  | ||||||
| 		if hook.Provides(OnExpireInflights) { |  | ||||||
| 			hook.OnExpireInflights(cl, expiry) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // HookBase provides a set of default methods for each hook. It should be embedded in | // HookBase provides a set of default methods for each hook. It should be embedded in | ||||||
| // all hooks. | // all hooks. | ||||||
| type HookBase struct { | type HookBase struct { | ||||||
| @@ -755,9 +753,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {} | |||||||
| // OnRetainedExpired is called when a retained message for a topic has expired. | // OnRetainedExpired is called when a retained message for a topic has expired. | ||||||
| func (h *HookBase) OnRetainedExpired(topic string) {} | func (h *HookBase) OnRetainedExpired(topic string) {} | ||||||
|  |  | ||||||
| // OnExpireInflights is called when the server issues a clear request for expired inflight messages. |  | ||||||
| func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {} |  | ||||||
|  |  | ||||||
| // StoredClients returns all clients from a store. | // StoredClients returns all clients from a store. | ||||||
| func (h *HookBase) StoredClients() (v []storage.Client, err error) { | func (h *HookBase) StoredClients() (v []storage.Client, err error) { | ||||||
| 	return | 	return | ||||||
|   | |||||||
| @@ -80,7 +80,6 @@ func (h *Hook) Provides(b byte) bool { | |||||||
| 		mqtt.OnSysInfoTick, | 		mqtt.OnSysInfoTick, | ||||||
| 		mqtt.OnClientExpired, | 		mqtt.OnClientExpired, | ||||||
| 		mqtt.OnRetainedExpired, | 		mqtt.OnRetainedExpired, | ||||||
| 		mqtt.OnExpireInflights, |  | ||||||
| 		mqtt.StoredClients, | 		mqtt.StoredClients, | ||||||
| 		mqtt.StoredInflightMessages, | 		mqtt.StoredInflightMessages, | ||||||
| 		mqtt.StoredRetainedMessages, | 		mqtt.StoredRetainedMessages, | ||||||
| @@ -199,11 +198,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by | |||||||
| 	var in *storage.Subscription | 	var in *storage.Subscription | ||||||
| 	for i := 0; i < len(pk.Filters); i++ { | 	for i := 0; i < len(pk.Filters); i++ { | ||||||
| 		in = &storage.Subscription{ | 		in = &storage.Subscription{ | ||||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | 			ID:                subscriptionKey(cl, pk.Filters[i].Filter), | ||||||
| 			T:      storage.SubscriptionKey, | 			T:                 storage.SubscriptionKey, | ||||||
| 			Client: cl.ID, | 			Client:            cl.ID, | ||||||
| 			Filter: pk.Filters[i].Filter, | 			Qos:               reasonCodes[i], | ||||||
| 			Qos:    reasonCodes[i], | 			Filter:            pk.Filters[i].Filter, | ||||||
|  | 			Identifier:        pk.Filters[i].Identifier, | ||||||
|  | 			NoLocal:           pk.Filters[i].NoLocal, | ||||||
|  | 			RetainHandling:    pk.Filters[i].RetainHandling, | ||||||
|  | 			RetainAsPublished: pk.Filters[i].RetainAsPublished, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err := h.db.Upsert(in.ID, in) | 		err := h.db.Upsert(in.ID, in) | ||||||
| @@ -348,32 +351,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // OnExpireInflights removes all inflight messages which have passed the provided expiry time. | // OnRetainedExpired deletes expired retained messages from the store. | ||||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | func (h *Hook) OnRetainedExpired(filter string) { | ||||||
| 	if h.db == nil { | 	if h.db == nil { | ||||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var v []storage.Message |  | ||||||
| 	err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) |  | ||||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { |  | ||||||
| 		h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, m := range v { |  | ||||||
| 		if m.Created < expiry || m.Created == 0 { |  | ||||||
| 			err := h.db.Delete(m.ID, new(storage.Message)) |  | ||||||
| 			if err != nil { |  | ||||||
| 				h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data") |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // OnRetainedExpired deletes expired retained messages from the store. |  | ||||||
| func (h *Hook) OnRetainedExpired(filter string) { |  | ||||||
| 	err := h.db.Delete(retainedKey(filter), new(storage.Message)) | 	err := h.db.Delete(retainedKey(filter), new(storage.Message)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data") | 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data") | ||||||
| @@ -382,6 +366,11 @@ func (h *Hook) OnRetainedExpired(filter string) { | |||||||
|  |  | ||||||
| // OnClientExpired deleted expired clients from the store. | // OnClientExpired deleted expired clients from the store. | ||||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||||
|  | 	if h.db == nil { | ||||||
|  | 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err := h.db.Delete(clientKey(cl), new(storage.Client)) | 	err := h.db.Delete(clientKey(cl), new(storage.Client)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data") | 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data") | ||||||
|   | |||||||
| @@ -5,13 +5,11 @@ | |||||||
| package badger | package badger | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/asdine/storm/v3" |  | ||||||
| 	"github.com/mochi-co/mqtt/v2" | 	"github.com/mochi-co/mqtt/v2" | ||||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||||
| 	"github.com/mochi-co/mqtt/v2/packets" | 	"github.com/mochi-co/mqtt/v2/packets" | ||||||
| @@ -170,6 +168,21 @@ func TestOnClientExpired(t *testing.T) { | |||||||
| 	require.ErrorIs(t, badgerhold.ErrNotFound, err) | 	require.ErrorIs(t, badgerhold.ErrNotFound, err) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOnClientExpiredNoDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	h.OnClientExpired(client) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestOnClientExpiredClosedDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	err := h.Init(nil) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	teardown(t, h.config.Path, h) | ||||||
|  | 	h.OnClientExpired(client) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestOnSessionEstablishedNoDB(t *testing.T) { | func TestOnSessionEstablishedNoDB(t *testing.T) { | ||||||
| 	h := new(Hook) | 	h := new(Hook) | ||||||
| 	h.SetOpts(&logger, nil) | 	h.SetOpts(&logger, nil) | ||||||
| @@ -333,6 +346,21 @@ func TestOnRetainedExpired(t *testing.T) { | |||||||
| 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOnRetainExpiredNoDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	h.OnRetainedExpired("a/b/c") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestOnRetainExpiredClosedDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	err := h.Init(nil) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	teardown(t, h.config.Path, h) | ||||||
|  | 	h.OnRetainedExpired("a/b/c") | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestOnRetainMessageNoDB(t *testing.T) { | func TestOnRetainMessageNoDB(t *testing.T) { | ||||||
| 	h := new(Hook) | 	h := new(Hook) | ||||||
| 	h.SetOpts(&logger, nil) | 	h.SetOpts(&logger, nil) | ||||||
| @@ -419,48 +447,6 @@ func TestOnQosDroppedNoDB(t *testing.T) { | |||||||
| 	h.OnQosDropped(client, packets.Packet{}) | 	h.OnQosDropped(client, packets.Packet{}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestOnExpireInflights(t *testing.T) { |  | ||||||
| 	h := new(Hook) |  | ||||||
| 	h.SetOpts(&logger, nil) |  | ||||||
|  |  | ||||||
| 	err := h.Init(nil) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	defer teardown(t, h.config.Path, h) |  | ||||||
|  |  | ||||||
| 	err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
|  |  | ||||||
| 	var v []storage.Message |  | ||||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) |  | ||||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	require.Len(t, v, 1) |  | ||||||
| 	require.Equal(t, "i1", v[0].ID) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnExpireInflightsNoDB(t *testing.T) { |  | ||||||
| 	h := new(Hook) |  | ||||||
| 	h.SetOpts(&logger, nil) |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { |  | ||||||
| 	h := new(Hook) |  | ||||||
| 	h.SetOpts(&logger, nil) |  | ||||||
| 	err := h.Init(nil) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	teardown(t, h.config.Path, h) |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnSysInfoTick(t *testing.T) { | func TestOnSysInfoTick(t *testing.T) { | ||||||
| 	h := new(Hook) | 	h := new(Hook) | ||||||
| 	h.SetOpts(&logger, nil) | 	h.SetOpts(&logger, nil) | ||||||
|   | |||||||
| @@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool { | |||||||
| 		mqtt.OnSysInfoTick, | 		mqtt.OnSysInfoTick, | ||||||
| 		mqtt.OnClientExpired, | 		mqtt.OnClientExpired, | ||||||
| 		mqtt.OnRetainedExpired, | 		mqtt.OnRetainedExpired, | ||||||
| 		mqtt.OnExpireInflights, |  | ||||||
| 		mqtt.StoredClients, | 		mqtt.StoredClients, | ||||||
| 		mqtt.StoredInflightMessages, | 		mqtt.StoredInflightMessages, | ||||||
| 		mqtt.StoredRetainedMessages, | 		mqtt.StoredRetainedMessages, | ||||||
| @@ -201,12 +200,17 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by | |||||||
| 	var in *storage.Subscription | 	var in *storage.Subscription | ||||||
| 	for i := 0; i < len(pk.Filters); i++ { | 	for i := 0; i < len(pk.Filters); i++ { | ||||||
| 		in = &storage.Subscription{ | 		in = &storage.Subscription{ | ||||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | 			ID:                subscriptionKey(cl, pk.Filters[i].Filter), | ||||||
| 			T:      storage.SubscriptionKey, | 			T:                 storage.SubscriptionKey, | ||||||
| 			Client: cl.ID, | 			Client:            cl.ID, | ||||||
| 			Filter: pk.Filters[i].Filter, | 			Qos:               reasonCodes[i], | ||||||
| 			Qos:    reasonCodes[i], | 			Filter:            pk.Filters[i].Filter, | ||||||
|  | 			Identifier:        pk.Filters[i].Identifier, | ||||||
|  | 			NoLocal:           pk.Filters[i].NoLocal, | ||||||
|  | 			RetainHandling:    pk.Filters[i].RetainHandling, | ||||||
|  | 			RetainAsPublished: pk.Filters[i].RetainAsPublished, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err := h.db.Save(in) | 		err := h.db.Save(in) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			h.Log.Error().Err(err). | 			h.Log.Error().Err(err). | ||||||
| @@ -369,34 +373,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // OnExpireInflights removes all inflight messages which have passed the | // OnRetainedExpired deletes expired retained messages from the store. | ||||||
| // provided expiry time. | func (h *Hook) OnRetainedExpired(filter string) { | ||||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { |  | ||||||
| 	if h.db == nil { | 	if h.db == nil { | ||||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var v []storage.Message |  | ||||||
| 	err := h.db.Find("T", storage.InflightKey, &v) |  | ||||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { |  | ||||||
| 		h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, m := range v { |  | ||||||
| 		if m.Created < expiry || m.Created == 0 { |  | ||||||
| 			err := h.db.DeleteStruct(&storage.Message{ID: m.ID}) |  | ||||||
| 			if err != nil && !errors.Is(err, storm.ErrNotFound) { |  | ||||||
| 				h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data") |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // OnRetainedExpired deletes expired retained messages from the store. |  | ||||||
| func (h *Hook) OnRetainedExpired(filter string) { |  | ||||||
| 	if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil { | 	if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil { | ||||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish") | 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish") | ||||||
| 	} | 	} | ||||||
| @@ -404,6 +387,11 @@ func (h *Hook) OnRetainedExpired(filter string) { | |||||||
|  |  | ||||||
| // OnClientExpired deleted expired clients from the store. | // OnClientExpired deleted expired clients from the store. | ||||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||||
|  | 	if h.db == nil { | ||||||
|  | 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) | 	err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) | ||||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | ||||||
|   | |||||||
| @@ -5,7 +5,6 @@ | |||||||
| package bolt | package bolt | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -212,6 +211,21 @@ func TestOnClientExpired(t *testing.T) { | |||||||
| 	require.ErrorIs(t, storm.ErrNotFound, err) | 	require.ErrorIs(t, storm.ErrNotFound, err) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOnClientExpiredClosedDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	err := h.Init(nil) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	teardown(t, h.config.Path, h) | ||||||
|  | 	h.OnClientExpired(client) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestOnClientExpiredNoDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	h.OnClientExpired(client) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestOnDisconnectNoDB(t *testing.T) { | func TestOnDisconnectNoDB(t *testing.T) { | ||||||
| 	h := new(Hook) | 	h := new(Hook) | ||||||
| 	h.SetOpts(&logger, nil) | 	h.SetOpts(&logger, nil) | ||||||
| @@ -341,6 +355,21 @@ func TestOnRetainedExpired(t *testing.T) { | |||||||
| 	require.Equal(t, storm.ErrNotFound, err) | 	require.Equal(t, storm.ErrNotFound, err) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOnRetainedExpiredClosedDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	err := h.Init(nil) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	teardown(t, h.config.Path, h) | ||||||
|  | 	h.OnRetainedExpired("a/b/c") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestOnRetainedExpiredNoDB(t *testing.T) { | ||||||
|  | 	h := new(Hook) | ||||||
|  | 	h.SetOpts(&logger, nil) | ||||||
|  | 	h.OnRetainedExpired("a/b/c") | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestOnRetainMessageNoDB(t *testing.T) { | func TestOnRetainMessageNoDB(t *testing.T) { | ||||||
| 	h := new(Hook) | 	h := new(Hook) | ||||||
| 	h.SetOpts(&logger, nil) | 	h.SetOpts(&logger, nil) | ||||||
| @@ -427,48 +456,6 @@ func TestOnQosDroppedNoDB(t *testing.T) { | |||||||
| 	h.OnQosDropped(client, packets.Packet{}) | 	h.OnQosDropped(client, packets.Packet{}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestOnExpireInflights(t *testing.T) { |  | ||||||
| 	h := new(Hook) |  | ||||||
| 	h.SetOpts(&logger, nil) |  | ||||||
|  |  | ||||||
| 	err := h.Init(nil) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	defer teardown(t, h.config.Path, h) |  | ||||||
|  |  | ||||||
| 	err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
|  |  | ||||||
| 	var v []storage.Message |  | ||||||
| 	err = h.db.Find("T", storage.InflightKey, &v) |  | ||||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	require.Len(t, v, 1) |  | ||||||
| 	require.Equal(t, "i1", v[0].ID) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { |  | ||||||
| 	h := new(Hook) |  | ||||||
| 	h.SetOpts(&logger, nil) |  | ||||||
| 	err := h.Init(nil) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	teardown(t, h.config.Path, h) |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnExpireInflightsNoDB(t *testing.T) { |  | ||||||
| 	h := new(Hook) |  | ||||||
| 	h.SetOpts(&logger, nil) |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnSysInfoTick(t *testing.T) { | func TestOnSysInfoTick(t *testing.T) { | ||||||
| 	h := new(Hook) | 	h := new(Hook) | ||||||
| 	h.SetOpts(&logger, nil) | 	h.SetOpts(&logger, nil) | ||||||
|   | |||||||
| @@ -83,7 +83,6 @@ func (h *Hook) Provides(b byte) bool { | |||||||
| 		mqtt.OnSysInfoTick, | 		mqtt.OnSysInfoTick, | ||||||
| 		mqtt.OnClientExpired, | 		mqtt.OnClientExpired, | ||||||
| 		mqtt.OnRetainedExpired, | 		mqtt.OnRetainedExpired, | ||||||
| 		mqtt.OnExpireInflights, |  | ||||||
| 		mqtt.StoredClients, | 		mqtt.StoredClients, | ||||||
| 		mqtt.StoredInflightMessages, | 		mqtt.StoredInflightMessages, | ||||||
| 		mqtt.StoredRetainedMessages, | 		mqtt.StoredRetainedMessages, | ||||||
| @@ -216,11 +215,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by | |||||||
| 	var in *storage.Subscription | 	var in *storage.Subscription | ||||||
| 	for i := 0; i < len(pk.Filters); i++ { | 	for i := 0; i < len(pk.Filters); i++ { | ||||||
| 		in = &storage.Subscription{ | 		in = &storage.Subscription{ | ||||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | 			ID:                subscriptionKey(cl, pk.Filters[i].Filter), | ||||||
| 			T:      storage.SubscriptionKey, | 			T:                 storage.SubscriptionKey, | ||||||
| 			Client: cl.ID, | 			Client:            cl.ID, | ||||||
| 			Filter: pk.Filters[i].Filter, | 			Qos:               reasonCodes[i], | ||||||
| 			Qos:    reasonCodes[i], | 			Filter:            pk.Filters[i].Filter, | ||||||
|  | 			Identifier:        pk.Filters[i].Identifier, | ||||||
|  | 			NoLocal:           pk.Filters[i].NoLocal, | ||||||
|  | 			RetainHandling:    pk.Filters[i].RetainHandling, | ||||||
|  | 			RetainAsPublished: pk.Filters[i].RetainAsPublished, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() | 		err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() | ||||||
| @@ -364,37 +367,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // OnExpireInflights removes all inflight messages which have passed the | // OnRetainedExpired deletes expired retained messages from the store. | ||||||
| // provided expiry time. | func (h *Hook) OnRetainedExpired(filter string) { | ||||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { |  | ||||||
| 	if h.db == nil { | 	if h.db == nil { | ||||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() |  | ||||||
| 	if err != nil && !errors.Is(err, redis.Nil) { |  | ||||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll inflight data") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, row := range rows { |  | ||||||
| 		var d storage.Message |  | ||||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { |  | ||||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if d.Created < expiry || d.Created == 0 { |  | ||||||
| 			err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err() |  | ||||||
| 			if err != nil { |  | ||||||
| 				h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // OnRetainedExpired deletes expired retained messages from the store. |  | ||||||
| func (h *Hook) OnRetainedExpired(filter string) { |  | ||||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() | 	err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") | 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") | ||||||
| @@ -403,6 +382,11 @@ func (h *Hook) OnRetainedExpired(filter string) { | |||||||
|  |  | ||||||
| // OnClientExpired deleted expired clients from the store. | // OnClientExpired deleted expired clients from the store. | ||||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||||
|  | 	if h.db == nil { | ||||||
|  | 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() | 	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | ||||||
|   | |||||||
| @@ -253,6 +253,22 @@ func TestOnClientExpired(t *testing.T) { | |||||||
| 	require.ErrorIs(t, redis.Nil, err) | 	require.ErrorIs(t, redis.Nil, err) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOnClientExpiredClosedDB(t *testing.T) { | ||||||
|  | 	s := miniredis.RunT(t) | ||||||
|  | 	defer s.Close() | ||||||
|  | 	h := newHook(t, s.Addr()) | ||||||
|  | 	teardown(t, h) | ||||||
|  | 	h.OnClientExpired(client) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestOnClientExpiredNoDB(t *testing.T) { | ||||||
|  | 	s := miniredis.RunT(t) | ||||||
|  | 	defer s.Close() | ||||||
|  | 	h := newHook(t, s.Addr()) | ||||||
|  | 	h.db = nil | ||||||
|  | 	h.OnClientExpired(client) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestOnDisconnectNoDB(t *testing.T) { | func TestOnDisconnectNoDB(t *testing.T) { | ||||||
| 	s := miniredis.RunT(t) | 	s := miniredis.RunT(t) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
| @@ -392,6 +408,22 @@ func TestOnRetainedExpired(t *testing.T) { | |||||||
| 	require.ErrorIs(t, err, redis.Nil) | 	require.ErrorIs(t, err, redis.Nil) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOnRetainedExpiredClosedDB(t *testing.T) { | ||||||
|  | 	s := miniredis.RunT(t) | ||||||
|  | 	defer s.Close() | ||||||
|  | 	h := newHook(t, s.Addr()) | ||||||
|  | 	teardown(t, h) | ||||||
|  | 	h.OnRetainedExpired("a/b/c") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestOnRetainedExpiredNoDB(t *testing.T) { | ||||||
|  | 	s := miniredis.RunT(t) | ||||||
|  | 	defer s.Close() | ||||||
|  | 	h := newHook(t, s.Addr()) | ||||||
|  | 	h.db = nil | ||||||
|  | 	h.OnRetainedExpired("a/b/c") | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestOnRetainMessageNoDB(t *testing.T) { | func TestOnRetainMessageNoDB(t *testing.T) { | ||||||
| 	s := miniredis.RunT(t) | 	s := miniredis.RunT(t) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
| @@ -484,60 +516,6 @@ func TestOnQosDroppedNoDB(t *testing.T) { | |||||||
| 	h.OnQosDropped(client, packets.Packet{}) | 	h.OnQosDropped(client, packets.Packet{}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestOnExpireInflights(t *testing.T) { |  | ||||||
| 	s := miniredis.RunT(t) |  | ||||||
| 	defer s.Close() |  | ||||||
| 	h := newHook(t, s.Addr()) |  | ||||||
| 	defer teardown(t, h) |  | ||||||
|  |  | ||||||
| 	n := time.Now().Unix() |  | ||||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", |  | ||||||
| 		&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1}, |  | ||||||
| 	).Err() |  | ||||||
| 	require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", |  | ||||||
| 		&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20}, |  | ||||||
| 	).Err() |  | ||||||
| 	require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", |  | ||||||
| 		&storage.Message{ID: "i3", T: storage.InflightKey}, |  | ||||||
| 	).Err() |  | ||||||
| 	require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
|  |  | ||||||
| 	var r []storage.Message |  | ||||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	require.Len(t, rows, 1) |  | ||||||
| 	for _, row := range rows { |  | ||||||
| 		var d storage.Message |  | ||||||
| 		err = d.UnmarshalBinary([]byte(row)) |  | ||||||
| 		require.NoError(t, err) |  | ||||||
| 		r = append(r, d) |  | ||||||
| 	} |  | ||||||
| 	require.Len(t, r, 1) |  | ||||||
| 	require.Equal(t, "i1", r[0].ID) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { |  | ||||||
| 	s := miniredis.RunT(t) |  | ||||||
| 	defer s.Close() |  | ||||||
| 	h := newHook(t, s.Addr()) |  | ||||||
| 	teardown(t, h) |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnExpireInflightsNoDB(t *testing.T) { |  | ||||||
| 	s := miniredis.RunT(t) |  | ||||||
| 	defer s.Close() |  | ||||||
| 	h := newHook(t, s.Addr()) |  | ||||||
| 	h.db = nil |  | ||||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestOnSysInfoTick(t *testing.T) { | func TestOnSysInfoTick(t *testing.T) { | ||||||
| 	s := miniredis.RunT(t) | 	s := miniredis.RunT(t) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|   | |||||||
| @@ -27,6 +27,10 @@ type modifiedHookBase struct { | |||||||
|  |  | ||||||
| var errTestHook = errors.New("error") | var errTestHook = errors.New("error") | ||||||
|  |  | ||||||
|  | func (h *modifiedHookBase) ID() string { | ||||||
|  | 	return "modified" | ||||||
|  | } | ||||||
|  |  | ||||||
| func (h *modifiedHookBase) Init(config any) error { | func (h *modifiedHookBase) Init(config any) error { | ||||||
| 	if config != nil { | 	if config != nil { | ||||||
| 		return errTestHook | 		return errTestHook | ||||||
| @@ -178,12 +182,20 @@ func TestHooksProvides(t *testing.T) { | |||||||
| 	require.False(t, h.Provides(OnDisconnect)) | 	require.False(t, h.Provides(OnDisconnect)) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestHooksAddAndLen(t *testing.T) { | func TestHooksAddLenGetAll(t *testing.T) { | ||||||
| 	h := new(Hooks) | 	h := new(Hooks) | ||||||
| 	err := h.Add(new(HookBase), nil) | 	err := h.Add(new(HookBase), nil) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 	require.Equal(t, int64(1), atomic.LoadInt64(&h.qty)) |  | ||||||
| 	require.Equal(t, int64(1), h.Len()) | 	err = h.Add(new(modifiedHookBase), nil) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	require.Equal(t, int64(2), atomic.LoadInt64(&h.qty)) | ||||||
|  | 	require.Equal(t, int64(2), h.Len()) | ||||||
|  |  | ||||||
|  | 	all := h.GetAll() | ||||||
|  | 	require.Equal(t, "base", all[0].ID()) | ||||||
|  | 	require.Equal(t, "modified", all[1].ID()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestHooksAddInitFailure(t *testing.T) { | func TestHooksAddInitFailure(t *testing.T) { | ||||||
| @@ -231,7 +243,6 @@ func TestHooksNonReturns(t *testing.T) { | |||||||
| 			h.OnWillSent(cl, packets.Packet{}) | 			h.OnWillSent(cl, packets.Packet{}) | ||||||
| 			h.OnClientExpired(cl) | 			h.OnClientExpired(cl) | ||||||
| 			h.OnRetainedExpired("a/b/c") | 			h.OnRetainedExpired("a/b/c") | ||||||
| 			h.OnExpireInflights(cl, time.Now().Unix()-1) |  | ||||||
|  |  | ||||||
| 			// on second iteration, check added hook methods | 			// on second iteration, check added hook methods | ||||||
| 			err := h.Add(new(modifiedHookBase), nil) | 			err := h.Add(new(modifiedHookBase), nil) | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								inflight.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								inflight.go
									
									
									
									
									
								
							| @@ -104,14 +104,14 @@ func (i *Inflight) Delete(id uint16) bool { | |||||||
| } | } | ||||||
|  |  | ||||||
| // TakeRecieveQuota reduces the receive quota by 1. | // TakeRecieveQuota reduces the receive quota by 1. | ||||||
| func (i *Inflight) TakeReceiveQuota() { | func (i *Inflight) DecreaseReceiveQuota() { | ||||||
| 	if atomic.LoadInt32(&i.receiveQuota) > 0 { | 	if atomic.LoadInt32(&i.receiveQuota) > 0 { | ||||||
| 		atomic.AddInt32(&i.receiveQuota, -1) | 		atomic.AddInt32(&i.receiveQuota, -1) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // TakeRecieveQuota increases the receive quota by 1. | // TakeRecieveQuota increases the receive quota by 1. | ||||||
| func (i *Inflight) ReturnReceiveQuota() { | func (i *Inflight) IncreaseReceiveQuota() { | ||||||
| 	if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) { | 	if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) { | ||||||
| 		atomic.AddInt32(&i.receiveQuota, 1) | 		atomic.AddInt32(&i.receiveQuota, 1) | ||||||
| 	} | 	} | ||||||
| @@ -123,15 +123,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) { | |||||||
| 	atomic.StoreInt32(&i.maximumReceiveQuota, n) | 	atomic.StoreInt32(&i.maximumReceiveQuota, n) | ||||||
| } | } | ||||||
|  |  | ||||||
| // TakeSendQuota reduces the send quota by 1. | // DecreaseSendQuota reduces the send quota by 1. | ||||||
| func (i *Inflight) TakeSendQuota() { | func (i *Inflight) DecreaseSendQuota() { | ||||||
| 	if atomic.LoadInt32(&i.sendQuota) > 0 { | 	if atomic.LoadInt32(&i.sendQuota) > 0 { | ||||||
| 		atomic.AddInt32(&i.sendQuota, -1) | 		atomic.AddInt32(&i.sendQuota, -1) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // ReturnSendQuota increases the send quota by 1. | // IncreaseSendQuota increases the send quota by 1. | ||||||
| func (i *Inflight) ReturnSendQuota() { | func (i *Inflight) IncreaseSendQuota() { | ||||||
| 	if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) { | 	if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) { | ||||||
| 		atomic.AddInt32(&i.sendQuota, 1) | 		atomic.AddInt32(&i.sendQuota, 1) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestInflightSet(t *testing.T) { | func TestInflightSet(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) | 	r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) | ||||||
| 	require.True(t, r) | 	require.True(t, r) | ||||||
| @@ -25,7 +25,7 @@ func TestInflightSet(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestInflightGet(t *testing.T) { | func TestInflightGet(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||||
|  |  | ||||||
| 	msg, ok := cl.State.Inflight.Get(2) | 	msg, ok := cl.State.Inflight.Get(2) | ||||||
| @@ -34,7 +34,7 @@ func TestInflightGet(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestInflightGetAllAndImmediate(t *testing.T) { | func TestInflightGetAllAndImmediate(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) | ||||||
| @@ -56,13 +56,13 @@ func TestInflightGetAllAndImmediate(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestInflightLen(t *testing.T) { | func TestInflightLen(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestInflightDelete(t *testing.T) { | func TestInflightDelete(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
|  |  | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 3}) | ||||||
| 	require.NotNil(t, cl.State.Inflight.internal[3]) | 	require.NotNil(t, cl.State.Inflight.internal[3]) | ||||||
| @@ -95,12 +95,12 @@ func TestReceiveQuota(t *testing.T) { | |||||||
| 	require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota)) | 	require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota)) | ||||||
|  |  | ||||||
| 	// Return 1 | 	// Return 1 | ||||||
| 	i.ReturnReceiveQuota() | 	i.IncreaseReceiveQuota() | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | ||||||
|  |  | ||||||
| 	// Try to go over max limit | 	// Try to go over max limit | ||||||
| 	i.ReturnReceiveQuota() | 	i.IncreaseReceiveQuota() | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | ||||||
|  |  | ||||||
| @@ -110,12 +110,12 @@ func TestReceiveQuota(t *testing.T) { | |||||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota)) | 	require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota)) | ||||||
|  |  | ||||||
| 	// Take 1 | 	// Take 1 | ||||||
| 	i.TakeReceiveQuota() | 	i.DecreaseReceiveQuota() | ||||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||||
|  |  | ||||||
| 	// Try to go below zero | 	// Try to go below zero | ||||||
| 	i.TakeReceiveQuota() | 	i.DecreaseReceiveQuota() | ||||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||||
| } | } | ||||||
| @@ -137,12 +137,12 @@ func TestSendQuota(t *testing.T) { | |||||||
| 	require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota)) | 	require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota)) | ||||||
|  |  | ||||||
| 	// Return 1 | 	// Return 1 | ||||||
| 	i.ReturnSendQuota() | 	i.IncreaseSendQuota() | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | ||||||
|  |  | ||||||
| 	// Try to go over max limit | 	// Try to go over max limit | ||||||
| 	i.ReturnSendQuota() | 	i.IncreaseSendQuota() | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | ||||||
|  |  | ||||||
| @@ -152,18 +152,18 @@ func TestSendQuota(t *testing.T) { | |||||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota)) | 	require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota)) | ||||||
|  |  | ||||||
| 	// Take 1 | 	// Take 1 | ||||||
| 	i.TakeSendQuota() | 	i.DecreaseSendQuota() | ||||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||||
|  |  | ||||||
| 	// Try to go below zero | 	// Try to go below zero | ||||||
| 	i.TakeSendQuota() | 	i.DecreaseSendQuota() | ||||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestNextImmediate(t *testing.T) { | func TestNextImmediate(t *testing.T) { | ||||||
| 	cl, _, _ := newClient() | 	cl, _, _ := newTestClient() | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) | ||||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) | 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) | ||||||
|   | |||||||
							
								
								
									
										88
									
								
								listeners/net.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								listeners/net.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | |||||||
|  | package listeners | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
|  |  | ||||||
|  | 	"github.com/rs/zerolog" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Net is a listener for establishing client connections on basic TCP protocol. | ||||||
|  | type Net struct { // [MQTT-4.2.0-1] | ||||||
|  | 	mu       sync.Mutex | ||||||
|  | 	listener net.Listener    // a net.Listener which will listen for new clients | ||||||
|  | 	id       string          // the internal id of the listener | ||||||
|  | 	log      *zerolog.Logger // server logger | ||||||
|  | 	end      uint32          // ensure the close methods are only called once | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewNet initialises and returns a listener serving incoming connections on the given net.Listener | ||||||
|  | func NewNet(id string, listener net.Listener) *Net { | ||||||
|  | 	return &Net{ | ||||||
|  | 		id:       id, | ||||||
|  | 		listener: listener, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ID returns the id of the listener. | ||||||
|  | func (l *Net) ID() string { | ||||||
|  | 	return l.id | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Address returns the address of the listener. | ||||||
|  | func (l *Net) Address() string { | ||||||
|  | 	return l.listener.Addr().String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Protocol returns the network of the listener. | ||||||
|  | func (l *Net) Protocol() string { | ||||||
|  | 	return l.listener.Addr().Network() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Init initializes the listener. | ||||||
|  | func (l *Net) Init(log *zerolog.Logger) error { | ||||||
|  | 	l.log = log | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Serve starts waiting for new TCP connections, and calls the establish | ||||||
|  | // connection callback for any received. | ||||||
|  | func (l *Net) Serve(establish EstablishFn) { | ||||||
|  | 	for { | ||||||
|  | 		if atomic.LoadUint32(&l.end) == 1 { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		conn, err := l.listener.Accept() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if atomic.LoadUint32(&l.end) == 0 { | ||||||
|  | 			go func() { | ||||||
|  | 				err = establish(l.id, conn) | ||||||
|  | 				if err != nil { | ||||||
|  | 					l.log.Warn().Err(err).Send() | ||||||
|  | 				} | ||||||
|  | 			}() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close closes the listener and any client connections. | ||||||
|  | func (l *Net) Close(closeClients CloseFn) { | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	defer l.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||||
|  | 		closeClients(l.id) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if l.listener != nil { | ||||||
|  | 		err := l.listener.Close() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										105
									
								
								listeners/net_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								listeners/net_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | |||||||
|  | package listeners | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"net" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNewNet(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	require.Equal(t, "t1", l.id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNetID(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	require.Equal(t, "t1", l.ID()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNetAddress(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	require.Equal(t, n.Addr().String(), l.Address()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNetProtocol(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	require.Equal(t, "tcp", l.Protocol()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNetInit(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	err = l.Init(&logger) | ||||||
|  | 	l.Close(MockCloser) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNetServeAndClose(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	err = l.Init(&logger) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	o := make(chan bool) | ||||||
|  | 	go func(o chan bool) { | ||||||
|  | 		l.Serve(MockEstablisher) | ||||||
|  | 		o <- true | ||||||
|  | 	}(o) | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Millisecond) | ||||||
|  |  | ||||||
|  | 	var closed bool | ||||||
|  | 	l.Close(func(id string) { | ||||||
|  | 		closed = true | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	require.True(t, closed) | ||||||
|  | 	<-o | ||||||
|  |  | ||||||
|  | 	l.Close(MockCloser)      // coverage: close closed | ||||||
|  | 	l.Serve(MockEstablisher) // coverage: serve closed | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNetEstablishThenEnd(t *testing.T) { | ||||||
|  | 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l := NewNet("t1", n) | ||||||
|  | 	err = l.Init(&logger) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	o := make(chan bool) | ||||||
|  | 	established := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		l.Serve(func(id string, c net.Conn) error { | ||||||
|  | 			established <- true | ||||||
|  | 			return errors.New("ending") // return an error to exit immediately | ||||||
|  | 		}) | ||||||
|  | 		o <- true | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Millisecond) | ||||||
|  | 	net.Dial("tcp", n.Addr().String()) | ||||||
|  | 	require.Equal(t, true, <-established) | ||||||
|  | 	l.Close(MockCloser) | ||||||
|  | 	<-o | ||||||
|  | } | ||||||
							
								
								
									
										98
									
								
								listeners/unixsock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								listeners/unixsock.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,98 @@ | |||||||
|  | // SPDX-License-Identifier: MIT | ||||||
|  | // SPDX-FileCopyrightText: 2022 mochi-co | ||||||
|  | // SPDX-FileContributor: jason@zgwit.com | ||||||
|  |  | ||||||
|  | package listeners | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"os" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
|  |  | ||||||
|  | 	"github.com/rs/zerolog" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // UnixSock is a listener for establishing client connections on basic UnixSock protocol. | ||||||
|  | type UnixSock struct { | ||||||
|  | 	sync.RWMutex | ||||||
|  | 	id      string          // the internal id of the listener. | ||||||
|  | 	address string          // the network address to bind to. | ||||||
|  | 	listen  net.Listener    // a net.Listener which will listen for new clients. | ||||||
|  | 	log     *zerolog.Logger // server logger | ||||||
|  | 	end     uint32          // ensure the close methods are only called once. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewUnixSock initialises and returns a new UnixSock listener, listening on an address. | ||||||
|  | func NewUnixSock(id, address string) *UnixSock { | ||||||
|  | 	return &UnixSock{ | ||||||
|  | 		id:      id, | ||||||
|  | 		address: address, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ID returns the id of the listener. | ||||||
|  | func (l *UnixSock) ID() string { | ||||||
|  | 	return l.id | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Address returns the address of the listener. | ||||||
|  | func (l *UnixSock) Address() string { | ||||||
|  | 	return l.address | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Protocol returns the address of the listener. | ||||||
|  | func (l *UnixSock) Protocol() string { | ||||||
|  | 	return "unix" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Init initializes the listener. | ||||||
|  | func (l *UnixSock) Init(log *zerolog.Logger) error { | ||||||
|  | 	l.log = log | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  | 	_ = os.Remove(l.address) | ||||||
|  | 	l.listen, err = net.Listen("unix", l.address) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Serve starts waiting for new UnixSock connections, and calls the establish | ||||||
|  | // connection callback for any received. | ||||||
|  | func (l *UnixSock) Serve(establish EstablishFn) { | ||||||
|  | 	for { | ||||||
|  | 		if atomic.LoadUint32(&l.end) == 1 { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		conn, err := l.listen.Accept() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if atomic.LoadUint32(&l.end) == 0 { | ||||||
|  | 			go func() { | ||||||
|  | 				err = establish(l.id, conn) | ||||||
|  | 				if err != nil { | ||||||
|  | 					l.log.Warn().Err(err).Send() | ||||||
|  | 				} | ||||||
|  | 			}() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close closes the listener and any client connections. | ||||||
|  | func (l *UnixSock) Close(closeClients CloseFn) { | ||||||
|  | 	l.Lock() | ||||||
|  | 	defer l.Unlock() | ||||||
|  |  | ||||||
|  | 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||||
|  | 		closeClients(l.id) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if l.listen != nil { | ||||||
|  | 		err := l.listen.Close() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										96
									
								
								listeners/unixsock_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								listeners/unixsock_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | |||||||
|  | // SPDX-License-Identifier: MIT | ||||||
|  | // SPDX-FileCopyrightText: 2022 mochi-co | ||||||
|  | // SPDX-FileContributor: jason@zgwit.com | ||||||
|  |  | ||||||
|  | package listeners | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"net" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const testUnixAddr = "mochi.sock" | ||||||
|  |  | ||||||
|  | func TestNewUnixSock(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	require.Equal(t, "t1", l.id) | ||||||
|  | 	require.Equal(t, testUnixAddr, l.address) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestUnixSockID(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	require.Equal(t, "t1", l.ID()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestUnixSockAddress(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	require.Equal(t, testUnixAddr, l.Address()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestUnixSockProtocol(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	require.Equal(t, "unix", l.Protocol()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestUnixSockInit(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	err := l.Init(&logger) | ||||||
|  | 	l.Close(MockCloser) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	l2 := NewUnixSock("t2", testUnixAddr) | ||||||
|  | 	err = l2.Init(&logger) | ||||||
|  | 	l2.Close(MockCloser) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestUnixSockServeAndClose(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	err := l.Init(&logger) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	o := make(chan bool) | ||||||
|  | 	go func(o chan bool) { | ||||||
|  | 		l.Serve(MockEstablisher) | ||||||
|  | 		o <- true | ||||||
|  | 	}(o) | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Millisecond) | ||||||
|  |  | ||||||
|  | 	var closed bool | ||||||
|  | 	l.Close(func(id string) { | ||||||
|  | 		closed = true | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	require.True(t, closed) | ||||||
|  | 	<-o | ||||||
|  |  | ||||||
|  | 	l.Close(MockCloser)      // coverage: close closed | ||||||
|  | 	l.Serve(MockEstablisher) // coverage: serve closed | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestUnixSockEstablishThenEnd(t *testing.T) { | ||||||
|  | 	l := NewUnixSock("t1", testUnixAddr) | ||||||
|  | 	err := l.Init(&logger) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	o := make(chan bool) | ||||||
|  | 	established := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		l.Serve(func(id string, c net.Conn) error { | ||||||
|  | 			established <- true | ||||||
|  | 			return errors.New("ending") // return an error to exit immediately | ||||||
|  | 		}) | ||||||
|  | 		o <- true | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Millisecond) | ||||||
|  | 	net.Dial("unix", l.listen.Addr().String()) | ||||||
|  | 	require.Equal(t, true, <-established) | ||||||
|  | 	l.Close(MockCloser) | ||||||
|  | 	<-o | ||||||
|  | } | ||||||
| @@ -7,6 +7,7 @@ package listeners | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"sync" | 	"sync" | ||||||
| @@ -137,25 +138,35 @@ type wsConn struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Read reads the next span of bytes from the websocket connection and returns the number of bytes read. | // Read reads the next span of bytes from the websocket connection and returns the number of bytes read. | ||||||
| func (ws *wsConn) Read(p []byte) (n int, err error) { | func (ws *wsConn) Read(p []byte) (int, error) { | ||||||
| 	op, r, err := ws.c.NextReader() | 	op, r, err := ws.c.NextReader() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if op != websocket.BinaryMessage { | 	if op != websocket.BinaryMessage { | ||||||
| 		err = ErrInvalidMessage | 		err = ErrInvalidMessage | ||||||
| 		return | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return r.Read(p) | 	var n, br int | ||||||
|  | 	for { | ||||||
|  | 		br, err = r.Read(p[n:]) | ||||||
|  | 		n += br | ||||||
|  | 		if err != nil { | ||||||
|  | 			if errors.Is(err, io.EOF) { | ||||||
|  | 				err = nil | ||||||
|  | 			} | ||||||
|  | 			return n, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Write writes bytes to the websocket connection. | // Write writes bytes to the websocket connection. | ||||||
| func (ws *wsConn) Write(p []byte) (n int, err error) { | func (ws *wsConn) Write(p []byte) (int, error) { | ||||||
| 	err = ws.c.WriteMessage(websocket.BinaryMessage, p) | 	err := ws.c.WriteMessage(websocket.BinaryMessage, p) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return len(p), nil | 	return len(p), nil | ||||||
|   | |||||||
| @@ -124,4 +124,23 @@ var ( | |||||||
| 	ErrMaxConnectTime                         = Code{Code: 0xA0, Reason: "maximum connect time"} | 	ErrMaxConnectTime                         = Code{Code: 0xA0, Reason: "maximum connect time"} | ||||||
| 	ErrSubscriptionIdentifiersNotSupported    = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} | 	ErrSubscriptionIdentifiersNotSupported    = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} | ||||||
| 	ErrWildcardSubscriptionsNotSupported      = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} | 	ErrWildcardSubscriptionsNotSupported      = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} | ||||||
|  |  | ||||||
|  | 	// MQTTv3 specific bytes. | ||||||
|  | 	Err3UnsupportedProtocolVersion = Code{Code: 0x01} | ||||||
|  | 	Err3ClientIdentifierNotValid   = Code{Code: 0x02} | ||||||
|  | 	Err3ServerUnavailable          = Code{Code: 0x03} | ||||||
|  | 	ErrMalformedUsernameOrPassword = Code{Code: 0x04} | ||||||
|  | 	Err3NotAuthorized              = Code{Code: 0x05} | ||||||
|  |  | ||||||
|  | 	// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes. | ||||||
|  | 	// This is required because MQTTv3 has different return byte specification. | ||||||
|  | 	// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257 | ||||||
|  | 	V5CodesToV3 = map[Code]Code{ | ||||||
|  | 		ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion, | ||||||
|  | 		ErrClientIdentifierNotValid:   Err3ClientIdentifierNotValid, | ||||||
|  | 		ErrServerUnavailable:          Err3ServerUnavailable, | ||||||
|  | 		ErrMalformedUsername:          ErrMalformedUsernameOrPassword, | ||||||
|  | 		ErrMalformedPassword:          ErrMalformedUsernameOrPassword, | ||||||
|  | 		ErrBadUsernameOrPassword:      Err3NotAuthorized, | ||||||
|  | 	} | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -89,6 +89,7 @@ const ( | |||||||
| 	TConnackServerUnavailable | 	TConnackServerUnavailable | ||||||
| 	TConnackBadUsernamePassword | 	TConnackBadUsernamePassword | ||||||
| 	TConnackBadUsernamePasswordNoSession | 	TConnackBadUsernamePasswordNoSession | ||||||
|  | 	TConnackMqtt5BadUsernamePasswordNoSession | ||||||
| 	TConnackNotAuthorised | 	TConnackNotAuthorised | ||||||
| 	TConnackMalSessionPresent | 	TConnackMalSessionPresent | ||||||
| 	TConnackMalReturnCode | 	TConnackMalReturnCode | ||||||
| @@ -1316,10 +1317,28 @@ var TPacketData = map[byte]TPacketCases{ | |||||||
| 			Desc: "bad username or password no session", | 			Desc: "bad username or password no session", | ||||||
| 			RawBytes: []byte{ | 			RawBytes: []byte{ | ||||||
| 				Connack << 4, 2, // fixed header | 				Connack << 4, 2, // fixed header | ||||||
| 				0, // No session present | 				0,                      // No session present | ||||||
| 				ErrBadUsernameOrPassword.Code, | 				Err3NotAuthorized.Code, // use v3 remapping | ||||||
| 			}, | 			}, | ||||||
| 			Packet: &Packet{ | 			Packet: &Packet{ | ||||||
|  | 				FixedHeader: FixedHeader{ | ||||||
|  | 					Type:      Connack, | ||||||
|  | 					Remaining: 2, | ||||||
|  | 				}, | ||||||
|  | 				ReasonCode: Err3NotAuthorized.Code, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Case: TConnackMqtt5BadUsernamePasswordNoSession, | ||||||
|  | 			Desc: "mqtt5 bad username or password no session", | ||||||
|  | 			RawBytes: []byte{ | ||||||
|  | 				Connack << 4, 3, // fixed header | ||||||
|  | 				0, // No session present | ||||||
|  | 				ErrBadUsernameOrPassword.Code, | ||||||
|  | 				0, | ||||||
|  | 			}, | ||||||
|  | 			Packet: &Packet{ | ||||||
|  | 				ProtocolVersion: 5, | ||||||
| 				FixedHeader: FixedHeader{ | 				FixedHeader: FixedHeader{ | ||||||
| 					Type:      Connack, | 					Type:      Connack, | ||||||
| 					Remaining: 2, | 					Remaining: 2, | ||||||
| @@ -1327,6 +1346,7 @@ var TPacketData = map[byte]TPacketCases{ | |||||||
| 				ReasonCode: ErrBadUsernameOrPassword.Code, | 				ReasonCode: ErrBadUsernameOrPassword.Code, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		{ | 		{ | ||||||
| 			Case: TConnackNotAuthorised, | 			Case: TConnackNotAuthorised, | ||||||
| 			Desc: "not authorised", | 			Desc: "not authorised", | ||||||
| @@ -1804,13 +1824,10 @@ var TPacketData = map[byte]TPacketCases{ | |||||||
| 			Case: TPublishRetainMqtt5, | 			Case: TPublishRetainMqtt5, | ||||||
| 			Desc: "retain mqtt5", | 			Desc: "retain mqtt5", | ||||||
| 			RawBytes: []byte{ | 			RawBytes: []byte{ | ||||||
| 				Publish<<4 | 1<<0, 35, // Fixed header | 				Publish<<4 | 1<<0, 19, // Fixed header | ||||||
| 				0, 5, // Topic Name - LSB+MSB | 				0, 5, // Topic Name - LSB+MSB | ||||||
| 				'a', '/', 'b', '/', 'c', // Topic Name | 				'a', '/', 'b', '/', 'c', // Topic Name | ||||||
| 				16, // properties length | 				0,                                                     // properties length | ||||||
| 				38, // User Properties (38) |  | ||||||
| 				0, 5, 'h', 'e', 'l', 'l', 'o', |  | ||||||
| 				0, 6, 228, 184, 150, 231, 149, 140, |  | ||||||
| 				'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload | 				'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload | ||||||
| 			}, | 			}, | ||||||
| 			Packet: &Packet{ | 			Packet: &Packet{ | ||||||
| @@ -1818,18 +1835,11 @@ var TPacketData = map[byte]TPacketCases{ | |||||||
| 				FixedHeader: FixedHeader{ | 				FixedHeader: FixedHeader{ | ||||||
| 					Type:      Publish, | 					Type:      Publish, | ||||||
| 					Retain:    true, | 					Retain:    true, | ||||||
| 					Remaining: 35, | 					Remaining: 19, | ||||||
| 				}, | 				}, | ||||||
| 				TopicName: "a/b/c", | 				TopicName:  "a/b/c", | ||||||
| 				Properties: Properties{ | 				Properties: Properties{}, | ||||||
| 					User: []UserProperty{ | 				Payload:    []byte("hello mochi"), | ||||||
| 						{ |  | ||||||
| 							Key: "hello", |  | ||||||
| 							Val: "世界", |  | ||||||
| 						}, |  | ||||||
| 					}, |  | ||||||
| 				}, |  | ||||||
| 				Payload: []byte("hello mochi"), |  | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
|   | |||||||
							
								
								
									
										192
									
								
								server.go
									
									
									
									
									
								
							
							
						
						
									
										192
									
								
								server.go
									
									
									
									
									
								
							| @@ -26,10 +26,10 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	Version                        = "2.0.0"  // the current server version. | 	Version                        = "2.1.8" // the current server version. | ||||||
| 	defaultSysTopicInterval int64  = 1        // the interval between $SYS topic publishes | 	defaultSysTopicInterval int64  = 1       // the interval between $SYS topic publishes | ||||||
| 	defaultFanPoolSize      uint64 = 64       // the number of concurrent workers in the pool | 	defaultFanPoolSize      uint64 = 32      // the number of concurrent workers in the pool | ||||||
| 	defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue | 	defaultFanPoolQueueSize uint64 = 1024    // the capacity of each worker queue | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| @@ -61,13 +61,13 @@ type Capabilities struct { | |||||||
| 	ReceiveMaximum               uint16 | 	ReceiveMaximum               uint16 | ||||||
| 	TopicAliasMaximum            uint16 | 	TopicAliasMaximum            uint16 | ||||||
| 	ServerKeepAlive              uint16 | 	ServerKeepAlive              uint16 | ||||||
|  | 	SharedSubAvailable           byte | ||||||
|  | 	MinimumProtocolVersion       byte | ||||||
| 	Compatibilities              Compatibilities | 	Compatibilities              Compatibilities | ||||||
| 	MaximumQos                   byte | 	MaximumQos                   byte | ||||||
| 	RetainAvailable              byte | 	RetainAvailable              byte | ||||||
| 	WildcardSubAvailable         byte | 	WildcardSubAvailable         byte | ||||||
| 	SubIDAvailable               byte | 	SubIDAvailable               byte | ||||||
| 	SharedSubAvailable           byte |  | ||||||
| 	MinimumProtocolVersion       byte |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Compatibilities provides flags for using compatibility modes. | // Compatibilities provides flags for using compatibility modes. | ||||||
| @@ -199,6 +199,31 @@ func (o *Options) ensureDefaults() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // NewClient returns a new Client instance, populated with all the required values and | ||||||
|  | // references to be used with the server. If you are using this client to directly publish | ||||||
|  | // messages from the embedding application, set the inline flag to true to bypass ACL and | ||||||
|  | // topic validation checks. | ||||||
|  | func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client { | ||||||
|  | 	cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit | ||||||
|  | 		capabilities: s.Options.Capabilities, | ||||||
|  | 		info:         s.Info, | ||||||
|  | 		hooks:        s.hooks, | ||||||
|  | 		log:          s.Log, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	cl.ID = id | ||||||
|  | 	cl.Net.Listener = listener | ||||||
|  |  | ||||||
|  | 	if inline { // inline clients bypass acl and some validity checks. | ||||||
|  | 		cl.Net.Inline = true | ||||||
|  | 		// By default we don't want to restrict developer publishes, | ||||||
|  | 		// but if you do, reset this after creating inline client. | ||||||
|  | 		cl.State.Inflight.ResetReceiveQuota(math.MaxInt32) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return cl | ||||||
|  | } | ||||||
|  |  | ||||||
| // AddHook attaches a new Hook to the server. Ideally, this should be called | // AddHook attaches a new Hook to the server. Ideally, this should be called | ||||||
| // before the server is started with s.Serve(). | // before the server is started with s.Serve(). | ||||||
| func (s *Server) AddHook(hook Hook, config any) error { | func (s *Server) AddHook(hook Hook, config any) error { | ||||||
| @@ -281,27 +306,21 @@ func (s *Server) eventLoop() { | |||||||
| } | } | ||||||
|  |  | ||||||
| // EstablishConnection establishes a new client when a listener accepts a new connection. | // EstablishConnection establishes a new client when a listener accepts a new connection. | ||||||
| func (s *Server) EstablishConnection(lid string, c net.Conn) error { | func (s *Server) EstablishConnection(listener string, c net.Conn) error { | ||||||
| 	cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit | 	cl := s.NewClient(c, listener, "", false) | ||||||
| 		capabilities: s.Options.Capabilities, | 	return s.attachClient(cl, listener) | ||||||
| 		info:         s.Info, |  | ||||||
| 		hooks:        s.hooks, |  | ||||||
| 		log:          s.Log, |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	return s.attachClient(cl, lid) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // attachClient validates an incoming client connection and if viable, attaches the client | // attachClient validates an incoming client connection and if viable, attaches the client | ||||||
| // to the server, performs session housekeeping, and reads incoming packets. | // to the server, performs session housekeeping, and reads incoming packets. | ||||||
| func (s *Server) attachClient(cl *Client, lid string) error { | func (s *Server) attachClient(cl *Client, listener string) error { | ||||||
| 	defer cl.Stop(nil) | 	defer cl.Stop(nil) | ||||||
| 	pk, err := s.readConnectionPacket(cl) | 	pk, err := s.readConnectionPacket(cl) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("read connection: %w", err) | 		return fmt.Errorf("read connection: %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cl.ParseConnect(lid, pk) | 	cl.ParseConnect(listener, pk) | ||||||
| 	code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] | 	code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] | ||||||
| 	if code != packets.CodeSuccess { | 	if code != packets.CodeSuccess { | ||||||
| 		if err := s.sendConnack(cl, code, false); err != nil { | 		if err := s.sendConnack(cl, code, false); err != nil { | ||||||
| @@ -332,6 +351,13 @@ func (s *Server) attachClient(cl *Client, lid string) error { | |||||||
| 		return fmt.Errorf("ack connection packet: %w", err) | 		return fmt.Errorf("ack connection packet: %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Publish any retained messages for subscriptions which already existed on client takeover. | ||||||
|  | 	if sessionPresent { | ||||||
|  | 		for _, sub := range cl.State.Subscriptions.GetAll() { | ||||||
|  | 			s.publishRetainedToClient(cl, sub, true) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9] | 	s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9] | ||||||
|  |  | ||||||
| 	if sessionPresent { | 	if sessionPresent { | ||||||
| @@ -353,11 +379,11 @@ func (s *Server) attachClient(cl *Client, lid string) error { | |||||||
| 		cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] | 		cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected") | 	s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected") | ||||||
| 	expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) | 	expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) | ||||||
| 	s.hooks.OnDisconnect(cl, err, expire) | 	s.hooks.OnDisconnect(cl, err, expire) | ||||||
| 	if expire { | 	if expire { | ||||||
| 		s.unsubscribeClient(cl) | 		s.UnsubscribeClient(cl) | ||||||
| 		cl.ClearInflights(math.MaxInt64, 0) | 		cl.ClearInflights(math.MaxInt64, 0) | ||||||
| 		s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] | 		s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] | ||||||
| 	} | 	} | ||||||
| @@ -436,19 +462,25 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { | |||||||
| 		defer existing.Unlock() | 		defer existing.Unlock() | ||||||
| 		s.DisconnectClient(existing, packets.ErrSessionTakenOver)                                 // [MQTT-3.1.4-3] | 		s.DisconnectClient(existing, packets.ErrSessionTakenOver)                                 // [MQTT-3.1.4-3] | ||||||
| 		if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] | 		if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] | ||||||
| 			s.unsubscribeClient(existing) | 			s.UnsubscribeClient(existing) | ||||||
| 			existing.ClearInflights(math.MaxInt64, 0) | 			existing.ClearInflights(math.MaxInt64, 0) | ||||||
| 			return false // [MQTT-3.2.2-3] | 			return false // [MQTT-3.2.2-3] | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5] | 		if existing.State.Inflight.Len() > 0 { | ||||||
|  | 			cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5] | ||||||
|  | 			if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.capabilities.ReceiveMaximum != 0 { | ||||||
|  | 				cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client | ||||||
|  | 				cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))    // client receive max | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		 | ||||||
| 		for _, sub := range existing.State.Subscriptions.GetAll() { | 		for _, sub := range existing.State.Subscriptions.GetAll() { | ||||||
| 			existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | 			existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | ||||||
| 			if !existed { | 			if !existed { | ||||||
| 				atomic.AddInt64(&s.Info.Subscriptions, 1) | 				atomic.AddInt64(&s.Info.Subscriptions, 1) | ||||||
| 			} | 			} | ||||||
| 			cl.State.Subscriptions.Add(sub.Filter, sub) | 			cl.State.Subscriptions.Add(sub.Filter, sub) | ||||||
| 			s.publishRetainedToClient(cl, sub, existed) |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return true // [MQTT-3.2.2-3] | 		return true // [MQTT-3.2.2-3] | ||||||
| @@ -470,6 +502,12 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if reason.Code >= packets.ErrUnspecifiedError.Code { | 	if reason.Code >= packets.ErrUnspecifiedError.Code { | ||||||
|  | 		if cl.Properties.ProtocolVersion < 5 { | ||||||
|  | 			if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes | ||||||
|  | 				reason = v3reason | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		properties.ReasonString = reason.Reason | 		properties.ReasonString = reason.Reason | ||||||
| 		ack := packets.Packet{ | 		ack := packets.Packet{ | ||||||
| 			FixedHeader: packets.FixedHeader{ | 			FixedHeader: packets.FixedHeader{ | ||||||
| @@ -569,7 +607,7 @@ func (s *Server) processPacket(cl *Client, pk packets.Packet) error { | |||||||
| 			if ok := cl.State.Inflight.Delete(next.PacketID); ok { | 			if ok := cl.State.Inflight.Delete(next.PacketID); ok { | ||||||
| 				atomic.AddInt64(&s.Info.Inflight, -1) | 				atomic.AddInt64(&s.Info.Inflight, -1) | ||||||
| 			} | 			} | ||||||
| 			cl.State.Inflight.TakeSendQuota() | 			cl.State.Inflight.DecreaseSendQuota() | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -592,6 +630,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Publish publishes a publish packet into the broker as if it were sent from the speicfied client. | ||||||
|  | // This is a convenience function which wraps InjectPacket. As such, this method can publish packets | ||||||
|  | // to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the | ||||||
|  | // outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). | ||||||
|  | func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { | ||||||
|  | 	cl := s.NewClient(nil, "local", "inline", true) | ||||||
|  | 	return s.InjectPacket(cl, packets.Packet{ | ||||||
|  | 		FixedHeader: packets.FixedHeader{ | ||||||
|  | 			Type:   packets.Publish, | ||||||
|  | 			Qos:    qos, | ||||||
|  | 			Retain: retain, | ||||||
|  | 		}, | ||||||
|  | 		TopicName: topic, | ||||||
|  | 		Payload:   payload, | ||||||
|  | 		PacketID:  uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks. | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
| // InjectPacket injects a packet into the broker as if it were sent from the specified client. | // InjectPacket injects a packet into the broker as if it were sent from the specified client. | ||||||
| // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. | // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. | ||||||
| func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { | func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { | ||||||
| @@ -612,7 +668,7 @@ func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { | |||||||
|  |  | ||||||
| // processPublish processes a Publish packet. | // processPublish processes a Publish packet. | ||||||
| func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | ||||||
| 	if !IsValidFilter(pk.TopicName, true) && !cl.Net.Inline { | 	if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -620,20 +676,22 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | |||||||
| 		return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8] | 		return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8] | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !s.hooks.OnACLCheck(cl, pk.TopicName, true) && !cl.Net.Inline { | 	if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	pk.Origin = cl.ID | 	pk.Origin = cl.ID | ||||||
| 	pk.Created = time.Now().Unix() | 	pk.Created = time.Now().Unix() | ||||||
|  |  | ||||||
| 	if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { | 	if !cl.Net.Inline { | ||||||
| 		if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] | 		if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { | ||||||
| 			ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) | 			if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] | ||||||
| 			return cl.WritePacket(ack) | 				ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) | ||||||
| 		} | 				return cl.WritePacket(ack) | ||||||
| 		if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | 			} | ||||||
| 			atomic.AddInt64(&s.Info.Inflight, -1) | 			if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | ||||||
|  | 				atomic.AddInt64(&s.Info.Inflight, -1) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -660,10 +718,11 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | |||||||
| 			s.publishToSubscribers(pk) | 			s.publishToSubscribers(pk) | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		s.hooks.OnPublished(cl, pk) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cl.State.Inflight.TakeReceiveQuota() | 	cl.State.Inflight.DecreaseReceiveQuota() | ||||||
| 	ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4] | 	ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4] | ||||||
| 	if pk.FixedHeader.Qos == 2 { | 	if pk.FixedHeader.Qos == 2 { | ||||||
| 		ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8] | 		ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8] | ||||||
| @@ -671,6 +730,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | |||||||
|  |  | ||||||
| 	if ok := cl.State.Inflight.Set(ack); ok { | 	if ok := cl.State.Inflight.Set(ack); ok { | ||||||
| 		atomic.AddInt64(&s.Info.Inflight, 1) | 		atomic.AddInt64(&s.Info.Inflight, 1) | ||||||
|  | 		s.hooks.OnQosPublish(cl, ack, ack.Created, 0) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err := cl.WritePacket(ack) | 	err := cl.WritePacket(ack) | ||||||
| @@ -682,16 +742,15 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | |||||||
| 		if ok := cl.State.Inflight.Delete(ack.PacketID); ok { | 		if ok := cl.State.Inflight.Delete(ack.PacketID); ok { | ||||||
| 			atomic.AddInt64(&s.Info.Inflight, -1) | 			atomic.AddInt64(&s.Info.Inflight, -1) | ||||||
| 		} | 		} | ||||||
| 		cl.State.Inflight.ReturnReceiveQuota() | 		cl.State.Inflight.IncreaseReceiveQuota() | ||||||
| 		s.hooks.OnQosComplete(cl, pk) | 		s.hooks.OnQosComplete(cl, ack) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.fanpool.Enqueue(cl.ID, func() { | 	s.fanpool.Enqueue(cl.ID, func() { | ||||||
| 		s.publishToSubscribers(pk) | 		s.publishToSubscribers(pk) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	s.hooks.OnPublish(cl, pk) | 	s.hooks.OnPublished(cl, pk) | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -734,13 +793,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) { | func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) { | ||||||
| 	if sub.NoLocal && pk.Origin == cl.ID { | 	if sub.NoLocal && pk.Origin == cl.ID { | ||||||
| 		return pk, nil // [MQTT-3.8.3-3] | 		return pk, nil // [MQTT-3.8.3-3] | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	out = pk.Copy(false) | 	out := pk.Copy(false) | ||||||
| 	if !sub.RetainAsPublished { // ![MQTT-3.3.1-13] | 	if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13] | ||||||
| 		out.FixedHeader.Retain = false // [MQTT-3.3.1-12] | 		out.FixedHeader.Retain = false // [MQTT-3.3.1-12] | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -780,6 +839,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet | |||||||
| 		if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3] | 		if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3] | ||||||
| 			atomic.AddInt64(&s.Info.Inflight, 1) | 			atomic.AddInt64(&s.Info.Inflight, 1) | ||||||
| 			s.hooks.OnQosPublish(cl, out, out.Created, 0) | 			s.hooks.OnQosPublish(cl, out, out.Created, 0) | ||||||
|  | 			cl.State.Inflight.DecreaseSendQuota() | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 { | 		if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 { | ||||||
| @@ -789,12 +849,10 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 { | 	if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 { | ||||||
| 		return pk, packets.CodeDisconnect | 		return pk, packets.CodeDisconnect | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cl.State.Inflight.TakeSendQuota() |  | ||||||
|  |  | ||||||
| 	return out, cl.WritePacket(out) | 	return out, cl.WritePacket(out) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -844,7 +902,7 @@ func (s *Server) processPuback(cl *Client, pk packets.Packet) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | ||||||
| 		cl.State.Inflight.ReturnSendQuota() | 		cl.State.Inflight.IncreaseSendQuota() | ||||||
| 		atomic.AddInt64(&s.Info.Inflight, -1) | 		atomic.AddInt64(&s.Info.Inflight, -1) | ||||||
| 		s.hooks.OnQosComplete(cl, pk) | 		s.hooks.OnQosComplete(cl, pk) | ||||||
| 	} | 	} | ||||||
| @@ -867,7 +925,7 @@ func (s *Server) processPubrec(cl *Client, pk packets.Packet) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6] | 	ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6] | ||||||
| 	cl.State.Inflight.TakeReceiveQuota()                                                  // -1 RECV QUOTA | 	cl.State.Inflight.DecreaseReceiveQuota()                                              // -1 RECV QUOTA | ||||||
| 	cl.State.Inflight.Set(ack)                                                            // [MQTT-4.3.3-5] | 	cl.State.Inflight.Set(ack)                                                            // [MQTT-4.3.3-5] | ||||||
| 	return cl.WritePacket(ack) | 	return cl.WritePacket(ack) | ||||||
| } | } | ||||||
| @@ -894,8 +952,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cl.State.Inflight.ReturnReceiveQuota()               // +1 RECV QUOTA | 	cl.State.Inflight.IncreaseReceiveQuota()             // +1 RECV QUOTA | ||||||
| 	cl.State.Inflight.ReturnSendQuota()                  // +1 SENT QUOTA | 	cl.State.Inflight.IncreaseSendQuota()                // +1 SENT QUOTA | ||||||
| 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12] | 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12] | ||||||
| 		atomic.AddInt64(&s.Info.Inflight, -1) | 		atomic.AddInt64(&s.Info.Inflight, -1) | ||||||
| 		s.hooks.OnQosComplete(cl, pk) | 		s.hooks.OnQosComplete(cl, pk) | ||||||
| @@ -907,8 +965,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error { | |||||||
| // processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server. | // processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server. | ||||||
| func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error { | func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error { | ||||||
| 	// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas. | 	// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas. | ||||||
| 	cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA | 	cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA | ||||||
| 	cl.State.Inflight.ReturnSendQuota()    // +1 SENT QUOTA | 	cl.State.Inflight.IncreaseSendQuota()    // +1 SENT QUOTA | ||||||
| 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { | 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { | ||||||
| 		atomic.AddInt64(&s.Info.Inflight, -1) | 		atomic.AddInt64(&s.Info.Inflight, -1) | ||||||
| 		s.hooks.OnQosComplete(cl, pk) | 		s.hooks.OnQosComplete(cl, pk) | ||||||
| @@ -925,7 +983,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | |||||||
| 		code = packets.ErrPacketIdentifierInUse | 		code = packets.ErrPacketIdentifierInUse | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	existed := false | 	filterExisted := make([]bool, len(pk.Filters)) | ||||||
| 	reasonCodes := make([]byte, len(pk.Filters)) | 	reasonCodes := make([]byte, len(pk.Filters)) | ||||||
| 	for i, sub := range pk.Filters { | 	for i, sub := range pk.Filters { | ||||||
| 		if code != packets.CodeSuccess { | 		if code != packets.CodeSuccess { | ||||||
| @@ -941,8 +999,8 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | |||||||
| 		} else if sub.NoLocal && IsSharedFilter(sub.Filter) { | 		} else if sub.NoLocal && IsSharedFilter(sub.Filter) { | ||||||
| 			reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4] | 			reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4] | ||||||
| 		} else { | 		} else { | ||||||
| 			existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | 			isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | ||||||
| 			if !existed { | 			if isNew { | ||||||
| 				atomic.AddInt64(&s.Info.Subscriptions, 1) | 				atomic.AddInt64(&s.Info.Subscriptions, 1) | ||||||
| 			} | 			} | ||||||
| 			cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10] | 			cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10] | ||||||
| @@ -951,6 +1009,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | |||||||
| 				sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] | 				sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			filterExisted[i] = !isNew | ||||||
| 			reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7] | 			reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7] | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -985,7 +1044,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		s.publishRetainedToClient(cl, sub, existed) | 		s.publishRetainedToClient(cl, sub, filterExisted[i]) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| @@ -1035,14 +1094,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error { | |||||||
| 	return cl.WritePacket(ack) | 	return cl.WritePacket(ack) | ||||||
| } | } | ||||||
|  |  | ||||||
| // unsubscribeClient unsubscribes a client from all of their subscriptions. | // UnsubscribeClient unsubscribes a client from all of their subscriptions. | ||||||
| func (s *Server) unsubscribeClient(cl *Client) { | func (s *Server) UnsubscribeClient(cl *Client) { | ||||||
| 	for k := range cl.State.Subscriptions.GetAll() { | 	i := 0 | ||||||
|  | 	filterMap := cl.State.Subscriptions.GetAll() | ||||||
|  | 	filters := make([]packets.Subscription, len(filterMap)) | ||||||
|  | 	for k, v := range filterMap { | ||||||
| 		cl.State.Subscriptions.Delete(k) | 		cl.State.Subscriptions.Delete(k) | ||||||
| 		if s.Topics.Unsubscribe(k, cl.ID) { | 		if s.Topics.Unsubscribe(k, cl.ID) { | ||||||
| 			atomic.AddInt64(&s.Info.Subscriptions, -1) | 			atomic.AddInt64(&s.Info.Subscriptions, -1) | ||||||
| 		} | 		} | ||||||
|  | 		filters[i] = v | ||||||
|  | 		i++ | ||||||
| 	} | 	} | ||||||
|  | 	s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters}) | ||||||
| } | } | ||||||
|  |  | ||||||
| // processAuth processes an Auth packet. | // processAuth processes an Auth packet. | ||||||
| @@ -1087,9 +1152,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error { | |||||||
| 		out.Properties.ReasonString = code.Reason //  // [MQTT-3.14.2-1] | 		out.Properties.ReasonString = code.Reason //  // [MQTT-3.14.2-1] | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// We already have a code we are using to disconnect the client, so we are not | ||||||
|  | 	// interested if the write packet fails due to a closed connection (as we are closing it). | ||||||
| 	err := cl.WritePacket(out) | 	err := cl.WritePacket(out) | ||||||
| 	if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect { | 	if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect { | ||||||
| 		cl.Stop(code) | 		cl.Stop(code) | ||||||
|  | 		if code.Code >= packets.ErrUnspecifiedError.Code { | ||||||
|  | 			return code | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return err | 	return err | ||||||
| @@ -1304,9 +1374,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) { | |||||||
| // loadClients restores clients from the datastore. | // loadClients restores clients from the datastore. | ||||||
| func (s *Server) loadClients(v []storage.Client) { | func (s *Server) loadClients(v []storage.Client) { | ||||||
| 	for _, c := range v { | 	for _, c := range v { | ||||||
| 		cl := newClientStub() | 		cl := s.NewClient(nil, c.Listener, c.ID, false) | ||||||
| 		cl.ID = c.ID |  | ||||||
| 		cl.Net.Listener = c.Listener |  | ||||||
| 		cl.Properties.Username = c.Username | 		cl.Properties.Username = c.Username | ||||||
| 		cl.Properties.Clean = c.Clean | 		cl.Properties.Clean = c.Clean | ||||||
| 		cl.Properties.ProtocolVersion = c.ProtocolVersion | 		cl.Properties.ProtocolVersion = c.ProtocolVersion | ||||||
| @@ -1413,8 +1481,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) { | |||||||
| // clearExpiredInflights deletes any inflight messages which have expired. | // clearExpiredInflights deletes any inflight messages which have expired. | ||||||
| func (s *Server) clearExpiredInflights(now int64) { | func (s *Server) clearExpiredInflights(now int64) { | ||||||
| 	for _, client := range s.Clients.GetAll() { | 	for _, client := range s.Clients.GetAll() { | ||||||
| 		if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 { | 		if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 { | ||||||
| 			s.hooks.OnExpireInflights(client, now) | 			for _, id := range deleted { | ||||||
|  | 				s.hooks.OnQosDropped(client, packets.Packet{PacketID: id}) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										322
									
								
								server_test.go
									
									
									
									
									
								
							
							
						
						
									
										322
									
								
								server_test.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user