fix: rtsp tcp read timeout

This commit is contained in:
langhuihui
2025-10-14 20:37:05 +08:00
parent a87eeb8a30
commit fe5d31ad08
6 changed files with 179 additions and 174 deletions

View File

@@ -6,6 +6,7 @@ import (
"net/textproto"
"strings"
"time"
. "github.com/langhuihui/gomem"
)
@@ -18,13 +19,24 @@ type BufReader struct {
BufLen int
Mouth chan []byte
feedData func() error
timeout time.Duration
}
func (r *BufReader) SetTimeout(timeout time.Duration) {
r.timeout = timeout
}
func NewBufReaderWithBufLen(reader io.Reader, bufLen int) (r *BufReader) {
conn, _ := reader.(net.Conn)
r = &BufReader{
Allocator: NewScalableMemoryAllocator(bufLen),
BufLen: bufLen,
feedData: func() error {
if conn != nil && r.timeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(r.timeout)); err != nil {
return err
}
}
buf, err := r.Allocator.Read(reader, r.BufLen)
if err != nil {
return err
@@ -42,34 +54,6 @@ func NewBufReaderWithBufLen(reader io.Reader, bufLen int) (r *BufReader) {
return
}
// NewBufReaderWithTimeout 创建一个具有指定读取超时时间的 BufReader
func NewBufReaderWithTimeout(conn net.Conn, timeout time.Duration) (r *BufReader) {
r = &BufReader{
Allocator: NewScalableMemoryAllocator(defaultBufSize),
BufLen: defaultBufSize,
feedData: func() error {
// 设置读取超时
if conn != nil && timeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return err
}
}
buf, err := r.Allocator.Read(conn, r.BufLen)
if err != nil {
return err
}
n := len(buf)
r.totalRead += n
r.buf.Buffers = append(r.buf.Buffers, buf)
r.buf.Size += n
r.buf.Length += n
return nil
},
}
r.buf.Memory = &Memory{}
return
}
func NewBufReaderBuffersChan(feedChan chan net.Buffers) (r *BufReader) {
r = &BufReader{
feedData: func() error {

View File

@@ -23,42 +23,30 @@ var rtmpPullSteps = []pkg.StepDef{
{Name: pkg.StepStreaming, Description: "Receiving media stream"},
}
func (c *Client) Start() (err error) {
var addr string
if c.direction == DIRECTION_PULL {
// Initialize progress tracking for pull operations
c.pullCtx.SetProgressStepsDefs(rtmpPullSteps)
type Client struct {
NetStream
chunkSize int
u *url.URL
}
addr = c.pullCtx.Connection.RemoteURL
err = c.pullCtx.Publish()
if err != nil {
c.pullCtx.Fail(err.Error())
return
}
func (c *Client) GetPullJob() *m7s.PullJob {
return nil
}
c.pullCtx.GoToStepConst(pkg.StepURLParsing)
} else {
addr = c.pushCtx.Connection.RemoteURL
}
func (c *Client) GetPushJob() *m7s.PushJob {
return nil
}
func (c *Client) commonStart(addr string) (err error) {
c.u, err = url.Parse(addr)
if err != nil {
if c.direction == DIRECTION_PULL {
c.pullCtx.Fail(err.Error())
}
return
}
ps := strings.Split(c.u.Path, "/")
if len(ps) < 2 {
if c.direction == DIRECTION_PULL {
c.pullCtx.Fail("illegal rtmp url")
}
return errors.New("illegal rtmp url")
}
if c.direction == DIRECTION_PULL {
c.pullCtx.GoToStepConst(pkg.StepConnection)
}
isRtmps := c.u.Scheme == "rtmps"
if strings.Count(c.u.Host, ":") == 0 {
if isRtmps {
@@ -78,72 +66,19 @@ func (c *Client) Start() (err error) {
conn, err = net.Dial("tcp", c.u.Host)
}
if err != nil {
if c.direction == DIRECTION_PULL {
c.pullCtx.Fail(err.Error())
}
return err
}
if c.direction == DIRECTION_PULL {
c.pullCtx.GoToStepConst(pkg.StepHandshake)
}
c.Init(conn)
c.SetDescription("local", conn.LocalAddr().String())
c.Info("connect")
c.WriteChunkSize = c.chunkSize
c.AppName = strings.Join(ps[1:len(ps)-1], "/")
if c.direction == DIRECTION_PULL {
c.pullCtx.GoToStepConst(pkg.StepStreaming)
}
return err
}
const (
DIRECTION_PULL = "pull"
DIRECTION_PUSH = "push"
)
type Client struct {
NetStream
chunkSize int
pullCtx m7s.PullJob
pushCtx m7s.PushJob
direction string
u *url.URL
}
func (c *Client) GetPullJob() *m7s.PullJob {
return &c.pullCtx
}
func (c *Client) GetPushJob() *m7s.PushJob {
return &c.pushCtx
}
func NewPuller(_ config.Pull) m7s.IPuller {
ret := &Client{
direction: DIRECTION_PULL,
chunkSize: 4096,
}
ret.NetConnection = &NetConnection{}
ret.SetDescription(task.OwnerTypeKey, "RTMPPuller")
return ret
}
func NewPusher() m7s.IPusher {
ret := &Client{
direction: DIRECTION_PUSH,
chunkSize: 4096,
}
ret.NetConnection = &NetConnection{}
ret.SetDescription(task.OwnerTypeKey, "RTMPPusher")
return ret
}
func (c *Client) Run() (err error) {
func (c *Client) commonRun(handler func(commander Commander) error) (err error) {
if err = c.ClientHandshake(); err != nil {
return
}
@@ -171,6 +106,7 @@ func (c *Client) Run() (err error) {
return err
}
cmd := commander.GetCommand()
c.Debug(cmd.CommandName)
switch cmd.CommandName {
case Response_Result, Response_OnStatus:
switch response := commander.(type) {
@@ -185,66 +121,149 @@ func (c *Client) Run() (err error) {
}
case *ResponseCreateStreamMessage:
c.StreamID = response.StreamId
if c.direction == DIRECTION_PULL {
m := &PlayMessage{}
m.StreamId = response.StreamId
m.TransactionId = 4
m.CommandMessage.CommandName = "play"
URL, _ := url.Parse(c.pullCtx.Connection.RemoteURL)
ps := strings.Split(URL.Path, "/")
args := URL.Query()
m.StreamName = ps[len(ps)-1]
if len(args) > 0 {
m.StreamName += "?" + args.Encode()
if handler != nil {
if err = handler(commander); err != nil {
return err
}
if c.pullCtx.Publisher != nil {
c.Writers[response.StreamId] = &struct {
m7s.PublishWriter[*AudioFrame, *VideoFrame]
*m7s.Publisher
}{Publisher: c.pullCtx.Publisher}
}
err = c.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
// if response, ok := msg.MsgData.(*ResponsePlayMessage); ok {
// if response.Object["code"] == "NetStream.Play.Start" {
// } else if response.Object["level"] == Level_Error {
// return errors.New(response.Object["code"].(string))
// }
// } else {
// return errors.New("pull faild")
// }
} else {
err = c.pushCtx.Subscribe()
if err != nil {
return
}
URL, _ := url.Parse(c.pushCtx.Connection.RemoteURL)
_, streamPath, _ := strings.Cut(URL.Path, "/")
_, streamPath, _ = strings.Cut(streamPath, "/")
args := URL.Query()
if len(args) > 0 {
streamPath += "?" + args.Encode()
}
err = c.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{
CURDStreamMessage{
CommandMessage{
"publish",
1,
},
response.StreamId,
},
streamPath,
"live",
})
}
case *ResponsePublishMessage:
if response.Infomation["code"] == NetStream_Publish_Start {
c.Subscribe(c.pushCtx.Subscriber)
} else {
return errors.New(response.Infomation["code"].(string))
}
}
}
}
return
}
type Puller struct {
Client
pullCtx m7s.PullJob
}
func (p *Puller) GetPullJob() *m7s.PullJob {
return &p.pullCtx
}
func (p *Puller) Start() (err error) {
// Initialize progress tracking for pull operations
p.pullCtx.SetProgressStepsDefs(rtmpPullSteps)
addr := p.pullCtx.Connection.RemoteURL
err = p.pullCtx.Publish()
if err != nil {
p.pullCtx.Fail(err.Error())
return
}
p.pullCtx.GoToStepConst(pkg.StepURLParsing)
err = p.commonStart(addr)
if err != nil {
p.pullCtx.Fail(err.Error())
return
}
p.pullCtx.GoToStepConst(pkg.StepConnection)
p.pullCtx.GoToStepConst(pkg.StepHandshake)
p.pullCtx.GoToStepConst(pkg.StepStreaming)
return
}
func (p *Puller) Run() (err error) {
return p.commonRun(func(commander Commander) error {
switch response := commander.(type) {
case *ResponseCreateStreamMessage:
p.StreamID = response.StreamId
m := &PlayMessage{}
m.StreamId = response.StreamId
m.TransactionId = 4
m.CommandMessage.CommandName = "play"
URL, _ := url.Parse(p.pullCtx.Connection.RemoteURL)
ps := strings.Split(URL.Path, "/")
args := URL.Query()
m.StreamName = ps[len(ps)-1]
if len(args) > 0 {
m.StreamName += "?" + args.Encode()
}
if p.pullCtx.Publisher != nil {
p.Writers[response.StreamId] = &struct {
m7s.PublishWriter[*AudioFrame, *VideoFrame]
*m7s.Publisher
}{Publisher: p.pullCtx.Publisher}
}
return p.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
}
return nil
})
}
type Pusher struct {
Client
pushCtx m7s.PushJob
}
func (p *Pusher) GetPushJob() *m7s.PushJob {
return &p.pushCtx
}
func (p *Pusher) Start() (err error) {
return p.commonStart(p.pushCtx.Connection.RemoteURL)
}
func (p *Pusher) Run() (err error) {
return p.commonRun(func(commander Commander) error {
switch response := commander.(type) {
case *ResponseCreateStreamMessage:
p.StreamID = response.StreamId
err = p.pushCtx.Subscribe()
if err != nil {
return err
}
URL, _ := url.Parse(p.pushCtx.Connection.RemoteURL)
_, streamPath, _ := strings.Cut(URL.Path, "/")
_, streamPath, _ = strings.Cut(streamPath, "/")
args := URL.Query()
if len(args) > 0 {
streamPath += "?" + args.Encode()
}
return p.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{
CURDStreamMessage{
CommandMessage{
"publish",
1,
},
response.StreamId,
},
streamPath,
"live",
})
case *ResponsePublishMessage:
if response.Infomation["code"] == NetStream_Publish_Start {
p.Subscribe(p.pushCtx.Subscriber)
} else {
return errors.New(response.Infomation["code"].(string))
}
}
return nil
})
}
func NewPuller(_ config.Pull) m7s.IPuller {
ret := &Puller{
Client: Client{
chunkSize: 4096,
},
}
ret.NetConnection = &NetConnection{}
ret.SetDescription(task.OwnerTypeKey, "RTMPPuller")
return ret
}
func NewPusher() m7s.IPusher {
ret := &Pusher{
Client: Client{
chunkSize: 4096,
},
}
ret.NetConnection = &NetConnection{}
ret.SetDescription(task.OwnerTypeKey, "RTMPPusher")
return ret
}

View File

@@ -91,7 +91,8 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) {
func (nc *NetConnection) Init(conn net.Conn) {
nc.Conn = conn
nc.BufReader = util.NewBufReaderWithTimeout(conn, 30*time.Second)
nc.BufReader = util.NewBufReader(conn)
nc.BufReader.SetTimeout(time.Second * 30)
nc.bandwidth = RTMP_MAX_CHUNK_SIZE << 3
nc.ReadChunkSize = RTMP_DEFAULT_CHUNK_SIZE
nc.WriteChunkSize = RTMP_DEFAULT_CHUNK_SIZE

View File

@@ -22,12 +22,14 @@ import (
const Timeout = time.Second * 10
func NewNetConnection(conn net.Conn) *NetConnection {
return &NetConnection{
c := &NetConnection{
Conn: conn,
BufReader: util.NewBufReaderWithTimeout(conn, Timeout),
BufReader: util.NewBufReader(conn),
MemoryAllocator: gomem.NewScalableMemoryAllocator(1 << 12),
UserAgent: "monibuca" + m7s.Version,
}
c.BufReader.SetTimeout(Timeout)
return c
}
type NetConnection struct {
@@ -143,6 +145,7 @@ func (c *NetConnection) Connect(remoteURL string) (err error) {
}
c.Conn = conn
c.BufReader = util.NewBufReader(conn)
c.BufReader.SetTimeout(Timeout)
c.UserAgent = "monibuca" + m7s.Version
c.Session = ""
c.Auth = util.NewAuth(rtspURL.User)
@@ -255,9 +258,7 @@ func (c *NetConnection) Receive(sendMode bool, onReceive func(byte, []byte) erro
return
}
ts := time.Now()
if err = c.Conn.SetReadDeadline(ts.Add(util.Conditional(sendMode, time.Second*60, time.Second*15))); err != nil {
return
}
var magic []byte
// we can read:
// 1. RTP interleaved: `$` + 1B channel number + 2B size

View File

@@ -378,7 +378,7 @@ func (s *Sender) Send() (err error) {
}
}()
}
s.BufReader.SetTimeout(60 * time.Second)
// 接收处理(处理客户端发来的消息)
return s.NetConnection.Receive(true, nil, nil)
}

View File

@@ -152,8 +152,8 @@ func (task *RTSPServer) Go() (err error) {
Request: req,
}
// TCP传输模式
const tcpTransport = "RTP/AVP/TCP;unicast;interleaved="
// TCP传输模式 - 适配包含mode字段的格式
const tcpTransport = "RTP/AVP/TCP"
// UDP传输模式前缀
const udpTransport = "RTP/AVP"
@@ -164,13 +164,13 @@ func (task *RTSPServer) Go() (err error) {
if sendMode {
if i := reqTrackID(req); i >= 0 {
tr = fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", i*2, i*2+1)
tr = fmt.Sprintf("RTP/AVP/TCP;unicast;mode=record;interleaved=%d-%d", i*2, i*2+1)
res.Header.Set("Transport", tr)
} else {
res.Status = "400 Bad Request"
}
} else {
res.Header.Set("Transport", tr[:len(tcpTransport)+3])
res.Header.Set("Transport", tr)
}
} else if strings.HasPrefix(tr, udpTransport) && strings.Contains(tr, "unicast") && strings.Contains(tr, "client_port=") {
task.Debug("into udp play")