From e75b14c608344f18da318a5cc6f8470ec64094f3 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 10 May 2021 18:11:01 +0200 Subject: [PATCH] client: add DialReadContext, DialPublishContext --- client.go | 40 ++++++++++++++++++++++++++++++++++++++++ clientconn.go | 2 ++ 2 files changed, 42 insertions(+) diff --git a/client.go b/client.go index 03a397bc..f76df95d 100644 --- a/client.go +++ b/client.go @@ -115,6 +115,11 @@ func (c *Client) Dial(scheme string, host string) (*ClientConn, error) { // DialRead connects to the address and starts reading all tracks. func (c *Client) DialRead(address string) (*ClientConn, error) { + return c.DialReadContext(context.Background(), address) +} + +// DialReadContext connects to the address with the given context and starts reading all tracks. +func (c *Client) DialReadContext(ctx context.Context, address string) (*ClientConn, error) { u, err := base.ParseURL(address) if err != nil { return nil, err @@ -125,6 +130,21 @@ func (c *Client) DialRead(address string) (*ClientConn, error) { return nil, err } + ctxHandlerDone := make(chan struct{}) + defer func() { <-ctxHandlerDone }() + + ctxHandlerTerminate := make(chan struct{}) + defer close(ctxHandlerTerminate) + + go func() { + defer close(ctxHandlerDone) + select { + case <-ctx.Done(): + conn.Close() + case <-ctxHandlerTerminate: + } + }() + _, err = conn.Options(u) if err != nil { conn.Close() @@ -156,6 +176,11 @@ func (c *Client) DialRead(address string) (*ClientConn, error) { // DialPublish connects to the address and starts publishing the tracks. func (c *Client) DialPublish(address string, tracks Tracks) (*ClientConn, error) { + return c.DialPublishContext(context.Background(), address, tracks) +} + +// DialPublishContext connects to the address with the given context and starts publishing the tracks. +func (c *Client) DialPublishContext(ctx context.Context, address string, tracks Tracks) (*ClientConn, error) { u, err := base.ParseURL(address) if err != nil { return nil, err @@ -166,6 +191,21 @@ func (c *Client) DialPublish(address string, tracks Tracks) (*ClientConn, error) return nil, err } + ctxHandlerDone := make(chan struct{}) + defer func() { <-ctxHandlerDone }() + + ctxHandlerTerminate := make(chan struct{}) + defer close(ctxHandlerTerminate) + + go func() { + defer close(ctxHandlerDone) + select { + case <-ctx.Done(): + conn.Close() + case <-ctxHandlerTerminate: + } + }() + _, err = conn.Options(u) if err != nil { conn.Close() diff --git a/clientconn.go b/clientconn.go index b60944d2..2653c743 100644 --- a/clientconn.go +++ b/clientconn.go @@ -832,6 +832,8 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response, // it's better not to stop the request, but wait until teardown if !skipResponse { ctxHandlerDone := make(chan struct{}) + defer func() { <-ctxHandlerDone }() + ctxHandlerTerminate := make(chan struct{}) defer close(ctxHandlerTerminate)