diff --git a/pkg/util/buf_reader.go b/pkg/util/buf_reader.go index 34cb4d6..9de767b 100644 --- a/pkg/util/buf_reader.go +++ b/pkg/util/buf_reader.go @@ -6,6 +6,7 @@ import ( "net/textproto" "strings" "time" + . "github.com/langhuihui/gomem" ) @@ -18,13 +19,24 @@ type BufReader struct { BufLen int Mouth chan []byte feedData func() error + timeout time.Duration +} + +func (r *BufReader) SetTimeout(timeout time.Duration) { + r.timeout = timeout } func NewBufReaderWithBufLen(reader io.Reader, bufLen int) (r *BufReader) { + conn, _ := reader.(net.Conn) r = &BufReader{ Allocator: NewScalableMemoryAllocator(bufLen), BufLen: bufLen, feedData: func() error { + if conn != nil && r.timeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(r.timeout)); err != nil { + return err + } + } buf, err := r.Allocator.Read(reader, r.BufLen) if err != nil { return err @@ -42,34 +54,6 @@ func NewBufReaderWithBufLen(reader io.Reader, bufLen int) (r *BufReader) { return } -// NewBufReaderWithTimeout 创建一个具有指定读取超时时间的 BufReader -func NewBufReaderWithTimeout(conn net.Conn, timeout time.Duration) (r *BufReader) { - r = &BufReader{ - Allocator: NewScalableMemoryAllocator(defaultBufSize), - BufLen: defaultBufSize, - feedData: func() error { - // 设置读取超时 - if conn != nil && timeout > 0 { - if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { - return err - } - } - buf, err := r.Allocator.Read(conn, r.BufLen) - if err != nil { - return err - } - n := len(buf) - r.totalRead += n - r.buf.Buffers = append(r.buf.Buffers, buf) - r.buf.Size += n - r.buf.Length += n - return nil - }, - } - r.buf.Memory = &Memory{} - return -} - func NewBufReaderBuffersChan(feedChan chan net.Buffers) (r *BufReader) { r = &BufReader{ feedData: func() error { diff --git a/plugin/rtmp/pkg/client.go b/plugin/rtmp/pkg/client.go index cb1428a..bc5fd3f 100644 --- a/plugin/rtmp/pkg/client.go +++ b/plugin/rtmp/pkg/client.go @@ -23,42 +23,30 @@ var rtmpPullSteps = []pkg.StepDef{ {Name: pkg.StepStreaming, Description: "Receiving media stream"}, } -func (c *Client) Start() (err error) { - var addr string - if c.direction == DIRECTION_PULL { - // Initialize progress tracking for pull operations - c.pullCtx.SetProgressStepsDefs(rtmpPullSteps) +type Client struct { + NetStream + chunkSize int + u *url.URL +} - addr = c.pullCtx.Connection.RemoteURL - err = c.pullCtx.Publish() - if err != nil { - c.pullCtx.Fail(err.Error()) - return - } +func (c *Client) GetPullJob() *m7s.PullJob { + return nil +} - c.pullCtx.GoToStepConst(pkg.StepURLParsing) - } else { - addr = c.pushCtx.Connection.RemoteURL - } +func (c *Client) GetPushJob() *m7s.PushJob { + return nil +} + +func (c *Client) commonStart(addr string) (err error) { c.u, err = url.Parse(addr) if err != nil { - if c.direction == DIRECTION_PULL { - c.pullCtx.Fail(err.Error()) - } return } ps := strings.Split(c.u.Path, "/") if len(ps) < 2 { - if c.direction == DIRECTION_PULL { - c.pullCtx.Fail("illegal rtmp url") - } return errors.New("illegal rtmp url") } - if c.direction == DIRECTION_PULL { - c.pullCtx.GoToStepConst(pkg.StepConnection) - } - isRtmps := c.u.Scheme == "rtmps" if strings.Count(c.u.Host, ":") == 0 { if isRtmps { @@ -78,72 +66,19 @@ func (c *Client) Start() (err error) { conn, err = net.Dial("tcp", c.u.Host) } if err != nil { - if c.direction == DIRECTION_PULL { - c.pullCtx.Fail(err.Error()) - } return err } - if c.direction == DIRECTION_PULL { - c.pullCtx.GoToStepConst(pkg.StepHandshake) - } - c.Init(conn) c.SetDescription("local", conn.LocalAddr().String()) c.Info("connect") c.WriteChunkSize = c.chunkSize c.AppName = strings.Join(ps[1:len(ps)-1], "/") - if c.direction == DIRECTION_PULL { - c.pullCtx.GoToStepConst(pkg.StepStreaming) - } - return err } -const ( - DIRECTION_PULL = "pull" - DIRECTION_PUSH = "push" -) - -type Client struct { - NetStream - chunkSize int - pullCtx m7s.PullJob - pushCtx m7s.PushJob - direction string - u *url.URL -} - -func (c *Client) GetPullJob() *m7s.PullJob { - return &c.pullCtx -} - -func (c *Client) GetPushJob() *m7s.PushJob { - return &c.pushCtx -} - -func NewPuller(_ config.Pull) m7s.IPuller { - ret := &Client{ - direction: DIRECTION_PULL, - chunkSize: 4096, - } - ret.NetConnection = &NetConnection{} - ret.SetDescription(task.OwnerTypeKey, "RTMPPuller") - return ret -} - -func NewPusher() m7s.IPusher { - ret := &Client{ - direction: DIRECTION_PUSH, - chunkSize: 4096, - } - ret.NetConnection = &NetConnection{} - ret.SetDescription(task.OwnerTypeKey, "RTMPPusher") - return ret -} - -func (c *Client) Run() (err error) { +func (c *Client) commonRun(handler func(commander Commander) error) (err error) { if err = c.ClientHandshake(); err != nil { return } @@ -171,6 +106,7 @@ func (c *Client) Run() (err error) { return err } cmd := commander.GetCommand() + c.Debug(cmd.CommandName) switch cmd.CommandName { case Response_Result, Response_OnStatus: switch response := commander.(type) { @@ -185,66 +121,149 @@ func (c *Client) Run() (err error) { } case *ResponseCreateStreamMessage: c.StreamID = response.StreamId - if c.direction == DIRECTION_PULL { - m := &PlayMessage{} - m.StreamId = response.StreamId - m.TransactionId = 4 - m.CommandMessage.CommandName = "play" - URL, _ := url.Parse(c.pullCtx.Connection.RemoteURL) - ps := strings.Split(URL.Path, "/") - args := URL.Query() - m.StreamName = ps[len(ps)-1] - if len(args) > 0 { - m.StreamName += "?" + args.Encode() + if handler != nil { + if err = handler(commander); err != nil { + return err } - if c.pullCtx.Publisher != nil { - c.Writers[response.StreamId] = &struct { - m7s.PublishWriter[*AudioFrame, *VideoFrame] - *m7s.Publisher - }{Publisher: c.pullCtx.Publisher} - } - err = c.SendMessage(RTMP_MSG_AMF0_COMMAND, m) - // if response, ok := msg.MsgData.(*ResponsePlayMessage); ok { - // if response.Object["code"] == "NetStream.Play.Start" { - - // } else if response.Object["level"] == Level_Error { - // return errors.New(response.Object["code"].(string)) - // } - // } else { - // return errors.New("pull faild") - // } - } else { - err = c.pushCtx.Subscribe() - if err != nil { - return - } - URL, _ := url.Parse(c.pushCtx.Connection.RemoteURL) - _, streamPath, _ := strings.Cut(URL.Path, "/") - _, streamPath, _ = strings.Cut(streamPath, "/") - args := URL.Query() - if len(args) > 0 { - streamPath += "?" + args.Encode() - } - err = c.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{ - CURDStreamMessage{ - CommandMessage{ - "publish", - 1, - }, - response.StreamId, - }, - streamPath, - "live", - }) - } - case *ResponsePublishMessage: - if response.Infomation["code"] == NetStream_Publish_Start { - c.Subscribe(c.pushCtx.Subscriber) - } else { - return errors.New(response.Infomation["code"].(string)) } } } } return } + +type Puller struct { + Client + pullCtx m7s.PullJob +} + +func (p *Puller) GetPullJob() *m7s.PullJob { + return &p.pullCtx +} + +func (p *Puller) Start() (err error) { + // Initialize progress tracking for pull operations + p.pullCtx.SetProgressStepsDefs(rtmpPullSteps) + + addr := p.pullCtx.Connection.RemoteURL + err = p.pullCtx.Publish() + if err != nil { + p.pullCtx.Fail(err.Error()) + return + } + + p.pullCtx.GoToStepConst(pkg.StepURLParsing) + + err = p.commonStart(addr) + if err != nil { + p.pullCtx.Fail(err.Error()) + return + } + + p.pullCtx.GoToStepConst(pkg.StepConnection) + p.pullCtx.GoToStepConst(pkg.StepHandshake) + p.pullCtx.GoToStepConst(pkg.StepStreaming) + + return +} + +func (p *Puller) Run() (err error) { + return p.commonRun(func(commander Commander) error { + switch response := commander.(type) { + case *ResponseCreateStreamMessage: + p.StreamID = response.StreamId + m := &PlayMessage{} + m.StreamId = response.StreamId + m.TransactionId = 4 + m.CommandMessage.CommandName = "play" + URL, _ := url.Parse(p.pullCtx.Connection.RemoteURL) + ps := strings.Split(URL.Path, "/") + args := URL.Query() + m.StreamName = ps[len(ps)-1] + if len(args) > 0 { + m.StreamName += "?" + args.Encode() + } + if p.pullCtx.Publisher != nil { + p.Writers[response.StreamId] = &struct { + m7s.PublishWriter[*AudioFrame, *VideoFrame] + *m7s.Publisher + }{Publisher: p.pullCtx.Publisher} + } + return p.SendMessage(RTMP_MSG_AMF0_COMMAND, m) + } + return nil + }) +} + +type Pusher struct { + Client + pushCtx m7s.PushJob +} + +func (p *Pusher) GetPushJob() *m7s.PushJob { + return &p.pushCtx +} + +func (p *Pusher) Start() (err error) { + return p.commonStart(p.pushCtx.Connection.RemoteURL) +} + +func (p *Pusher) Run() (err error) { + return p.commonRun(func(commander Commander) error { + switch response := commander.(type) { + case *ResponseCreateStreamMessage: + p.StreamID = response.StreamId + err = p.pushCtx.Subscribe() + if err != nil { + return err + } + URL, _ := url.Parse(p.pushCtx.Connection.RemoteURL) + _, streamPath, _ := strings.Cut(URL.Path, "/") + _, streamPath, _ = strings.Cut(streamPath, "/") + args := URL.Query() + if len(args) > 0 { + streamPath += "?" + args.Encode() + } + return p.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{ + CURDStreamMessage{ + CommandMessage{ + "publish", + 1, + }, + response.StreamId, + }, + streamPath, + "live", + }) + case *ResponsePublishMessage: + if response.Infomation["code"] == NetStream_Publish_Start { + p.Subscribe(p.pushCtx.Subscriber) + } else { + return errors.New(response.Infomation["code"].(string)) + } + } + return nil + }) +} + +func NewPuller(_ config.Pull) m7s.IPuller { + ret := &Puller{ + Client: Client{ + chunkSize: 4096, + }, + } + ret.NetConnection = &NetConnection{} + ret.SetDescription(task.OwnerTypeKey, "RTMPPuller") + return ret +} + +func NewPusher() m7s.IPusher { + ret := &Pusher{ + Client: Client{ + chunkSize: 4096, + }, + } + ret.NetConnection = &NetConnection{} + ret.SetDescription(task.OwnerTypeKey, "RTMPPusher") + return ret +} diff --git a/plugin/rtmp/pkg/net-connection.go b/plugin/rtmp/pkg/net-connection.go index d8df9b7..8d6b9c8 100644 --- a/plugin/rtmp/pkg/net-connection.go +++ b/plugin/rtmp/pkg/net-connection.go @@ -91,7 +91,8 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) { func (nc *NetConnection) Init(conn net.Conn) { nc.Conn = conn - nc.BufReader = util.NewBufReaderWithTimeout(conn, 30*time.Second) + nc.BufReader = util.NewBufReader(conn) + nc.BufReader.SetTimeout(time.Second * 30) nc.bandwidth = RTMP_MAX_CHUNK_SIZE << 3 nc.ReadChunkSize = RTMP_DEFAULT_CHUNK_SIZE nc.WriteChunkSize = RTMP_DEFAULT_CHUNK_SIZE diff --git a/plugin/rtsp/pkg/connection.go b/plugin/rtsp/pkg/connection.go index 13f1275..3d87244 100644 --- a/plugin/rtsp/pkg/connection.go +++ b/plugin/rtsp/pkg/connection.go @@ -22,12 +22,14 @@ import ( const Timeout = time.Second * 10 func NewNetConnection(conn net.Conn) *NetConnection { - return &NetConnection{ + c := &NetConnection{ Conn: conn, - BufReader: util.NewBufReaderWithTimeout(conn, Timeout), + BufReader: util.NewBufReader(conn), MemoryAllocator: gomem.NewScalableMemoryAllocator(1 << 12), UserAgent: "monibuca" + m7s.Version, } + c.BufReader.SetTimeout(Timeout) + return c } type NetConnection struct { @@ -143,6 +145,7 @@ func (c *NetConnection) Connect(remoteURL string) (err error) { } c.Conn = conn c.BufReader = util.NewBufReader(conn) + c.BufReader.SetTimeout(Timeout) c.UserAgent = "monibuca" + m7s.Version c.Session = "" c.Auth = util.NewAuth(rtspURL.User) @@ -255,9 +258,7 @@ func (c *NetConnection) Receive(sendMode bool, onReceive func(byte, []byte) erro return } ts := time.Now() - if err = c.Conn.SetReadDeadline(ts.Add(util.Conditional(sendMode, time.Second*60, time.Second*15))); err != nil { - return - } + var magic []byte // we can read: // 1. RTP interleaved: `$` + 1B channel number + 2B size diff --git a/plugin/rtsp/pkg/transceiver.go b/plugin/rtsp/pkg/transceiver.go index b9eb244..e6a0c48 100644 --- a/plugin/rtsp/pkg/transceiver.go +++ b/plugin/rtsp/pkg/transceiver.go @@ -378,7 +378,7 @@ func (s *Sender) Send() (err error) { } }() } - + s.BufReader.SetTimeout(60 * time.Second) // 接收处理(处理客户端发来的消息) return s.NetConnection.Receive(true, nil, nil) } diff --git a/plugin/rtsp/server.go b/plugin/rtsp/server.go index 3922e6b..e9b560c 100644 --- a/plugin/rtsp/server.go +++ b/plugin/rtsp/server.go @@ -152,8 +152,8 @@ func (task *RTSPServer) Go() (err error) { Request: req, } - // TCP传输模式 - const tcpTransport = "RTP/AVP/TCP;unicast;interleaved=" + // TCP传输模式 - 适配包含mode字段的格式 + const tcpTransport = "RTP/AVP/TCP" // UDP传输模式前缀 const udpTransport = "RTP/AVP" @@ -164,13 +164,13 @@ func (task *RTSPServer) Go() (err error) { if sendMode { if i := reqTrackID(req); i >= 0 { - tr = fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", i*2, i*2+1) + tr = fmt.Sprintf("RTP/AVP/TCP;unicast;mode=record;interleaved=%d-%d", i*2, i*2+1) res.Header.Set("Transport", tr) } else { res.Status = "400 Bad Request" } } else { - res.Header.Set("Transport", tr[:len(tcpTransport)+3]) + res.Header.Set("Transport", tr) } } else if strings.HasPrefix(tr, udpTransport) && strings.Contains(tr, "unicast") && strings.Contains(tr, "client_port=") { task.Debug("into udp play")