rtmp: split net.Conn from rtmp.Conn

This commit is contained in:
aler9
2022-07-09 17:25:33 +02:00
parent bf1f45df32
commit 67e8a01d56
12 changed files with 141 additions and 123 deletions

2
go.mod
View File

@@ -5,7 +5,7 @@ go 1.17
require ( require (
code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5 code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5
github.com/abema/go-mp4 v0.7.2 github.com/abema/go-mp4 v0.7.2
github.com/aler9/gortsplib v0.0.0-20220705212903-df7336b5e81c github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f
github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757 github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/gin-gonic/gin v1.8.1 github.com/gin-gonic/gin v1.8.1

4
go.sum
View File

@@ -6,8 +6,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/gortsplib v0.0.0-20220705212903-df7336b5e81c h1:aTx9xxf5j00n9iSaEOKrsGc5HKIXogaGEOlmXWkJFww= github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f h1:EC+MOSv3e8ZEvtdHoL1++HahNoiVIkvu2Ygjrx6LyOg=
github.com/aler9/gortsplib v0.0.0-20220705212903-df7336b5e81c/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo= github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4= github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4=

View File

@@ -11,6 +11,7 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/url"
"github.com/asticode/go-astits" "github.com/asticode/go-astits"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -139,9 +140,18 @@ func TestHLSSource(t *testing.T) {
}, },
} }
err = c.StartReading("rtsp://localhost:8554/proxied") u, err := url.Parse("rtsp://localhost:8554/proxied")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
tracks, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAndPlay(tracks, baseURL)
require.NoError(t, err)
<-frameRecv <-frameRecv
} }

View File

