Websocket Listener

This commit is contained in:
Mochi
2020-01-04 20:06:37 +00:00
parent 0eaa111383
commit 6a2d5bede1
9 changed files with 151 additions and 113 deletions

View File

@@ -63,7 +63,6 @@ func (l *HTTPStats) ID() string {
// Listen starts listening on the listener's network address. // Listen starts listening on the listener's network address.
func (l *HTTPStats) Listen(s *system.Info) error { func (l *HTTPStats) Listen(s *system.Info) error {
l.system = s l.system = s
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler) mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{ l.listen = &http.Server{

View File

@@ -82,28 +82,20 @@ func (l *Listeners) Delete(id string) {
} }
// Serve starts a listener serving from the internal map. // Serve starts a listener serving from the internal map.
func (l *Listeners) Serve(id string, establisher EstablishFunc) error { func (l *Listeners) Serve(id string, establisher EstablishFunc) {
l.RLock() l.RLock()
listener := l.internal[id] listener := l.internal[id]
l.RUnlock() l.RUnlock()
// Start listening on the network address.
err := listener.Listen(l.system)
if err != nil {
return err
}
go func(e EstablishFunc) { go func(e EstablishFunc) {
defer l.wg.Done() defer l.wg.Done()
l.wg.Add(1) l.wg.Add(1)
listener.Serve(e) listener.Serve(e)
}(establisher) }(establisher)
return nil
} }
// ServeAll starts all listeners serving from the internal map. // ServeAll starts all listeners serving from the internal map.
func (l *Listeners) ServeAll(establisher EstablishFunc) error { func (l *Listeners) ServeAll(establisher EstablishFunc) {
l.RLock() l.RLock()
i := 0 i := 0
ids := make([]string, len(l.internal)) ids := make([]string, len(l.internal))
@@ -114,15 +106,10 @@ func (l *Listeners) ServeAll(establisher EstablishFunc) error {
l.RUnlock() l.RUnlock()
for _, id := range ids { for _, id := range ids {
err := l.Serve(id, establisher) l.Serve(id, establisher)
if err != nil {
return err
} }
} }
return nil
}
// Close stops a listener from the internal map. // Close stops a listener from the internal map.
func (l *Listeners) Close(id string, closer CloseFunc) { func (l *Listeners) Close(id string, closer CloseFunc) {
l.RLock() l.RLock()

View File

@@ -99,15 +99,6 @@ func TestServeListener(t *testing.T) {
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing) require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing)
} }
func TestServeListenerFailure(t *testing.T) {
l := New(nil)
m := NewMockListener("t1", ":1882")
m.errListen = true
l.Add(m)
err := l.Serve("t1", MockEstablisher)
require.Error(t, err)
}
func BenchmarkServeListener(b *testing.B) { func BenchmarkServeListener(b *testing.B) {
l := New(nil) l := New(nil)
l.Add(NewMockListener("t1", ":1882")) l.Add(NewMockListener("t1", ":1882"))
@@ -121,8 +112,7 @@ func TestServeAllListeners(t *testing.T) {
l.Add(NewMockListener("t1", ":1882")) l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882")) l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882")) l.Add(NewMockListener("t3", ":1882"))
err := l.ServeAll(MockEstablisher) l.ServeAll(MockEstablisher)
require.NoError(t, err)
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing) require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing)
@@ -138,15 +128,6 @@ func TestServeAllListeners(t *testing.T) {
require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing) require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing)
} }
func TestServeAllListenersFailure(t *testing.T) {
l := New(nil)
m := NewMockListener("t1", ":1882")
m.errListen = true
l.Add(m)
err := l.ServeAll(MockEstablisher)
require.Error(t, err)
}
func BenchmarkServeAllListeners(b *testing.B) { func BenchmarkServeAllListeners(b *testing.B) {
l := New(nil) l := New(nil)
l.Add(NewMockListener("t1", ":1882")) l.Add(NewMockListener("t1", ":1882"))

View File

@@ -27,7 +27,7 @@ type MockListener struct {
IsListening bool IsListening bool
IsServing bool IsServing bool
done chan bool done chan bool
errListen bool ErrListen bool
} }
// NewMockListener returns a new instance of MockListener // NewMockListener returns a new instance of MockListener
@@ -54,7 +54,7 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
// SetConfig sets the configuration values of the mock listener. // SetConfig sets the configuration values of the mock listener.
func (l *MockListener) Listen(s *system.Info) error { func (l *MockListener) Listen(s *system.Info) error {
if l.errListen { if l.ErrListen {
return fmt.Errorf("listen failure") return fmt.Errorf("listen failure")
} }

View File

@@ -29,9 +29,16 @@ func TestNewMockListenerListen(t *testing.T) {
require.Equal(t, ":1882", mocked.address) require.Equal(t, ":1882", mocked.address)
require.Equal(t, false, mocked.IsListening) require.Equal(t, false, mocked.IsListening)
mocked.Listen(nil) err := mocked.Listen(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening) require.Equal(t, true, mocked.IsListening)
} }
func TestNewMockListenerListenFailure(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
mocked.ErrListen = true
err := mocked.Listen(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) { func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", ":1882") mocked := NewMockListener("t1", ":1882")

View File

@@ -1,38 +1,85 @@
package listeners package listeners
import ( import (
"fmt" "context"
"errors"
"net" "net"
"net/http" "net/http"
"net/url"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"golang.org/x/net/websocket" "github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth" "github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system" "github.com/mochi-co/mqtt/server/system"
) )
var (
ErrInvalidMessage = errors.New("Message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
}
)
// Websocket is a listener for establishing websocket connections. // Websocket is a listener for establishing websocket connections.
type Websocket struct { type Websocket struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener. id string // the internal id of the listener.
protocol string // the protocol of the listener.
config *Config // configuration values for the listener. config *Config // configuration values for the listener.
address string // the network address to bind to. address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients. listen *http.Server // an http server for serving websocket connections.
end int64 // ensure the close methods are only called once.} end int64 // ensure the close methods are only called once.
establish EstablishFunc // the server's establish conection handler.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
// Inspired by
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns
// the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
}
return r.Read(p)
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
} }
// NewWebsocket initialises and returns a new Websocket listener, listening on an address. // NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string) *Websocket { func NewWebsocket(id, address string) *Websocket {
return &Websocket{ return &Websocket{
id: id, id: id,
protocol: "tcp",
address: address, address: address,
config: &Config{ // default configuration. config: &Config{
Auth: new(auth.Allow), Auth: new(auth.Allow),
TLS: new(TLS), TLS: new(TLS),
}, },
@@ -65,52 +112,31 @@ func (l *Websocket) ID() string {
// Listen starts listening on the listener's network address. // Listen starts listening on the listener's network address.
func (l *Websocket) Listen(s *system.Info) error { func (l *Websocket) Listen(s *system.Info) error {
var err error mux := http.NewServeMux()
l.listen, err = net.Listen(l.protocol, l.address) mux.HandleFunc("/", l.handler)
if err != nil { l.listen = &http.Server{
return err Addr: l.address,
Handler: mux,
} }
return nil return nil
} }
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth)
}
// Serve starts waiting for new Websocket connections, and calls the connection // Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received. // establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFunc) { func (l *Websocket) Serve(establish EstablishFunc) {
server := &websocket.Server{ l.establish = establish
Handshake: func(c *websocket.Config, req *http.Request) error { l.listen.ListenAndServe()
c.Protocol = []string{"mqtt"}
// If the remote address is an IP, prepend a protocol string so it can
// be parsed without errors.
if !strings.Contains(req.RemoteAddr, "://") {
req.RemoteAddr = "ws://" + req.RemoteAddr
}
// Websocket struggles to get a request origin address, so the remote
// address from the request is parsed into the origin struct instead.
var err error
c.Origin, err = url.Parse(req.RemoteAddr)
if err != nil {
fmt.Println(err)
}
return nil
},
Handler: func(c *websocket.Conn) {
c.PayloadType = websocket.BinaryFrame
err := establish(l.id, c, l.config.Auth)
if err != nil {
fmt.Println(err)
}
},
}
err := http.Serve(l.listen, server)
if err != nil {
return
}
} }
// Close closes the listener and any client connections. // Close closes the listener and any client connections.
@@ -120,13 +146,10 @@ func (l *Websocket) Close(closeClients CloseFunc) {
if atomic.LoadInt64(&l.end) == 0 { if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1) atomic.StoreInt64(&l.end, 1)
closeClients(l.id) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
} }
if l.listen != nil { closeClients(l.id)
err := l.listen.Close()
if err != nil {
return
}
}
} }

