fix: rtsp pull and push

This commit is contained in:
langhuihui
2024-08-13 14:44:35 +08:00
parent 43c0fc7be4
commit 78e8d74fec
18 changed files with 213 additions and 236 deletions

View File

@@ -18,27 +18,6 @@ import (
_ "m7s.live/m7s/v5/plugin/webrtc" _ "m7s.live/m7s/v5/plugin/webrtc"
) )
// func init() {
// //全局推流鉴权
// m7s.DefaultServer.OnAuthPubs["RTMP"] = func(p *util.Promise[*m7s.Publisher]) {
// var pub = p.Value
// if strings.Contains(pub.StreamPath, "20A222800207-2") {
// p.Fulfill(nil)
// } else {
// p.Fulfill(errors.New("auth failed"))
// }
// }
// //全局播放鉴权
// m7s.DefaultServer.OnAuthSubs["RTMP"] = func(p *util.Promise[*m7s.Subscriber]) {
// var sub = p.Value
// if strings.Contains(sub.StreamPath, "20A222800207-22") {
// p.Fulfill(nil)
// } else {
// p.Fulfill(errors.New("auth failed"))
// }
// }
// }
func main() { func main() {
conf := flag.String("c", "config.yaml", "config file") conf := flag.String("c", "config.yaml", "config file")
flag.Parse() flag.Parse()

View File

@@ -0,0 +1,8 @@
global:
loglevel: debug
flv:
record:
enableregexp: true
fragment: 10s
recordlist:
.+: record/$0.flv

View File

@@ -1,6 +1,2 @@
global: global:
loglevel: info loglevel: info
tcp:
listenaddr: :50050
rtmp:
chunksize: 2048

View File

@@ -6,7 +6,6 @@ global:
listenaddr: :8081 listenaddr: :8081
listenaddrtls: :8555 listenaddrtls: :8555
rtmp: rtmp:
chunksize: 2048
tcp: tcp:
listenaddr: listenaddr:
push: push:

View File

@@ -1,10 +1,6 @@
global: global:
loglevel: info loglevel: info
tcp:
listenaddr: :50050
flv: flv:
publish:
pubaudio: false
pull: pull:
pullonstart: pullonstart:
live/test: /Users/dexter/Movies/jb-demo.flv live/test: /Users/dexter/Movies/jb-demo.flv

View File

@@ -1,6 +1,6 @@
global: global:
tcp: tcp:
listenaddr: :50051 listenaddr: :50050
http: http:
listenaddr: :8081 listenaddr: :8081
listenaddrtls: :8555 listenaddrtls: :8555

View File

@@ -1,4 +1,2 @@
global: global:
loglevel: info loglevel: info
tcp:
listenaddr: :50050

View File

@@ -1,6 +1,6 @@
global: global:
tcp: tcp:
listenaddr: :50051 listenaddr: :50050
http: http:
listenaddr: :8081 listenaddr: :8081
listenaddrtls: :8555 listenaddrtls: :8555

View File

@@ -147,12 +147,3 @@ func Test_Hooks(t *testing.T) {
}) })
mt.AddTask(&task).WaitStarted() mt.AddTask(&task).WaitStarted()
} }
func Test_GetID_IncrementsID(t *testing.T) {
mt := createMarcoTask()
id1 := mt.GetNextID()
id2 := mt.GetNextID()
if id1 == id2 {
t.Errorf("expected different IDs, got %d and %d", id1, id2)
}
}

View File

