implement client TLS support

This commit is contained in:
aler9
2020-12-14 22:49:47 +01:00
parent 9cd36cdd68
commit 61318d7f96
10 changed files with 141 additions and 34 deletions

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"bufio"
"crypto/tls"
"fmt"
"net"
"strings"
@@ -18,8 +19,8 @@ import (
var DefaultClientConf = ClientConf{}
// Dial connects to a server.
func Dial(host string) (*ClientConn, error) {
return DefaultClientConf.Dial(host)
func Dial(scheme string, host string) (*ClientConn, error) {
return DefaultClientConf.Dial(scheme, host)
}
// DialRead connects to a server and starts reading all tracks.
@@ -40,6 +41,10 @@ type ClientConf struct {
// It defaults to nil.
StreamProtocol *StreamProtocol
// A TLS configuration to connect to TLS (RTSPS) servers.
// It defaults to &tls.Config{InsecureSkipVerify:true}
TLSConfig *tls.Config
// timeout of read operations.
// It defaults to 10 seconds.
ReadTimeout time.Duration
@@ -74,7 +79,10 @@ type ClientConf struct {
}
// Dial connects to a server.
func (c ClientConf) Dial(host string) (*ClientConn, error) {
func (c ClientConf) Dial(scheme string, host string) (*ClientConn, error) {
if c.TLSConfig == nil {
c.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
if c.ReadTimeout == 0 {
c.ReadTimeout = 10 * time.Second
}
@@ -91,6 +99,10 @@ func (c ClientConf) Dial(host string) (*ClientConn, error) {
c.ListenPacket = net.ListenPacket
}
if scheme != "rtsp" && scheme != "rtsps" {
return nil, fmt.Errorf("unsupported scheme '%s'", scheme)
}
if !strings.Contains(host, ":") {
host += ":554"
}
@@ -100,11 +112,18 @@ func (c ClientConf) Dial(host string) (*ClientConn, error) {
return nil, err
}
conn := func() net.Conn {
if scheme == "rtsps" {
return tls.Client(nconn, c.TLSConfig)
}
return nconn
}()
return &ClientConn{
conf: c,
nconn: nconn,
br: bufio.NewReaderSize(nconn, clientReadBufferSize),
bw: bufio.NewWriterSize(nconn, clientWriteBufferSize),
br: bufio.NewReaderSize(conn, clientReadBufferSize),
bw: bufio.NewWriterSize(conn, clientWriteBufferSize),
udpRtpListeners: make(map[int]*clientConnUDPListener),
udpRtcpListeners: make(map[int]*clientConnUDPListener),
rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver),
@@ -122,7 +141,7 @@ func (c ClientConf) DialRead(address string) (*ClientConn, error) {
return nil, err
}
conn, err := c.Dial(u.Host)
conn, err := c.Dial(u.Scheme, u.Host)
if err != nil {
return nil, err
}
@@ -163,7 +182,7 @@ func (c ClientConf) DialPublish(address string, tracks Tracks) (*ClientConn, err
return nil, err
}
conn, err := c.Dial(u.Host)
conn, err := c.Dial(u.Scheme, u.Host)
if err != nil {
return nil, err
}