diff --git a/client.go b/client.go index 9b63758..73124b5 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,9 @@ import ( "errors" "fmt" "io" + "log" "net" + "runtime" "sync" "time" ) @@ -56,9 +58,25 @@ func NewClient(options ClientOptions) (*Client, error) { c.wg.Add(2) go c.readUntilClosed() go c.collectUntilClosed() + runtime.SetFinalizer(c, clientFinalizer) return c, nil } +func clientFinalizer(c *Client) { + if c == nil { + return + } + err := c.Close() + if err == ErrClientClosed { + return + } + if err == nil { + log.Println("client: called finalizer on non-closed client") + return + } + log.Println("client: called finalizer on non-closed client:", err) +} + // Connection wraps Reader, Writer and Closer interfaces. type Connection interface { io.Reader @@ -171,9 +189,18 @@ func (c *Client) Close() error { } c.closed = true c.closedMux.Unlock() - agentErr := c.a.Close() - connErr := c.c.Close() - close(c.close) + var ( + agentErr, connErr error + ) + if c.a != nil { + agentErr = c.a.Close() + } + if c.c != nil { + connErr = c.c.Close() + } + if c.close != nil { + close(c.close) + } c.wg.Wait() if agentErr == nil && connErr == nil { return nil diff --git a/client_test.go b/client_test.go index 000efcb..df3c4d0 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,12 @@ package stun import ( + "bufio" + "bytes" "errors" "io" "log" + "os" "testing" "time" ) @@ -443,3 +446,58 @@ func TestClientGC(t *testing.T) { t.Error("timed out") } } + +func TestClientFinalizer(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + defer func() { + log.SetOutput(os.Stderr) + }() + clientFinalizer(nil) // should not panic + clientFinalizer(&Client{}) + conn := &testConnection{ + write: func(bytes []byte) (int, error) { + return 0, io.ErrClosedPipe + }, + } + c, err := NewClient(ClientOptions{ + Connection: conn, + }) + if err != nil { + log.Fatal(err) + } + if err := c.Close(); err != nil { + t.Error(err) + } + clientFinalizer(c) + response := MustBuild(TransactionID, BindingSuccess) + response.Encode() + conn = &testConnection{ + b: response.Raw, + write: func(bytes []byte) (int, error) { + return len(bytes), nil + }, + } + c, err = NewClient(ClientOptions{ + Agent: errorAgent{ + closeErr: io.ErrUnexpectedEOF, + }, + Connection: conn, + }) + if err != nil { + log.Fatal(err) + } + clientFinalizer(c) + reader := bufio.NewScanner(&buf) + var lines int + for reader.Scan() { + lines += 1 + t.Log(reader.Text()) + } + if reader.Err() != nil { + t.Error(err) + } + if lines != 2 { + t.Error("incorrect count of log lines:", lines) + } +}