@@ -11,6 +11,62 @@ import (
"time" "time"
) )
var writeMetaTagQueueTask pkg.MarcoLongTask
func init() {
pkg.RootTask.AddTask(&writeMetaTagQueueTask)
}
type writeMetaTagTask struct {
pkg.Task
file *os.File
flags byte
metaData []byte
}
func (task *writeMetaTagTask) Start() (err error) {
defer func() {
err = task.file.Close()
if info, err := task.file.Stat(); err == nil && info.Size() == 0 {
err = os.Remove(info.Name())
}
}()
var tempFile *os.File
if tempFile, err = os.CreateTemp("", "*.flv"); err != nil {
task.Error("create temp file failed", "err", err)
return
} else {
defer func() {
err = tempFile.Close()
err = os.Remove(tempFile.Name())
task.Info("writeMetaData success")
}()
_, err = tempFile.Write([]byte{'F', 'L', 'V', 0x01, task.flags, 0, 0, 0, 9, 0, 0, 0, 0})
if err != nil {
task.Error(err.Error())
return
}
err = WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, task.metaData)
_, err = task.file.Seek(13, io.SeekStart)
if err != nil {
task.Error("writeMetaData Seek failed", "err", err)
return
}
_, err = io.Copy(tempFile, task.file)
if err != nil {
task.Error("writeMetaData Copy failed", "err", err)
return
}
_, err = tempFile.Seek(0, io.SeekStart)
_, err = task.file.Seek(0, io.SeekStart)
_, err = io.Copy(task.file, tempFile)
if err != nil {
task.Error("writeMetaData Copy failed", "err", err)
}
return
}
}
func RecordFlv(ctx *m7s.RecordContext) (err error) { func RecordFlv(ctx *m7s.RecordContext) (err error) {
var file *os.File var file *os.File
var filepositions []uint64 var filepositions []uint64
@@ -23,13 +79,7 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) {
suber := ctx.Subscriber suber := ctx.Subscriber
ar, vr := suber.AudioReader, suber.VideoReader ar, vr := suber.AudioReader, suber.VideoReader
hasAudio, hasVideo := ar != nil, vr != nil hasAudio, hasVideo := ar != nil, vr != nil
writeMetaTag := func() { writeMetaTag := func(file *os.File, filepositions []uint64, times []float64) {
defer func() {
err = file.Close()
if info, err := file.Stat(); err == nil && info.Size() == 0 {
os.Remove(file.Name())
}
}()
var amf rtmp.AMF var amf rtmp.AMF
metaData := rtmp.EcmaArray{ metaData := rtmp.EcmaArray{
"MetaDataCreator": "m7s/" + m7s.Version, "MetaDataCreator": "m7s/" + m7s.Version,
@@ -62,10 +112,6 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) {
"filepositions": filepositions, "filepositions": filepositions,
"times": times, "times": times,
} }
defer func() {
filepositions = []uint64{0}
times = []float64{0}
}()
} }
amf.Marshals("onMetaData", metaData) amf.Marshals("onMetaData", metaData)
offset := amf.Len() + 13 + 15 offset := amf.Len() + 13 + 15
@@ -79,42 +125,13 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) {
"times": times, "times": times,
} }
} }
amf.Reset()
if tempFile, err := os.CreateTemp("", "*.flv"); err != nil { marshals := amf.Marshals("onMetaData", metaData)
ctx.Error("create temp file failed", "err", err) writeMetaTagQueueTask.AddTask(&writeMetaTagTask{
return file: file,
} else { flags: flags,
defer func() { metaData: marshals,
tempFile.Close() })
os.Remove(tempFile.Name())
ctx.Info("writeMetaData success")
}()
_, err := tempFile.Write([]byte{'F', 'L', 'V', 0x01, flags, 0, 0, 0, 9, 0, 0, 0, 0})
if err != nil {
ctx.Error(err.Error())
return
}
amf.Reset()
marshals := amf.Marshals("onMetaData", metaData)
WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, marshals)
_, err = file.Seek(13, io.SeekStart)
if err != nil {
ctx.Error("writeMetaData Seek failed", "err", err)
return
}
_, err = io.Copy(tempFile, file)
if err != nil {
ctx.Error("writeMetaData Copy failed", "err", err)
return
}
_, err = tempFile.Seek(0, io.SeekStart)
_, err = file.Seek(0, io.SeekStart)
_, err = io.Copy(file, tempFile)
if err != nil {
ctx.Error("writeMetaData Copy failed", "err", err)
return
}
}
} }
if ctx.Append { if ctx.Append {
var metaData rtmp.EcmaArray var metaData rtmp.EcmaArray
@@ -151,14 +168,16 @@ func RecordFlv(ctx *m7s.RecordContext) (err error) {
file.Write(FLVHead) file.Write(FLVHead)
} }
if ctx.Fragment == 0 { if ctx.Fragment == 0 {
defer writeMetaTag() defer writeMetaTag(file, filepositions, times)
} }
checkFragment := func(absTime uint32) { checkFragment := func(absTime uint32) {
if ctx.Fragment == 0 { if ctx.Fragment == 0 {
return return
} }
if duration = int64(absTime); time.Duration(duration)*time.Millisecond >= ctx.Fragment { if duration = int64(absTime); time.Duration(duration)*time.Millisecond >= ctx.Fragment {
writeMetaTag() writeMetaTag(file, filepositions, times)
filepositions = []uint64{0}
times = []float64{0}
offset = 0 offset = 0
if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR, 0666); err != nil { if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR, 0666); err != nil {
return return

View File

@@ -3,7 +3,6 @@ package plugin_rtmp
import ( import (
"errors" "errors"
"io" "io"
"m7s.live/m7s/v5/pkg"
"maps" "maps"
"net" "net"
"slices" "slices"
@@ -36,7 +35,6 @@ func (p *RTMPPlugin) GetPullableList() []string {
} }
func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
receivers := make(map[uint32]*Receiver)
var err error var err error
nc := NewNetConnection(conn) nc := NewNetConnection(conn)
nc.Logger = p.With("remote", conn.RemoteAddr().String()) nc.Logger = p.With("remote", conn.RemoteAddr().String())
@@ -57,8 +55,6 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
continue continue
} }
switch msg.MessageTypeID { switch msg.MessageTypeID {
case RTMP_MSG_CHUNK_SIZE:
nc.Info("msg read chunk size", "readChunkSize", nc.ReadChunkSize)
case RTMP_MSG_AMF0_COMMAND: case RTMP_MSG_AMF0_COMMAND:
if msg.MsgData == nil { if msg.MsgData == nil {
err = errors.New("msg.MsgData is nil") err = errors.New("msg.MsgData is nil")
@@ -145,26 +141,24 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
// } // }
// err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m) // err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
case *PublishMessage: case *PublishMessage:
receiver := &Receiver{ ns := NetStream{
NetStream: NetStream{ NetConnection: nc,
NetConnection: nc, StreamID: cmd.StreamId,
StreamID: cmd.StreamId,
},
} }
receiver.Publisher, err = p.Publish(nc.Context, nc.AppName+"/"+cmd.PublishingName) var publisher *m7s.Publisher
receiver.Publisher.Description = nc.Description publisher, err = p.Publish(nc.Context, nc.AppName+"/"+cmd.PublishingName)
publisher.Description = nc.Description
if err != nil { if err != nil {
delete(receivers, cmd.StreamId) err = ns.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
} else { } else {
receivers[cmd.StreamId] = receiver ns.Receivers[cmd.StreamId] = publisher
err = receiver.BeginPublish(cmd.TransactionId) err = ns.BeginPublish(cmd.TransactionId)
} }
if err != nil { if err != nil {
nc.Error("sendMessage publish", "error", err) nc.Error("sendMessage publish", "error", err)
} else { } else {
receiver.Publisher.OnDispose(func() { publisher.OnDispose(func() {
nc.Stop(receiver.StopReason()) nc.Stop(publisher.StopReason())
}) })
} }
case *PlayMessage: case *PlayMessage:
@@ -181,29 +175,12 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error) err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
} else { } else {
err = ns.BeginPlay(cmd.TransactionId) err = ns.BeginPlay(cmd.TransactionId)
nc.AddCall(func(task *pkg.Task) error { ns.Subscribe(suber)
audio, video := ns.CreateSender(false)
return m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo)
}, nil)
} }
if err != nil { if err != nil {
nc.Error("sendMessage play", "error", err) nc.Error("sendMessage play", "error", err)
} }
} }
case RTMP_MSG_AUDIO:
if r, ok := receivers[msg.MessageStreamID]; ok && r.PubAudio {
err = r.WriteAudio(msg.AVData.WrapAudio())
} else {
msg.AVData.Recycle()
nc.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID)
}
case RTMP_MSG_VIDEO:
if r, ok := receivers[msg.MessageStreamID]; ok && r.PubVideo {
err = r.WriteVideo(msg.AVData.WrapVideo())
} else {
msg.AVData.Recycle()
nc.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID)
}
} }
} else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { } else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
nc.Info("rtmp client closed") nc.Info("rtmp client closed")

