diff --git a/pkg/rtmp/producer.go b/pkg/rtmp/producer.go index 380fff17..c7e42c01 100644 --- a/pkg/rtmp/producer.go +++ b/pkg/rtmp/producer.go @@ -1,14 +1,11 @@ package rtmp import ( - "crypto/tls" - "errors" - "net" "net/url" - "strings" "github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/flv" + "github.com/AlexxIT/go2rtc/pkg/tcp" ) func Dial(rawURL string) (core.Producer, error) { @@ -17,38 +14,11 @@ func Dial(rawURL string) (core.Producer, error) { return nil, err } - var hostname string // without port - if i := strings.IndexByte(u.Host, ':'); i > 0 { - hostname = u.Host[:i] - } else { - hostname = u.Host - u.Host += ":1935" - } - - conn, err := net.DialTimeout("tcp", u.Host, core.ConnDialTimeout) + conn, err := tcp.Dial(u, "1935") if err != nil { return nil, err } - if u.Scheme != "rtmp" { - var conf *tls.Config - - switch { - case u.Scheme == "rtmpx" || net.ParseIP(hostname) != nil: - conf = &tls.Config{InsecureSkipVerify: true} - case u.Scheme == "rtmps": - conf = &tls.Config{ServerName: hostname} - default: - return nil, errors.New("unsupported scheme: " + u.Scheme) - } - - tlsConn := tls.Client(conn, conf) - if err = tlsConn.Handshake(); err != nil { - return nil, err - } - conn = tlsConn - } - rd, err := NewReader(u, conn) if err != nil { return nil, err diff --git a/pkg/rtsp/client.go b/pkg/rtsp/client.go index 3a076444..82518ede 100644 --- a/pkg/rtsp/client.go +++ b/pkg/rtsp/client.go @@ -24,19 +24,18 @@ func NewClient(uri string) *Conn { } func (c *Conn) Dial() (err error) { - var conn net.Conn - - if c.Transport == "" { - conn, err = Dial(c.uri) - } else { - conn, err = websocket.Dial(c.Transport) - } - - if err != nil { + if c.URL, err = url.Parse(c.uri); err != nil { return } - if c.URL, err = url.Parse(c.uri); err != nil { + var conn net.Conn + + if c.Transport == "" { + conn, err = tcp.Dial(c.URL, "554") + } else { + conn, err = websocket.Dial(c.Transport) + } + if err != nil { return } diff --git a/pkg/rtsp/dial.go b/pkg/rtsp/dial.go deleted file mode 100644 index 58d5dd65..00000000 --- a/pkg/rtsp/dial.go +++ /dev/null @@ -1,44 +0,0 @@ -package rtsp - -import ( - "crypto/tls" - "errors" - "net" - "net/url" - "strings" - "time" -) - -func Dial(uri string) (net.Conn, error) { - u, err := url.Parse(uri) - if err != nil { - return nil, err - } - - switch u.Scheme { - case "rtsp": - return dialTCP(u.Host, nil) - case "rtsps": - tlsConf := &tls.Config{ServerName: u.Hostname()} - return dialTCP(u.Host, tlsConf) - case "rtspx": - tlsConf := &tls.Config{InsecureSkipVerify: true} - return dialTCP(u.Host, tlsConf) - } - - return nil, errors.New("unsupported scheme: " + u.Scheme) -} - -func dialTCP(address string, tlsConf *tls.Config) (net.Conn, error) { - if strings.IndexByte(address, ':') < 0 { - address += ":554" - } - - conn, err := net.DialTimeout("tcp", address, time.Second*5) - if tlsConf == nil || err != nil { - return conn, err - } - - tlsConn := tls.Client(conn, tlsConf) - return tlsConn, tlsConn.Handshake() -} diff --git a/pkg/tcp/dial.go b/pkg/tcp/dial.go new file mode 100644 index 00000000..bb604234 --- /dev/null +++ b/pkg/tcp/dial.go @@ -0,0 +1,56 @@ +package tcp + +import ( + "crypto/tls" + "errors" + "net" + "net/url" + "strings" + + "github.com/AlexxIT/go2rtc/pkg/core" +) + +// Dial - for RTSP(S|X) and RTMP(S|X) +func Dial(u *url.URL, port string) (net.Conn, error) { + var hostname string // without port + if i := strings.IndexByte(u.Host, ':'); i > 0 { + hostname = u.Host[:i] + } else { + hostname = u.Host + u.Host += ":" + port + } + + var secure *tls.Config + + switch u.Scheme { + case "rtsp", "rtmp": + case "rtsps", "rtspx", "rtmps", "rtmpx": + if u.Scheme[4] == 'x' || net.ParseIP(hostname) != nil { + secure = &tls.Config{InsecureSkipVerify: true} + } else { + secure = &tls.Config{ServerName: hostname} + } + default: + return nil, errors.New("unsupported scheme: " + u.Scheme) + } + + conn, err := net.DialTimeout("tcp", u.Host, core.ConnDialTimeout) + if err != nil { + return nil, err + } + + if secure == nil { + return conn, nil + } + + tlsConn := tls.Client(conn, secure) + if err = tlsConn.Handshake(); err != nil { + return nil, err + } + + if u.Scheme[4] == 'x' { + u.Scheme = u.Scheme[:4] + "s" + } + + return tlsConn, nil +}