diff --git a/flv/http_flv.go b/flv/http_flv.go index 2f39323..0d48abe 100644 --- a/flv/http_flv.go +++ b/flv/http_flv.go @@ -73,7 +73,7 @@ func (t *httpTransStream) Input(packet utils.AVPacket) error { var separatorSize int //新的合并写切片, 预留包长字节 - if t.mwBuffer.IsCompeted() { + if t.mwBuffer.IsCompleted() { separatorSize = HttpFlvBlockLengthSize //10字节描述flv包长, 前2个字节描述无效字节长度 n = HttpFlvBlockLengthSize @@ -186,8 +186,6 @@ func (t *httpTransStream) writeSeparator(dst []byte) { } func (t *httpTransStream) WriteHeader() error { - t.Init() - t.headerSize += t.muxer.WriteHeader(t.header[HttpFlvBlockLengthSize:]) for _, track := range t.BaseTransStream.Tracks { diff --git a/gb28181/filter.go b/gb28181/filter.go index d2ae5a8..a6d9a90 100644 --- a/gb28181/filter.go +++ b/gb28181/filter.go @@ -1,56 +1,9 @@ package gb28181 -import ( - "github.com/pion/rtp" - "github.com/yangjiechina/avformat/utils" - "github.com/yangjiechina/lkm/log" - "github.com/yangjiechina/lkm/stream" - "net" -) - type Filter interface { AddSource(ssrc uint32, source GBSource) bool RemoveSource(ssrc uint32) - Input(conn net.Conn, data []byte) GBSource - - ParseRtpPacket(conn net.Conn, data []byte) (*rtp.Packet, error) - - PreparePublishSource(conn net.Conn, ssrc uint32, source GBSource) -} - -type BaseFilter struct { -} - -func (r BaseFilter) ParseRtpPacket(conn net.Conn, data []byte) (*rtp.Packet, error) { - packet := rtp.Packet{} - err := packet.Unmarshal(data) - - if err != nil { - log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) - return nil, err - } - - return &packet, err -} - -func (r BaseFilter) PreparePublishSource(conn net.Conn, ssrc uint32, source GBSource) { - source.SetConn(conn) - source.SetSSRC(ssrc) - - source.SetState(stream.SessionStateTransferring) - - if stream.AppConfig.Hook.EnablePublishEvent() { - go func() { - _, state := stream.HookPublishEvent(source) - if utils.HookStateOK != state { - log.Sugar.Errorf("GB28181 推流失败") - - if conn != nil { - conn.Close() - } - } - }() - } + FindSource(ssrc uint32) GBSource } diff --git a/gb28181/filter_single.go b/gb28181/filter_single.go index e57ae4d..dd3f0f1 100644 --- a/gb28181/filter_single.go +++ b/gb28181/filter_single.go @@ -1,42 +1,21 @@ package gb28181 -import ( - "github.com/yangjiechina/lkm/stream" - "net" -) - -type SingleFilter struct { - BaseFilter - +type singleFilter struct { source GBSource } -func NewSingleFilter(source GBSource) *SingleFilter { - return &SingleFilter{source: source} +func NewSingleFilter(source GBSource) Filter { + return &singleFilter{source: source} } -func (s *SingleFilter) AddSource(ssrc uint32, source GBSource) bool { +func (s *singleFilter) AddSource(ssrc uint32, source GBSource) bool { panic("implement me") } -func (s *SingleFilter) RemoveSource(ssrc uint32) { +func (s *singleFilter) RemoveSource(ssrc uint32) { panic("implement me") } -func (s *SingleFilter) Input(conn net.Conn, data []byte) GBSource { - packet, err := s.ParseRtpPacket(conn, data) - if err != nil { - return nil - } - - if s.source == nil { - return nil - } - - if stream.SessionStateHandshakeDone == s.source.State() { - s.PreparePublishSource(conn, packet.SSRC, s.source) - } - - s.source.InputRtp(packet) +func (s *singleFilter) FindSource(ssrc uint32) GBSource { return s.source } diff --git a/gb28181/filter_ssrc.go b/gb28181/filter_ssrc.go index fb3c22c..bdb808a 100644 --- a/gb28181/filter_ssrc.go +++ b/gb28181/filter_ssrc.go @@ -1,25 +1,21 @@ package gb28181 import ( - "github.com/yangjiechina/lkm/stream" - "net" "sync" ) -type SSRCFilter struct { - BaseFilter - +type ssrcFilter struct { sources map[uint32]GBSource mute sync.RWMutex } -func NewSharedFilter(guestCount int) *SSRCFilter { - return &SSRCFilter{sources: make(map[uint32]GBSource, guestCount)} +func NewSharedFilter(guestCount int) Filter { + return &ssrcFilter{sources: make(map[uint32]GBSource, guestCount)} } -func (r SSRCFilter) AddSource(ssrc uint32, source GBSource) bool { +func (r *ssrcFilter) AddSource(ssrc uint32, source GBSource) bool { r.mute.Lock() - defer r.mute.Lock() + defer r.mute.Unlock() if _, ok := r.sources[ssrc]; !ok { r.sources[ssrc] = source @@ -29,34 +25,14 @@ func (r SSRCFilter) AddSource(ssrc uint32, source GBSource) bool { return false } -func (r SSRCFilter) RemoveSource(ssrc uint32) { +func (r *ssrcFilter) RemoveSource(ssrc uint32) { r.mute.Lock() - defer r.mute.Lock() + defer r.mute.Unlock() delete(r.sources, ssrc) } -func (r SSRCFilter) Input(conn net.Conn, data []byte) GBSource { - packet, err := r.ParseRtpPacket(conn, data) - if err != nil { - return nil - } - - var source GBSource - var ok bool - { - r.mute.RLock() - source, ok = r.sources[packet.SSRC] - r.mute.RUnlock() - } - - if !ok { - return nil - } - - if stream.SessionStateHandshakeDone == source.State() { - r.PreparePublishSource(conn, packet.SSRC, source) - } - - source.InputRtp(packet) - return source +func (r *ssrcFilter) FindSource(ssrc uint32) GBSource { + r.mute.RLock() + defer r.mute.RUnlock() + return r.sources[ssrc] } diff --git a/gb28181/source.go b/gb28181/source.go index adef263..c5fb874 100644 --- a/gb28181/source.go +++ b/gb28181/source.go @@ -42,6 +42,8 @@ type GBSource interface { PrepareTransDeMuxer(id string, ssrc uint32) + PreparePublishSource(conn net.Conn, ssrc uint32, source GBSource) + SetConn(conn net.Conn) SetSSRC(ssrc uint32) @@ -54,8 +56,9 @@ type BaseGBSource struct { audioStream utils.AVStream videoStream utils.AVStream - ssrc uint32 - transport transport.ITransport + ssrc uint32 + transport transport.ITransport + receiveBuffer *stream.ReceiveBuffer } func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, uint16, error) { @@ -144,8 +147,15 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, uint1 return nil, 0, fmt.Errorf("error code %d", state) } + var bufferBlockCount int + if active || tcp { + bufferBlockCount = stream.ReceiveBufferTCPBlockCount + } else { + bufferBlockCount = stream.ReceiveBufferUdpBlockCount + } + source.SetType(stream.SourceType28181) - source.Init(source.Input, source.Close) + source.Init(source.Input, source.Close, bufferBlockCount) go source.LoopEvent() return source, port, err } @@ -334,3 +344,31 @@ func (source *BaseGBSource) SetConn(conn net.Conn) { func (source *BaseGBSource) SetSSRC(ssrc uint32) { source.ssrc = ssrc } + +func (source *BaseGBSource) SetReceiveBuffer(buffer *stream.ReceiveBuffer) { + source.receiveBuffer = buffer +} + +func (source *BaseGBSource) ReceiveBuffer() *stream.ReceiveBuffer { + return source.receiveBuffer +} + +func (source *BaseGBSource) PreparePublishSource(conn net.Conn, ssrc uint32, source_ GBSource) { + source.SetConn(conn) + source.SetSSRC(ssrc) + + source.SetState(stream.SessionStateTransferring) + + if stream.AppConfig.Hook.EnablePublishEvent() { + go func() { + _, state := stream.HookPublishEvent(source_) + if utils.HookStateOK != state { + log.Sugar.Errorf("GB28181 推流失败 source:%s", source.Id()) + + if conn != nil { + conn.Close() + } + } + }() + } +} diff --git a/gb28181/source_passive.go b/gb28181/source_passive.go index a5829a6..fd40bc3 100644 --- a/gb28181/source_passive.go +++ b/gb28181/source_passive.go @@ -17,6 +17,7 @@ func (t PassiveSource) TransportType() TransportType { } func (t PassiveSource) InputRtp(pkt *rtp.Packet) error { - t.PublishSource.Input(pkt.Payload) + //TCP收流, 解析rtp后直接送给ps解析 + t.Input(pkt.Payload) return nil } diff --git a/gb28181/source_udp.go b/gb28181/source_udp.go index b236b1c..5cabce0 100644 --- a/gb28181/source_udp.go +++ b/gb28181/source_udp.go @@ -1,24 +1,23 @@ package gb28181 import ( - "fmt" "github.com/pion/rtp" "github.com/yangjiechina/lkm/jitterbuffer" "github.com/yangjiechina/lkm/stream" ) +// UDPSource GB28181 UDP推流源 type UDPSource struct { BaseGBSource - rtpDeMuxer *jitterbuffer.JitterBuffer - - rtpBuffer stream.MemoryPool + jitterBuffer *jitterbuffer.JitterBuffer + receiveBuffer *stream.ReceiveBuffer } func NewUDPSource() *UDPSource { return &UDPSource{ - rtpDeMuxer: jitterbuffer.New(), - rtpBuffer: stream.NewDirectMemoryPool(JitterBufferSize), + jitterBuffer: jitterbuffer.New(), + receiveBuffer: stream.NewReceiveBuffer(1500, stream.ReceiveBufferUdpBlockCount+50), } } @@ -26,25 +25,20 @@ func (u UDPSource) TransportType() TransportType { return TransportTypeUDP } +// InputRtp UDP收流会先拷贝rtp包,交给jitter buffer处理后再发给source func (u UDPSource) InputRtp(pkt *rtp.Packet) error { - n := u.rtpBuffer.Capacity() - u.rtpBuffer.Size() - if n < len(pkt.Payload) { - return fmt.Errorf("RTP receive buffer overflow") - } + block := u.receiveBuffer.GetBlock() - allocate := u.rtpBuffer.Allocate(len(pkt.Payload)) - copy(allocate, pkt.Payload) - pkt.Payload = allocate - u.rtpDeMuxer.Push(pkt) + copy(block, pkt.Payload) + pkt.Payload = block[:len(pkt.Payload)] + u.jitterBuffer.Push(pkt) for { - pkt, _ := u.rtpDeMuxer.Pop() + pkt, _ := u.jitterBuffer.Pop() if pkt == nil { return nil } - u.rtpBuffer.FreeHead() - u.PublishSource.Input(pkt.Payload) } } diff --git a/gb28181/tcp_server.go b/gb28181/tcp_server.go index 09d9831..3491213 100644 --- a/gb28181/tcp_server.go +++ b/gb28181/tcp_server.go @@ -1,8 +1,10 @@ package gb28181 import ( + "github.com/pion/rtp" "github.com/yangjiechina/avformat/transport" "github.com/yangjiechina/lkm/log" + "github.com/yangjiechina/lkm/stream" "net" ) @@ -12,8 +14,9 @@ type TCPServer struct { } type TCPSession struct { - source GBSource - decoder *transport.LengthFieldFrameDecoder + source GBSource + decoder *transport.LengthFieldFrameDecoder + receiveBuffer *stream.ReceiveBuffer } func NewTCPServer(addr net.Addr, filter Filter) (*TCPServer, error) { @@ -31,33 +34,80 @@ func NewTCPServer(addr net.Addr, filter Filter) (*TCPServer, error) { return server, nil } -func (T *TCPServer) OnConnected(conn net.Conn) { +func (T *TCPServer) OnConnected(conn net.Conn) []byte { log.Sugar.Infof("GB28181连接 conn:%s", conn.RemoteAddr().String()) -} -func (T *TCPServer) OnPacket(conn net.Conn, data []byte) { con := conn.(*transport.Conn) - if con.Data == nil { - session := &TCPSession{} - session.decoder = transport.NewLengthFieldFrameDecoder(0xFFFF, 2, func(bytes []byte) { - source := T.filter.Input(con, bytes[2:]) - if source != nil && session.source == nil { - session.source = source - } - }) - - con.Data = session + session := &TCPSession{} + if stream.AppConfig.GB28181.IsMultiPort() { + session.source = T.filter.(*singleFilter).source + session.source.SetConn(con) + session.receiveBuffer = stream.NewTCPReceiveBuffer() } - con.Data.(*TCPSession).decoder.Input(data) + session.decoder = transport.NewLengthFieldFrameDecoder(0xFFFF, 2, func(bytes []byte) { + packet := rtp.Packet{} + err := packet.Unmarshal(bytes) + if err != nil { + log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) + return + } + + //单端口模式,ssrc匹配source + if session.source == nil { + //匹配不到直接关闭链接 + source := T.filter.FindSource(packet.SSRC) + if source == nil { + conn.Close() + return + } + + session.source = source + session.receiveBuffer = stream.NewTCPReceiveBuffer() + session.source.SetConn(con) + + //直接丢给ps解析器, 虽然是非线程安全, 但是是阻塞执行的, 不会和后续走loop event的包冲突 + session.source.InputRtp(&packet) + } + + if stream.SessionStateHandshakeDone == session.source.State() { + session.source.PreparePublishSource(conn, packet.SSRC, session.source) + } + + session.source.InputRtp(&packet) + }) + + con.Data = session + + if session.source != nil { + return session.receiveBuffer.GetBlock() + } + + return nil +} + +func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte { + session := conn.(*transport.Conn).Data.(*TCPSession) + + //单端口收流 + if session.source == nil { + //直接传给解码器, 先根据ssrc找到source. 后续还是会直接传给source + if err := session.decoder.Input(data); err != nil { + conn.Close() + } + } else { + session.source.Input(data) + } + + return session.receiveBuffer.GetBlock() } func (T *TCPServer) OnDisConnected(conn net.Conn, err error) { log.Sugar.Infof("GB28181断开连接 conn:%s", conn.RemoteAddr().String()) con := conn.(*transport.Conn) - if con.Data != nil { + if con.Data != nil && con.Data.(*TCPSession).source != nil { con.Data.(*TCPSession).source.Close() - con.Data = nil } + con.Data = nil } diff --git a/gb28181/udp_server.go b/gb28181/udp_server.go index c0918ad..7d9b203 100644 --- a/gb28181/udp_server.go +++ b/gb28181/udp_server.go @@ -1,12 +1,15 @@ package gb28181 import ( + "github.com/pion/rtp" "github.com/yangjiechina/avformat/transport" + "github.com/yangjiechina/lkm/log" + "github.com/yangjiechina/lkm/stream" "net" ) type UDPServer struct { - udp *transport.UDPTransport + udp *transport.UDPServer filter Filter } @@ -24,12 +27,31 @@ func NewUDPServer(addr net.Addr, filter Filter) (*UDPServer, error) { return server, nil } -func (U UDPServer) OnConnected(conn net.Conn) { - +func (U UDPServer) OnConnected(conn net.Conn) []byte { + return nil } -func (U UDPServer) OnPacket(conn net.Conn, data []byte) { - U.filter.Input(conn, data) +func (U UDPServer) OnPacket(conn net.Conn, data []byte) []byte { + packet := rtp.Packet{} + err := packet.Unmarshal(data) + + if err != nil { + log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) + return nil + } + + source := U.filter.FindSource(packet.SSRC) + if source == nil { + log.Sugar.Errorf("ssrc匹配source失败 ssrc:%x conn:%s", packet.SSRC, conn.RemoteAddr().String()) + return nil + } + + if stream.SessionStateHandshakeDone == source.State() { + source.PreparePublishSource(conn, packet.SSRC, source) + } + + source.InputRtp(&packet) + return nil } func (U UDPServer) OnDisConnected(conn net.Conn, err error) { diff --git a/hls/hls_stream.go b/hls/hls_stream.go index a372316..5bf83d9 100644 --- a/hls/hls_stream.go +++ b/hls/hls_stream.go @@ -36,59 +36,6 @@ type transStream struct { m3u8Sinks map[stream.SinkId]stream.Sink } -// NewTransStream 创建HLS传输流 -// @url url前缀 -// @m3u8Name m3u8文件名 -// @tsFormat ts文件格式, 例如: %d.ts -// @parentDir 保存切片的绝对路径. mu38和ts切片放在同一目录下, 目录地址使用parentDir+urlPrefix -// @segmentDuration 单个切片时长 -// @playlistLength 缓存多少个切片 -func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playlistLength int) (stream.TransStream, error) { - //创建文件夹 - if err := os.MkdirAll(dir, 0666); err != nil { - return nil, err - } - - //创建m3u8文件 - m3u8Path := fmt.Sprintf("%s/%s", dir, m3u8Name) - file, err := os.OpenFile(m3u8Path, os.O_CREATE|os.O_RDWR, 0666) - if err != nil { - return nil, err - } - - stream_ := &transStream{ - url: url, - m3u8Name: m3u8Name, - tsFormat: tsFormat, - dir: dir, - duration: segmentDuration, - playlistLength: playlistLength, - } - - //创建TS封装器 - muxer := libmpeg.NewTSMuxer() - muxer.SetWriteHandler(stream_.onTSWrite) - muxer.SetAllocHandler(stream_.onTSAlloc) - - stream_.context = &tsContext{ - segmentSeq: 0, - writeBuffer: make([]byte, 1024*1024), - writeBufferSize: 0, - } - - stream_.muxer = muxer - stream_.m3u8 = NewM3U8Writer(playlistLength) - stream_.m3u8File = file - - stream_.m3u8Sinks = make(map[stream.SinkId]stream.Sink, 24) - return stream_, nil -} - -func TransStreamFactory(source stream.Source, protocol stream.Protocol, streams []utils.AVStream) (stream.TransStream, error) { - id := source.Id() - return NewTransStream("", stream.AppConfig.Hls.M3U8Format(id), stream.AppConfig.Hls.TSFormat(id, "%d"), stream.AppConfig.Hls.Dir, stream.AppConfig.Hls.Duration, stream.AppConfig.Hls.PlaylistLength) -} - func (t *transStream) Input(packet utils.AVPacket) error { if packet.Index() >= t.muxer.TrackCount() { return fmt.Errorf("track not available") @@ -137,8 +84,6 @@ func (t *transStream) AddTrack(stream utils.AVStream) error { } func (t *transStream) WriteHeader() error { - t.Init() - return t.createSegment() } @@ -246,3 +191,56 @@ func (t *transStream) Close() error { return err } + +// NewTransStream 创建HLS传输流 +// @url url前缀 +// @m3u8Name m3u8文件名 +// @tsFormat ts文件格式, 例如: %d.ts +// @parentDir 保存切片的绝对路径. mu38和ts切片放在同一目录下, 目录地址使用parentDir+urlPrefix +// @segmentDuration 单个切片时长 +// @playlistLength 缓存多少个切片 +func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playlistLength int) (stream.TransStream, error) { + //创建文件夹 + if err := os.MkdirAll(dir, 0666); err != nil { + return nil, err + } + + //创建m3u8文件 + m3u8Path := fmt.Sprintf("%s/%s", dir, m3u8Name) + file, err := os.OpenFile(m3u8Path, os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return nil, err + } + + stream_ := &transStream{ + url: url, + m3u8Name: m3u8Name, + tsFormat: tsFormat, + dir: dir, + duration: segmentDuration, + playlistLength: playlistLength, + } + + //创建TS封装器 + muxer := libmpeg.NewTSMuxer() + muxer.SetWriteHandler(stream_.onTSWrite) + muxer.SetAllocHandler(stream_.onTSAlloc) + + stream_.context = &tsContext{ + segmentSeq: 0, + writeBuffer: make([]byte, 1024*1024), + writeBufferSize: 0, + } + + stream_.muxer = muxer + stream_.m3u8 = NewM3U8Writer(playlistLength) + stream_.m3u8File = file + + stream_.m3u8Sinks = make(map[stream.SinkId]stream.Sink, 24) + return stream_, nil +} + +func TransStreamFactory(source stream.Source, protocol stream.Protocol, streams []utils.AVStream) (stream.TransStream, error) { + id := source.Id() + return NewTransStream("", stream.AppConfig.Hls.M3U8Format(id), stream.AppConfig.Hls.TSFormat(id, "%d"), stream.AppConfig.Hls.Dir, stream.AppConfig.Hls.Duration, stream.AppConfig.Hls.PlaylistLength) +} diff --git a/jt1078/jt_server.go b/jt1078/jt_server.go index 7b59ea8..ad137c9 100644 --- a/jt1078/jt_server.go +++ b/jt1078/jt_server.go @@ -21,15 +21,18 @@ func NewServer() Server { return &jtServer{} } -func (s jtServer) OnConnected(conn net.Conn) { +func (s jtServer) OnConnected(conn net.Conn) []byte { log.Sugar.Debugf("jtserver连接 conn:%s", conn.RemoteAddr().String()) t := conn.(*transport.Conn) t.Data = NewSession(conn) + + return t.Data.(*Session).receiveBuffer.GetBlock() } -func (s jtServer) OnPacket(conn net.Conn, data []byte) { - conn.(*transport.Conn).Data.(*Session).Input(data) +func (s jtServer) OnPacket(conn net.Conn, data []byte) []byte { + conn.(*transport.Conn).Data.(*Session).PublishSource.Input(data) + return conn.(*transport.Conn).Data.(*Session).receiveBuffer.GetBlock() } func (s jtServer) OnDisConnected(conn net.Conn, err error) { diff --git a/jt1078/jt_session.go b/jt1078/jt_session.go index 467ebaa..23f8f0b 100644 --- a/jt1078/jt_session.go +++ b/jt1078/jt_session.go @@ -37,13 +37,14 @@ type Session struct { phone string decoder *transport.DelimiterFrameDecoder - audioIndex int - videoIndex int - audioStream utils.AVStream - videoStream utils.AVStream - audioBuffer stream.MemoryPool - videoBuffer stream.MemoryPool - rtpPacket *RtpPacket + audioIndex int + videoIndex int + audioStream utils.AVStream + videoStream utils.AVStream + audioBuffer stream.MemoryPool + videoBuffer stream.MemoryPool + rtpPacket *RtpPacket + receiveBuffer *stream.ReceiveBuffer } type RtpPacket struct { @@ -298,8 +299,9 @@ func NewSession(conn net.Conn) *Session { } delimiter := [4]byte{0x30, 0x31, 0x63, 0x64} session.decoder = transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:], session.OnJtPTPPacket) + session.receiveBuffer = stream.NewTCPReceiveBuffer() - session.Init(session.Input, session.Close) + session.Init(session.Input, session.Close, stream.ReceiveBufferTCPBlockCount) go session.LoopEvent() return &session } diff --git a/rtc/rtc_stream.go b/rtc/rtc_stream.go index e787270..55199b1 100644 --- a/rtc/rtc_stream.go +++ b/rtc/rtc_stream.go @@ -12,7 +12,6 @@ type transStream struct { func NewTransStream() stream.TransStream { t := &transStream{} - t.Init() return t } diff --git a/rtmp/rtmp_publisher.go b/rtmp/rtmp_publisher.go index 606da0c..a8cdaed 100644 --- a/rtmp/rtmp_publisher.go +++ b/rtmp/rtmp_publisher.go @@ -3,7 +3,6 @@ package rtmp import ( "github.com/yangjiechina/avformat/libflv" "github.com/yangjiechina/avformat/librtmp" - "github.com/yangjiechina/avformat/transport" "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/lkm/stream" "net" @@ -22,7 +21,7 @@ func NewPublisher(sourceId string, stack *librtmp.Stack, conn net.Conn) *Publish //设置回调,从flv解析出来的Stream和AVPacket都将统一回调到stream.PublishSource deMuxer.SetHandler(publisher_) //为推流方分配足够多的缓冲区 - conn.(*transport.Conn).ReallocateRecvBuffer(1024 * 1024) + //conn.(*transport.Conn).ReallocateRecvBuffer(1024 * 1024) return publisher_ } diff --git a/rtmp/rtmp_server.go b/rtmp/rtmp_server.go index 8506cd0..f545d5b 100644 --- a/rtmp/rtmp_server.go +++ b/rtmp/rtmp_server.go @@ -41,16 +41,20 @@ func (s *server) Close() { panic("implement me") } -func (s *server) OnConnected(conn net.Conn) { +func (s *server) OnConnected(conn net.Conn) []byte { log.Sugar.Debugf("rtmp连接 conn:%s", conn.RemoteAddr().String()) t := conn.(*transport.Conn) t.Data = NewSession(conn) + return nil } -func (s *server) OnPacket(conn net.Conn, data []byte) { +func (s *server) OnPacket(conn net.Conn, data []byte) []byte { + log.Sugar.Infof("rtmp包大小:%d", len(data)) + t := conn.(*transport.Conn) - err := t.Data.(*Session).Input(conn, data) + session := t.Data.(*Session) + err := session.Input(conn, data) if err != nil { log.Sugar.Errorf("处理rtmp包失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) @@ -59,6 +63,12 @@ func (s *server) OnPacket(conn net.Conn, data []byte) { t.Data.(*Session).Close() t.Data = nil } + + if session.isPublisher { + return session.receiveBuffer.GetBlock() + } + + return nil } func (s *server) OnDisConnected(conn net.Conn, err error) { diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index 1eb6c9b..d17986c 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -16,7 +16,8 @@ type Session struct { handle interface{} isPublisher bool - conn net.Conn + conn net.Conn + receiveBuffer *stream.ReceiveBuffer } func (s *Session) generateSourceId(app, stream_ string) string { @@ -37,6 +38,9 @@ func (s *Session) OnPublish(app, stream_ string, response chan utils.HookState) //设置推流的音视频回调 s.stack.SetOnPublishHandler(source) + //初始化放在add source前面, 以防add-init空窗期, 拉流队列空指针. + source.Init(source.Input, source.Close, stream.ReceiveBufferTCPBlockCount) + //推流事件Source统一处理, 是否已经存在, Hook回调.... _, state := stream.PreparePublishSource(source, true) if utils.HookStateOK != state { @@ -44,8 +48,8 @@ func (s *Session) OnPublish(app, stream_ string, response chan utils.HookState) } else { s.handle = source s.isPublisher = true + s.receiveBuffer = stream.NewTCPReceiveBuffer() - source.Init(source.Input, source.Close) go source.LoopEvent() } diff --git a/rtmp/rtmp_stream.go b/rtmp/rtmp_stream.go index f887e42..d1e246f 100644 --- a/rtmp/rtmp_stream.go +++ b/rtmp/rtmp_stream.go @@ -7,7 +7,7 @@ import ( "github.com/yangjiechina/lkm/stream" ) -type TransStream struct { +type transStream struct { stream.BaseTransStream chunkSize int @@ -21,7 +21,7 @@ type TransStream struct { mwBuffer stream.MergeWritingBuffer } -func (t *TransStream) Input(packet utils.AVPacket) error { +func (t *transStream) Input(packet utils.AVPacket) error { utils.Assert(t.BaseTransStream.Completed) var data []byte @@ -96,7 +96,7 @@ func (t *TransStream) Input(packet utils.AVPacket) error { return nil } -func (t *TransStream) AddSink(sink stream.Sink) error { +func (t *transStream) AddSink(sink stream.Sink) error { utils.Assert(t.headerSize > 0) t.BaseTransStream.AddSink(sink) @@ -113,12 +113,10 @@ func (t *TransStream) AddSink(sink stream.Sink) error { return nil } -func (t *TransStream) WriteHeader() error { +func (t *transStream) WriteHeader() error { utils.Assert(t.Tracks != nil) utils.Assert(!t.BaseTransStream.Completed) - t.Init() - var audioStream utils.AVStream var videoStream utils.AVStream var audioCodecId utils.AVCodecID @@ -181,7 +179,7 @@ func (t *TransStream) WriteHeader() error { return nil } -func (t *TransStream) Close() error { +func (t *transStream) Close() error { //发送剩余的流 segment := t.mwBuffer.PopSegment() if len(segment) > 0 { @@ -191,7 +189,7 @@ func (t *TransStream) Close() error { } func NewTransStream(chunkSize int) stream.TransStream { - return &TransStream{chunkSize: chunkSize} + return &transStream{chunkSize: chunkSize} } func TransStreamFactory(source stream.Source, protocol stream.Protocol, streams []utils.AVStream) (stream.TransStream, error) { diff --git a/rtsp/rtsp_server.go b/rtsp/rtsp_server.go index 147188c..8150ac8 100644 --- a/rtsp/rtsp_server.go +++ b/rtsp/rtsp_server.go @@ -51,21 +51,22 @@ func (s *server) Close() { } -func (s *server) OnConnected(conn net.Conn) { +func (s *server) OnConnected(conn net.Conn) []byte { log.Sugar.Debugf("rtsp连接 conn:%s", conn.RemoteAddr().String()) t := conn.(*transport.Conn) t.Data = NewSession(conn) + return nil } -func (s *server) OnPacket(conn net.Conn, data []byte) { +func (s *server) OnPacket(conn net.Conn, data []byte) []byte { t := conn.(*transport.Conn) method, url, header, err := parseMessage(data) if err != nil { log.Sugar.Errorf("failed to prase message:%s. err:%s conn:%s", string(data), err.Error(), conn.RemoteAddr().String()) _ = conn.Close() - return + return nil } err = s.handler.Process(t.Data.(*session), method, url, header) @@ -73,6 +74,9 @@ func (s *server) OnPacket(conn net.Conn, data []byte) { log.Sugar.Errorf("failed to process message of RTSP. err:%s conn:%s msg:%s", err.Error(), conn.RemoteAddr().String(), string(data)) _ = conn.Close() } + + //后续实现rtsp推流, 需要返回收流buffer + return nil } func (s *server) OnDisConnected(conn net.Conn, err error) { diff --git a/rtsp/rtsp_stream.go b/rtsp/rtsp_stream.go index 28a8013..02cadef 100644 --- a/rtsp/rtsp_stream.go +++ b/rtsp/rtsp_stream.go @@ -30,30 +30,6 @@ type tranStream struct { sdp string } -func NewTransStream(addr net.IPAddr, urlFormat string) stream.TransStream { - t := &tranStream{ - addr: addr, - urlFormat: urlFormat, - } - - if addr.IP.To4() != nil { - t.addrType = "IP4" - } else { - t.addrType = "IP6" - } - - t.Init() - return t -} - -func TransStreamFactory(source stream.Source, protocol stream.Protocol, streams []utils.AVStream) (stream.TransStream, error) { - trackFormat := source.Id() + "?track=%d" - return NewTransStream(net.IPAddr{ - IP: net.ParseIP(stream.AppConfig.PublicIP), - Zone: "", - }, trackFormat), nil -} - // rtpMuxer申请输出流内存的回调 // 无论是tcp/udp拉流, 均使用同一块内存, 并且给tcp预留4字节的包长. func (t *tranStream) onAllocBuffer(params interface{}) []byte { @@ -274,3 +250,26 @@ func (t *tranStream) WriteHeader() error { t.sdp = string(marshal) return nil } + +func NewTransStream(addr net.IPAddr, urlFormat string) stream.TransStream { + t := &tranStream{ + addr: addr, + urlFormat: urlFormat, + } + + if addr.IP.To4() != nil { + t.addrType = "IP4" + } else { + t.addrType = "IP6" + } + + return t +} + +func TransStreamFactory(source stream.Source, protocol stream.Protocol, streams []utils.AVStream) (stream.TransStream, error) { + trackFormat := source.Id() + "?track=%d" + return NewTransStream(net.IPAddr{ + IP: net.ParseIP(stream.AppConfig.PublicIP), + Zone: "", + }, trackFormat), nil +} diff --git a/stream/mw_buffer.go b/stream/mw_buffer.go index c87d059..034f84c 100644 --- a/stream/mw_buffer.go +++ b/stream/mw_buffer.go @@ -17,7 +17,7 @@ type MergeWritingBuffer interface { IsFull(ts int64) bool - IsCompeted() bool + IsCompleted() bool IsEmpty() bool @@ -72,7 +72,7 @@ func (m *mergeWritingBuffer) IsFull(ts int64) bool { return int(ts-m.prePacketTS) >= AppConfig.MergeWriteLatency } -func (m *mergeWritingBuffer) IsCompeted() bool { +func (m *mergeWritingBuffer) IsCompleted() bool { data, _ := m.transStreamBuffer.Data() return m.segmentOffset == len(data) } diff --git a/stream/receive_buffer.go b/stream/receive_buffer.go new file mode 100644 index 0000000..101dcb6 --- /dev/null +++ b/stream/receive_buffer.go @@ -0,0 +1,36 @@ +package stream + +const ( + ReceiveBufferUdpBlockCount = 200 + + ReceiveBufferTCPBlockCount = 100 +) + +// ReceiveBuffer 收流缓冲区. 网络收流->解析流->封装流->发送流是同步的,从解析到发送可能耗时,从而影响读取网络流. 使用收流缓冲区,可有效降低出现此问题的概率. +// 从网络IO读取数据->送给解复用器, 此过程需做到无内存拷贝 +// rtmp和1078推流直接使用ReceiveBuffer +// 国标推流,UDP收流都要经过jitter buffer处理, 还是需要拷贝一次, 没必要使用ReceiveBuffer. TCP全都使用ReceiveBuffer, 区别在于多端口模式, 第一包传给source, 单端口模式先解析出ssrc, 找到source. 后续再传给source. +type ReceiveBuffer struct { + blockSize int //单个缓存块大小 + blockCount int //缓存块数据流. 应当和Source的数据输入管道容量保持一致. + data []byte + index int +} + +func (r *ReceiveBuffer) GetBlock() []byte { + bytes := r.data[r.index*r.blockSize:] + r.index = r.index + 1%r.blockCount + return bytes[:r.blockSize] +} + +func NewReceiveBuffer(blockSize, blockCount int) *ReceiveBuffer { + return &ReceiveBuffer{blockSize: blockSize, blockCount: blockCount, data: make([]byte, blockSize*blockCount), index: 0} +} + +func NewUDPReceiveBuffer() *ReceiveBuffer { + return NewReceiveBuffer(1500, ReceiveBufferUdpBlockCount) +} + +func NewTCPReceiveBuffer() *ReceiveBuffer { + return NewReceiveBuffer(4096*20, ReceiveBufferTCPBlockCount) +} diff --git a/stream/source.go b/stream/source.go index 4099d32..f987321 100644 --- a/stream/source.go +++ b/stream/source.go @@ -108,7 +108,7 @@ type Source interface { // OnDeMuxDone 所有流解析完毕回调 OnDeMuxDone() - Init(inputCB func(data []byte) error, closeCB func()) + Init(inputCB func(data []byte) error, closeCB func(), receiveQueueSize int) LoopEvent() @@ -123,6 +123,8 @@ type Source interface { StartIdleTimer() State() SessionState + + SetInputCb(func(data []byte) error) } type PublishSource struct { @@ -154,7 +156,6 @@ type PublishSource struct { //sink的拉流和断开拉流事件,都通过管道交给Source处理. 意味着Source内部解析流、封装流、传输流都可以做到无锁操作 //golang的管道是有锁的(https://github.com/golang/go/blob/d38f1d13fa413436d38d86fe86d6a146be44bb84/src/runtime/chan.go#L202), 后面使用cas队列传输事件, 并且可以做到一次读取多个事件 inputDataEvent chan []byte - dataConsumedEvent chan byte //解析完input的数据后,才能继续从网络io中读取流 closedEvent chan byte playingEventQueue chan Sink playingDoneEventQueue chan Sink @@ -172,15 +173,15 @@ func (s *PublishSource) Id() string { return s.Id_ } -func (s *PublishSource) Init(inputCB func(data []byte) error, closeCB func()) { +func (s *PublishSource) Init(inputCB func(data []byte) error, closeCB func(), receiveQueueSize int) { s.inputCB = inputCB s.closeCB = closeCB s.SetState(SessionStateHandshakeDone) //初始化事件接收缓冲区 //收流和网络断开的chan都阻塞执行 - s.inputDataEvent = make(chan []byte) - s.dataConsumedEvent = make(chan byte) + //-1是为了保证从管道取到流, 到解析流是安全的, 不会被覆盖 + s.inputDataEvent = make(chan []byte, receiveQueueSize-1) s.closedEvent = make(chan byte) s.playingEventQueue = make(chan Sink, 128) s.playingDoneEventQueue = make(chan Sink, 128) @@ -234,22 +235,21 @@ func (s *PublishSource) LoopEvent() { select { case data := <-s.inputDataEvent: if !s.closed { - if AppConfig.ReceiveTimeout > 0 { - s.lastPacketTime = time.Now() - } - - if s.state == SessionStateHandshakeDone { - s.state = SessionStateTransferring - //不在父类处理hook和prepare事情 - } - - if err := s.inputCB(data); err != nil { - log.Sugar.Errorf("处理输入流失败 释放source:%s err:%s", s.Id_, err.Error()) - s.Close() - } + break } - s.dataConsumedEvent <- 0 + if AppConfig.ReceiveTimeout > 0 { + s.lastPacketTime = time.Now() + } + + if s.state == SessionStateHandshakeDone { + s.state = SessionStateTransferring + } + + if err := s.inputCB(data); err != nil { + log.Sugar.Errorf("处理输入流失败 释放source:%s err:%s", s.Id_, err.Error()) + s.Close() + } break case sink := <-s.playingEventQueue: if !s.completed { @@ -363,6 +363,7 @@ func (s *PublishSource) AddSink(sink Sink) bool { transStream.AddTrack(streams[i]) } + transStream.Init() _ = transStream.WriteHeader() } @@ -413,7 +414,6 @@ func (s *PublishSource) RemoveSink(sink Sink) bool { func (s *PublishSource) AddEvent(event SourceEvent, data interface{}) { if SourceEventInput == event { s.inputDataEvent <- data.([]byte) - <-s.dataConsumedEvent } else if SourceEventPlay == event { s.playingEventQueue <- data.(Sink) } else if SourceEventPlayDone == event { @@ -650,3 +650,7 @@ func (s *PublishSource) StartIdleTimer() { func (s *PublishSource) State() SessionState { return s.state } + +func (s *PublishSource) SetInputCb(cb func(data []byte) error) { + s.inputCB = cb +}