View File

@@ -115,22 +115,10 @@ func Pull(p *m7s.PullContext) (err error) {
return err return err
} }
switch msg.MessageTypeID { switch msg.MessageTypeID {
case RTMP_MSG_AUDIO:
if p.Publisher.PubAudio {
err = p.Publisher.WriteAudio(msg.AVData.WrapAudio())
} else {
msg.AVData.Recycle()
}
case RTMP_MSG_VIDEO:
if p.Publisher.PubVideo {
err = p.Publisher.WriteVideo(msg.AVData.WrapVideo())
} else {
msg.AVData.Recycle()
}
case RTMP_MSG_AMF0_COMMAND: case RTMP_MSG_AMF0_COMMAND:
cmd := msg.MsgData.(Commander).GetCommand() cmd := msg.MsgData.(Commander).GetCommand()
switch cmd.CommandName { switch cmd.CommandName {
case "_result": case Response_Result:
if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok { if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok {
connection.StreamID = response.StreamId connection.StreamID = response.StreamId
m := &PlayMessage{} m := &PlayMessage{}
@@ -144,6 +132,7 @@ func Pull(p *m7s.PullContext) (err error) {
if len(args) > 0 { if len(args) > 0 {
m.StreamName += "?" + args.Encode() m.StreamName += "?" + args.Encode()
} }
connection.Receivers[response.StreamId] = p.Publisher
connection.SendMessage(RTMP_MSG_AMF0_COMMAND, m) connection.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
// if response, ok := msg.MsgData.(*ResponsePlayMessage); ok { // if response, ok := msg.MsgData.(*ResponsePlayMessage); ok {
// if response.Object["code"] == "NetStream.Play.Start" { // if response.Object["code"] == "NetStream.Play.Start" {
@@ -200,14 +189,7 @@ func Push(p *m7s.PushContext) (err error) {
}) })
} else if response, ok := msg.MsgData.(*ResponsePublishMessage); ok { } else if response, ok := msg.MsgData.(*ResponsePublishMessage); ok {
if response.Infomation["code"] == NetStream_Publish_Start { if response.Infomation["code"] == NetStream_Publish_Start {
audio, video := connection.CreateSender(true) connection.Subscribe(p.Subscriber)
go func() {
for err == nil {
msg, err = connection.RecvMessage()
}
p.Subscriber.Stop(err)
}()
return m7s.PlayBlock(p.Subscriber, audio.HandleAudio, video.HandleVideo)
} else { } else {
return errors.New(response.Infomation["code"].(string)) return errors.New(response.Infomation["code"].(string))
} }

View File

@@ -89,18 +89,18 @@ func (nc *NetConnection) Handshake(checkC2 bool) (err error) {
return nc.complex_handshake(C1) return nc.complex_handshake(C1)
} }
func (client *NetConnection) ClientHandshake() (err error) { func (nc *NetConnection) ClientHandshake() (err error) {
C0C1 := client.mediaDataPool.NextN(C1S1_SIZE + 1) C0C1 := nc.mediaDataPool.NextN(C1S1_SIZE + 1)
defer client.mediaDataPool.Recycle() defer nc.mediaDataPool.Recycle()
C0C1[0] = RTMP_HANDSHAKE_VERSION C0C1[0] = RTMP_HANDSHAKE_VERSION
if _, err = client.Write(C0C1); err == nil { if _, err = nc.Write(C0C1); err == nil {
// read S0 S1 // read S0 S1
if _, err = io.ReadFull(client.Conn, C0C1); err == nil { if _, err = io.ReadFull(nc.Conn, C0C1); err == nil {
if C0C1[0] != RTMP_HANDSHAKE_VERSION { if C0C1[0] != RTMP_HANDSHAKE_VERSION {
err = errors.New("S1 C1 Error") err = errors.New("S1 C1 Error")
// C2 // C2
} else if _, err = client.Write(C0C1[1:]); err == nil { } else if _, err = nc.Write(C0C1[1:]); err == nil {
_, err = io.ReadFull(client.Conn, C0C1[1:]) // S2 _, err = io.ReadFull(nc.Conn, C0C1[1:]) // S2
} }
} }
} }

View File

@@ -2,6 +2,7 @@ package rtmp
import ( import (
"errors" "errors"
"m7s.live/m7s/v5"
"net" "net"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
@@ -57,6 +58,7 @@ type NetConnection struct {
chunkHeaderBuf util.Buffer chunkHeaderBuf util.Buffer
mediaDataPool util.RecyclableMemory mediaDataPool util.RecyclableMemory
writing atomic.Bool // false 可写true 不可写 writing atomic.Bool // false 可写true 不可写
Receivers map[uint32]*m7s.Publisher
} }
func NewNetConnection(conn net.Conn) (ret *NetConnection) { func NewNetConnection(conn net.Conn) (ret *NetConnection) {
@@ -69,30 +71,31 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) {
bandwidth: RTMP_MAX_CHUNK_SIZE << 3, bandwidth: RTMP_MAX_CHUNK_SIZE << 3,
tmpBuf: make(util.Buffer, 4), tmpBuf: make(util.Buffer, 4),
chunkHeaderBuf: make(util.Buffer, 0, 20), chunkHeaderBuf: make(util.Buffer, 0, 20),
Receivers: make(map[uint32]*m7s.Publisher),
} }
ret.mediaDataPool.SetAllocator(util.NewScalableMemoryAllocator(1 << util.MinPowerOf2)) ret.mediaDataPool.SetAllocator(util.NewScalableMemoryAllocator(1 << util.MinPowerOf2))
return return
} }
func (conn *NetConnection) Dispose() { func (nc *NetConnection) Dispose() {
conn.Conn.Close() nc.Conn.Close()
conn.BufReader.Recycle() nc.BufReader.Recycle()
conn.mediaDataPool.Recycle() nc.mediaDataPool.Recycle()
} }
func (conn *NetConnection) SendStreamID(eventType uint16, streamID uint32) (err error) { func (nc *NetConnection) SendStreamID(eventType uint16, streamID uint32) (err error) {
return conn.SendMessage(RTMP_MSG_USER_CONTROL, &StreamIDMessage{UserControlMessage{EventType: eventType}, streamID}) return nc.SendMessage(RTMP_MSG_USER_CONTROL, &StreamIDMessage{UserControlMessage{EventType: eventType}, streamID})
} }
func (conn *NetConnection) SendUserControl(eventType uint16) error { func (nc *NetConnection) SendUserControl(eventType uint16) error {
return conn.SendMessage(RTMP_MSG_USER_CONTROL, &UserControlMessage{EventType: eventType}) return nc.SendMessage(RTMP_MSG_USER_CONTROL, &UserControlMessage{EventType: eventType})
} }
func (conn *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) error { func (nc *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) error {
m := &ResponseCreateStreamMessage{} m := &ResponseCreateStreamMessage{}
m.CommandName = Response_Result m.CommandName = Response_Result
m.TransactionId = tid m.TransactionId = tid
m.StreamId = streamID m.StreamId = streamID
return conn.SendMessage(RTMP_MSG_AMF0_COMMAND, m) return nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
} }
// func (conn *NetConnection) SendCommand(message string, args any) error { // func (conn *NetConnection) SendCommand(message string, args any) error {
@@ -110,21 +113,21 @@ func (conn *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) err
// return errors.New("send message no exist") // return errors.New("send message no exist")
// } // }
func (conn *NetConnection) readChunk() (msg *Chunk, err error) { func (nc *NetConnection) readChunk() (msg *Chunk, err error) {
head, err := conn.ReadByte() head, err := nc.ReadByte()
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn.readSeqNum++ nc.readSeqNum++
ChunkStreamID := uint32(head & 0x3f) // 0011 1111 ChunkStreamID := uint32(head & 0x3f) // 0011 1111
ChunkType := head >> 6 // 1100 0000 ChunkType := head >> 6 // 1100 0000
// 如果块流ID为0,1的话,就需要计算. // 如果块流ID为0,1的话,就需要计算.
ChunkStreamID, err = conn.readChunkStreamID(ChunkStreamID) ChunkStreamID, err = nc.readChunkStreamID(ChunkStreamID)
if err != nil { if err != nil {
return nil, errors.New("get chunk stream id error :" + err.Error()) return nil, errors.New("get chunk stream id error :" + err.Error())
} }
//println("ChunkStreamID:", ChunkStreamID, "ChunkType:", ChunkType) //println("ChunkStreamID:", ChunkStreamID, "ChunkType:", ChunkType)
chunk, ok := conn.incommingChunks[ChunkStreamID] chunk, ok := nc.incommingChunks[ChunkStreamID]
if ChunkType != 3 && ok && chunk.bufLen > 0 { if ChunkType != 3 && ok && chunk.bufLen > 0 {
// 如果块类型不为3,那么这个rtmp的body应该为空. // 如果块类型不为3,那么这个rtmp的body应该为空.
@@ -132,10 +135,10 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
} }
if !ok { if !ok {
chunk = &Chunk{} chunk = &Chunk{}
conn.incommingChunks[ChunkStreamID] = chunk nc.incommingChunks[ChunkStreamID] = chunk
} }
if err = conn.readChunkType(&chunk.ChunkHeader, ChunkType); err != nil { if err = nc.readChunkType(&chunk.ChunkHeader, ChunkType); err != nil {
return nil, errors.New("get chunk type error :" + err.Error()) return nil, errors.New("get chunk type error :" + err.Error())
} }
msgLen := int(chunk.MessageLength) msgLen := int(chunk.MessageLength)
@@ -143,19 +146,19 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
return nil, nil return nil, nil
} }
var bufSize = 0 var bufSize = 0
if unRead := msgLen - chunk.bufLen; unRead < conn.ReadChunkSize { if unRead := msgLen - chunk.bufLen; unRead < nc.ReadChunkSize {
bufSize = unRead bufSize = unRead
} else { } else {
bufSize = conn.ReadChunkSize bufSize = nc.ReadChunkSize
} }
conn.readSeqNum += uint32(bufSize) nc.readSeqNum += uint32(bufSize)
if chunk.bufLen == 0 { if chunk.bufLen == 0 {
chunk.AVData.RecyclableMemory = util.RecyclableMemory{} chunk.AVData.RecyclableMemory = util.RecyclableMemory{}
chunk.AVData.SetAllocator(conn.mediaDataPool.GetAllocator()) chunk.AVData.SetAllocator(nc.mediaDataPool.GetAllocator())
chunk.AVData.NextN(msgLen) chunk.AVData.NextN(msgLen)
} }
buffer := chunk.AVData.Buffers[0] buffer := chunk.AVData.Buffers[0]
err = conn.ReadRange(bufSize, func(buf []byte) { err = nc.ReadRange(bufSize, func(buf []byte) {
copy(buffer[chunk.bufLen:], buf) copy(buffer[chunk.bufLen:], buf)
chunk.bufLen += len(buf) chunk.bufLen += len(buf)
}) })
@@ -176,14 +179,14 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
return return
} }
func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, err error) { func (nc *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, err error) {
chunkStreamID = csid chunkStreamID = csid
switch csid { switch csid {
case 0: case 0:
{ {
u8, err := conn.ReadByte() u8, err := nc.ReadByte()
conn.readSeqNum++ nc.readSeqNum++
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -192,15 +195,15 @@ func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32,
} }
case 1: case 1:
{ {
u16_0, err1 := conn.ReadByte() u16_0, err1 := nc.ReadByte()
if err1 != nil { if err1 != nil {
return 0, err1 return 0, err1
} }
u16_1, err1 := conn.ReadByte() u16_1, err1 := nc.ReadByte()
if err1 != nil { if err1 != nil {
return 0, err1 return 0, err1
} }
conn.readSeqNum += 2 nc.readSeqNum += 2
chunkStreamID = 64 + uint32(u16_0) + (uint32(u16_1) << 8) chunkStreamID = 64 + uint32(u16_0) + (uint32(u16_1) << 8)
} }
} }
@@ -208,27 +211,27 @@ func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32,
return chunkStreamID, nil return chunkStreamID, nil
} }
func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err error) { func (nc *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err error) {
if chunkType == 3 { if chunkType == 3 {
// 3个字节的时间戳 // 3个字节的时间戳
} else { } else {
// Timestamp 3 bytes // Timestamp 3 bytes
if h.Timestamp, err = conn.ReadBE32(3); err != nil { if h.Timestamp, err = nc.ReadBE32(3); err != nil {
return err return err
} }
if chunkType != 2 { if chunkType != 2 {
if h.MessageLength, err = conn.ReadBE32(3); err != nil { if h.MessageLength, err = nc.ReadBE32(3); err != nil {
return err return err
} }
// Message Type ID 1 bytes // Message Type ID 1 bytes
if h.MessageTypeID, err = conn.ReadByte(); err != nil { if h.MessageTypeID, err = nc.ReadByte(); err != nil {
return err return err
} }
conn.readSeqNum++ nc.readSeqNum++
if chunkType == 0 { if chunkType == 0 {
// Message Stream ID 4bytes // Message Stream ID 4bytes
if h.MessageStreamID, err = conn.ReadBE32(4); err != nil { // 读取Message Stream ID if h.MessageStreamID, err = nc.ReadBE32(4); err != nil { // 读取Message Stream ID
return err return err
} }
} }
@@ -237,7 +240,7 @@ func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err er
// ExtendTimestamp 4 bytes // ExtendTimestamp 4 bytes
if h.Timestamp >= 0xffffff { // 对于type 0的chunk,绝对时间戳在这里表示,如果时间戳值大于等于0xffffff(16777215),该值必须是0xffffff,且时间戳扩展字段必须发送,其他情况没有要求 if h.Timestamp >= 0xffffff { // 对于type 0的chunk,绝对时间戳在这里表示,如果时间戳值大于等于0xffffff(16777215),该值必须是0xffffff,且时间戳扩展字段必须发送,其他情况没有要求
if h.Timestamp, err = conn.ReadBE32(4); err != nil { if h.Timestamp, err = nc.ReadBE32(4); err != nil {
return err return err
} }
switch chunkType { switch chunkType {
@@ -258,75 +261,90 @@ func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err er
return nil return nil
} }
func (conn *NetConnection) RecvMessage() (msg *Chunk, err error) { func (nc *NetConnection) RecvMessage() (msg *Chunk, err error) {
if conn.readSeqNum >= conn.bandwidth { if nc.readSeqNum >= nc.bandwidth {
conn.totalRead += conn.readSeqNum nc.totalRead += nc.readSeqNum
conn.readSeqNum = 0 nc.readSeqNum = 0
err = conn.SendMessage(RTMP_MSG_ACK, Uint32Message(conn.totalRead)) err = nc.SendMessage(RTMP_MSG_ACK, Uint32Message(nc.totalRead))
} }
for msg == nil && err == nil { for msg == nil && err == nil {
if msg, err = conn.readChunk(); msg != nil && err == nil { if msg, err = nc.readChunk(); msg != nil && err == nil {
switch msg.MessageTypeID { switch msg.MessageTypeID {
case RTMP_MSG_CHUNK_SIZE: case RTMP_MSG_CHUNK_SIZE:
conn.ReadChunkSize = int(msg.MsgData.(Uint32Message)) nc.ReadChunkSize = int(msg.MsgData.(Uint32Message))
nc.Info("msg read chunk size", "readChunkSize", nc.ReadChunkSize)
case RTMP_MSG_ABORT: case RTMP_MSG_ABORT:
delete(conn.incommingChunks, uint32(msg.MsgData.(Uint32Message))) delete(nc.incommingChunks, uint32(msg.MsgData.(Uint32Message)))
case RTMP_MSG_ACK, RTMP_MSG_EDGE: case RTMP_MSG_ACK, RTMP_MSG_EDGE:
case RTMP_MSG_USER_CONTROL: case RTMP_MSG_USER_CONTROL:
if _, ok := msg.MsgData.(*PingRequestMessage); ok { if _, ok := msg.MsgData.(*PingRequestMessage); ok {
conn.SendUserControl(RTMP_USER_PING_RESPONSE) nc.SendUserControl(RTMP_USER_PING_RESPONSE)
} }
case RTMP_MSG_ACK_SIZE: case RTMP_MSG_ACK_SIZE:
conn.bandwidth = uint32(msg.MsgData.(Uint32Message)) nc.bandwidth = uint32(msg.MsgData.(Uint32Message))
case RTMP_MSG_BANDWIDTH: case RTMP_MSG_BANDWIDTH:
conn.bandwidth = msg.MsgData.(*SetPeerBandwidthMessage).AcknowledgementWindowsize nc.bandwidth = msg.MsgData.(*SetPeerBandwidthMessage).AcknowledgementWindowsize
case RTMP_MSG_AMF0_COMMAND, RTMP_MSG_AUDIO, RTMP_MSG_VIDEO: case RTMP_MSG_AMF0_COMMAND:
return msg, err return msg, err
case RTMP_MSG_AUDIO:
if r, ok := nc.Receivers[msg.MessageStreamID]; ok && r.PubAudio {
err = r.WriteAudio(msg.AVData.WrapAudio())
} else {
msg.AVData.Recycle()
nc.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID)
}
case RTMP_MSG_VIDEO:
if r, ok := nc.Receivers[msg.MessageStreamID]; ok && r.PubVideo {
err = r.WriteVideo(msg.AVData.WrapVideo())
} else {
msg.AVData.Recycle()
nc.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID)
}
} }
} }
} }
return return
} }
func (conn *NetConnection) SendMessage(t byte, msg RtmpMessage) (err error) { func (nc *NetConnection) SendMessage(t byte, msg RtmpMessage) (err error) {
if conn == nil { if nc == nil {
return errors.New("connection is nil") return errors.New("connection is nil")
} }
if conn.writeSeqNum > conn.bandwidth { if nc.writeSeqNum > nc.bandwidth {
conn.totalWrite += conn.writeSeqNum nc.totalWrite += nc.writeSeqNum
conn.writeSeqNum = 0 nc.writeSeqNum = 0
err = conn.SendMessage(RTMP_MSG_ACK, Uint32Message(conn.totalWrite)) err = nc.SendMessage(RTMP_MSG_ACK, Uint32Message(nc.totalWrite))
err = conn.SendStreamID(RTMP_USER_PING_REQUEST, 0) err = nc.SendStreamID(RTMP_USER_PING_REQUEST, 0)
} }
for !conn.writing.CompareAndSwap(false, true) { for !nc.writing.CompareAndSwap(false, true) {
runtime.Gosched() runtime.Gosched()
} }
defer conn.writing.Store(false) defer nc.writing.Store(false)
conn.tmpBuf.Reset() nc.tmpBuf.Reset()
amf := AMF{conn.tmpBuf} amf := AMF{nc.tmpBuf}
if conn.ObjectEncoding == 0 { if nc.ObjectEncoding == 0 {
msg.Encode(&amf) msg.Encode(&amf)
} else { } else {
amf := AMF3{AMF: amf} amf := AMF3{AMF: amf}
msg.Encode(&amf) msg.Encode(&amf)
} }
conn.tmpBuf = amf.Buffer nc.tmpBuf = amf.Buffer
head := newChunkHeader(t) head := newChunkHeader(t)
head.MessageLength = uint32(conn.tmpBuf.Len()) head.MessageLength = uint32(nc.tmpBuf.Len())
if sid, ok := msg.(HaveStreamID); ok { if sid, ok := msg.(HaveStreamID); ok {
head.MessageStreamID = sid.GetStreamID() head.MessageStreamID = sid.GetStreamID()
} }
return conn.sendChunk(net.Buffers{conn.tmpBuf}, head, RTMP_CHUNK_HEAD_12) return nc.sendChunk(net.Buffers{nc.tmpBuf}, head, RTMP_CHUNK_HEAD_12)
} }
func (conn *NetConnection) sendChunk(data net.Buffers, head *ChunkHeader, headType byte) (err error) { func (nc *NetConnection) sendChunk(data net.Buffers, head *ChunkHeader, headType byte) (err error) {
conn.chunkHeaderBuf.Reset() nc.chunkHeaderBuf.Reset()
head.WriteTo(headType, &conn.chunkHeaderBuf) head.WriteTo(headType, &nc.chunkHeaderBuf)
chunks := net.Buffers{conn.chunkHeaderBuf} chunks := net.Buffers{nc.chunkHeaderBuf}
var chunk3 util.Buffer = conn.chunkHeaderBuf[conn.chunkHeaderBuf.Len():20] var chunk3 util.Buffer = nc.chunkHeaderBuf[nc.chunkHeaderBuf.Len():20]
head.WriteTo(RTMP_CHUNK_HEAD_1, &chunk3) head.WriteTo(RTMP_CHUNK_HEAD_1, &chunk3)
r := util.NewReadableBuffersFromBytes(data...) r := util.NewReadableBuffersFromBytes(data...)
for { for {
r.RangeN(conn.WriteChunkSize, func(buf []byte) { r.RangeN(nc.WriteChunkSize, func(buf []byte) {
chunks = append(chunks, buf) chunks = append(chunks, buf)
}) })
if r.Length <= 0 { if r.Length <= 0 {
@@ -336,7 +354,7 @@ func (conn *NetConnection) sendChunk(data net.Buffers, head *ChunkHeader, headTy
chunks = append(chunks, chunk3) chunks = append(chunks, chunk3)
} }
var nw int64 var nw int64
nw, err = chunks.WriteTo(conn.Conn) nw, err = chunks.WriteTo(nc.Conn)
conn.writeSeqNum += uint32(nw) nc.writeSeqNum += uint32(nw)
return err return err
} }

View File

@@ -1,5 +1,10 @@
package rtmp package rtmp
import (
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
)
type NetStream struct { type NetStream struct {
*NetConnection *NetConnection
StreamID uint32 StreamID uint32
@@ -62,3 +67,10 @@ func (ns *NetStream) BeginPlay(tid uint64) (err error) {
err = ns.Response(tid, NetStream_Play_Start, Level_Status) err = ns.Response(tid, NetStream_Play_Start, Level_Status)
return return
} }
func (ns *NetStream) Subscribe(suber *m7s.Subscriber) {
ns.AddCall(func(task *pkg.Task) error {
audio, video := ns.CreateSender(false)
return m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo)
}, nil)
}

View File

@@ -43,15 +43,13 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
var err error var err error
nc := NewNetConnection(conn) nc := NewNetConnection(conn)
nc.Logger = logger nc.Logger = logger
p.AddTask(nc).WaitStarted()
defer func() { defer func() {
nc.Destroy() nc.Stop(err)
if p := recover(); p != nil { if p := recover(); p != nil {
err = p.(error) err = p.(error)
logger.Error(err.Error(), "stack", string(debug.Stack())) logger.Error(err.Error(), "stack", string(debug.Stack()))
} }
if receiver != nil {
receiver.Stop(err)
}
}() }()
var req *util.Request var req *util.Request
var sendMode bool var sendMode bool
@@ -106,8 +104,9 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
return return
} }
receiver = &Receiver{} receiver = &Receiver{
receiver.NetConnection = nc Stream: &Stream{NetConnection: nc},
}
if receiver.Publisher, err = p.Publish(nc, strings.TrimPrefix(nc.URL.Path, "/")); err != nil { if receiver.Publisher, err = p.Publish(nc, strings.TrimPrefix(nc.URL.Path, "/")); err != nil {
receiver = nil receiver = nil
err = nc.WriteResponse(&util.Response{ err = nc.WriteResponse(&util.Response{
@@ -122,11 +121,14 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
if err = nc.WriteResponse(res); err != nil { if err = nc.WriteResponse(res); err != nil {
return return
} }
receiver.Publisher.OnDispose(func() {
nc.Stop(receiver.Publisher.StopReason())
})
case MethodDescribe: case MethodDescribe:
sendMode = true sendMode = true
sender = &Sender{} sender = &Sender{
sender.NetConnection = nc Stream: &Stream{NetConnection: nc},
}
sender.Subscriber, err = p.Subscribe(nc, strings.TrimPrefix(nc.URL.Path, "/")) sender.Subscriber, err = p.Subscribe(nc, strings.TrimPrefix(nc.URL.Path, "/"))
if err != nil { if err != nil {
res := &util.Response{ res := &util.Response{

View File

@@ -64,7 +64,7 @@ func (c *NetConnection) StopWrite() {
c.writing.Store(false) c.writing.Store(false)
} }
func (c *NetConnection) Destroy() { func (c *NetConnection) Dispose() {
c.conn.Close() c.conn.Close()
c.BufReader.Recycle() c.BufReader.Recycle()
c.MemoryAllocator.Recycle() c.MemoryAllocator.Recycle()

View File

@@ -238,6 +238,6 @@ func (c *Stream) Teardown() (err error) {
func (ns *Stream) disconnect() { func (ns *Stream) disconnect() {
if ns != nil && ns.NetConnection != nil { if ns != nil && ns.NetConnection != nil {
_ = ns.Teardown() _ = ns.Teardown()
ns.NetConnection.Destroy() ns.NetConnection.Dispose()
} }
} }