From f539368a80df93b1851003ba43893eb63d62a9a9 Mon Sep 17 00:00:00 2001 From: Aleksandr Razumov Date: Sun, 29 Apr 2018 14:20:09 +0300 Subject: [PATCH] client: add initialization checks --- client.go | 33 +++++++++++++++++++++------------ client_test.go | 15 +++++++++++---- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 73124b5..a65bd90 100644 --- a/client.go +++ b/client.go @@ -182,6 +182,9 @@ var ErrClientClosed = errors.New("client is closed") // Close stops internal connection and agent, returning CloseErr on error. func (c *Client) Close() error { + if err := c.checkInit(); err != nil { + return err + } c.closedMux.Lock() if c.closed { c.closedMux.Unlock() @@ -189,18 +192,8 @@ func (c *Client) Close() error { } c.closed = true c.closedMux.Unlock() - var ( - agentErr, connErr error - ) - if c.a != nil { - agentErr = c.a.Close() - } - if c.c != nil { - connErr = c.c.Close() - } - if c.close != nil { - close(c.close) - } + agentErr, connErr := c.a.Close(), c.c.Close() + close(c.close) c.wg.Wait() if agentErr == nil && connErr == nil { return nil @@ -263,12 +256,25 @@ var callbackWaitHandlerPool = sync.Pool{ }, } +// ErrClientNotInitialized means that client connection or agent is nil. +var ErrClientNotInitialized = errors.New("client not initialized") + +func (c *Client) checkInit() error { + if c == nil || c.c == nil || c.a == nil || c.close == nil { + return ErrClientNotInitialized + } + return nil +} + // Do is Start wrapper that waits until callback is called. If no callback // provided, Indicate is called instead. // // Do has cpu overhead due to blocking, see BenchmarkClient_Do. // Use Start method for less overhead. func (c *Client) Do(m *Message, d time.Time, f func(Event)) error { + if err := c.checkInit(); err != nil { + return err + } if f == nil { return c.Indicate(m) } @@ -288,6 +294,9 @@ func (c *Client) Do(m *Message, d time.Time, f func(Event)) error { // Start starts transaction (if f set) and writes message to server, handler // is called asynchronously. func (c *Client) Start(m *Message, d time.Time, h Handler) error { + if err := c.checkInit(); err != nil { + return err + } c.closedMux.RLock() closed := c.closed c.closedMux.RUnlock() diff --git a/client_test.go b/client_test.go index df3c4d0..688fc2f 100644 --- a/client_test.go +++ b/client_test.go @@ -447,6 +447,15 @@ func TestClientGC(t *testing.T) { } } +func TestClientCheckInit(t *testing.T) { + if err := (&Client{}).Indicate(nil); err != ErrClientNotInitialized { + t.Error("unexpected error") + } + if err := (&Client{}).Do(nil, time.Time{}, nil); err != ErrClientNotInitialized { + t.Error("unexpected error") + } +} + func TestClientFinalizer(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) @@ -466,9 +475,7 @@ func TestClientFinalizer(t *testing.T) { if err != nil { log.Fatal(err) } - if err := c.Close(); err != nil { - t.Error(err) - } + clientFinalizer(c) clientFinalizer(c) response := MustBuild(TransactionID, BindingSuccess) response.Encode() @@ -497,7 +504,7 @@ func TestClientFinalizer(t *testing.T) { if reader.Err() != nil { t.Error(err) } - if lines != 2 { + if lines != 3 { t.Error("incorrect count of log lines:", lines) } }