package mqtt import ( "errors" "log" "net" "sync" "github.com/rs/xid" "github.com/mochi-co/mqtt/auth" "github.com/mochi-co/mqtt/listeners" "github.com/mochi-co/mqtt/packets" "github.com/mochi-co/mqtt/pools" ) var ( ErrListenerIDExists = errors.New("Listener id already exists") ErrReadConnectFixedHeader = errors.New("Error reading fixed header on CONNECT packet") ErrReadConnectPacket = errors.New("Error reading CONNECT packet") ErrFirstPacketInvalid = errors.New("First packet was not CONNECT packet") ErrReadConnectInvalid = errors.New("CONNECT packet was not valid") ErrConnectNotAuthorized = errors.New("CONNECT packet was not authorized") ErrReadFixedHeader = errors.New("Error reading fixed header") ErrReadPacketPayload = errors.New("Error reading packet payload") ErrReadPacketValidation = errors.New("Error validating packet") ErrConnectionClosed = errors.New("Connection not open") ErrNoData = errors.New("No data") // clientKeepalive is the default keepalive time in seconds. clientKeepalive uint16 = 60 ) /* ErrListenerInvalid = errors.New("listener validation failed") ErrListenerIDExists = errors.New("listener id already exists") ErrPortInUse = errors.New("port already in use") ErrFailedListening = errors.New("couldnt start net listener") ErrIDNotSet = errors.New("id not set") ErrListenerNotFound = errors.New("listener id not found") ErrFailedInitializing = errors.New("failed initializing") ErrFailedServingTCP = errors.New("error serving tcp listener") ErrFailedServingWS = errors.New("error serving websocket listener") ErrAcceptConnection = errors.New("error accepting connection") ErrEstablishingConnection = errors.New("error establishing connection") ErrCloseConnection = errors.New("error closing connection") ErrReadConnectPacket = errors.New("error reading CONNECT packet") ErrFirstPacketInvalid = errors.New("first packet was not CONNECT packet") ErrReadConnectInvalid = errors.New("CONNECT packet was not valid") ErrParsingRemoteOrigin = errors.New("error parsing remote origin from websocket") ErrFailedConnack = errors.New("failed sending CONNACK packet") */ // Server is an MQTT broker server. type Server struct { // listeners is a map of listeners, which listen for new connections. listeners listeners.Listeners // clients is a map of clients known to the broker. clients clients // buffers is a pool of bytes.buffer. buffers pools.BytesBuffersPool // readers is a pool of bufio.reader. readers pools.BufioReadersPool // writers is a pool of bufio.writer. writers pools.BufioWritersPool } // New returns a pointer to a new instance of the MQTT broker. func New() *Server { return &Server{ listeners: listeners.NewListeners(), clients: newClients(), buffers: pools.NewBytesBuffersPool(), readers: pools.NewBufioReadersPool(512), writers: pools.NewBufioWritersPool(512), } } // AddListener adds a new network listener to the server. func (s *Server) AddListener(listener listeners.Listener, config *listeners.Config) error { if _, ok := s.listeners.Get(listener.ID()); ok { return ErrListenerIDExists } if config != nil { listener.SetConfig(config) } s.listeners.Add(listener) return nil } // Serve begins the event loops for establishing client connections on all // attached listeners. func (s *Server) Serve() error { s.listeners.ServeAll(s.EstablishConnection) return nil } // Close attempts to gracefully shutdown the server, all listeners, and clients. func (s *Server) Close() error { s.listeners.CloseAll(listeners.MockCloser) return nil } // EstablishConnection establishes a new client connection with the broker. func (s *Server) EstablishConnection(c net.Conn, ac auth.Controller) error { log.Println("connecting") // Create a new packets parser which will parse all packets for this client, // using buffered writers and readers from the pool. r, w := s.readers.Get(c), s.writers.Get(c) defer s.readers.Put(r) defer s.writers.Put(w) p := packets.NewParser(c, r, w) // Pull the header from the first packet and check for a CONNECT message. fh := new(packets.FixedHeader) err := p.ReadFixedHeader(fh) if err != nil { log.Println("_A", err) return ErrReadConnectFixedHeader } // Read the first packet expecting a CONNECT message. pk, err := p.Read() if err != nil { log.Println("_B", err) return ErrReadConnectPacket } // Ensure first packet is a connect packet. msg, ok := pk.(*packets.ConnectPacket) if !ok { log.Println("_C") return ErrFirstPacketInvalid } // Ensure the packet conforms to MQTT CONNECT specifications. retcode, _ := msg.Validate() if retcode != packets.Accepted { log.Println("_D", retcode) return ErrReadConnectInvalid } // If a username and password has been provided, perform authentication. if msg.Username != "" && !ac.Authenticate(msg.Username, msg.Password) { retcode = packets.CodeConnectNotAuthorised log.Println("_E", retcode) return ErrConnectNotAuthorized } // Add the new client to the clients manager. client := newClient(p, msg) s.clients.Add(client) // Send a CONNACK back to the client. err = s.writeClient(client, &packets.ConnackPacket{ FixedHeader: packets.NewFixedHeader(packets.Connack), SessionPresent: msg.CleanSession, ReturnCode: retcode, }) // Publish out any unacknowledged QOS messages still pending for the client. // @TODO ... // Block and listen for more packets, and end if an error or nil packet occurs. // @TODO ... s.readClient log.Println(msg, pk) return nil } // readClient reads new packets from a client connection. func (s *Server) readClient(cl *client) error { var err error var pk packets.Packet fh := new(packets.FixedHeader) DONE: for { select { case <-cl.end: break DONE default: if cl.p.Conn == nil { return ErrConnectionClosed } // Reset the keepalive read deadline. cl.p.RefreshDeadline(cl.keepalive) // Read in the fixed header of the packet. err = cl.p.ReadFixedHeader(fh) if err != nil { return ErrReadFixedHeader } // If it's a disconnect packet, begin the close process. if fh.Type == packets.Disconnect { return nil } // Otherwise read in the packet payload. pk, err = cl.p.Read() if err != nil { return ErrReadPacketPayload } // Validate the packet if necessary. _, err := pk.Validate() if err != nil { return ErrReadPacketValidation } // Process inbound packet. go s.processPacket(cl, pk) } } return nil } // writeClient writes packets to a client connection. func (s *Server) writeClient(cl *client, pk packets.Packet) error { // Ensure Writer is open. if cl.p.W == nil { return ErrConnectionClosed } // Encode packet to a pooled byte buffer. buf := s.buffers.Get() defer s.buffers.Put(buf) err := pk.Encode(buf) if err != nil { return err } // Write packet to client. _, err = buf.WriteTo(cl.p.W) if err != nil { return err } err = cl.p.W.Flush() if err != nil { return err } // Refresh deadline. cl.p.RefreshDeadline(cl.keepalive) // Log $SYS stats. // @TODO ... return nil } // closeClient closes a client connection and publishes any LWT messages. func (s *Server) closeClient(cl *client) error { // close client connection // send LWT return nil } // processPacket processes an inbound packet for a client. func (s *Server) processPacket(cl *client, pk packets.Packet) error { log.Println("PROCESSING PACKET", cl, pk) // Log read stats for $SYS. // @TODO ... // // switch on packet type //// connect // stop //// disconnect // stop //// ping // pingresp // else stop //// publish // retain if 1 // find valid subscribers // upgrade copied packet // if (qos > 1) add packetID > cl.nextPacketID() // write packet to client > go s.writeClient // handle qos > s.processQOS(cl, pk) //// pub* // handle qos > s.processQOS(cl, pk) //// subscribe // subscribe topics // send subacks // receive any retained messages //// unsubscribe // unsubscribe topics // send unsuback return nil } // processQOS handles the back and forth of QOS>0 packets. func (s *Server) processQOS(cl *client, pk packets.Packet) error { // handle publish in/out // handle puback // handle pubrec // handle pubrel // handle pubcomp return nil } // clients contains a map of the clients known by the broker. type clients struct { sync.RWMutex // internal is a map of the clients known by the broker, keyed on client id. internal map[string]*client } // newClients returns an instance of clients. func newClients() clients { return clients{ internal: make(map[string]*client), } } // Add adds a new client to the clients map, keyed on client id. func (cl *clients) Add(val *client) { cl.Lock() cl.internal[val.id] = val cl.Unlock() } // Get returns the value of a client if it exists. func (cl *clients) Get(id string) (*client, bool) { cl.RLock() val, ok := cl.internal[id] cl.RUnlock() return val, ok } // Len returns the length of the clients map. func (cl *clients) Len() int { cl.RLock() val := len(cl.internal) cl.RUnlock() return val } // Delete removes a client from the internal map. func (cl *clients) Delete(id string) { cl.Lock() delete(cl.internal, id) cl.Unlock() } // Client contains information about a client known by the broker. type client struct { sync.RWMutex // p is a packets parser which reads incoming packets. p *packets.Parser // end is a channel that indicates the client should halt. end chan struct{} // done can be called to ensure the close methods are only called once. done *sync.Once // id is the client id. id string // user is the username the client authenticated with. user string // keepalive is the number of seconds the connection can stay open without // receiving a message from the client. keepalive uint16 // cleanSession indicates if the client expects a cleansession. cleanSession bool } // newClient creates a new instance of client. func newClient(p *packets.Parser, pk *packets.ConnectPacket) *client { cl := &client{ p: p, end: make(chan struct{}), done: new(sync.Once), id: pk.ClientIdentifier, user: pk.Username, keepalive: pk.Keepalive, cleanSession: pk.CleanSession, } // If no client id was provided, generate a new one. if cl.id == "" { cl.id = xid.New().String() } // if no deadline value was provided, set it to the default seconds. if cl.keepalive == 0 { cl.keepalive = clientKeepalive } // If a last will and testament has been provided, record it. /*if pk.WillFlag { // @TODO ... client.will = lwt{ topic: pk.WillTopic, message: pk.WillMessage, qos: pk.WillQos, retain: pk.WillRetain, } } */ return cl }