View File

@@ -1,15 +1,26 @@
package listeners package listeners
import ( import (
//"errors" "net"
//"net" "net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth" "github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWsConnClose(t *testing.T) {
r, _ := net.Pipe()
ws := &wsConn{r, new(websocket.Conn)}
err := ws.Close()
require.NoError(t, err)
}
func TestNewWebsocket(t *testing.T) { func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testPort) l := NewWebsocket("t1", testPort)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
@@ -60,30 +71,46 @@ func BenchmarkWebsocketID(b *testing.B) {
func TestWebsocketListen(t *testing.T) { func TestWebsocketListen(t *testing.T) {
l := NewWebsocket("t1", testPort) l := NewWebsocket("t1", testPort)
require.Nil(t, l.listen)
err := l.Listen(nil) err := l.Listen(nil)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, l.listen)
l2 := NewWebsocket("t2", testPort)
err = l2.Listen(nil)
require.Error(t, err)
l.listen.Close()
} }
func TestWebsocketServeAndClose(t *testing.T) { func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testPort) l := NewWebsocket("t1", testPort)
err := l.Listen(nil) l.Listen(nil)
require.NoError(t, err)
o := make(chan bool) o := make(chan bool)
go func(o chan bool) { go func(o chan bool) {
l.Serve(MockEstablisher) l.Serve(MockEstablisher)
o <- true o <- true
}(o) }(o)
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
var closed bool var closed bool
l.Close(func(id string) { l.Close(func(id string) {
closed = true closed = true
}) })
require.Equal(t, true, closed) require.Equal(t, true, closed)
<-o <-o
} }
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.Listen(nil)
e := make(chan bool)
l.establish = func(id string, c net.Conn, ac auth.Controller) error {
e <- true
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
u := "ws" + strings.TrimPrefix(s.URL, "http")
ws, _, err := websocket.DefaultDialer.Dial(u, nil)
require.NoError(t, err)
require.Equal(t, true, <-e)
s.Close()
ws.Close()
}

