diff --git a/example/default/main.go b/example/default/main.go index d13174f..e161222 100644 --- a/example/default/main.go +++ b/example/default/main.go @@ -18,27 +18,6 @@ import ( _ "m7s.live/m7s/v5/plugin/webrtc" ) -// func init() { -// //全局推流鉴权 -// m7s.DefaultServer.OnAuthPubs["RTMP"] = func(p *util.Promise[*m7s.Publisher]) { -// var pub = p.Value -// if strings.Contains(pub.StreamPath, "20A222800207-2") { -// p.Fulfill(nil) -// } else { -// p.Fulfill(errors.New("auth failed")) -// } -// } -// //全局播放鉴权 -// m7s.DefaultServer.OnAuthSubs["RTMP"] = func(p *util.Promise[*m7s.Subscriber]) { -// var sub = p.Value -// if strings.Contains(sub.StreamPath, "20A222800207-22") { -// p.Fulfill(nil) -// } else { -// p.Fulfill(errors.New("auth failed")) -// } -// } -// } - func main() { conf := flag.String("c", "config.yaml", "config file") flag.Parse() diff --git a/example/default/recordflv.yaml b/example/default/recordflv.yaml new file mode 100644 index 0000000..a9eb186 --- /dev/null +++ b/example/default/recordflv.yaml @@ -0,0 +1,8 @@ +global: + loglevel: debug +flv: + record: + enableregexp: true + fragment: 10s + recordlist: + .+: record/$0.flv \ No newline at end of file diff --git a/example/rtmp-push/config1.yaml b/example/rtmp-push/config1.yaml index 2414201..480fcad 100644 --- a/example/rtmp-push/config1.yaml +++ b/example/rtmp-push/config1.yaml @@ -1,6 +1,2 @@ global: loglevel: info - tcp: - listenaddr: :50050 -rtmp: - chunksize: 2048 diff --git a/example/rtmp-push/config2.yaml b/example/rtmp-push/config2.yaml index 704ce9d..8905ddb 100644 --- a/example/rtmp-push/config2.yaml +++ b/example/rtmp-push/config2.yaml @@ -6,7 +6,6 @@ global: listenaddr: :8081 listenaddrtls: :8555 rtmp: - chunksize: 2048 tcp: listenaddr: push: diff --git a/example/rtsp-pull/config1.yaml b/example/rtsp-pull/config1.yaml index 9742515..dfdfe0f 100644 --- a/example/rtsp-pull/config1.yaml +++ b/example/rtsp-pull/config1.yaml @@ -1,10 +1,6 @@ global: loglevel: info - tcp: - listenaddr: :50050 flv: - publish: - pubaudio: false pull: pullonstart: live/test: /Users/dexter/Movies/jb-demo.flv \ No newline at end of file diff --git a/example/rtsp-pull/config2.yaml b/example/rtsp-pull/config2.yaml index 8b31500..f72ab68 100644 --- a/example/rtsp-pull/config2.yaml +++ b/example/rtsp-pull/config2.yaml @@ -1,6 +1,6 @@ global: tcp: - listenaddr: :50051 + listenaddr: :50050 http: listenaddr: :8081 listenaddrtls: :8555 diff --git a/example/rtsp-push/config1.yaml b/example/rtsp-push/config1.yaml index e1ac262..480fcad 100644 --- a/example/rtsp-push/config1.yaml +++ b/example/rtsp-push/config1.yaml @@ -1,4 +1,2 @@ global: loglevel: info - tcp: - listenaddr: :50050 diff --git a/example/rtsp-push/config2.yaml b/example/rtsp-push/config2.yaml index 7e12420..46607fc 100644 --- a/example/rtsp-push/config2.yaml +++ b/example/rtsp-push/config2.yaml @@ -1,6 +1,6 @@ global: tcp: - listenaddr: :50051 + listenaddr: :50050 http: listenaddr: :8081 listenaddrtls: :8555 diff --git a/pkg/task_test.go b/pkg/task_test.go index f825fec..de88568 100644 --- a/pkg/task_test.go +++ b/pkg/task_test.go @@ -147,12 +147,3 @@ func Test_Hooks(t *testing.T) { }) mt.AddTask(&task).WaitStarted() } - -func Test_GetID_IncrementsID(t *testing.T) { - mt := createMarcoTask() - id1 := mt.GetNextID() - id2 := mt.GetNextID() - if id1 == id2 { - t.Errorf("expected different IDs, got %d and %d", id1, id2) - } -} diff --git a/plugin/flv/pkg/record.go b/plugin/flv/pkg/record.go index 9de9c55..bf02626 100644 --- a/plugin/flv/pkg/record.go +++ b/plugin/flv/pkg/record.go @@ -11,6 +11,62 @@ import ( "time" ) +var writeMetaTagQueueTask pkg.MarcoLongTask + +func init() { + pkg.RootTask.AddTask(&writeMetaTagQueueTask) +} + +type writeMetaTagTask struct { + pkg.Task + file *os.File + flags byte + metaData []byte +} + +func (task *writeMetaTagTask) Start() (err error) { + defer func() { + err = task.file.Close() + if info, err := task.file.Stat(); err == nil && info.Size() == 0 { + err = os.Remove(info.Name()) + } + }() + var tempFile *os.File + if tempFile, err = os.CreateTemp("", "*.flv"); err != nil { + task.Error("create temp file failed", "err", err) + return + } else { + defer func() { + err = tempFile.Close() + err = os.Remove(tempFile.Name()) + task.Info("writeMetaData success") + }() + _, err = tempFile.Write([]byte{'F', 'L', 'V', 0x01, task.flags, 0, 0, 0, 9, 0, 0, 0, 0}) + if err != nil { + task.Error(err.Error()) + return + } + err = WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, task.metaData) + _, err = task.file.Seek(13, io.SeekStart) + if err != nil { + task.Error("writeMetaData Seek failed", "err", err) + return + } + _, err = io.Copy(tempFile, task.file) + if err != nil { + task.Error("writeMetaData Copy failed", "err", err) + return + } + _, err = tempFile.Seek(0, io.SeekStart) + _, err = task.file.Seek(0, io.SeekStart) + _, err = io.Copy(task.file, tempFile) + if err != nil { + task.Error("writeMetaData Copy failed", "err", err) + } + return + } +} + func RecordFlv(ctx *m7s.RecordContext) (err error) { var file *os.File var filepositions []uint64 @@ -23,13 +79,7 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) { suber := ctx.Subscriber ar, vr := suber.AudioReader, suber.VideoReader hasAudio, hasVideo := ar != nil, vr != nil - writeMetaTag := func() { - defer func() { - err = file.Close() - if info, err := file.Stat(); err == nil && info.Size() == 0 { - os.Remove(file.Name()) - } - }() + writeMetaTag := func(file *os.File, filepositions []uint64, times []float64) { var amf rtmp.AMF metaData := rtmp.EcmaArray{ "MetaDataCreator": "m7s/" + m7s.Version, @@ -62,10 +112,6 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) { "filepositions": filepositions, "times": times, } - defer func() { - filepositions = []uint64{0} - times = []float64{0} - }() } amf.Marshals("onMetaData", metaData) offset := amf.Len() + 13 + 15 @@ -79,42 +125,13 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) { "times": times, } } - - if tempFile, err := os.CreateTemp("", "*.flv"); err != nil { - ctx.Error("create temp file failed", "err", err) - return - } else { - defer func() { - tempFile.Close() - os.Remove(tempFile.Name()) - ctx.Info("writeMetaData success") - }() - _, err := tempFile.Write([]byte{'F', 'L', 'V', 0x01, flags, 0, 0, 0, 9, 0, 0, 0, 0}) - if err != nil { - ctx.Error(err.Error()) - return - } - amf.Reset() - marshals := amf.Marshals("onMetaData", metaData) - WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, marshals) - _, err = file.Seek(13, io.SeekStart) - if err != nil { - ctx.Error("writeMetaData Seek failed", "err", err) - return - } - _, err = io.Copy(tempFile, file) - if err != nil { - ctx.Error("writeMetaData Copy failed", "err", err) - return - } - _, err = tempFile.Seek(0, io.SeekStart) - _, err = file.Seek(0, io.SeekStart) - _, err = io.Copy(file, tempFile) - if err != nil { - ctx.Error("writeMetaData Copy failed", "err", err) - return - } - } + amf.Reset() + marshals := amf.Marshals("onMetaData", metaData) + writeMetaTagQueueTask.AddTask(&writeMetaTagTask{ + file: file, + flags: flags, + metaData: marshals, + }) } if ctx.Append { var metaData rtmp.EcmaArray @@ -151,14 +168,16 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) { file.Write(FLVHead) } if ctx.Fragment == 0 { - defer writeMetaTag() + defer writeMetaTag(file, filepositions, times) } checkFragment := func(absTime uint32) { if ctx.Fragment == 0 { return } if duration = int64(absTime); time.Duration(duration)*time.Millisecond >= ctx.Fragment { - writeMetaTag() + writeMetaTag(file, filepositions, times) + filepositions = []uint64{0} + times = []float64{0} offset = 0 if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR, 0666); err != nil { return diff --git a/plugin/rtmp/index.go b/plugin/rtmp/index.go index a84125b..583d9bd 100644 --- a/plugin/rtmp/index.go +++ b/plugin/rtmp/index.go @@ -3,7 +3,6 @@ package plugin_rtmp import ( "errors" "io" - "m7s.live/m7s/v5/pkg" "maps" "net" "slices" @@ -36,7 +35,6 @@ func (p *RTMPPlugin) GetPullableList() []string { } func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { - receivers := make(map[uint32]*Receiver) var err error nc := NewNetConnection(conn) nc.Logger = p.With("remote", conn.RemoteAddr().String()) @@ -57,8 +55,6 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { continue } switch msg.MessageTypeID { - case RTMP_MSG_CHUNK_SIZE: - nc.Info("msg read chunk size", "readChunkSize", nc.ReadChunkSize) case RTMP_MSG_AMF0_COMMAND: if msg.MsgData == nil { err = errors.New("msg.MsgData is nil") @@ -145,26 +141,24 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { // } // err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m) case *PublishMessage: - receiver := &Receiver{ - NetStream: NetStream{ - NetConnection: nc, - StreamID: cmd.StreamId, - }, + ns := NetStream{ + NetConnection: nc, + StreamID: cmd.StreamId, } - receiver.Publisher, err = p.Publish(nc.Context, nc.AppName+"/"+cmd.PublishingName) - receiver.Publisher.Description = nc.Description + var publisher *m7s.Publisher + publisher, err = p.Publish(nc.Context, nc.AppName+"/"+cmd.PublishingName) + publisher.Description = nc.Description if err != nil { - delete(receivers, cmd.StreamId) - err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error) + err = ns.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error) } else { - receivers[cmd.StreamId] = receiver - err = receiver.BeginPublish(cmd.TransactionId) + ns.Receivers[cmd.StreamId] = publisher + err = ns.BeginPublish(cmd.TransactionId) } if err != nil { nc.Error("sendMessage publish", "error", err) } else { - receiver.Publisher.OnDispose(func() { - nc.Stop(receiver.StopReason()) + publisher.OnDispose(func() { + nc.Stop(publisher.StopReason()) }) } case *PlayMessage: @@ -181,29 +175,12 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error) } else { err = ns.BeginPlay(cmd.TransactionId) - nc.AddCall(func(task *pkg.Task) error { - audio, video := ns.CreateSender(false) - return m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo) - }, nil) + ns.Subscribe(suber) } if err != nil { nc.Error("sendMessage play", "error", err) } } - case RTMP_MSG_AUDIO: - if r, ok := receivers[msg.MessageStreamID]; ok && r.PubAudio { - err = r.WriteAudio(msg.AVData.WrapAudio()) - } else { - msg.AVData.Recycle() - nc.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID) - } - case RTMP_MSG_VIDEO: - if r, ok := receivers[msg.MessageStreamID]; ok && r.PubVideo { - err = r.WriteVideo(msg.AVData.WrapVideo()) - } else { - msg.AVData.Recycle() - nc.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID) - } } } else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { nc.Info("rtmp client closed") diff --git a/plugin/rtmp/pkg/client.go b/plugin/rtmp/pkg/client.go index b55e6e4..c71f57d 100644 --- a/plugin/rtmp/pkg/client.go +++ b/plugin/rtmp/pkg/client.go @@ -115,22 +115,10 @@ func Pull(p *m7s.PullContext) (err error) { return err } switch msg.MessageTypeID { - case RTMP_MSG_AUDIO: - if p.Publisher.PubAudio { - err = p.Publisher.WriteAudio(msg.AVData.WrapAudio()) - } else { - msg.AVData.Recycle() - } - case RTMP_MSG_VIDEO: - if p.Publisher.PubVideo { - err = p.Publisher.WriteVideo(msg.AVData.WrapVideo()) - } else { - msg.AVData.Recycle() - } case RTMP_MSG_AMF0_COMMAND: cmd := msg.MsgData.(Commander).GetCommand() switch cmd.CommandName { - case "_result": + case Response_Result: if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok { connection.StreamID = response.StreamId m := &PlayMessage{} @@ -144,6 +132,7 @@ func Pull(p *m7s.PullContext) (err error) { if len(args) > 0 { m.StreamName += "?" + args.Encode() } + connection.Receivers[response.StreamId] = p.Publisher connection.SendMessage(RTMP_MSG_AMF0_COMMAND, m) // if response, ok := msg.MsgData.(*ResponsePlayMessage); ok { // if response.Object["code"] == "NetStream.Play.Start" { @@ -200,14 +189,7 @@ func Push(p *m7s.PushContext) (err error) { }) } else if response, ok := msg.MsgData.(*ResponsePublishMessage); ok { if response.Infomation["code"] == NetStream_Publish_Start { - audio, video := connection.CreateSender(true) - go func() { - for err == nil { - msg, err = connection.RecvMessage() - } - p.Subscriber.Stop(err) - }() - return m7s.PlayBlock(p.Subscriber, audio.HandleAudio, video.HandleVideo) + connection.Subscribe(p.Subscriber) } else { return errors.New(response.Infomation["code"].(string)) } diff --git a/plugin/rtmp/pkg/handshake.go b/plugin/rtmp/pkg/handshake.go index eadb71c..fb55b49 100644 --- a/plugin/rtmp/pkg/handshake.go +++ b/plugin/rtmp/pkg/handshake.go @@ -89,18 +89,18 @@ func (nc *NetConnection) Handshake(checkC2 bool) (err error) { return nc.complex_handshake(C1) } -func (client *NetConnection) ClientHandshake() (err error) { - C0C1 := client.mediaDataPool.NextN(C1S1_SIZE + 1) - defer client.mediaDataPool.Recycle() +func (nc *NetConnection) ClientHandshake() (err error) { + C0C1 := nc.mediaDataPool.NextN(C1S1_SIZE + 1) + defer nc.mediaDataPool.Recycle() C0C1[0] = RTMP_HANDSHAKE_VERSION - if _, err = client.Write(C0C1); err == nil { + if _, err = nc.Write(C0C1); err == nil { // read S0 S1 - if _, err = io.ReadFull(client.Conn, C0C1); err == nil { + if _, err = io.ReadFull(nc.Conn, C0C1); err == nil { if C0C1[0] != RTMP_HANDSHAKE_VERSION { err = errors.New("S1 C1 Error") // C2 - } else if _, err = client.Write(C0C1[1:]); err == nil { - _, err = io.ReadFull(client.Conn, C0C1[1:]) // S2 + } else if _, err = nc.Write(C0C1[1:]); err == nil { + _, err = io.ReadFull(nc.Conn, C0C1[1:]) // S2 } } } diff --git a/plugin/rtmp/pkg/net-connection.go b/plugin/rtmp/pkg/net-connection.go index fc1d1dd..ff60389 100644 --- a/plugin/rtmp/pkg/net-connection.go +++ b/plugin/rtmp/pkg/net-connection.go @@ -2,6 +2,7 @@ package rtmp import ( "errors" + "m7s.live/m7s/v5" "net" "runtime" "sync/atomic" @@ -57,6 +58,7 @@ type NetConnection struct { chunkHeaderBuf util.Buffer mediaDataPool util.RecyclableMemory writing atomic.Bool // false 可写,true 不可写 + Receivers map[uint32]*m7s.Publisher } func NewNetConnection(conn net.Conn) (ret *NetConnection) { @@ -69,30 +71,31 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) { bandwidth: RTMP_MAX_CHUNK_SIZE << 3, tmpBuf: make(util.Buffer, 4), chunkHeaderBuf: make(util.Buffer, 0, 20), + Receivers: make(map[uint32]*m7s.Publisher), } ret.mediaDataPool.SetAllocator(util.NewScalableMemoryAllocator(1 << util.MinPowerOf2)) return } -func (conn *NetConnection) Dispose() { - conn.Conn.Close() - conn.BufReader.Recycle() - conn.mediaDataPool.Recycle() +func (nc *NetConnection) Dispose() { + nc.Conn.Close() + nc.BufReader.Recycle() + nc.mediaDataPool.Recycle() } -func (conn *NetConnection) SendStreamID(eventType uint16, streamID uint32) (err error) { - return conn.SendMessage(RTMP_MSG_USER_CONTROL, &StreamIDMessage{UserControlMessage{EventType: eventType}, streamID}) +func (nc *NetConnection) SendStreamID(eventType uint16, streamID uint32) (err error) { + return nc.SendMessage(RTMP_MSG_USER_CONTROL, &StreamIDMessage{UserControlMessage{EventType: eventType}, streamID}) } -func (conn *NetConnection) SendUserControl(eventType uint16) error { - return conn.SendMessage(RTMP_MSG_USER_CONTROL, &UserControlMessage{EventType: eventType}) +func (nc *NetConnection) SendUserControl(eventType uint16) error { + return nc.SendMessage(RTMP_MSG_USER_CONTROL, &UserControlMessage{EventType: eventType}) } -func (conn *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) error { +func (nc *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) error { m := &ResponseCreateStreamMessage{} m.CommandName = Response_Result m.TransactionId = tid m.StreamId = streamID - return conn.SendMessage(RTMP_MSG_AMF0_COMMAND, m) + return nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m) } // func (conn *NetConnection) SendCommand(message string, args any) error { @@ -110,21 +113,21 @@ func (conn *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) err // return errors.New("send message no exist") // } -func (conn *NetConnection) readChunk() (msg *Chunk, err error) { - head, err := conn.ReadByte() +func (nc *NetConnection) readChunk() (msg *Chunk, err error) { + head, err := nc.ReadByte() if err != nil { return nil, err } - conn.readSeqNum++ + nc.readSeqNum++ ChunkStreamID := uint32(head & 0x3f) // 0011 1111 ChunkType := head >> 6 // 1100 0000 // 如果块流ID为0,1的话,就需要计算. - ChunkStreamID, err = conn.readChunkStreamID(ChunkStreamID) + ChunkStreamID, err = nc.readChunkStreamID(ChunkStreamID) if err != nil { return nil, errors.New("get chunk stream id error :" + err.Error()) } //println("ChunkStreamID:", ChunkStreamID, "ChunkType:", ChunkType) - chunk, ok := conn.incommingChunks[ChunkStreamID] + chunk, ok := nc.incommingChunks[ChunkStreamID] if ChunkType != 3 && ok && chunk.bufLen > 0 { // 如果块类型不为3,那么这个rtmp的body应该为空. @@ -132,10 +135,10 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) { } if !ok { chunk = &Chunk{} - conn.incommingChunks[ChunkStreamID] = chunk + nc.incommingChunks[ChunkStreamID] = chunk } - if err = conn.readChunkType(&chunk.ChunkHeader, ChunkType); err != nil { + if err = nc.readChunkType(&chunk.ChunkHeader, ChunkType); err != nil { return nil, errors.New("get chunk type error :" + err.Error()) } msgLen := int(chunk.MessageLength) @@ -143,19 +146,19 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) { return nil, nil } var bufSize = 0 - if unRead := msgLen - chunk.bufLen; unRead < conn.ReadChunkSize { + if unRead := msgLen - chunk.bufLen; unRead < nc.ReadChunkSize { bufSize = unRead } else { - bufSize = conn.ReadChunkSize + bufSize = nc.ReadChunkSize } - conn.readSeqNum += uint32(bufSize) + nc.readSeqNum += uint32(bufSize) if chunk.bufLen == 0 { chunk.AVData.RecyclableMemory = util.RecyclableMemory{} - chunk.AVData.SetAllocator(conn.mediaDataPool.GetAllocator()) + chunk.AVData.SetAllocator(nc.mediaDataPool.GetAllocator()) chunk.AVData.NextN(msgLen) } buffer := chunk.AVData.Buffers[0] - err = conn.ReadRange(bufSize, func(buf []byte) { + err = nc.ReadRange(bufSize, func(buf []byte) { copy(buffer[chunk.bufLen:], buf) chunk.bufLen += len(buf) }) @@ -176,14 +179,14 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) { return } -func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, err error) { +func (nc *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, err error) { chunkStreamID = csid switch csid { case 0: { - u8, err := conn.ReadByte() - conn.readSeqNum++ + u8, err := nc.ReadByte() + nc.readSeqNum++ if err != nil { return 0, err } @@ -192,15 +195,15 @@ func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, } case 1: { - u16_0, err1 := conn.ReadByte() + u16_0, err1 := nc.ReadByte() if err1 != nil { return 0, err1 } - u16_1, err1 := conn.ReadByte() + u16_1, err1 := nc.ReadByte() if err1 != nil { return 0, err1 } - conn.readSeqNum += 2 + nc.readSeqNum += 2 chunkStreamID = 64 + uint32(u16_0) + (uint32(u16_1) << 8) } } @@ -208,27 +211,27 @@ func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, return chunkStreamID, nil } -func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err error) { +func (nc *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err error) { if chunkType == 3 { // 3个字节的时间戳 } else { // Timestamp 3 bytes - if h.Timestamp, err = conn.ReadBE32(3); err != nil { + if h.Timestamp, err = nc.ReadBE32(3); err != nil { return err } if chunkType != 2 { - if h.MessageLength, err = conn.ReadBE32(3); err != nil { + if h.MessageLength, err = nc.ReadBE32(3); err != nil { return err } // Message Type ID 1 bytes - if h.MessageTypeID, err = conn.ReadByte(); err != nil { + if h.MessageTypeID, err = nc.ReadByte(); err != nil { return err } - conn.readSeqNum++ + nc.readSeqNum++ if chunkType == 0 { // Message Stream ID 4bytes - if h.MessageStreamID, err = conn.ReadBE32(4); err != nil { // 读取Message Stream ID + if h.MessageStreamID, err = nc.ReadBE32(4); err != nil { // 读取Message Stream ID return err } } @@ -237,7 +240,7 @@ func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err er // ExtendTimestamp 4 bytes if h.Timestamp >= 0xffffff { // 对于type 0的chunk,绝对时间戳在这里表示,如果时间戳值大于等于0xffffff(16777215),该值必须是0xffffff,且时间戳扩展字段必须发送,其他情况没有要求 - if h.Timestamp, err = conn.ReadBE32(4); err != nil { + if h.Timestamp, err = nc.ReadBE32(4); err != nil { return err } switch chunkType { @@ -258,75 +261,90 @@ func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err er return nil } -func (conn *NetConnection) RecvMessage() (msg *Chunk, err error) { - if conn.readSeqNum >= conn.bandwidth { - conn.totalRead += conn.readSeqNum - conn.readSeqNum = 0 - err = conn.SendMessage(RTMP_MSG_ACK, Uint32Message(conn.totalRead)) +func (nc *NetConnection) RecvMessage() (msg *Chunk, err error) { + if nc.readSeqNum >= nc.bandwidth { + nc.totalRead += nc.readSeqNum + nc.readSeqNum = 0 + err = nc.SendMessage(RTMP_MSG_ACK, Uint32Message(nc.totalRead)) } for msg == nil && err == nil { - if msg, err = conn.readChunk(); msg != nil && err == nil { + if msg, err = nc.readChunk(); msg != nil && err == nil { switch msg.MessageTypeID { case RTMP_MSG_CHUNK_SIZE: - conn.ReadChunkSize = int(msg.MsgData.(Uint32Message)) + nc.ReadChunkSize = int(msg.MsgData.(Uint32Message)) + nc.Info("msg read chunk size", "readChunkSize", nc.ReadChunkSize) case RTMP_MSG_ABORT: - delete(conn.incommingChunks, uint32(msg.MsgData.(Uint32Message))) + delete(nc.incommingChunks, uint32(msg.MsgData.(Uint32Message))) case RTMP_MSG_ACK, RTMP_MSG_EDGE: case RTMP_MSG_USER_CONTROL: if _, ok := msg.MsgData.(*PingRequestMessage); ok { - conn.SendUserControl(RTMP_USER_PING_RESPONSE) + nc.SendUserControl(RTMP_USER_PING_RESPONSE) } case RTMP_MSG_ACK_SIZE: - conn.bandwidth = uint32(msg.MsgData.(Uint32Message)) + nc.bandwidth = uint32(msg.MsgData.(Uint32Message)) case RTMP_MSG_BANDWIDTH: - conn.bandwidth = msg.MsgData.(*SetPeerBandwidthMessage).AcknowledgementWindowsize - case RTMP_MSG_AMF0_COMMAND, RTMP_MSG_AUDIO, RTMP_MSG_VIDEO: + nc.bandwidth = msg.MsgData.(*SetPeerBandwidthMessage).AcknowledgementWindowsize + case RTMP_MSG_AMF0_COMMAND: return msg, err + case RTMP_MSG_AUDIO: + if r, ok := nc.Receivers[msg.MessageStreamID]; ok && r.PubAudio { + err = r.WriteAudio(msg.AVData.WrapAudio()) + } else { + msg.AVData.Recycle() + nc.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID) + } + case RTMP_MSG_VIDEO: + if r, ok := nc.Receivers[msg.MessageStreamID]; ok && r.PubVideo { + err = r.WriteVideo(msg.AVData.WrapVideo()) + } else { + msg.AVData.Recycle() + nc.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID) + } } } } return } -func (conn *NetConnection) SendMessage(t byte, msg RtmpMessage) (err error) { - if conn == nil { +func (nc *NetConnection) SendMessage(t byte, msg RtmpMessage) (err error) { + if nc == nil { return errors.New("connection is nil") } - if conn.writeSeqNum > conn.bandwidth { - conn.totalWrite += conn.writeSeqNum - conn.writeSeqNum = 0 - err = conn.SendMessage(RTMP_MSG_ACK, Uint32Message(conn.totalWrite)) - err = conn.SendStreamID(RTMP_USER_PING_REQUEST, 0) + if nc.writeSeqNum > nc.bandwidth { + nc.totalWrite += nc.writeSeqNum + nc.writeSeqNum = 0 + err = nc.SendMessage(RTMP_MSG_ACK, Uint32Message(nc.totalWrite)) + err = nc.SendStreamID(RTMP_USER_PING_REQUEST, 0) } - for !conn.writing.CompareAndSwap(false, true) { + for !nc.writing.CompareAndSwap(false, true) { runtime.Gosched() } - defer conn.writing.Store(false) - conn.tmpBuf.Reset() - amf := AMF{conn.tmpBuf} - if conn.ObjectEncoding == 0 { + defer nc.writing.Store(false) + nc.tmpBuf.Reset() + amf := AMF{nc.tmpBuf} + if nc.ObjectEncoding == 0 { msg.Encode(&amf) } else { amf := AMF3{AMF: amf} msg.Encode(&amf) } - conn.tmpBuf = amf.Buffer + nc.tmpBuf = amf.Buffer head := newChunkHeader(t) - head.MessageLength = uint32(conn.tmpBuf.Len()) + head.MessageLength = uint32(nc.tmpBuf.Len()) if sid, ok := msg.(HaveStreamID); ok { head.MessageStreamID = sid.GetStreamID() } - return conn.sendChunk(net.Buffers{conn.tmpBuf}, head, RTMP_CHUNK_HEAD_12) + return nc.sendChunk(net.Buffers{nc.tmpBuf}, head, RTMP_CHUNK_HEAD_12) } -func (conn *NetConnection) sendChunk(data net.Buffers, head *ChunkHeader, headType byte) (err error) { - conn.chunkHeaderBuf.Reset() - head.WriteTo(headType, &conn.chunkHeaderBuf) - chunks := net.Buffers{conn.chunkHeaderBuf} - var chunk3 util.Buffer = conn.chunkHeaderBuf[conn.chunkHeaderBuf.Len():20] +func (nc *NetConnection) sendChunk(data net.Buffers, head *ChunkHeader, headType byte) (err error) { + nc.chunkHeaderBuf.Reset() + head.WriteTo(headType, &nc.chunkHeaderBuf) + chunks := net.Buffers{nc.chunkHeaderBuf} + var chunk3 util.Buffer = nc.chunkHeaderBuf[nc.chunkHeaderBuf.Len():20] head.WriteTo(RTMP_CHUNK_HEAD_1, &chunk3) r := util.NewReadableBuffersFromBytes(data...) for { - r.RangeN(conn.WriteChunkSize, func(buf []byte) { + r.RangeN(nc.WriteChunkSize, func(buf []byte) { chunks = append(chunks, buf) }) if r.Length <= 0 { @@ -336,7 +354,7 @@ func (conn *NetConnection) sendChunk(data net.Buffers, head *ChunkHeader, headTy chunks = append(chunks, chunk3) } var nw int64 - nw, err = chunks.WriteTo(conn.Conn) - conn.writeSeqNum += uint32(nw) + nw, err = chunks.WriteTo(nc.Conn) + nc.writeSeqNum += uint32(nw) return err } diff --git a/plugin/rtmp/pkg/net-stream.go b/plugin/rtmp/pkg/net-stream.go index 4d7c61a..fb4cf2c 100644 --- a/plugin/rtmp/pkg/net-stream.go +++ b/plugin/rtmp/pkg/net-stream.go @@ -1,5 +1,10 @@ package rtmp +import ( + "m7s.live/m7s/v5" + "m7s.live/m7s/v5/pkg" +) + type NetStream struct { *NetConnection StreamID uint32 @@ -62,3 +67,10 @@ func (ns *NetStream) BeginPlay(tid uint64) (err error) { err = ns.Response(tid, NetStream_Play_Start, Level_Status) return } + +func (ns *NetStream) Subscribe(suber *m7s.Subscriber) { + ns.AddCall(func(task *pkg.Task) error { + audio, video := ns.CreateSender(false) + return m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo) + }, nil) +} diff --git a/plugin/rtsp/index.go b/plugin/rtsp/index.go index eb09a62..d9ceea8 100644 --- a/plugin/rtsp/index.go +++ b/plugin/rtsp/index.go @@ -43,15 +43,13 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { var err error nc := NewNetConnection(conn) nc.Logger = logger + p.AddTask(nc).WaitStarted() defer func() { - nc.Destroy() + nc.Stop(err) if p := recover(); p != nil { err = p.(error) logger.Error(err.Error(), "stack", string(debug.Stack())) } - if receiver != nil { - receiver.Stop(err) - } }() var req *util.Request var sendMode bool @@ -106,8 +104,9 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { return } - receiver = &Receiver{} - receiver.NetConnection = nc + receiver = &Receiver{ + Stream: &Stream{NetConnection: nc}, + } if receiver.Publisher, err = p.Publish(nc, strings.TrimPrefix(nc.URL.Path, "/")); err != nil { receiver = nil err = nc.WriteResponse(&util.Response{ @@ -122,11 +121,14 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { if err = nc.WriteResponse(res); err != nil { return } - + receiver.Publisher.OnDispose(func() { + nc.Stop(receiver.Publisher.StopReason()) + }) case MethodDescribe: sendMode = true - sender = &Sender{} - sender.NetConnection = nc + sender = &Sender{ + Stream: &Stream{NetConnection: nc}, + } sender.Subscriber, err = p.Subscribe(nc, strings.TrimPrefix(nc.URL.Path, "/")) if err != nil { res := &util.Response{ diff --git a/plugin/rtsp/pkg/connection.go b/plugin/rtsp/pkg/connection.go index 4b52506..5c62a12 100644 --- a/plugin/rtsp/pkg/connection.go +++ b/plugin/rtsp/pkg/connection.go @@ -64,7 +64,7 @@ func (c *NetConnection) StopWrite() { c.writing.Store(false) } -func (c *NetConnection) Destroy() { +func (c *NetConnection) Dispose() { c.conn.Close() c.BufReader.Recycle() c.MemoryAllocator.Recycle() diff --git a/plugin/rtsp/pkg/net-stream.go b/plugin/rtsp/pkg/net-stream.go index 9b20de7..22b5add 100644 --- a/plugin/rtsp/pkg/net-stream.go +++ b/plugin/rtsp/pkg/net-stream.go @@ -238,6 +238,6 @@ func (c *Stream) Teardown() (err error) { func (ns *Stream) disconnect() { if ns != nil && ns.NetConnection != nil { _ = ns.Teardown() - ns.NetConnection.Destroy() + ns.NetConnection.Dispose() } }