Mirror of https://github.com/mochi-mqtt/server.git (synced 2025-09-27 04:26:23 +08:00)
Compare commits
9 Commits
- 75504ff201
- a556feb325
- 8d4cc091b4
- d8f28cb843
- 88861c219d
- 7ba6cf28d9
- c174cfdc6b
- 4f198a99dd
- 2a9c9fcc40
```diff
@@ -383,6 +383,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
 		return err
 	}

+	if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
+		return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
+	}
+
 	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
 	return nil
 }
```
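The new guard rejects an inbound packet as soon as its fixed header has been decoded, before any of the body is read. A minimal standalone sketch of the same arithmetic (the function and variable names below are illustrative, not part of the library):

```go
package main

import "fmt"

// exceedsMaximumPacketSize mirrors the check added to ReadFixedHeader above:
// the packet size is approximated as the remaining length plus one byte for
// the fixed-header control byte, and compared against the configured maximum.
// A maximum of 0 means the limit is disabled.
func exceedsMaximumPacketSize(remaining int, maximum uint32) bool {
	return maximum > 0 && uint32(remaining+1) > maximum
}

func main() {
	fmt.Println(exceedsMaximumPacketSize(2, 2))  // true:  3 > 2, rejected as ErrPacketTooLarge
	fmt.Println(exceedsMaximumPacketSize(1, 2))  // false: 2 <= 2, accepted
	fmt.Println(exceedsMaximumPacketSize(10, 0)) // false: no limit configured
}
```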
```diff
@@ -350,6 +350,22 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
 	require.Error(t, err)
 }

+func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
+	cl, r, _ := newTestClient()
+	cl.ops.capabilities.MaximumPacketSize = 2
+	defer cl.Stop(errClientStop)
+
+	go func() {
+		r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
+		r.Close()
+	}()
+
+	fh := new(packets.FixedHeader)
+	err := cl.ReadFixedHeader(fh)
+	require.Error(t, err)
+	require.ErrorIs(t, err, packets.ErrPacketTooLarge)
+}
+
 func TestClientReadFixedHeaderReadEOF(t *testing.T) {
 	cl, r, _ := newTestClient()
 	defer cl.Stop(errClientStop)
```
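The capability the test flips (`cl.ops.capabilities.MaximumPacketSize`) appears to originate from the server's options, which the server copies into each attached client. A rough, hedged sketch of how a deployment might set it; the helper name is illustrative and the module path is assumed from this mirror's URL:

```go
package example

import (
	mqtt "github.com/mochi-mqtt/server/v2" // assumed current module path of this mirror
)

// capPacketSize limits inbound packets to maxBytes for clients attached to the
// server, on the assumption that Options.Capabilities is the source of the
// per-client capabilities consulted by ReadFixedHeader.
func capPacketSize(server *mqtt.Server, maxBytes uint32) {
	server.Options.Capabilities.MaximumPacketSize = maxBytes
}
```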
hooks.go (2 changed lines)
```diff
@@ -351,7 +351,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
 	}
 }

-// OnPublish is called when a client publishes a message. This method differs from OnMessage
+// OnPublish is called when a client publishes a message. This method differs from OnPublished
 // in that it allows you to modify the incoming packet before it is processed.
 // The return values of the hook methods are passed-through in the order the hooks were attached.
 func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
```
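Since the corrected comment is the main documentation of this behaviour, here is a hedged sketch of a user hook that relies on it: OnPublish runs before processing, and the packet it returns replaces the inbound one. The HookBase embedding and import paths follow the usual v2 hook pattern and are assumptions, not part of this diff.

```go
package example

import (
	"bytes"

	mqtt "github.com/mochi-mqtt/server/v2" // assumed current module path of this mirror
	"github.com/mochi-mqtt/server/v2/packets"
)

// TagHook appends a suffix to every inbound publish payload.
type TagHook struct {
	mqtt.HookBase
}

func (h *TagHook) ID() string { return "tag-hook" }

// Provides declares which hook events this hook subscribes to.
func (h *TagHook) Provides(b byte) bool {
	return bytes.Contains([]byte{mqtt.OnPublish}, []byte{b})
}

// OnPublish may modify the incoming packet before it is processed; the
// modified copy is what gets routed to subscribers.
func (h *TagHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
	pk.Payload = append(pk.Payload, []byte(" [tagged]")...)
	return pk, nil
}
```

A hook like this would typically be registered with `server.AddHook(new(TagHook), nil)` before the listeners are started.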
```diff
@@ -7,6 +7,7 @@ package listeners
 import (
 	"context"
 	"errors"
+	"io"
 	"net"
 	"net/http"
 	"sync"
```
```diff
@@ -137,25 +138,35 @@ type wsConn struct {
 }

 // Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
-func (ws *wsConn) Read(p []byte) (n int, err error) {
+func (ws *wsConn) Read(p []byte) (int, error) {
 	op, r, err := ws.c.NextReader()
 	if err != nil {
-		return
+		return 0, err
 	}

 	if op != websocket.BinaryMessage {
 		err = ErrInvalidMessage
-		return
+		return 0, err
 	}

-	return r.Read(p)
+	var n, br int
+	for {
+		br, err = r.Read(p[n:])
+		n += br
+		if err != nil {
+			if err == io.EOF {
+				err = nil
+			}
+			return n, err
+		}
+	}
 }

 // Write writes bytes to the websocket connection.
-func (ws *wsConn) Write(p []byte) (n int, err error) {
-	err = ws.c.WriteMessage(websocket.BinaryMessage, p)
+func (ws *wsConn) Write(p []byte) (int, error) {
+	err := ws.c.WriteMessage(websocket.BinaryMessage, p)
 	if err != nil {
-		return
+		return 0, err
 	}

 	return len(p), nil
```
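The rewritten Read drains a whole websocket message instead of returning after the first short read, and maps the reader's io.EOF to a nil error so callers only see the byte count. A small self-contained sketch of the same loop against an artificially chunked reader (the helper name is illustrative):

```go
package main

import (
	"fmt"
	"io"
	"strings"
	"testing/iotest"
)

// readAllInto keeps reading until the underlying reader reports io.EOF,
// accumulating into p, and hides the EOF from the caller. This is the same
// loop shape the new wsConn.Read uses for a single websocket message.
func readAllInto(r io.Reader, p []byte) (int, error) {
	var n, br int
	var err error
	for {
		br, err = r.Read(p[n:])
		n += br
		if err != nil {
			if err == io.EOF {
				err = nil
			}
			return n, err
		}
	}
}

func main() {
	// OneByteReader forces short reads, much like a fragmented websocket frame.
	r := iotest.OneByteReader(strings.NewReader("hello"))
	buf := make([]byte, 16)
	n, err := readAllInto(r, buf)
	fmt.Println(n, err, string(buf[:n])) // 5 <nil> hello
}
```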
server.go (35 changed lines)
```diff
@@ -26,10 +26,10 @@ import (
 )

 const (
-	Version                        = "2.0.0"   // the current server version.
-	defaultSysTopicInterval int64  = 1         // the interval between $SYS topic publishes
-	defaultFanPoolSize      uint64 = 64        // the number of concurrent workers in the pool
-	defaultFanPoolQueueSize uint64 = 32 * 128  // the capacity of each worker queue
+	Version                        = "2.0.7" // the current server version.
+	defaultSysTopicInterval int64  = 1       // the interval between $SYS topic publishes
+	defaultFanPoolSize      uint64 = 32      // the number of concurrent workers in the pool
+	defaultFanPoolQueueSize uint64 = 1024    // the capacity of each worker queue
 )

 var (
```
```diff
@@ -376,7 +376,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
 	expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
 	s.hooks.OnDisconnect(cl, err, expire)
 	if expire {
-		s.unsubscribeClient(cl)
+		s.UnsubscribeClient(cl)
 		cl.ClearInflights(math.MaxInt64, 0)
 		s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
 	}
```
```diff
@@ -455,7 +455,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
 		defer existing.Unlock()
 		s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
 		if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
-			s.unsubscribeClient(existing)
+			s.UnsubscribeClient(existing)
 			existing.ClearInflights(math.MaxInt64, 0)
 			return false // [MQTT-3.2.2-3]
 		}
```
```diff
@@ -697,6 +697,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
 			s.publishToSubscribers(pk)
 		})

+		s.hooks.OnPublished(cl, pk)
 		return nil
 	}

```
```diff
@@ -727,8 +728,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
 		s.publishToSubscribers(pk)
 	})

-	s.hooks.OnPublish(cl, pk)
-
+	s.hooks.OnPublished(cl, pk)
 	return nil
 }

```
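These two hunks make the server fire OnPublished consistently once a publish has been queued for delivery. For contrast with the OnPublish sketch above, here is a hedged sketch of an OnPublished observer under the same assumed hook pattern and import paths; OnPublished is notification-only and cannot alter the packet.

```go
package example

import (
	"bytes"
	"log"

	mqtt "github.com/mochi-mqtt/server/v2" // assumed current module path of this mirror
	"github.com/mochi-mqtt/server/v2/packets"
)

// AuditHook logs every message after the server has finished processing it.
type AuditHook struct {
	mqtt.HookBase
}

func (h *AuditHook) ID() string { return "audit-hook" }

func (h *AuditHook) Provides(b byte) bool {
	return bytes.Contains([]byte{mqtt.OnPublished}, []byte{b})
}

// OnPublished fires after the publish has been handed to subscribers; there is
// no return value, so it can observe the packet but not modify it.
func (h *AuditHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
	log.Printf("client %s published %d bytes to %s", cl.ID, len(pk.Payload), pk.TopicName)
}
```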
```diff
@@ -1072,14 +1072,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
 	return cl.WritePacket(ack)
 }

-// unsubscribeClient unsubscribes a client from all of their subscriptions.
-func (s *Server) unsubscribeClient(cl *Client) {
-	for k := range cl.State.Subscriptions.GetAll() {
+// UnsubscribeClient unsubscribes a client from all of their subscriptions.
+func (s *Server) UnsubscribeClient(cl *Client) {
+	i := 0
+	filterMap := cl.State.Subscriptions.GetAll()
+	filters := make([]packets.Subscription, len(filterMap))
+	for k, v := range filterMap {
 		cl.State.Subscriptions.Delete(k)
 		if s.Topics.Unsubscribe(k, cl.ID) {
 			atomic.AddInt64(&s.Info.Subscriptions, -1)
 		}
+		filters[i] = v
+		i++
 	}
+	s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
 }

 // processAuth processes an Auth packet.
```
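unsubscribeClient is now exported as UnsubscribeClient and additionally fires OnUnsubscribed with the removed filters, so application code can clear a client's subscriptions directly. A hedged usage sketch; UnsubscribeClient is the method exported in this diff, while the helper name and the Clients.Get lookup are assumed from the surrounding v2 API:

```go
package example

import (
	mqtt "github.com/mochi-mqtt/server/v2" // assumed current module path of this mirror
)

// dropAllSubscriptions looks up a connected client by ID and removes every
// subscription it holds, without disconnecting it. OnUnsubscribed fires once
// with the collected filters.
func dropAllSubscriptions(server *mqtt.Server, clientID string) {
	if cl, ok := server.Clients.Get(clientID); ok {
		server.UnsubscribeClient(cl)
	}
}
```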
```diff
@@ -1126,12 +1132,15 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {

 	// We already have a code we are using to disconnect the client, so we are not
 	// interested if the write packet fails due to a closed connection (as we are closing it).
-	_ = cl.WritePacket(out)
+	err := cl.WritePacket(out)
+	if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
 		cl.Stop(code)
 		if code.Code >= packets.ErrUnspecifiedError.Code {
 			return code
 		}
+	}

-	return code
+	return err
 }

 // publishSysTopics publishes the current values to the server $SYS topics.
```
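DisconnectClient now returns the outcome of writing the DISCONNECT packet, and the Compatibilities.PassiveClientDisconnect flag lets the caller keep the network connection open after the packet is sent. A hedged sketch of how that might be used; the helper name is illustrative, while the option, method, and code names are the ones appearing in the hunk above:

```go
package example

import (
	mqtt "github.com/mochi-mqtt/server/v2" // assumed current module path of this mirror
	"github.com/mochi-mqtt/server/v2/packets"
)

// politeDisconnect sends a DISCONNECT packet but leaves closing the connection
// to the caller, by relying on the PassiveClientDisconnect compatibility flag.
// The flag would normally be set once at startup rather than per call.
func politeDisconnect(server *mqtt.Server, clientID string) error {
	server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true

	cl, ok := server.Clients.Get(clientID)
	if !ok {
		return nil // nothing to do; the client is not connected
	}

	// With the flag set, DisconnectClient writes the packet and returns the
	// write error (if any) instead of stopping the client itself.
	return server.DisconnectClient(cl, packets.CodeDisconnect)
}
```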
```diff
@@ -844,7 +844,7 @@ func TestServerUnsubscribeClient(t *testing.T) {
 	s.Topics.Subscribe(cl.ID, pk)
 	subs := s.Topics.Subscribers("a/b/c")
 	require.Equal(t, 1, len(subs.Subscriptions))
-	s.unsubscribeClient(cl)
+	s.UnsubscribeClient(cl)
 	subs = s.Topics.Subscribers("a/b/c")
 	require.Equal(t, 0, len(subs.Subscriptions))
 }
```
```diff
@@ -2291,6 +2291,21 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
 	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
 }

+func TestServerRecievePacketDisconnectClient(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+
+	go func() {
+		err := s.DisconnectClient(cl, packets.CodeDisconnect)
+		require.NoError(t, err)
+		w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf)
+}
+
 func TestServerProcessPacketDisconnect(t *testing.T) {
 	s := newServer()
 	cl, _, _ := newTestClient()
```