View File

@@ -75,6 +75,10 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
} }
s.Listeners.Add(listener) s.Listeners.Add(listener)
err := listener.Listen(s.System)
if err != nil {
return err
}
return nil return nil
} }
@@ -528,7 +532,6 @@ func (s *Server) closeListenerClients(listener string) {
// closeClient closes a client connection and publishes any LWT messages. // closeClient closes a client connection and publishes any LWT messages.
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error { func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
if sendLWT && cl.LWT.Topic != "" { if sendLWT && cl.LWT.Topic != "" {
// omit errors, since we're not logging and need to close the client in either case.
s.processPublish(cl, packets.Packet{ s.processPublish(cl, packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Publish, Type: packets.Publish,
@@ -538,6 +541,7 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
TopicName: cl.LWT.Topic, TopicName: cl.LWT.Topic,
Payload: cl.LWT.Message, Payload: cl.LWT.Message,
}) })
// omit errors, since we're not logging and need to close the client in either case.
} }
cl.Stop() cl.Stop()

View File

@@ -18,6 +18,8 @@ import (
"github.com/mochi-co/mqtt/server/listeners/auth" "github.com/mochi-co/mqtt/server/listeners/auth"
) )
const defaultPort = ":18882"
func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) { func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
s = New() s = New()
r, w = net.Pipe() r, w = net.Pipe()
@@ -47,12 +49,11 @@ func BenchmarkNew(b *testing.B) {
func TestServerAddListener(t *testing.T) { func TestServerAddListener(t *testing.T) {
s := New() s := New()
require.NotNil(t, s) require.NotNil(t, s)
err := s.AddListener(listeners.NewMockListener("t1", defaultPort), nil)
err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil)
require.NoError(t, err) require.NoError(t, err)
// Add listener with config. // Add listener with config.
err = s.AddListener(listeners.NewMockListener("t2", ":1882"), &listeners.Config{ err = s.AddListener(listeners.NewMockListener("t2", defaultPort), &listeners.Config{
Auth: new(auth.Disallow), Auth: new(auth.Disallow),
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -66,6 +67,15 @@ func TestServerAddListener(t *testing.T) {
require.Equal(t, ErrListenerIDExists, err) require.Equal(t, ErrListenerIDExists, err)
} }
func TestServerAddListenerFailure(t *testing.T) {
s := New()
require.NotNil(t, s)
m := listeners.NewMockListener("t1", ":1882")
m.ErrListen = true
err := s.AddListener(m, nil)
require.Error(t, err)
}
func BenchmarkServerAddListener(b *testing.B) { func BenchmarkServerAddListener(b *testing.B) {
s := New() s := New()
l := listeners.NewMockListener("t1", ":1882") l := listeners.NewMockListener("t1", ":1882")