@@ -66,6 +66,7 @@ type rtmpConn struct {
runOnConnectRestart bool runOnConnectRestart bool
wg *sync.WaitGroup wg *sync.WaitGroup
conn *rtmp.Conn conn *rtmp.Conn
nconn net.Conn
externalCmdPool *externalcmd.Pool externalCmdPool *externalcmd.Pool
pathManager rtmpConnPathManager pathManager rtmpConnPathManager
parent rtmpConnParent parent rtmpConnParent
@@ -107,6 +108,7 @@ func newRTMPConn(
runOnConnectRestart: runOnConnectRestart, runOnConnectRestart: runOnConnectRestart,
wg: wg, wg: wg,
conn: rtmp.NewServerConn(nconn), conn: rtmp.NewServerConn(nconn),
nconn: nconn,
externalCmdPool: externalCmdPool, externalCmdPool: externalCmdPool,
pathManager: pathManager, pathManager: pathManager,
parent: parent, parent: parent,
@@ -134,15 +136,15 @@ func (c *rtmpConn) ID() string {
// RemoteAddr returns the remote address of the Conn. // RemoteAddr returns the remote address of the Conn.
func (c *rtmpConn) RemoteAddr() net.Addr { func (c *rtmpConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.nconn.RemoteAddr()
} }
func (c *rtmpConn) log(level logger.Level, format string, args ...interface{}) { func (c *rtmpConn) log(level logger.Level, format string, args ...interface{}) {
c.parent.log(level, "[conn %v] "+format, append([]interface{}{c.conn.RemoteAddr()}, args...)...) c.parent.log(level, "[conn %v] "+format, append([]interface{}{c.nconn.RemoteAddr()}, args...)...)
} }
func (c *rtmpConn) ip() net.IP { func (c *rtmpConn) ip() net.IP {
return c.conn.RemoteAddr().(*net.TCPAddr).IP return c.nconn.RemoteAddr().(*net.TCPAddr).IP
} }
func (c *rtmpConn) safeState() rtmpConnState { func (c *rtmpConn) safeState() rtmpConnState {
@@ -204,11 +206,11 @@ func (c *rtmpConn) run() {
func (c *rtmpConn) runInner(ctx context.Context) error { func (c *rtmpConn) runInner(ctx context.Context) error {
go func() { go func() {
<-ctx.Done() <-ctx.Done()
c.conn.Close() c.nconn.Close()
}() }()
c.conn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.ServerHandshake() err := c.conn.ServerHandshake()
if err != nil { if err != nil {
return err return err
@@ -291,7 +293,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
return fmt.Errorf("the stream doesn't contain an H264 track or an AAC track") return fmt.Errorf("the stream doesn't contain an H264 track or an AAC track")
} }
c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.WriteTracks(videoTrack, audioTrack) err := c.conn.WriteTracks(videoTrack, audioTrack)
if err != nil { if err != nil {
return err return err
@@ -325,7 +327,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
} }
// disable read deadline // disable read deadline
c.conn.SetReadDeadline(time.Time{}) c.nconn.SetReadDeadline(time.Time{})
var videoInitialPTS *time.Duration var videoInitialPTS *time.Duration
videoFirstIDRFound := false videoFirstIDRFound := false
@@ -435,7 +437,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
return err return err
} }
c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = c.conn.WritePacket(av.Packet{ err = c.conn.WritePacket(av.Packet{
Type: av.H264, Type: av.H264,
Data: avcc, Data: avcc,
@@ -464,7 +466,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
} }
for i, au := range aus { for i, au := range aus {
c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.WritePacket(av.Packet{ err := c.conn.WritePacket(av.Packet{
Type: av.AAC, Type: av.AAC,
Data: au, Data: au,
@@ -479,7 +481,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
} }
func (c *rtmpConn) runPublish(ctx context.Context) error { func (c *rtmpConn) runPublish(ctx context.Context) error {
c.conn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
videoTrack, audioTrack, err := c.conn.ReadTracks() videoTrack, audioTrack, err := c.conn.ReadTracks()
if err != nil { if err != nil {
return err return err
@@ -545,7 +547,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
c.stateMutex.Unlock() c.stateMutex.Unlock()
// disable write deadline // disable write deadline
c.conn.SetWriteDeadline(time.Time{}) c.nconn.SetWriteDeadline(time.Time{})
rres := c.path.onPublisherRecord(pathPublisherRecordReq{ rres := c.path.onPublisherRecord(pathPublisherRecordReq{
author: c, author: c,
@@ -556,7 +558,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
} }
for { for {
c.conn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
pkt, err := c.conn.ReadPacket() pkt, err := c.conn.ReadPacket()
if err != nil { if err != nil {
return err return err

View File

@@ -1,8 +1,9 @@
package core package core
import ( import (
"context"
"io" "io"
"net"
"net/url"
"testing" "testing"
"time" "time"
@@ -134,10 +135,13 @@ func TestRTMPServerAuth(t *testing.T) {
defer a.close() defer a.close()
} }
conn, err := rtmp.DialContext(context.Background(), u, err := url.Parse("rtmp://127.0.0.1:1935/teststream?user=testreader&pass=testpass&param=value")
"rtmp://127.0.0.1/teststream?user=testreader&pass=testpass&param=value")
require.NoError(t, err) require.NoError(t, err)
defer conn.Close()
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u)
err = conn.ClientHandshake(true) err = conn.ClientHandshake(true)
require.NoError(t, err) require.NoError(t, err)
@@ -219,9 +223,13 @@ func TestRTMPServerAuthFail(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
conn, err := rtmp.DialContext(context.Background(), "rtmp://127.0.0.1/teststream?user=testuser&pass=testpass") u, err := url.Parse("rtmp://127.0.0.1:1935/teststream?user=testuser&pass=testpass")
require.NoError(t, err) require.NoError(t, err)
defer conn.Close()
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u)
err = conn.ClientHandshake(true) err = conn.ClientHandshake(true)
require.Equal(t, err, io.EOF) require.Equal(t, err, io.EOF)

View File

@@ -3,6 +3,8 @@ package core
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/url"
"sync" "sync"
"time" "time"
@@ -104,26 +106,40 @@ func (s *rtmpSource) runInner() bool {
runErr <- func() error { runErr <- func() error {
s.log(logger.Debug, "connecting") s.log(logger.Debug, "connecting")
ctx2, cancel2 := context.WithTimeout(innerCtx, time.Duration(s.readTimeout)) u, err := url.Parse(s.ur)
defer cancel2()
conn, err := rtmp.DialContext(ctx2, s.ur)
if err != nil { if err != nil {
return err return err
} }
// add default port
_, _, err = net.SplitHostPort(u.Host)
if err != nil {
u.Host = net.JoinHostPort(u.Host, "1935")
}
ctx2, cancel2 := context.WithTimeout(innerCtx, time.Duration(s.readTimeout))
defer cancel2()
var d net.Dialer
nconn, err := d.DialContext(ctx2, "tcp", u.Host)
if err != nil {
return err
}
conn := rtmp.NewClientConn(nconn, u)
readDone := make(chan error) readDone := make(chan error)
go func() { go func() {
readDone <- func() error { readDone <- func() error {
conn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
conn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout))) nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
err = conn.ClientHandshake(true) err = conn.ClientHandshake(true)
if err != nil { if err != nil {
return err return err
} }
conn.SetWriteDeadline(time.Time{}) nconn.SetWriteDeadline(time.Time{})
conn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
videoTrack, audioTrack, err := conn.ReadTracks() videoTrack, audioTrack, err := conn.ReadTracks()
if err != nil { if err != nil {
return err return err
@@ -170,7 +186,7 @@ func (s *rtmpSource) runInner() bool {
}() }()
for { for {
conn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
pkt, err := conn.ReadPacket() pkt, err := conn.ReadPacket()
if err != nil { if err != nil {
return err return err
@@ -237,11 +253,11 @@ func (s *rtmpSource) runInner() bool {
select { select {
case err := <-readDone: case err := <-readDone:
conn.Close() nconn.Close()
return err return err
case <-innerCtx.Done(): case <-innerCtx.Done():
conn.Close() nconn.Close()
<-readDone <-readDone
return nil return nil
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/url"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -254,9 +255,18 @@ func TestRTSPServerAuth(t *testing.T) {
reader := gortsplib.Client{} reader := gortsplib.Client{}
err = reader.StartReading("rtsp://testreader:testpass@127.0.0.1:8554/teststream?param=value") u, err := url.Parse("rtsp://testreader:testpass@127.0.0.1:8554/teststream?param=value")
require.NoError(t, err)
err = reader.Start(u.Scheme, u.Host)
require.NoError(t, err) require.NoError(t, err)
defer reader.Close() defer reader.Close()
tracks, baseURL, _, err := reader.Describe(u)
require.NoError(t, err)
err = reader.SetupAndPlay(tracks, baseURL)
require.NoError(t, err)
}) })
} }
@@ -367,9 +377,14 @@ func TestRTSPServerAuthFail(t *testing.T) {
c := gortsplib.Client{} c := gortsplib.Client{}
err := c.StartReading( u, err := url.Parse("rtsp://" + ca.user + ":" + ca.pass + "@localhost:8554/test/stream")
"rtsp://" + ca.user + ":" + ca.pass + "@localhost:8554/test/stream", require.NoError(t, err)
)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
_, _, _, err = c.Describe(u)
require.EqualError(t, err, "bad status code: 401 (Unauthorized)") require.EqualError(t, err, "bad status code: 401 (Unauthorized)")
}) })
} }
@@ -481,10 +496,19 @@ func TestRTSPServerPublisherOverride(t *testing.T) {
}, },
} }
err = c.StartReading("rtsp://localhost:8554/teststream") u, err := url.Parse("rtsp://localhost:8554/teststream")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
tracks, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAndPlay(tracks, baseURL)
require.NoError(t, err)
err = s1.WritePacketRTP(0, &rtp.Packet{ err = s1.WritePacketRTP(0, &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 0x02, Version: 0x02,

View File

@@ -156,10 +156,19 @@ func TestRTSPSource(t *testing.T) {
}, },
} }
err = c.StartReading("rtsp://127.0.0.1:8554/proxied") u, err := url.Parse("rtsp://127.0.0.1:8554/proxied")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
tracks, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAndPlay(tracks, baseURL)
require.NoError(t, err)
<-received <-received
}) })
} }

View File

@@ -1,38 +0,0 @@
package rtmp
import (
"bufio"
"context"
"net"
"net/url"
"github.com/notedit/rtmp/format/rtmp"
)
// DialContext connects to a server in reading mode.
func DialContext(ctx context.Context, address string) (*Conn, error) {
// https://github.com/aler9/rtmp/blob/3be4a55359274dcd88762e72aa0a702e2d8ba2fd/format/rtmp/client.go#L74
u, err := url.Parse(address)
if err != nil {
return nil, err
}
host := rtmp.UrlGetHost(u)
var d net.Dialer
nconn, err := d.DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
rconn := rtmp.NewConn(&bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, readBufferSize),
Writer: bufio.NewWriterSize(nconn, writeBufferSize),
})
rconn.URL = u
return &Conn{
rconn: rconn,
nconn: nconn,
}, nil
}

View File

@@ -1,6 +1,7 @@
package rtmp package rtmp
import ( import (
"bufio"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -26,12 +27,33 @@ const (
// Conn is a RTMP connection. // Conn is a RTMP connection.
type Conn struct { type Conn struct {
rconn *rtmp.Conn rconn *rtmp.Conn
nconn net.Conn
} }
// Close closes the connection. // NewClientConn initializes a client-side connection.
func (c *Conn) Close() error { func NewClientConn(nconn net.Conn, u *url.URL) *Conn {
return c.nconn.Close() c := rtmp.NewConn(&bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, readBufferSize),
Writer: bufio.NewWriterSize(nconn, writeBufferSize),
})
c.URL = u
return &Conn{
rconn: c,
}
}
// NewServerConn initializes a server-side connection.
func NewServerConn(nconn net.Conn) *Conn {
// https://github.com/aler9/rtmp/blob/master/format/rtmp/server.go#L46
c := rtmp.NewConn(&bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, readBufferSize),
Writer: bufio.NewWriterSize(nconn, writeBufferSize),
})
c.IsServer = true
return &Conn{
rconn: c,
}
} }
// ClientHandshake performs the handshake of a client-side connection. // ClientHandshake performs the handshake of a client-side connection.
@@ -50,21 +72,6 @@ func (c *Conn) ServerHandshake() error {
return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, 0) return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, 0)
} }
// SetReadDeadline sets the read deadline.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.nconn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.nconn.SetWriteDeadline(t)
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.nconn.RemoteAddr()
}
// IsPublishing returns whether the connection is publishing. // IsPublishing returns whether the connection is publishing.
func (c *Conn) IsPublishing() bool { func (c *Conn) IsPublishing() bool {
return c.rconn.Publishing return c.rconn.Publishing

View File

@@ -1,7 +1,6 @@
package rtmp package rtmp
import ( import (
"context"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -218,9 +217,13 @@ func TestClientHandshake(t *testing.T) {
close(done) close(done)
}() }()
conn, err := DialContext(context.Background(), "rtmp://127.0.0.1:9121/stream") u, err := url.Parse("rtmp://127.0.0.1:9121/stream")
require.NoError(t, err) require.NoError(t, err)
defer conn.Close()
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := NewClientConn(nconn, u)
err = conn.ClientHandshake(true) err = conn.ClientHandshake(true)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -1,23 +0,0 @@
package rtmp
import (
"bufio"
"net"
"github.com/notedit/rtmp/format/rtmp"
)
// NewServerConn initializes a server-side connection.
func NewServerConn(nconn net.Conn) *Conn {
// https://github.com/aler9/rtmp/blob/master/format/rtmp/server.go#L46
c := rtmp.NewConn(&bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, readBufferSize),
Writer: bufio.NewWriterSize(nconn, writeBufferSize),
})
c.IsServer = true
return &Conn{
rconn: c,
nconn: nconn,
}
}