mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-27 18:10:38 +08:00
479 lines
12 KiB
Go
479 lines
12 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/rs/xid"
|
|
|
|
"github.com/mochi-co/mqtt/internal/auth"
|
|
"github.com/mochi-co/mqtt/internal/circ"
|
|
"github.com/mochi-co/mqtt/internal/packets"
|
|
"github.com/mochi-co/mqtt/internal/topics"
|
|
)
|
|
|
|
var (
|
|
defaultKeepalive uint16 = 30 // in seconds.
|
|
|
|
ErrInsufficientBytes = errors.New("Insufficient bytes in buffer")
|
|
ErrReadPacketValidation = errors.New("Error validating packet")
|
|
ErrConnectionClosed = errors.New("Connection not open")
|
|
)
|
|
|
|
// Clients contains a map of the clients known by the broker.
|
|
type Clients struct {
|
|
sync.RWMutex
|
|
internal map[string]*Client // clients known by the broker, keyed on client id.
|
|
}
|
|
|
|
// New returns an instance of Clients.
|
|
func New() Clients {
|
|
return Clients{
|
|
internal: make(map[string]*Client),
|
|
}
|
|
}
|
|
|
|
// Add adds a new client to the clients map, keyed on client id.
|
|
func (cl *Clients) Add(val *Client) {
|
|
cl.Lock()
|
|
cl.internal[val.ID] = val
|
|
cl.Unlock()
|
|
}
|
|
|
|
// Get returns the value of a client if it exists.
|
|
func (cl *Clients) Get(id string) (*Client, bool) {
|
|
cl.RLock()
|
|
val, ok := cl.internal[id]
|
|
cl.RUnlock()
|
|
return val, ok
|
|
}
|
|
|
|
// Len returns the length of the clients map.
|
|
func (cl *Clients) Len() int {
|
|
cl.RLock()
|
|
val := len(cl.internal)
|
|
cl.RUnlock()
|
|
return val
|
|
}
|
|
|
|
// Delete removes a client from the internal map.
|
|
func (cl *Clients) Delete(id string) {
|
|
cl.Lock()
|
|
delete(cl.internal, id)
|
|
cl.Unlock()
|
|
}
|
|
|
|
// GetByListener returns clients matching a listener id.
|
|
func (cl *Clients) GetByListener(id string) []*Client {
|
|
clients := make([]*Client, 0, cl.Len())
|
|
cl.RLock()
|
|
for _, v := range cl.internal {
|
|
if v.Listener == id {
|
|
clients = append(clients, v)
|
|
}
|
|
}
|
|
cl.RUnlock()
|
|
return clients
|
|
}
|
|
|
|
// Client contains information about a client known by the broker.
|
|
type Client struct {
|
|
sync.RWMutex
|
|
conn net.Conn // the net.Conn used to establish the connection.
|
|
r *circ.Reader // a reader for reading incoming bytes.
|
|
w *circ.Writer // a writer for writing outgoing bytes.
|
|
ID string // the client id.
|
|
AC auth.Controller // an auth controller inherited from the listener.
|
|
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
|
|
Listener string // the id of the listener the client is connected to.
|
|
InFlight InFlight // a map of in-flight qos messages.
|
|
Username []byte // the username the client authenticated with.
|
|
keepalive uint16 // the number of seconds the connection can wait.
|
|
cleanSession bool // indicates if the client expects a clean-session.
|
|
packetID uint32 // the current highest packetID.
|
|
lwt LWT // the last will and testament for the client.
|
|
state clientState // the operational state of the client.
|
|
//fh packets.FixedHeader // the FixedHeader from the last read packet.
|
|
}
|
|
|
|
// clientState tracks the state of the client.
|
|
type clientState struct {
|
|
started *sync.WaitGroup // tracks the goroutines which have been started.
|
|
endedW *sync.WaitGroup // tracks when the writer has ended.
|
|
endedR *sync.WaitGroup // tracks when the reader has ended.
|
|
end chan struct{} // a channel which indicates the client event loop should be halted.
|
|
done int64 // atomic counter which indicates that the client has closed.
|
|
endOnce *sync.Once // called to ensure the close methods are only called once.
|
|
//wasClosed bool // indicates that the connection was closed deliberately.
|
|
}
|
|
|
|
// NewClient returns a new instance of Client.
|
|
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer) *Client {
|
|
cl := &Client{
|
|
conn: c,
|
|
r: r,
|
|
w: w,
|
|
|
|
keepalive: defaultKeepalive,
|
|
InFlight: InFlight{
|
|
internal: make(map[uint16]InFlightMessage),
|
|
},
|
|
Subscriptions: make(map[string]byte),
|
|
|
|
state: clientState{
|
|
started: new(sync.WaitGroup),
|
|
endedW: new(sync.WaitGroup),
|
|
endedR: new(sync.WaitGroup),
|
|
end: make(chan struct{}),
|
|
endOnce: new(sync.Once),
|
|
},
|
|
}
|
|
return cl
|
|
}
|
|
|
|
// Identify sets the identification values of a client instance.
|
|
func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
|
|
cl.Listener = lid
|
|
cl.AC = ac
|
|
cl.ID = pk.ClientIdentifier
|
|
|
|
// TBR
|
|
//cl.id = strings.Replace(pk.ClientIdentifier, "Jonathans-MacBook-Pro.local-worker0", "", -1)
|
|
//if strings.Contains(cl.id, "sub") {
|
|
// cl.id = "\t\033[0;34m" + cl.id + "\033[0m"
|
|
//}
|
|
cl.r.ID = cl.ID + " READER"
|
|
cl.w.ID = cl.ID + " WRITER"
|
|
|
|
cl.Username = pk.Username
|
|
cl.keepalive = pk.Keepalive
|
|
cl.cleanSession = pk.CleanSession
|
|
|
|
if cl.ID == "" {
|
|
cl.ID = xid.New().String()
|
|
}
|
|
|
|
if pk.WillFlag {
|
|
cl.lwt = LWT{
|
|
topic: pk.WillTopic,
|
|
message: pk.WillMessage,
|
|
qos: pk.WillQos,
|
|
retain: pk.WillRetain,
|
|
}
|
|
}
|
|
}
|
|
|
|
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
|
func (cl *Client) refreshDeadline(keepalive uint16) {
|
|
if cl.conn != nil {
|
|
expiry := time.Duration(keepalive+(keepalive/2)) * time.Second
|
|
cl.conn.SetDeadline(time.Now().Add(expiry))
|
|
}
|
|
}
|
|
|
|
// NextPacketID returns the next packet id for a client, looping back to 0
|
|
// if the maximum ID has been reached.
|
|
func (cl *Client) NextPacketID() uint32 {
|
|
i := atomic.LoadUint32(&cl.packetID)
|
|
if i == uint32(65535) || i == uint32(0) {
|
|
atomic.StoreUint32(&cl.packetID, 1)
|
|
return 1
|
|
}
|
|
|
|
return atomic.AddUint32(&cl.packetID, 1)
|
|
}
|
|
|
|
// NoteSubscription makes a note of a subscription for the client.
|
|
func (cl *Client) NoteSubscription(filter string, qos byte) {
|
|
cl.Lock()
|
|
cl.Subscriptions[filter] = qos
|
|
cl.Unlock()
|
|
}
|
|
|
|
// ForgetSubscription forgests a subscription note for the client.
|
|
func (cl *Client) ForgetSubscription(filter string) {
|
|
cl.Lock()
|
|
delete(cl.Subscriptions, filter)
|
|
cl.Unlock()
|
|
}
|
|
|
|
// Start begins the client goroutines reading and writing packets.
|
|
func (cl *Client) Start() {
|
|
cl.state.started.Add(2)
|
|
|
|
go func() {
|
|
cl.state.started.Done()
|
|
//_, err :=
|
|
cl.w.WriteTo(cl.conn)
|
|
cl.state.endedW.Done()
|
|
//cl.close()
|
|
}()
|
|
cl.state.endedW.Add(1)
|
|
go func() {
|
|
cl.state.started.Done()
|
|
//_, err :=
|
|
cl.r.ReadFrom(cl.conn)
|
|
cl.state.endedR.Done()
|
|
//cl.close()
|
|
}()
|
|
|
|
cl.state.endedR.Add(1)
|
|
cl.state.started.Wait()
|
|
}
|
|
|
|
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
|
func (cl *Client) Stop() {
|
|
cl.r.Stop()
|
|
cl.w.Stop()
|
|
cl.state.endedW.Wait()
|
|
|
|
if cl.conn != nil {
|
|
cl.conn.Close()
|
|
cl.conn = nil
|
|
}
|
|
|
|
cl.state.endedR.Wait()
|
|
}
|
|
|
|
// readFixedHeader reads in the values of the next packet's fixed header.
|
|
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
|
p, err := cl.r.Read(1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = fh.Decode(p[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// The remaining length value can be up to 5 bytes. Read through each byte
|
|
// looking for continue values, and if found increase the read. Otherwise
|
|
// decode the bytes that were legit.
|
|
buf := make([]byte, 0, 6)
|
|
i := 1
|
|
n := 2
|
|
for ; n < 6; n++ {
|
|
p, err = cl.r.Read(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
//if i >= len(p) {
|
|
// return ErrInsufficientBytes
|
|
//}
|
|
|
|
buf = append(buf, p[i])
|
|
|
|
// If it's not a continuation flag, end here.
|
|
if p[i] < 128 {
|
|
break
|
|
}
|
|
|
|
// If i has reached 4 without a length terminator, return a protocol violation.
|
|
i++
|
|
if i == 4 {
|
|
return packets.ErrOversizedLengthIndicator
|
|
}
|
|
}
|
|
|
|
// Calculate and store the remaining length of the packet payload.
|
|
rem, _ := binary.Uvarint(buf)
|
|
fh.Remaining = int(rem)
|
|
|
|
// Having successfully read n bytes, commit the tail forward.
|
|
cl.r.CommitTail(n)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Read reads new packets from a client connection
|
|
func (cl *Client) Read(h func(*Client, packets.Packet) (bool, error)) error {
|
|
for {
|
|
if atomic.LoadInt64(&cl.state.done) == 1 && cl.r.CapDelta() == 0 {
|
|
return nil
|
|
}
|
|
|
|
if cl.conn == nil {
|
|
return ErrConnectionClosed
|
|
}
|
|
|
|
cl.refreshDeadline(cl.keepalive)
|
|
fh := new(packets.FixedHeader)
|
|
err := cl.ReadFixedHeader(fh)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pk, err := cl.ReadPacket(fh)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
close, err := h(cl, pk) // Process inbound packet.
|
|
if err != nil || close {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReadPacket reads the remaining buffer into an MQTT packet.
|
|
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
|
pk.FixedHeader = *fh
|
|
|
|
p, err := cl.r.Read(pk.FixedHeader.Remaining)
|
|
if err != nil {
|
|
return pk, err
|
|
}
|
|
|
|
// Decode the remaining packet values using a fresh copy of the bytes,
|
|
// otherwise the next packet will change the data of this one.
|
|
px := append([]byte{}, p[:]...)
|
|
|
|
switch pk.FixedHeader.Type {
|
|
case packets.Connect:
|
|
err = pk.ConnectDecode(px)
|
|
case packets.Connack:
|
|
err = pk.ConnackDecode(px)
|
|
case packets.Publish:
|
|
err = pk.PublishDecode(px)
|
|
case packets.Puback:
|
|
err = pk.PubackDecode(px)
|
|
case packets.Pubrec:
|
|
err = pk.PubrecDecode(px)
|
|
case packets.Pubrel:
|
|
err = pk.PubrelDecode(px)
|
|
case packets.Pubcomp:
|
|
err = pk.PubcompDecode(px)
|
|
case packets.Subscribe:
|
|
err = pk.SubscribeDecode(px)
|
|
case packets.Suback:
|
|
err = pk.SubackDecode(px)
|
|
case packets.Unsubscribe:
|
|
err = pk.UnsubscribeDecode(px)
|
|
case packets.Unsuback:
|
|
err = pk.UnsubackDecode(px)
|
|
case packets.Pingreq:
|
|
// err = pk.PingreqDecode(px)
|
|
case packets.Pingresp:
|
|
// err = pk.PingrespDecode(px)
|
|
case packets.Disconnect:
|
|
// err = pk.DisconnectDecode(px)
|
|
default:
|
|
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
|
|
}
|
|
|
|
cl.r.CommitTail(pk.FixedHeader.Remaining)
|
|
|
|
return
|
|
}
|
|
|
|
// WritePacket encodes and writes a packet to the client.
|
|
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
|
if cl.conn == nil {
|
|
return 0, ErrConnectionClosed
|
|
}
|
|
|
|
cl.w.Mu.Lock()
|
|
defer cl.w.Mu.Unlock()
|
|
|
|
buf := new(bytes.Buffer)
|
|
switch pk.FixedHeader.Type {
|
|
case packets.Connect:
|
|
err = pk.ConnectEncode(buf)
|
|
case packets.Connack:
|
|
err = pk.ConnackEncode(buf)
|
|
case packets.Publish:
|
|
err = pk.PublishEncode(buf)
|
|
case packets.Puback:
|
|
err = pk.PubackEncode(buf)
|
|
case packets.Pubrec:
|
|
err = pk.PubrecEncode(buf)
|
|
case packets.Pubrel:
|
|
err = pk.PubrelEncode(buf)
|
|
case packets.Pubcomp:
|
|
err = pk.PubcompEncode(buf)
|
|
case packets.Subscribe:
|
|
err = pk.SubscribeEncode(buf)
|
|
case packets.Suback:
|
|
err = pk.SubackEncode(buf)
|
|
case packets.Unsubscribe:
|
|
err = pk.UnsubscribeEncode(buf)
|
|
case packets.Unsuback:
|
|
err = pk.UnsubackEncode(buf)
|
|
case packets.Pingreq:
|
|
err = pk.PingreqEncode(buf)
|
|
case packets.Pingresp:
|
|
err = pk.PingrespEncode(buf)
|
|
case packets.Disconnect:
|
|
err = pk.DisconnectEncode(buf)
|
|
default:
|
|
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
n, err = cl.w.Write(buf.Bytes())
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
cl.refreshDeadline(cl.keepalive)
|
|
|
|
return
|
|
}
|
|
|
|
// LWT contains the last will and testament details for a client connection.
|
|
type LWT struct {
|
|
topic string // the topic the will message shall be sent to.
|
|
message []byte // the message that shall be sent when the client disconnects.
|
|
qos byte // the quality of service desired.
|
|
retain bool // indicates whether the will message should be retained
|
|
}
|
|
|
|
// InFlightMessage contains data about a packet which is currently in-flight.
|
|
type InFlightMessage struct {
|
|
Packet packets.Packet // the packet currently in-flight.
|
|
Sent int64 // the last time the message was sent (for retries) in unixtime.
|
|
}
|
|
|
|
// InFlight is a map of InFlightMessage keyed on packet id.
|
|
type InFlight struct {
|
|
sync.RWMutex
|
|
internal map[uint16]InFlightMessage // internal contains the inflight messages.
|
|
}
|
|
|
|
// set stores the packet of an in-flight message, keyed on message id.
|
|
func (i *InFlight) Set(key uint16, in InFlightMessage) {
|
|
i.Lock()
|
|
i.internal[key] = in
|
|
i.Unlock()
|
|
}
|
|
|
|
// get returns the value of an in-flight message if it exists.
|
|
func (i *InFlight) Get(key uint16) (InFlightMessage, bool) {
|
|
i.RLock()
|
|
val, ok := i.internal[key]
|
|
i.RUnlock()
|
|
return val, ok
|
|
}
|
|
|
|
func (i *InFlight) GetAll() map[uint16]InFlightMessage {
|
|
i.RLock()
|
|
defer i.RUnlock()
|
|
return i.internal
|
|
}
|
|
|
|
// delete removes an in-flight message from the map.
|
|
func (i *InFlight) Delete(key uint16) {
|
|
i.Lock()
|
|
delete(i.internal, key)
|
|
i.Unlock()
|
|
}
|