From 6a2d5bede13a8b89e17f16e6cc1a73e9d8f54b6d Mon Sep 17 00:00:00 2001 From: Mochi Date: Sat, 4 Jan 2020 20:06:37 +0000 Subject: [PATCH] Websocket Listener --- server/listeners/http_sysinfo.go | 1 - server/listeners/listeners.go | 19 +--- server/listeners/listeners_test.go | 21 +---- server/listeners/mock.go | 4 +- server/listeners/mock_test.go | 9 +- server/listeners/websocket.go | 141 +++++++++++++++++------------ server/listeners/websocket_test.go | 47 ++++++++-- server/server.go | 6 +- server/server_test.go | 16 +++- 9 files changed, 151 insertions(+), 113 deletions(-) diff --git a/server/listeners/http_sysinfo.go b/server/listeners/http_sysinfo.go index 4e648c3..0d08392 100644 --- a/server/listeners/http_sysinfo.go +++ b/server/listeners/http_sysinfo.go @@ -63,7 +63,6 @@ func (l *HTTPStats) ID() string { // Listen starts listening on the listener's network address. func (l *HTTPStats) Listen(s *system.Info) error { l.system = s - mux := http.NewServeMux() mux.HandleFunc("/", l.jsonHandler) l.listen = &http.Server{ diff --git a/server/listeners/listeners.go b/server/listeners/listeners.go index a8a29a0..1dbbcd1 100644 --- a/server/listeners/listeners.go +++ b/server/listeners/listeners.go @@ -82,28 +82,20 @@ func (l *Listeners) Delete(id string) { } // Serve starts a listener serving from the internal map. -func (l *Listeners) Serve(id string, establisher EstablishFunc) error { +func (l *Listeners) Serve(id string, establisher EstablishFunc) { l.RLock() listener := l.internal[id] l.RUnlock() - // Start listening on the network address. - err := listener.Listen(l.system) - if err != nil { - return err - } - go func(e EstablishFunc) { defer l.wg.Done() l.wg.Add(1) listener.Serve(e) }(establisher) - - return nil } // ServeAll starts all listeners serving from the internal map. -func (l *Listeners) ServeAll(establisher EstablishFunc) error { +func (l *Listeners) ServeAll(establisher EstablishFunc) { l.RLock() i := 0 ids := make([]string, len(l.internal)) @@ -114,13 +106,8 @@ func (l *Listeners) ServeAll(establisher EstablishFunc) error { l.RUnlock() for _, id := range ids { - err := l.Serve(id, establisher) - if err != nil { - return err - } + l.Serve(id, establisher) } - - return nil } // Close stops a listener from the internal map. diff --git a/server/listeners/listeners_test.go b/server/listeners/listeners_test.go index 0d94a21..b6cf0bf 100644 --- a/server/listeners/listeners_test.go +++ b/server/listeners/listeners_test.go @@ -99,15 +99,6 @@ func TestServeListener(t *testing.T) { require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing) } -func TestServeListenerFailure(t *testing.T) { - l := New(nil) - m := NewMockListener("t1", ":1882") - m.errListen = true - l.Add(m) - err := l.Serve("t1", MockEstablisher) - require.Error(t, err) -} - func BenchmarkServeListener(b *testing.B) { l := New(nil) l.Add(NewMockListener("t1", ":1882")) @@ -121,8 +112,7 @@ func TestServeAllListeners(t *testing.T) { l.Add(NewMockListener("t1", ":1882")) l.Add(NewMockListener("t2", ":1882")) l.Add(NewMockListener("t3", ":1882")) - err := l.ServeAll(MockEstablisher) - require.NoError(t, err) + l.ServeAll(MockEstablisher) time.Sleep(time.Millisecond) require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing) @@ -138,15 +128,6 @@ func TestServeAllListeners(t *testing.T) { require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing) } -func TestServeAllListenersFailure(t *testing.T) { - l := New(nil) - m := NewMockListener("t1", ":1882") - m.errListen = true - l.Add(m) - err := l.ServeAll(MockEstablisher) - require.Error(t, err) -} - func BenchmarkServeAllListeners(b *testing.B) { l := New(nil) l.Add(NewMockListener("t1", ":1882")) diff --git a/server/listeners/mock.go b/server/listeners/mock.go index 6a5e6cd..981596f 100644 --- a/server/listeners/mock.go +++ b/server/listeners/mock.go @@ -27,7 +27,7 @@ type MockListener struct { IsListening bool IsServing bool done chan bool - errListen bool + ErrListen bool } // NewMockListener returns a new instance of MockListener @@ -54,7 +54,7 @@ func (l *MockListener) Serve(establisher EstablishFunc) { // SetConfig sets the configuration values of the mock listener. func (l *MockListener) Listen(s *system.Info) error { - if l.errListen { + if l.ErrListen { return fmt.Errorf("listen failure") } diff --git a/server/listeners/mock_test.go b/server/listeners/mock_test.go index 5eead17..4d144dd 100644 --- a/server/listeners/mock_test.go +++ b/server/listeners/mock_test.go @@ -29,9 +29,16 @@ func TestNewMockListenerListen(t *testing.T) { require.Equal(t, ":1882", mocked.address) require.Equal(t, false, mocked.IsListening) - mocked.Listen(nil) + err := mocked.Listen(nil) + require.NoError(t, err) require.Equal(t, true, mocked.IsListening) } +func TestNewMockListenerListenFailure(t *testing.T) { + mocked := NewMockListener("t1", ":1882") + mocked.ErrListen = true + err := mocked.Listen(nil) + require.Error(t, err) +} func TestMockListenerServe(t *testing.T) { mocked := NewMockListener("t1", ":1882") diff --git a/server/listeners/websocket.go b/server/listeners/websocket.go index 7d8be12..06dc163 100644 --- a/server/listeners/websocket.go +++ b/server/listeners/websocket.go @@ -1,38 +1,85 @@ package listeners import ( - "fmt" + "context" + "errors" "net" "net/http" - "net/url" - "strings" "sync" "sync/atomic" + "time" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" "github.com/mochi-co/mqtt/server/listeners/auth" "github.com/mochi-co/mqtt/server/system" ) +var ( + ErrInvalidMessage = errors.New("Message type not binary") + + // wsUpgrader is used to upgrade the incoming http/tcp connection to a + // websocket compliant connection. + wsUpgrader = &websocket.Upgrader{ + Subprotocols: []string{"mqtt"}, + } +) + // Websocket is a listener for establishing websocket connections. type Websocket struct { sync.RWMutex - id string // the internal id of the listener. - protocol string // the protocol of the listener. - config *Config // configuration values for the listener. - address string // the network address to bind to. - listen net.Listener // a net.Listener which will listen for new clients. - end int64 // ensure the close methods are only called once.} + id string // the internal id of the listener. + config *Config // configuration values for the listener. + address string // the network address to bind to. + listen *http.Server // an http server for serving websocket connections. + end int64 // ensure the close methods are only called once. + establish EstablishFunc // the server's establish conection handler. +} + +// wsConn is a websocket connection which satisfies the net.Conn interface. +// Inspired by +type wsConn struct { + net.Conn + c *websocket.Conn +} + +// Read reads the next span of bytes from the websocket connection and returns +// the number of bytes read. +func (ws *wsConn) Read(p []byte) (n int, err error) { + op, r, err := ws.c.NextReader() + if err != nil { + return + } + + if op != websocket.BinaryMessage { + err = ErrInvalidMessage + return + } + + return r.Read(p) +} + +// Write writes bytes to the websocket connection. +func (ws *wsConn) Write(p []byte) (n int, err error) { + err = ws.c.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return + } + + return len(p), nil +} + +// Close signals the underlying websocket conn to close. +func (ws *wsConn) Close() error { + return ws.Conn.Close() } // NewWebsocket initialises and returns a new Websocket listener, listening on an address. func NewWebsocket(id, address string) *Websocket { return &Websocket{ - id: id, - protocol: "tcp", - address: address, - config: &Config{ // default configuration. + id: id, + address: address, + config: &Config{ Auth: new(auth.Allow), TLS: new(TLS), }, @@ -65,52 +112,31 @@ func (l *Websocket) ID() string { // Listen starts listening on the listener's network address. func (l *Websocket) Listen(s *system.Info) error { - var err error - l.listen, err = net.Listen(l.protocol, l.address) - if err != nil { - return err + mux := http.NewServeMux() + mux.HandleFunc("/", l.handler) + l.listen = &http.Server{ + Addr: l.address, + Handler: mux, } return nil } -// Serve starts waiting for new Websocket connections, and calls the connection -// establishment callback for any received. -func (l *Websocket) Serve(establish EstablishFunc) { - server := &websocket.Server{ - Handshake: func(c *websocket.Config, req *http.Request) error { - - c.Protocol = []string{"mqtt"} - - // If the remote address is an IP, prepend a protocol string so it can - // be parsed without errors. - if !strings.Contains(req.RemoteAddr, "://") { - req.RemoteAddr = "ws://" + req.RemoteAddr - } - - // Websocket struggles to get a request origin address, so the remote - // address from the request is parsed into the origin struct instead. - var err error - c.Origin, err = url.Parse(req.RemoteAddr) - if err != nil { - fmt.Println(err) - } - - return nil - }, - Handler: func(c *websocket.Conn) { - c.PayloadType = websocket.BinaryFrame - err := establish(l.id, c, l.config.Auth) - if err != nil { - fmt.Println(err) - } - }, - } - - err := http.Serve(l.listen, server) +func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) { + c, err := wsUpgrader.Upgrade(w, r, nil) if err != nil { return } + defer c.Close() + + l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth) +} + +// Serve starts waiting for new Websocket connections, and calls the connection +// establishment callback for any received. +func (l *Websocket) Serve(establish EstablishFunc) { + l.establish = establish + l.listen.ListenAndServe() } // Close closes the listener and any client connections. @@ -120,13 +146,10 @@ func (l *Websocket) Close(closeClients CloseFunc) { if atomic.LoadInt64(&l.end) == 0 { atomic.StoreInt64(&l.end, 1) - closeClients(l.id) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + l.listen.Shutdown(ctx) } - if l.listen != nil { - err := l.listen.Close() - if err != nil { - return - } - } + closeClients(l.id) } diff --git a/server/listeners/websocket_test.go b/server/listeners/websocket_test.go index 55da0b9..1ebaa4c 100644 --- a/server/listeners/websocket_test.go +++ b/server/listeners/websocket_test.go @@ -1,15 +1,26 @@ package listeners import ( - //"errors" - //"net" + "net" + "net/http" + "net/http/httptest" + "strings" "testing" "time" + "github.com/gorilla/websocket" + "github.com/mochi-co/mqtt/server/listeners/auth" "github.com/stretchr/testify/require" ) +func TestWsConnClose(t *testing.T) { + r, _ := net.Pipe() + ws := &wsConn{r, new(websocket.Conn)} + err := ws.Close() + require.NoError(t, err) +} + func TestNewWebsocket(t *testing.T) { l := NewWebsocket("t1", testPort) require.Equal(t, "t1", l.id) @@ -60,30 +71,46 @@ func BenchmarkWebsocketID(b *testing.B) { func TestWebsocketListen(t *testing.T) { l := NewWebsocket("t1", testPort) + require.Nil(t, l.listen) err := l.Listen(nil) require.NoError(t, err) - - l2 := NewWebsocket("t2", testPort) - err = l2.Listen(nil) - require.Error(t, err) - l.listen.Close() + require.NotNil(t, l.listen) } func TestWebsocketServeAndClose(t *testing.T) { l := NewWebsocket("t1", testPort) - err := l.Listen(nil) - require.NoError(t, err) - + l.Listen(nil) o := make(chan bool) go func(o chan bool) { l.Serve(MockEstablisher) o <- true }(o) time.Sleep(time.Millisecond) + var closed bool l.Close(func(id string) { closed = true }) require.Equal(t, true, closed) + <-o } + +func TestWebsocketUpgrade(t *testing.T) { + l := NewWebsocket("t1", testPort) + l.Listen(nil) + e := make(chan bool) + l.establish = func(id string, c net.Conn, ac auth.Controller) error { + e <- true + return nil + } + s := httptest.NewServer(http.HandlerFunc(l.handler)) + u := "ws" + strings.TrimPrefix(s.URL, "http") + ws, _, err := websocket.DefaultDialer.Dial(u, nil) + require.NoError(t, err) + require.Equal(t, true, <-e) + + s.Close() + ws.Close() + +} diff --git a/server/server.go b/server/server.go index a82eefa..73bc604 100644 --- a/server/server.go +++ b/server/server.go @@ -75,6 +75,10 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf } s.Listeners.Add(listener) + err := listener.Listen(s.System) + if err != nil { + return err + } return nil } @@ -528,7 +532,6 @@ func (s *Server) closeListenerClients(listener string) { // closeClient closes a client connection and publishes any LWT messages. func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error { if sendLWT && cl.LWT.Topic != "" { - // omit errors, since we're not logging and need to close the client in either case. s.processPublish(cl, packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, @@ -538,6 +541,7 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error { TopicName: cl.LWT.Topic, Payload: cl.LWT.Message, }) + // omit errors, since we're not logging and need to close the client in either case. } cl.Stop() diff --git a/server/server_test.go b/server/server_test.go index 8b5e335..f2a456f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,6 +18,8 @@ import ( "github.com/mochi-co/mqtt/server/listeners/auth" ) +const defaultPort = ":18882" + func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) { s = New() r, w = net.Pipe() @@ -47,12 +49,11 @@ func BenchmarkNew(b *testing.B) { func TestServerAddListener(t *testing.T) { s := New() require.NotNil(t, s) - - err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil) + err := s.AddListener(listeners.NewMockListener("t1", defaultPort), nil) require.NoError(t, err) // Add listener with config. - err = s.AddListener(listeners.NewMockListener("t2", ":1882"), &listeners.Config{ + err = s.AddListener(listeners.NewMockListener("t2", defaultPort), &listeners.Config{ Auth: new(auth.Disallow), }) require.NoError(t, err) @@ -66,6 +67,15 @@ func TestServerAddListener(t *testing.T) { require.Equal(t, ErrListenerIDExists, err) } +func TestServerAddListenerFailure(t *testing.T) { + s := New() + require.NotNil(t, s) + m := listeners.NewMockListener("t1", ":1882") + m.ErrListen = true + err := s.AddListener(m, nil) + require.Error(t, err) +} + func BenchmarkServerAddListener(b *testing.B) { s := New() l := listeners.NewMockListener("t1", ":1882")