diff --git a/api.go b/api.go index 90ac493..352e7f1 100644 --- a/api.go +++ b/api.go @@ -2,10 +2,12 @@ package main import ( "encoding/json" + "fmt" "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/flv" + "github.com/yangjiechina/live-server/gb28181" "github.com/yangjiechina/live-server/hls" "github.com/yangjiechina/live-server/log" "github.com/yangjiechina/live-server/rtc" @@ -40,6 +42,12 @@ func init() { func startApiServer(addr string) { apiServer.router.HandleFunc("/live/{source}", apiServer.filterLive) + + apiServer.router.HandleFunc("/v1/gb28181/source/create", apiServer.createGBSource) + //TCP主动,设置连接地址 + apiServer.router.HandleFunc("/v1/gb28181/source/connect", apiServer.connectGBSource) + apiServer.router.HandleFunc("/v1/gb28181/source/close", apiServer.closeGBSource) + apiServer.router.HandleFunc("/rtc.html", func(writer http.ResponseWriter, request *http.Request) { http.ServeFile(writer, request, "./rtc.html") }) @@ -60,6 +68,147 @@ func startApiServer(addr string) { } } +func (api *ApiServer) createGBSource(w http.ResponseWriter, r *http.Request) { + //请求参数 + v := &struct { + Source string `json:"source"` //SourceId + Transport string `json:"transport,omitempty"` + Setup string `json:"setup"` //active/passive + SSRC uint32 `json:"ssrc,omitempty"` + }{} + + //返回监听的端口 + response := &struct { + Port uint16 `json:"port,omitempty"` + }{} + + var err error + defer func() { + if err != nil { + log.Sugar.Errorf(err.Error()) + httpResponse2(w, err) + } + }() + + if err = HttpDecodeJSONBody(w, r, v); err != nil { + return + } + + log.Sugar.Infof("gb create:%v", v) + + source := stream.SourceManager.Find(v.Source) + if source != nil { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "gbsource 已经存在"} + return + } + + tcp := strings.Contains(v.Transport, "tcp") + var active bool + if tcp && "active" == v.Setup { + if !stream.AppConfig.GB28181.IsMultiPort() { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "创建GB28181 Source失败, 单端口模式下不能主动拉流"} + } else if !tcp { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "创建GB28181 Source失败, UDP不能主动拉流"} + } else if !stream.AppConfig.GB28181.EnableTCP() { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "创建GB28181 Source失败, 未开启TCP, UDP不能主动拉流"} + } + + if err != nil { + return + } + + active = true + } + + _, port, err := gb28181.NewGBSource(v.Source, v.SSRC, tcp, active) + if err != nil { + err = &MalformedRequest{Code: http.StatusInternalServerError, Msg: fmt.Sprintf("创建GB28181 Source失败 err:%s", err.Error())} + return + } + + response.Port = port + httpResponseOk(w, response) +} + +func (api *ApiServer) connectGBSource(w http.ResponseWriter, r *http.Request) { + //请求参数 + v := &struct { + Source string `json:"source"` //SourceId + RemoteAddr string `json:"remote_addr"` + }{} + + var err error + defer func() { + if err != nil { + log.Sugar.Errorf(err.Error()) + httpResponse2(w, err) + } + }() + + if err = HttpDecodeJSONBody(w, r, v); err != nil { + return + } + + log.Sugar.Infof("gb connect:%v", v) + + source := stream.SourceManager.Find(v.Source) + if source == nil { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "gb28181 source 不存在"} + return + } + + activeSource, ok := source.(*gb28181.ActiveSource) + if !ok { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "gbsource 不能转为active source"} + return + } + + addr, err := net.ResolveTCPAddr("tcp", v.RemoteAddr) + if err != nil { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "解析连接地址失败"} + return + } + + err = activeSource.Connect(addr) + if err != nil { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: fmt.Sprintf("连接Server失败 err:%s", err.Error())} + return + } + + httpResponseOk(w, nil) +} + +func (api *ApiServer) closeGBSource(w http.ResponseWriter, r *http.Request) { + //请求参数 + v := &struct { + Source string `json:"source"` //SourceId + }{} + + var err error + defer func() { + if err != nil { + log.Sugar.Errorf(err.Error()) + httpResponse2(w, err) + } + }() + + if err = HttpDecodeJSONBody(w, r, v); err != nil { + httpResponse2(w, err) + return + } + + log.Sugar.Infof("gb close:%v", v) + + source := stream.SourceManager.Find(v.Source) + if source == nil { + err = &MalformedRequest{Code: http.StatusBadRequest, Msg: "gb28181 source 不存在"} + return + } + + source.Close() + httpResponseOk(w, nil) +} + func (api *ApiServer) generateSinkId(remoteAddr string) stream.SinkId { tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr) if err != nil { diff --git a/flv/http_flv.go b/flv/http_flv.go index 2671692..3f0aeb4 100644 --- a/flv/http_flv.go +++ b/flv/http_flv.go @@ -42,6 +42,10 @@ func NewHttpTransStream() stream.ITransStream { } } +func TransStreamFactory(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) (stream.ITransStream, error) { + return NewHttpTransStream(), nil +} + func (t *httpTransStream) Input(packet utils.AVPacket) error { var flvSize int var data []byte @@ -210,7 +214,7 @@ func (t *httpTransStream) WriteHeader() error { if utils.AVMediaTypeAudio == track.Type() { data = track.Extra() } else if utils.AVMediaTypeVideo == track.Type() { - data, _ = track.M4VCExtraData() + data = track.CodecParameters().DecoderConfRecord().ToMP4VC() } t.headerSize += t.muxer.Input(t.header[t.headerSize:], track.Type(), len(data), 0, 0, false, true) diff --git a/gb28181/filter.go b/gb28181/filter.go new file mode 100644 index 0000000..a90d37f --- /dev/null +++ b/gb28181/filter.go @@ -0,0 +1,30 @@ +package gb28181 + +import ( + "github.com/pion/rtp" + "github.com/yangjiechina/live-server/log" + "net" +) + +type Filter interface { + AddSource(ssrc uint32, source GBSource) bool + + Input(conn net.Conn, data []byte) GBSource + + ParseRtpPacket(conn net.Conn, data []byte) (*rtp.Packet, error) +} + +type FilterImpl struct { +} + +func (r FilterImpl) ParseRtpPacket(conn net.Conn, data []byte) (*rtp.Packet, error) { + packet := rtp.Packet{} + err := packet.Unmarshal(data) + + if err != nil { + log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) + return nil, err + } + + return &packet, err +} diff --git a/gb28181/filter_single.go b/gb28181/filter_single.go new file mode 100644 index 0000000..003e666 --- /dev/null +++ b/gb28181/filter_single.go @@ -0,0 +1,36 @@ +package gb28181 + +import ( + "net" +) + +type SingleFilter struct { + FilterImpl + + source GBSource +} + +func NewSingleFilter(source GBSource) *SingleFilter { + return &SingleFilter{source: source} +} + +func (s *SingleFilter) AddSource(ssrc uint32, source GBSource) bool { + panic("implement me") + /* utils.Assert(s.source == nil) + s.source = source + return true*/ +} + +func (s *SingleFilter) Input(conn net.Conn, data []byte) GBSource { + packet, err := s.ParseRtpPacket(conn, data) + if err != nil { + return nil + } + + if s.source == nil { + return nil + } + + s.source.InputRtp(packet) + return s.source +} diff --git a/gb28181/filter_ssrc.go b/gb28181/filter_ssrc.go new file mode 100644 index 0000000..8e7a556 --- /dev/null +++ b/gb28181/filter_ssrc.go @@ -0,0 +1,40 @@ +package gb28181 + +import ( + "net" +) + +type SSRCFilter struct { + FilterImpl + + sources map[uint32]GBSource +} + +func NewSharedFilter(guestCount int) *SSRCFilter { + return &SSRCFilter{sources: make(map[uint32]GBSource, guestCount)} +} + +func (r SSRCFilter) AddSource(ssrc uint32, source GBSource) bool { + _, ok := r.sources[ssrc] + if ok { + return false + } + + r.sources[ssrc] = source + return true +} + +func (r SSRCFilter) Input(conn net.Conn, data []byte) GBSource { + packet, err := r.ParseRtpPacket(conn, data) + if err != nil { + return nil + } + + source, ok := r.sources[packet.SSRC] + if !ok { + return nil + } + + source.InputRtp(packet) + return source +} diff --git a/gb28181/gb28181_test.go b/gb28181/gb28181_test.go new file mode 100644 index 0000000..90bc18d --- /dev/null +++ b/gb28181/gb28181_test.go @@ -0,0 +1,198 @@ +package gb28181 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "github.com/yangjiechina/avformat/transport" + "io" + "net" + "net/http" + "os" + "testing" + "time" +) + +// 输入rtp负载的ps流文件路径, 根据ssrc解析, rtp头不要带扩展 +func readRtp(path string, ssrc uint32, tcp bool, cb func([]byte)) { + file, err := os.ReadFile(path) + if err != nil { + panic(err) + } + + var offset int + tcpRtp := make([]byte, 1500) + + for i := 0; i < len(file)-4; i++ { + if ssrc != binary.BigEndian.Uint32(file[i:]) { + continue + } + + if i-8 != 0 { + var err error + rtp := file[offset : i-8] + + if tcp { + binary.BigEndian.PutUint16(tcpRtp, uint16(len(rtp))) + copy(tcpRtp[2:], rtp) + cb(tcpRtp[:2+len(rtp)]) + } else { + cb(rtp) + } + + if err != nil { + panic(err.Error()) + } + } + offset = i - 8 + } +} + +func connectSource(source string, addr string) { + v := &struct { + Source string `json:"source"` //SourceId + RemoteAddr string `json:"remote_addr"` + }{ + Source: source, + RemoteAddr: addr, + } + + marshal, err := json.Marshal(v) + if err != nil { + panic(err) + } + + request, err := http.NewRequest("POST", "http://localhost:8080/v1/gb28181/source/connect", bytes.NewBuffer(marshal)) + if err != nil { + panic(err) + } + + client := http.Client{} + response, err := client.Do(request) + if err != nil { + panic(err) + } + + _, err = io.ReadAll(response.Body) + if err != nil { + panic(err) + } +} + +func createSource(source, transport, setup string, ssrc uint32) int { + v := struct { + Source string `json:"source"` //SourceId + Transport string `json:"transport,omitempty"` + Setup string `json:"setup"` //active/passive + SSRC uint32 `json:"ssrc,omitempty"` + }{ + Source: source, + Transport: transport, + Setup: setup, + SSRC: ssrc, + } + + marshal, err := json.Marshal(v) + if err != nil { + panic(err) + } + + request, err := http.NewRequest("POST", "http://localhost:8080/v1/gb28181/source/create", bytes.NewBuffer(marshal)) + if err != nil { + panic(err) + } + + client := http.Client{} + response, err := client.Do(request) + if err != nil { + panic(err) + } + + all, err := io.ReadAll(response.Body) + if err != nil { + panic(err) + } + + resposne := &struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + Port int `json:"port"` + } `json:"data"` + }{} + + err = json.Unmarshal(all, resposne) + if err != nil { + panic(err) + } + + if resposne.Code != http.StatusOK { + panic("") + } + + return resposne.Data.Port +} + +func TestUDPRecv(t *testing.T) { + path := "D:\\GOProjects\\avformat\\gb28181_h264.rtp" + ssrc := 0xBEBC201 + ip := "192.168.31.112" + localAddr := "0.0.0.0:20001" + network := "tcp" + setup := "active" + id := "hls_mystream" + + port := createSource(id, network, setup, uint32(ssrc)) + + if network == "udp" { + addr, _ := net.ResolveUDPAddr(network, localAddr) + remoteAddr, _ := net.ResolveUDPAddr(network, fmt.Sprintf("%s:%d", ip, port)) + + client := &transport.UDPClient{} + err := client.Connect(addr, remoteAddr) + if err != nil { + panic(err) + } + + readRtp(path, uint32(ssrc), false, func(data []byte) { + client.Write(data) + time.Sleep(1 * time.Millisecond) + }) + } else if !(setup == "active") { + addr, _ := net.ResolveTCPAddr(network, localAddr) + remoteAddr, _ := net.ResolveTCPAddr(network, fmt.Sprintf("%s:%d", ip, port)) + + client := transport.TCPClient{} + err := client.Connect(addr, remoteAddr) + + if err != nil { + panic(err) + } + + readRtp(path, uint32(ssrc), true, func(data []byte) { + client.Write(data) + time.Sleep(1 * time.Millisecond) + }) + } else { + addr, _ := net.ResolveTCPAddr(network, localAddr) + server := transport.TCPServer{} + + server.SetHandler2(func(conn net.Conn) { + readRtp(path, uint32(ssrc), true, func(data []byte) { + conn.Write(data) + time.Sleep(1 * time.Millisecond) + }) + }, nil, nil) + + err := server.Bind(addr) + if err != nil { + panic(err) + } + + connectSource(id, "192.168.31.112:20001") + // + } + + select {} +} diff --git a/gb28181/source.go b/gb28181/source.go new file mode 100644 index 0000000..9ec5a2d --- /dev/null +++ b/gb28181/source.go @@ -0,0 +1,312 @@ +package gb28181 + +import ( + "encoding/hex" + "fmt" + "github.com/pion/rtp" + "github.com/yangjiechina/avformat/libavc" + "github.com/yangjiechina/avformat/libhevc" + "github.com/yangjiechina/avformat/libmpeg" + "github.com/yangjiechina/avformat/transport" + "github.com/yangjiechina/avformat/utils" + "github.com/yangjiechina/live-server/log" + "github.com/yangjiechina/live-server/stream" + "net" +) + +type Transport int + +const ( + TransportUDP = Transport(0) + TransportTCPPassive = Transport(1) + TransportTCPActive = Transport(2) + + PsProbeBufferSize = 1024 * 1024 * 2 + JitterBufferSize = 1024 * 1024 +) + +var ( + TransportManger stream.TransportManager + SharedUDPServer *UDPServer + SharedTCPServer *TCPServer +) + +type GBSource interface { + stream.ISource + + InputRtp(pkt *rtp.Packet) error + + Transport() Transport + + PrepareTransDeMuxer(id string, ssrc uint32) +} + +// GBSourceImpl GB28181推流Source +// 负责解析生成AVStream和AVPacket, 后续全权交给父类Source处理. +type GBSourceImpl struct { + stream.SourceImpl + + deMuxerCtx *libmpeg.PSDeMuxerContext + + audioStream utils.AVStream + videoStream utils.AVStream + + ssrc uint32 + + transport transport.ITransport +} + +func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, uint16, error) { + if tcp { + utils.Assert(stream.AppConfig.GB28181.EnableTCP()) + } else { + utils.Assert(stream.AppConfig.GB28181.EnableUDP()) + } + + if active { + utils.Assert(tcp && stream.AppConfig.GB28181.EnableTCP() && stream.AppConfig.GB28181.IsMultiPort()) + } + + var source GBSource + var port uint16 + var err error + + if active { + source, port, err = NewActiveSource() + } else if tcp { + source = NewPassiveSource() + } else { + source = NewUDPSource() + } + + if err != nil { + return nil, 0, err + } + + //单端口模式,绑定ssrc + if !stream.AppConfig.GB28181.IsMultiPort() { + var success bool + if tcp { + success = SharedTCPServer.filter.AddSource(ssrc, source) + } else { + success = SharedUDPServer.filter.AddSource(ssrc, source) + } + + if !success { + return nil, 0, fmt.Errorf("source existing") + } + + port = stream.AppConfig.GB28181.Port[0] + } else if !active { + if tcp { + err := TransportManger.AllocTransport(true, func(port_ uint16) error { + + addr, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", stream.AppConfig.GB28181.Addr, port_)) + server, err := NewTCPServer(addr, NewSingleFilter(source)) + if err != nil { + + return err + } + + source.(*PassiveSource).transport = server.tcp + port = port_ + return nil + }) + + if err != nil { + return nil, 0, err + } + } else { + err := TransportManger.AllocTransport(false, func(port_ uint16) error { + + addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", stream.AppConfig.GB28181.Addr, port_)) + server, err := NewUDPServer(addr, NewSingleFilter(source)) + if err != nil { + return err + } + + source.(*UDPSource).transport = server.udp + port = port_ + return nil + }) + + if err != nil { + return nil, 0, err + } + } + } + + source.PrepareTransDeMuxer(id, ssrc) + + if err = stream.SourceManager.Add(source); err != nil { + source.Close() + return nil, 0, err + } + + source.Init(source.Input) + go source.LoopEvent() + return source, port, err +} + +func (source *GBSourceImpl) InputRtp(pkt *rtp.Packet) error { + panic("implement me") +} + +func (source *GBSourceImpl) Transport() Transport { + panic("implement me") +} + +func (source *GBSourceImpl) PrepareTransDeMuxer(id string, ssrc uint32) { + source.Id_ = id + source.ssrc = ssrc + source.deMuxerCtx = libmpeg.NewPSDeMuxerContext(make([]byte, PsProbeBufferSize)) + source.deMuxerCtx.SetHandler(source) +} + +// Input 输入PS流 +func (source *GBSourceImpl) Input(data []byte) error { + return source.deMuxerCtx.Input(data) +} + +// OnPartPacket 部分es流回调 +func (source *GBSourceImpl) OnPartPacket(index int, mediaType utils.AVMediaType, codec utils.AVCodecID, data []byte, first bool) { + buffer := source.FindOrCreatePacketBuffer(index, mediaType) + + //第一个es包, 标记内存起始位置 + if first { + buffer.Mark() + } + + buffer.Write(data) +} + +// OnLossPacket 非完整es包丢弃回调, 直接释放内存块 +func (source *GBSourceImpl) OnLossPacket(index int, mediaType utils.AVMediaType, codec utils.AVCodecID) { + buffer := source.FindOrCreatePacketBuffer(index, mediaType) + + buffer.Fetch() + buffer.FreeTail() +} + +// OnCompletePacket 完整帧回调 +func (source *GBSourceImpl) OnCompletePacket(index int, mediaType utils.AVMediaType, codec utils.AVCodecID, dts int64, pts int64, key bool) error { + buffer := source.FindOrCreatePacketBuffer(index, mediaType) + + data := buffer.Fetch() + var packet utils.AVPacket + var stream_ utils.AVStream + defer func() { + if packet == nil { + buffer.FreeTail() + } + }() + + if utils.AVCodecIdH264 == codec { + //从关键帧中解析出sps和pps + if source.videoStream == nil { + sps, pps, err := libavc.ParseExtraDataFromKeyNALU(data) + if err != nil { + log.Sugar.Errorf("从关键帧中解析sps pps失败 source:%s data:%s", source.Id_, hex.EncodeToString(data)) + return err + } + + codecData, err := utils.NewAVCCodecData(sps, pps) + if err != nil { + log.Sugar.Errorf("解析sps pps失败 source:%s data:%s sps:%s, pps:%s", source.Id_, hex.EncodeToString(data), hex.EncodeToString(sps), hex.EncodeToString(pps)) + return err + } + + source.videoStream = utils.NewAVStream(utils.AVMediaTypeVideo, 0, codec, codecData.Record(), codecData) + stream_ = source.videoStream + } + + packet = utils.NewVideoPacket(data, dts, pts, key, utils.PacketTypeAnnexB, codec, index, 90000) + } else if utils.AVCodecIdH265 == codec { + if source.videoStream == nil { + vps, sps, pps, err := libhevc.ParseExtraDataFromKeyNALU(data) + if err != nil { + log.Sugar.Errorf("从关键帧中解析vps sps pps失败 source:%s data:%s", source.Id_, hex.EncodeToString(data)) + return err + } + + codecData, err := utils.NewHevcCodecData(vps, sps, pps) + if err != nil { + log.Sugar.Errorf("解析sps pps失败 source:%s data:%s vps:%s sps:%s, pps:%s", source.Id_, hex.EncodeToString(data), hex.EncodeToString(vps), hex.EncodeToString(sps), hex.EncodeToString(pps)) + return err + } + + source.videoStream = utils.NewAVStream(utils.AVMediaTypeVideo, 0, codec, codecData.Record(), codecData) + stream_ = source.videoStream + } + + packet = utils.NewVideoPacket(data, dts, pts, key, utils.PacketTypeAnnexB, codec, index, 90000) + } else if utils.AVCodecIdAAC == codec { + //必须包含ADTSHeader + if len(data) < 7 { + log.Sugar.Warnf("need more data...") + return nil + } + + var skip int + header, err := utils.ReadADtsFixedHeader(data) + if err != nil { + log.Sugar.Errorf("读取ADTSHeader失败 suorce:%s data:%s", source.Id_, hex.EncodeToString(data[:7])) + return nil + } else { + skip = 7 + //跳过ADtsHeader长度 + if header.ProtectionAbsent() == 0 { + skip += 2 + } + } + + if source.audioStream == nil { + if source.IsCompleted() { + return nil + } + + configData, err := utils.ADtsHeader2MpegAudioConfigData(header) + config, err := utils.ParseMpeg4AudioConfig(configData) + println(config) + if err != nil { + log.Sugar.Errorf("adt头转m4ac失败 suorce:%s data:%s", source.Id_, hex.EncodeToString(data[:7])) + return nil + } + + source.audioStream = utils.NewAVStream(utils.AVMediaTypeAudio, index, codec, configData, nil) + stream_ = source.audioStream + } + + packet = utils.NewAudioPacket(data[skip:], dts, pts, codec, index, 90000) + } else if utils.AVCodecIdPCMALAW == codec || utils.AVCodecIdPCMMULAW == codec { + if source.audioStream == nil { + source.audioStream = utils.NewAVStream(utils.AVMediaTypeAudio, index, codec, nil, nil) + stream_ = source.audioStream + } + + packet = utils.NewAudioPacket(data, dts, pts, codec, index, 90000) + } else { + log.Sugar.Errorf("the codec %d is not implemented.", codec) + return nil + } + + if stream_ != nil { + source.OnDeMuxStream(stream_) + if len(source.OriginStreams()) >= source.deMuxerCtx.TrackCount() { + source.OnDeMuxStreamDone() + } + } + + source.OnDeMuxPacket(packet) + + return nil +} + +func (source *GBSourceImpl) Close() { + if source.transport != nil { + source.transport.Close() + source.transport = nil + } + + source.SourceImpl.Close() +} diff --git a/gb28181/source_active.go b/gb28181/source_active.go new file mode 100644 index 0000000..a305b18 --- /dev/null +++ b/gb28181/source_active.go @@ -0,0 +1,37 @@ +package gb28181 + +import ( + "net" +) + +type ActiveSource struct { + PassiveSource + + port uint16 + remoteAddr net.TCPAddr + tcp *TCPClient +} + +func NewActiveSource() (*ActiveSource, uint16, error) { + var port uint16 + TransportManger.AllocTransport(true, func(port_ uint16) error { + port = port_ + return nil + }) + + return &ActiveSource{port: port}, port, nil +} + +func (a ActiveSource) Connect(remoteAddr *net.TCPAddr) error { + client, err := NewTCPClient(a.port, remoteAddr, &a) + if err != nil { + return err + } + + a.tcp = client + return nil +} + +func (a ActiveSource) Transport() Transport { + return TransportTCPActive +} diff --git a/gb28181/source_passive.go b/gb28181/source_passive.go new file mode 100644 index 0000000..78902e3 --- /dev/null +++ b/gb28181/source_passive.go @@ -0,0 +1,23 @@ +package gb28181 + +import ( + "github.com/pion/rtp" + "github.com/yangjiechina/live-server/stream" +) + +type PassiveSource struct { + GBSourceImpl +} + +func NewPassiveSource() *PassiveSource { + return &PassiveSource{} +} + +func (t PassiveSource) Transport() Transport { + return TransportTCPPassive +} + +func (t PassiveSource) InputRtp(pkt *rtp.Packet) error { + t.SourceImpl.AddEvent(stream.SourceEventInput, pkt.Payload) + return nil +} diff --git a/gb28181/source_udp.go b/gb28181/source_udp.go new file mode 100644 index 0000000..8c4c26d --- /dev/null +++ b/gb28181/source_udp.go @@ -0,0 +1,50 @@ +package gb28181 + +import ( + "fmt" + "github.com/pion/rtp" + "github.com/yangjiechina/live-server/jitterbuffer" + "github.com/yangjiechina/live-server/stream" +) + +type UDPSource struct { + GBSourceImpl + + rtpDeMuxer *jitterbuffer.JitterBuffer + + rtpBuffer stream.MemoryPool +} + +func NewUDPSource() *UDPSource { + return &UDPSource{ + rtpDeMuxer: jitterbuffer.New(), + rtpBuffer: stream.NewMemoryPoolWithDirect(JitterBufferSize, true), + } +} + +func (u UDPSource) Transport() Transport { + return TransportUDP +} + +func (u UDPSource) InputRtp(pkt *rtp.Packet) error { + n := u.rtpBuffer.Capacity() - u.rtpBuffer.Size() + if n < len(pkt.Payload) { + return fmt.Errorf("RTP receive buffer overflow") + } + + allocate := u.rtpBuffer.Allocate(len(pkt.Payload)) + copy(allocate, pkt.Payload) + pkt.Payload = allocate + u.rtpDeMuxer.Push(pkt) + + for { + pkt, _ := u.rtpDeMuxer.Pop() + if pkt == nil { + return nil + } + + u.rtpBuffer.FreeHead() + + u.SourceImpl.AddEvent(stream.SourceEventInput, pkt.Payload) + } +} diff --git a/gb28181/tcp_client.go b/gb28181/tcp_client.go new file mode 100644 index 0000000..66336c3 --- /dev/null +++ b/gb28181/tcp_client.go @@ -0,0 +1,28 @@ +package gb28181 + +import ( + "fmt" + "github.com/yangjiechina/avformat/transport" + "github.com/yangjiechina/live-server/stream" + "net" +) + +type TCPClient struct { + TCPServer +} + +func NewTCPClient(listenPort uint16, remoteAddr *net.TCPAddr, source GBSource) (*TCPClient, error) { + client := &TCPClient{ + TCPServer{filter: NewSingleFilter(source)}, + } + tcp := transport.TCPClient{} + tcp.SetHandler(client) + + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", stream.AppConfig.GB28181.Addr, listenPort)) + if err != nil { + return client, err + } + + err = tcp.Connect(addr, remoteAddr) + return client, err +} diff --git a/gb28181/tcp_server.go b/gb28181/tcp_server.go new file mode 100644 index 0000000..766b5f1 --- /dev/null +++ b/gb28181/tcp_server.go @@ -0,0 +1,63 @@ +package gb28181 + +import ( + "github.com/yangjiechina/avformat/transport" + "github.com/yangjiechina/live-server/log" + "net" +) + +type TCPServer struct { + tcp *transport.TCPServer + filter Filter +} + +type TCPSession struct { + source GBSource + decoder *transport.LengthFieldFrameDecoder +} + +func NewTCPServer(addr net.Addr, filter Filter) (*TCPServer, error) { + server := &TCPServer{ + filter: filter, + } + + tcp := &transport.TCPServer{} + tcp.SetHandler(server) + if err := tcp.Bind(addr); err != nil { + return server, err + } + + server.tcp = tcp + return server, nil +} + +func (T *TCPServer) OnConnected(conn net.Conn) { + log.Sugar.Infof("客户端链接 conn:%s", conn.RemoteAddr().String()) +} + +func (T *TCPServer) OnPacket(conn net.Conn, data []byte) { + con := conn.(*transport.Conn) + if con.Data == nil { + session := &TCPSession{} + session.decoder = transport.NewLengthFieldFrameDecoder(0xFFFF, 2, func(bytes []byte) { + source := T.filter.Input(con, bytes[2:]) + if source != nil && session.source == nil { + session.source = source + } + }) + + con.Data = session + } + + con.Data.(*TCPSession).decoder.Input(data) +} + +func (T *TCPServer) OnDisConnected(conn net.Conn, err error) { + log.Sugar.Infof("客户端断开链接 conn:%s", conn.RemoteAddr().String()) + + con := conn.(*transport.Conn) + if con.Data != nil { + con.Data.(*TCPSession).source.Close() + con.Data = nil + } +} diff --git a/gb28181/udp_server.go b/gb28181/udp_server.go new file mode 100644 index 0000000..c0918ad --- /dev/null +++ b/gb28181/udp_server.go @@ -0,0 +1,37 @@ +package gb28181 + +import ( + "github.com/yangjiechina/avformat/transport" + "net" +) + +type UDPServer struct { + udp *transport.UDPTransport + filter Filter +} + +func NewUDPServer(addr net.Addr, filter Filter) (*UDPServer, error) { + server := &UDPServer{ + filter: filter, + } + + udp, err := transport.NewUDPServer(addr, server) + if err != nil { + return nil, err + } + + server.udp = udp + return server, nil +} + +func (U UDPServer) OnConnected(conn net.Conn) { + +} + +func (U UDPServer) OnPacket(conn net.Conn, data []byte) { + U.filter.Input(conn, data) +} + +func (U UDPServer) OnDisConnected(conn net.Conn, err error) { + +} diff --git a/hls/hls_stream.go b/hls/hls_stream.go index 8c8e914..69dcea3 100644 --- a/hls/hls_stream.go +++ b/hls/hls_stream.go @@ -84,6 +84,11 @@ func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playli return stream_, nil } +func TransStreamFactory(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) (stream.ITransStream, error) { + id := source.Id() + return NewTransStream("", stream.AppConfig.Hls.M3U8Format(id), stream.AppConfig.Hls.TSFormat(id, "%d"), stream.AppConfig.Hls.Dir, stream.AppConfig.Hls.Duration, stream.AppConfig.Hls.PlaylistLength) +} + func (t *transStream) Input(packet utils.AVPacket) error { if packet.Index() >= t.muxer.TrackCount() { return fmt.Errorf("track not available") @@ -112,11 +117,7 @@ func (t *transStream) AddTrack(stream utils.AVStream) error { } if stream.CodecId() == utils.AVCodecIdH264 { - data, err := stream.AnnexBExtraData() - if err != nil { - return err - } - + data := stream.CodecParameters().DecoderConfRecord().ToAnnexB() _, err = t.muxer.AddTrack(stream.Type(), stream.CodecId(), data) } else { _, err = t.muxer.AddTrack(stream.Type(), stream.CodecId(), stream.Extra()) diff --git a/http_json_body_decode.go b/http_json_body_decode.go new file mode 100644 index 0000000..eb7a818 --- /dev/null +++ b/http_json_body_decode.go @@ -0,0 +1,82 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" +) + +const ( + maxHttpJsonBodySize = 256 * 1024 +) + +const ( + EmptyRequestBody = "Request body must not be empty" +) + +type MalformedRequest struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data interface{} `json:"data,omitempty"` +} + +func (mr *MalformedRequest) Error() string { + return mr.Msg +} + +func HttpDecodeJSONBody(w http.ResponseWriter, r *http.Request, dst interface{}) error { + // Use http.MaxBytesReader to enforce a maximum read of 1MB from the + // response body. A request body larger than that will now result in + // Decode() returning a "http: request body too large" error. + r.Body = http.MaxBytesReader(w, r.Body, maxHttpJsonBodySize) + + dec := json.NewDecoder(r.Body) + //dec.DisallowUnknownFields() + + err := dec.Decode(&dst) + if err != nil { + var syntaxError *json.SyntaxError + var unmarshalTypeError *json.UnmarshalTypeError + + switch { + case errors.As(err, &syntaxError): + msg := fmt.Sprintf("Request body contains badly-formed JSON (at position %d)", syntaxError.Offset) + return &MalformedRequest{Code: http.StatusBadRequest, Msg: msg} + + case errors.Is(err, io.ErrUnexpectedEOF): + msg := fmt.Sprintf("Request body contains badly-formed JSON") + return &MalformedRequest{Code: http.StatusBadRequest, Msg: msg} + + case errors.As(err, &unmarshalTypeError): + msg := fmt.Sprintf("Request body contains an invalid value for the %q field (at position %d)", unmarshalTypeError.Field, unmarshalTypeError.Offset) + return &MalformedRequest{Code: http.StatusBadRequest, Msg: msg} + + case strings.HasPrefix(err.Error(), "json: unknown field "): + fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ") + msg := fmt.Sprintf("Request body contains unknown field %s", fieldName) + return &MalformedRequest{Code: http.StatusBadRequest, Msg: msg} + + case errors.Is(err, io.EOF): + msg := "Request body must not be empty" + return &MalformedRequest{Code: http.StatusBadRequest, Msg: msg} + + case err.Error() == "http: request body too large": + msg := "Request body must not be larger than 1MB" + return &MalformedRequest{Code: http.StatusRequestEntityTooLarge, Msg: msg} + + default: + return err + } + } + + err = dec.Decode(&struct{}{}) + if err != io.EOF { + msg := "Request body must only contain a single JSON object" + return &MalformedRequest{Code: http.StatusBadRequest, Msg: msg} + } + + return nil +} diff --git a/http_response.go b/http_response.go new file mode 100644 index 0000000..bc02d4d --- /dev/null +++ b/http_response.go @@ -0,0 +1,29 @@ +package main + +import ( + "encoding/json" + "net/http" +) + +func httpResponse(w http.ResponseWriter, code int, msg string) { + httpResponse2(w, MalformedRequest{ + Code: code, + Msg: msg, + }) +} + +func httpResponseOk(w http.ResponseWriter, data interface{}) { + httpResponse2(w, MalformedRequest{ + Code: http.StatusOK, + Msg: "ok", + Data: data, + }) +} + +func httpResponse2(w http.ResponseWriter, payload interface{}) { + body, _ := json.Marshal(payload) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT") + w.Write(body) +} diff --git a/jitterbuffer/jitter_buffer.go b/jitterbuffer/jitter_buffer.go new file mode 100644 index 0000000..976ea76 --- /dev/null +++ b/jitterbuffer/jitter_buffer.go @@ -0,0 +1,282 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package jitterbuffer implements a buffer for RTP packets designed to help +// counteract non-deterministic sources of latency +package jitterbuffer + +import ( + "errors" + "math" + "sync" + + "github.com/pion/rtp" +) + +// State tracks a JitterBuffer as either Buffering or Emitting +type State uint16 + +// Event represents all events a JitterBuffer can emit +type Event string + +var ( + // ErrBufferUnderrun is returned when the buffer has no items + ErrBufferUnderrun = errors.New("invalid Peek: Empty jitter buffer") + // ErrPopWhileBuffering is returned if a jitter buffer is not in a playback state + ErrPopWhileBuffering = errors.New("attempt to pop while buffering") +) + +const ( + // Buffering is the state when the jitter buffer has not started emitting yet, or has hit an underflow and needs to re-buffer packets + Buffering State = iota + // Emitting is the state when the jitter buffer is operating nominally + Emitting +) + +const ( + // StartBuffering is emitted when the buffer receives its first packet + StartBuffering Event = "startBuffering" + // BeginPlayback is emitted when the buffer has satisfied its buffer length + BeginPlayback = "playing" + // BufferUnderflow is emitted when the buffer does not have enough packets to Pop + BufferUnderflow = "underflow" + // BufferOverflow is emitted when the buffer has exceeded its limit + BufferOverflow = "overflow" +) + +func (jbs State) String() string { + switch jbs { + case Buffering: + return "Buffering" + case Emitting: + return "Emitting" + } + return "unknown" +} + +type ( + // Option will Override JitterBuffer's defaults + Option func(jb *JitterBuffer) + // EventListener will be called when the corresponding Event occurs + EventListener func(event Event, jb *JitterBuffer) +) + +// A JitterBuffer will accept Pushed packets, put them in sequence number +// order, and allows removing in either sequence number order or via a +// provided timestamp +type JitterBuffer struct { + packets *PriorityQueue + minStartCount uint16 + lastSequence uint16 + playoutHead uint16 + playoutReady bool + state State + stats Stats + listeners map[Event][]EventListener + mutex sync.Mutex +} + +// Stats Track interesting statistics for the life of this JitterBuffer +// outOfOrderCount will provide the number of times a packet was Pushed +// +// without its predecessor being present +// +// underflowCount will provide the count of attempts to Pop an empty buffer +// overflowCount will track the number of times the jitter buffer exceeds its limit +type Stats struct { + outOfOrderCount uint32 + underflowCount uint32 + overflowCount uint32 +} + +// New will initialize a jitter buffer and its associated statistics +func New(opts ...Option) *JitterBuffer { + jb := &JitterBuffer{ + state: Buffering, + stats: Stats{0, 0, 0}, + minStartCount: 50, + packets: NewQueue(), + listeners: make(map[Event][]EventListener), + } + + for _, o := range opts { + o(jb) + } + + return jb +} + +// WithMinimumPacketCount will set the required number of packets to be received before +// any attempt to pop a packet can succeed +func WithMinimumPacketCount(count uint16) Option { + return func(jb *JitterBuffer) { + jb.minStartCount = count + } +} + +// Listen will register an event listener +// The jitter buffer may emit events correspnding, interested listerns should +// look at Event for available events +func (jb *JitterBuffer) Listen(event Event, cb EventListener) { + jb.listeners[event] = append(jb.listeners[event], cb) +} + +// PlayoutHead returns the SequenceNumber that will be attempted to Pop next +func (jb *JitterBuffer) PlayoutHead() uint16 { + jb.mutex.Lock() + defer jb.mutex.Unlock() + + return jb.playoutHead +} + +// SetPlayoutHead allows you to manually specify the packet you wish to pop next +// If you have encountered a packet that hasn't resolved you can skip it +func (jb *JitterBuffer) SetPlayoutHead(playoutHead uint16) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + + jb.playoutHead = playoutHead +} + +func (jb *JitterBuffer) updateStats(lastPktSeqNo uint16) { + // If we have at least one packet, and the next packet being pushed in is not + // at the expected sequence number increment the out of order count + if jb.packets.Length() > 0 && lastPktSeqNo != ((jb.lastSequence+1)%math.MaxUint16) { + jb.stats.outOfOrderCount++ + } + jb.lastSequence = lastPktSeqNo +} + +// Push an RTP packet into the jitter buffer, this does not clone +// the data so if the memory is expected to be reused, the caller should +// take this in to account and pass a copy of the packet they wish to buffer +func (jb *JitterBuffer) Push(packet *rtp.Packet) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + if jb.packets.Length() == 0 { + jb.emit(StartBuffering) + } + if jb.packets.Length() > 100 { + jb.stats.overflowCount++ + jb.emit(BufferOverflow) + } + if !jb.playoutReady && jb.packets.Length() == 0 { + jb.playoutHead = packet.SequenceNumber + } + jb.updateStats(packet.SequenceNumber) + jb.packets.Push(packet, packet.SequenceNumber) + jb.updateState() +} + +func (jb *JitterBuffer) emit(event Event) { + for _, l := range jb.listeners[event] { + l(event, jb) + } +} + +func (jb *JitterBuffer) updateState() { + // For now, we only look at the number of packets captured in the play buffer + if jb.packets.Length() >= jb.minStartCount && jb.state == Buffering { + jb.state = Emitting + jb.playoutReady = true + jb.emit(BeginPlayback) + } +} + +// Peek at the packet which is either: +// +// At the playout head when we are emitting, and the playoutHead flag is true +// +// or else +// +// At the last sequence received +func (jb *JitterBuffer) Peek(playoutHead bool) (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + if jb.packets.Length() < 1 { + return nil, ErrBufferUnderrun + } + if playoutHead && jb.state == Emitting { + return jb.packets.Find(jb.playoutHead) + } + return jb.packets.Find(jb.lastSequence) +} + +// Pop an RTP packet from the jitter buffer at the current playout head +func (jb *JitterBuffer) Pop() (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + if jb.state != Emitting { + return nil, ErrPopWhileBuffering + } + packet, err := jb.packets.PopAt(jb.playoutHead) + if err != nil { + jb.stats.underflowCount++ + jb.emit(BufferUnderflow) + return nil, err + } + jb.playoutHead = (jb.playoutHead + 1) % math.MaxUint16 + jb.updateState() + return packet, nil +} + +// PopAtSequence will pop an RTP packet from the jitter buffer at the specified Sequence +func (jb *JitterBuffer) PopAtSequence(sq uint16) (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + if jb.state != Emitting { + return nil, ErrPopWhileBuffering + } + packet, err := jb.packets.PopAt(sq) + if err != nil { + jb.stats.underflowCount++ + jb.emit(BufferUnderflow) + return nil, err + } + jb.playoutHead = (jb.playoutHead + 1) % math.MaxUint16 + jb.updateState() + return packet, nil +} + +// PeekAtSequence will return an RTP packet from the jitter buffer at the specified Sequence +// without removing it from the buffer +func (jb *JitterBuffer) PeekAtSequence(sq uint16) (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + packet, err := jb.packets.Find(sq) + if err != nil { + return nil, err + } + return packet, nil +} + +// PopAtTimestamp pops an RTP packet from the jitter buffer with the provided timestamp +// Call this method repeatedly to drain the buffer at the timestamp +func (jb *JitterBuffer) PopAtTimestamp(ts uint32) (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + if jb.state != Emitting { + return nil, ErrPopWhileBuffering + } + packet, err := jb.packets.PopAtTimestamp(ts) + if err != nil { + jb.stats.underflowCount++ + jb.emit(BufferUnderflow) + return nil, err + } + jb.updateState() + return packet, nil +} + +// Clear will empty the buffer and optionally reset the state +func (jb *JitterBuffer) Clear(resetState bool) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + jb.packets.Clear() + if resetState { + jb.lastSequence = 0 + jb.state = Buffering + jb.stats = Stats{0, 0, 0} + jb.minStartCount = 50 + } +} diff --git a/jitterbuffer/jitter_buffer_test.go b/jitterbuffer/jitter_buffer_test.go new file mode 100644 index 0000000..205e610 --- /dev/null +++ b/jitterbuffer/jitter_buffer_test.go @@ -0,0 +1,238 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "math" + "testing" + + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +func TestJitterBuffer(t *testing.T) { + assert := assert.New(t) + + t.Run("Appends packets in order", func(*testing.T) { + jb := New() + assert.Equal(jb.lastSequence, uint16(0)) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + + assert.Equal(jb.lastSequence, uint16(5002)) + + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5012, Timestamp: 512}, Payload: []byte{0x02}}) + + assert.Equal(jb.stats.outOfOrderCount, uint32(1)) + assert.Equal(jb.packets.Length(), uint16(4)) + assert.Equal(jb.lastSequence, uint16(5012)) + }) + + t.Run("Appends packets and begins playout", func(*testing.T) { + jb := New() + for i := 0; i < 100; i++ { + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + assert.Equal(jb.packets.Length(), uint16(100)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(5012)) + head, err := jb.Pop() + assert.Equal(head.SequenceNumber, uint16(5012)) + assert.Equal(err, nil) + }) + t.Run("Appends packets and begins playout", func(*testing.T) { + jb := New(WithMinimumPacketCount(1)) + events := make([]Event, 0) + jb.Listen(BeginPlayback, func(event Event, _ *JitterBuffer) { + events = append(events, event) + }) + for i := 0; i < 2; i++ { + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + assert.Equal(jb.packets.Length(), uint16(2)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(5012)) + head, err := jb.Pop() + assert.Equal(head.SequenceNumber, uint16(5012)) + assert.Equal(err, nil) + assert.Equal(1, len(events)) + assert.Equal(Event(BeginPlayback), events[0]) + }) + + t.Run("Wraps playout correctly", func(*testing.T) { + jb := New() + for i := 0; i < 100; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + assert.Equal(jb.packets.Length(), uint16(100)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(math.MaxUint16-32)) + head, err := jb.Pop() + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) + assert.Equal(err, nil) + for i := 0; i < 100; i++ { + head, err := jb.Pop() + if i < 99 { + assert.Equal(head.SequenceNumber, uint16((math.MaxUint16-31+i)%math.MaxUint16)) + assert.Equal(err, nil) + } else { + assert.Equal(head, (*rtp.Packet)(nil)) + } + } + }) + + t.Run("Pops at timestamp correctly", func(*testing.T) { + jb := New() + for i := 0; i < 100; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + assert.Equal(jb.packets.Length(), uint16(100)) + assert.Equal(jb.state, Emitting) + head, err := jb.PopAtTimestamp(uint32(513)) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32+1)) + assert.Equal(err, nil) + head, err = jb.PopAtTimestamp(uint32(513)) + assert.Equal(head, (*rtp.Packet)(nil)) + assert.NotEqual(err, nil) + + head, err = jb.Pop() + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) + assert.Equal(err, nil) + }) + + t.Run("Can peek at a packet", func(*testing.T) { + jb := New() + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + pkt, err := jb.Peek(false) + assert.Equal(pkt.SequenceNumber, uint16(5002)) + assert.Equal(err, nil) + for i := 0; i < 100; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + pkt, err = jb.Peek(true) + assert.Equal(pkt.SequenceNumber, uint16(5000)) + assert.Equal(err, nil) + }) + + t.Run("Pops at sequence with an invalid sequence number", func(*testing.T) { + jb := New() + for i := 0; i < 50; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + head, err := jb.PopAtSequence(uint16(9000)) + assert.Equal(head, (*rtp.Packet)(nil)) + assert.NotEqual(err, nil) + }) + + t.Run("Pops at timestamp with multiple packets", func(*testing.T) { + jb := New() + for i := 0; i < 50; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + head, err := jb.PopAtTimestamp(uint32(9000)) + assert.Equal(head.SequenceNumber, uint16(1019)) + assert.Equal(err, nil) + head, err = jb.PopAtTimestamp(uint32(9000)) + assert.Equal(head.SequenceNumber, uint16(1020)) + assert.Equal(err, nil) + + head, err = jb.Pop() + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) + assert.Equal(err, nil) + }) + + t.Run("Peeks at timestamp with multiple packets", func(*testing.T) { + jb := New() + for i := 0; i < 50; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + head, err := jb.PeekAtSequence(uint16(1019)) + assert.Equal(head.SequenceNumber, uint16(1019)) + assert.Equal(err, nil) + head, err = jb.PeekAtSequence(uint16(1020)) + assert.Equal(head.SequenceNumber, uint16(1020)) + assert.Equal(err, nil) + + head, err = jb.PopAtSequence(uint16(math.MaxUint16 - 32)) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) + assert.Equal(err, nil) + }) + + t.Run("SetPlayoutHead", func(*testing.T) { + jb := New(WithMinimumPacketCount(1)) + + // Push packets 0-9, but no packet 4 + for i := uint16(0); i < 10; i++ { + if i == 4 { + continue + } + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: i, Timestamp: uint32(512 + i)}, Payload: []byte{0x00}}) + } + + // The first 3 packets will be able to popped + for i := 0; i < 4; i++ { + pkt, err := jb.Pop() + assert.NoError(err) + assert.NotNil(pkt) + } + + // The next pop will fail because of gap + pkt, err := jb.Pop() + assert.ErrorIs(err, ErrNotFound) + assert.Nil(pkt) + assert.Equal(jb.PlayoutHead(), uint16(4)) + + // Assert that PlayoutHead isn't modified with pushing/popping again + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 10, Timestamp: uint32(522)}, Payload: []byte{0x00}}) + pkt, err = jb.Pop() + assert.ErrorIs(err, ErrNotFound) + assert.Nil(pkt) + assert.Equal(jb.PlayoutHead(), uint16(4)) + + // Increment the PlayoutHead and popping will work again + jb.SetPlayoutHead(jb.PlayoutHead() + 1) + for i := 0; i < 6; i++ { + pkt, err := jb.Pop() + assert.NoError(err) + assert.NotNil(pkt) + } + }) + + t.Run("Allows clearing the buffer", func(*testing.T) { + jb := New() + jb.Clear(false) + + assert.Equal(jb.lastSequence, uint16(0)) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 501}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 502}, Payload: []byte{0x02}}) + + assert.Equal(jb.lastSequence, uint16(5002)) + jb.Clear(true) + assert.Equal(jb.lastSequence, uint16(0)) + assert.Equal(jb.stats.outOfOrderCount, uint32(0)) + assert.Equal(jb.packets.Length(), uint16(0)) + }) +} diff --git a/jitterbuffer/option.go b/jitterbuffer/option.go new file mode 100644 index 0000000..9a33c22 --- /dev/null +++ b/jitterbuffer/option.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "github.com/pion/logging" +) + +// ReceiverInterceptorOption can be used to configure ReceiverInterceptor +type ReceiverInterceptorOption func(d *ReceiverInterceptor) error + +// Log sets a logger for the interceptor +func Log(log logging.LeveledLogger) ReceiverInterceptorOption { + return func(d *ReceiverInterceptor) error { + d.log = log + return nil + } +} diff --git a/jitterbuffer/priority_queue.go b/jitterbuffer/priority_queue.go new file mode 100644 index 0000000..f6d7d93 --- /dev/null +++ b/jitterbuffer/priority_queue.go @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "errors" + + "github.com/pion/rtp" +) + +// PriorityQueue provides a linked list sorting of RTP packets by SequenceNumber +type PriorityQueue struct { + next *node + length uint16 +} + +type node struct { + val *rtp.Packet + next *node + prev *node + priority uint16 +} + +var ( + // ErrInvalidOperation may be returned if a Pop or Find operation is performed on an empty queue + ErrInvalidOperation = errors.New("attempt to find or pop on an empty list") + // ErrNotFound will be returned if the packet cannot be found in the queue + ErrNotFound = errors.New("priority not found") +) + +// NewQueue will create a new PriorityQueue whose order relies on monotonically +// increasing Sequence Number, wrapping at MaxUint16, so +// a packet with sequence number MaxUint16 - 1 will be after 0 +func NewQueue() *PriorityQueue { + return &PriorityQueue{ + next: nil, + length: 0, + } +} + +func newNode(val *rtp.Packet, priority uint16) *node { + return &node{ + val: val, + prev: nil, + next: nil, + priority: priority, + } +} + +// Find a packet in the queue with the provided sequence number, +// regardless of position (the packet is retained in the queue) +func (q *PriorityQueue) Find(sqNum uint16) (*rtp.Packet, error) { + next := q.next + for next != nil { + if next.priority == sqNum { + return next.val, nil + } + next = next.next + } + + return nil, ErrNotFound +} + +// Push will insert a packet in to the queue in order of sequence number +func (q *PriorityQueue) Push(val *rtp.Packet, priority uint16) { + newPq := newNode(val, priority) + if q.next == nil { + q.next = newPq + q.length++ + return + } + if priority < q.next.priority { + newPq.next = q.next + q.next.prev = newPq + q.next = newPq + q.length++ + return + } + head := q.next + prev := q.next + for head != nil { + if priority <= head.priority { + break + } + prev = head + head = head.next + } + if head == nil { + if prev != nil { + prev.next = newPq + } + newPq.prev = prev + } else { + newPq.next = head + newPq.prev = prev + if prev != nil { + prev.next = newPq + } + head.prev = newPq + } + q.length++ +} + +// Length will get the total length of the queue +func (q *PriorityQueue) Length() uint16 { + return q.length +} + +// Pop removes the first element from the queue, regardless +// sequence number +func (q *PriorityQueue) Pop() (*rtp.Packet, error) { + if q.next == nil { + return nil, ErrInvalidOperation + } + val := q.next.val + q.length-- + q.next = q.next.next + return val, nil +} + +// PopAt removes an element at the specified sequence number (priority) +func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { + if q.next == nil { + return nil, ErrInvalidOperation + } + if q.next.priority == sqNum { + val := q.next.val + q.next = q.next.next + q.length-- + return val, nil + } + pos := q.next + prev := q.next.prev + for pos != nil { + if pos.priority == sqNum { + val := pos.val + prev.next = pos.next + if prev.next != nil { + prev.next.prev = prev + } + q.length-- + return val, nil + } + prev = pos + pos = pos.next + } + return nil, ErrNotFound +} + +// PopAtTimestamp removes and returns a packet at the given RTP Timestamp, regardless +// sequence number order +func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { + if q.next == nil { + return nil, ErrInvalidOperation + } + if q.next.val.Timestamp == timestamp { + val := q.next.val + q.next = q.next.next + q.length-- + return val, nil + } + pos := q.next + prev := q.next.prev + for pos != nil { + if pos.val.Timestamp == timestamp { + val := pos.val + prev.next = pos.next + if prev.next != nil { + prev.next.prev = prev + } + q.length-- + return val, nil + } + prev = pos + pos = pos.next + } + return nil, ErrNotFound +} + +// Clear will empty a PriorityQueue +func (q *PriorityQueue) Clear() { + next := q.next + q.length = 0 + for next != nil { + next.prev = nil + next = next.next + } +} diff --git a/jitterbuffer/priority_queue_test.go b/jitterbuffer/priority_queue_test.go new file mode 100644 index 0000000..7fb2a7a --- /dev/null +++ b/jitterbuffer/priority_queue_test.go @@ -0,0 +1,138 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "testing" + + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +func TestPriorityQueue(t *testing.T) { + assert := assert.New(t) + + t.Run("Appends packets in order", func(*testing.T) { + pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} + q := NewQueue() + q.Push(pkt, pkt.SequenceNumber) + pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} + q.Push(pkt2, pkt2.SequenceNumber) + assert.Equal(q.next.next.val, pkt2) + assert.Equal(q.next.priority, uint16(5000)) + assert.Equal(q.next.next.priority, uint16(5004)) + }) + + t.Run("Appends many in order", func(*testing.T) { + q := NewQueue() + for i := 0; i < 100; i++ { + q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + } + assert.Equal(uint16(100), q.Length()) + last := (*node)(nil) + cur := q.next + for cur != nil { + last = cur + cur = cur.next + if cur != nil { + assert.Equal(cur.priority, last.priority+1) + } + } + assert.Equal(q.next.priority, uint16(5012)) + assert.Equal(last.priority, uint16(5012+99)) + }) + + t.Run("Can remove an element", func(*testing.T) { + pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} + q := NewQueue() + q.Push(pkt, pkt.SequenceNumber) + pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} + q.Push(pkt2, pkt2.SequenceNumber) + for i := 0; i < 100; i++ { + q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + } + popped, _ := q.Pop() + assert.Equal(popped.SequenceNumber, uint16(5000)) + _, _ = q.Pop() + nextPop, _ := q.Pop() + assert.Equal(nextPop.SequenceNumber, uint16(5012)) + }) + + t.Run("Appends in order", func(*testing.T) { + q := NewQueue() + for i := 0; i < 100; i++ { + q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + } + assert.Equal(uint16(100), q.Length()) + pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} + q.Push(pkt, pkt.SequenceNumber) + assert.Equal(pkt, q.next.val) + assert.Equal(uint16(101), q.Length()) + assert.Equal(q.next.priority, uint16(5000)) + }) + + t.Run("Can find", func(*testing.T) { + q := NewQueue() + for i := 0; i < 100; i++ { + q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + } + pkt, err := q.Find(5012) + assert.Equal(pkt.SequenceNumber, uint16(5012)) + assert.Equal(err, nil) + }) + + t.Run("Updates the length when PopAt* are called", func(*testing.T) { + pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} + q := NewQueue() + q.Push(pkt, pkt.SequenceNumber) + pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} + q.Push(pkt2, pkt2.SequenceNumber) + for i := 0; i < 100; i++ { + q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + } + assert.Equal(uint16(102), q.Length()) + popped, _ := q.PopAt(uint16(5012)) + assert.Equal(popped.SequenceNumber, uint16(5012)) + assert.Equal(uint16(101), q.Length()) + + popped, err := q.PopAtTimestamp(uint32(500)) + assert.Equal(popped.SequenceNumber, uint16(5000)) + assert.Equal(uint16(100), q.Length()) + assert.Equal(err, nil) + }) +} + +func TestPriorityQueue_Find(t *testing.T) { + packets := NewQueue() + + packets.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: 1000, + Timestamp: 5, + SSRC: 5, + }, + Payload: []uint8{0xA}, + }, 1000) + + _, err := packets.PopAt(1000) + assert.NoError(t, err) + + _, err = packets.Find(1001) + assert.Error(t, err) +} + +func TestPriorityQueue_Clean(t *testing.T) { + packets := NewQueue() + packets.Clear() + packets.Push(&rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: 1000, + Timestamp: 5, + SSRC: 5, + }, + Payload: []uint8{0xA}, + }, 1000) + assert.EqualValues(t, 1, packets.Length()) + packets.Clear() +} diff --git a/jitterbuffer/receiver_interceptor.go b/jitterbuffer/receiver_interceptor.go new file mode 100644 index 0000000..b4c032b --- /dev/null +++ b/jitterbuffer/receiver_interceptor.go @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package jitterbuffer + +import ( + "sync" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtp" +) + +// InterceptorFactory is a interceptor.Factory for a GeneratorInterceptor +type InterceptorFactory struct { + opts []ReceiverInterceptorOption +} + +// NewInterceptor constructs a new ReceiverInterceptor +func (g *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + i := &ReceiverInterceptor{ + close: make(chan struct{}), + log: logging.NewDefaultLoggerFactory().NewLogger("jitterbuffer"), + buffer: New(), + } + + for _, opt := range g.opts { + if err := opt(i); err != nil { + return nil, err + } + } + + return i, nil +} + +// ReceiverInterceptor places a JitterBuffer in the chain to smooth packet arrival +// and allow for network jitter +// +// The Interceptor is designed to fit in a RemoteStream +// pipeline and buffer incoming packets for a short period (currently +// defaulting to 50 packets) before emitting packets to be consumed by the +// next step in the pipeline. +// +// The caller must ensure they are prepared to handle an +// ErrPopWhileBuffering in the case that insufficient packets have been +// received by the jitter buffer. The caller should retry the operation +// at some point later as the buffer may have been filled in the interim. +// +// The caller should also be aware that an ErrBufferUnderrun may be +// returned in the case that the initial buffering was sufficient and +// playback began but the caller is consuming packets (or they are not +// arriving) quickly enough. +type ReceiverInterceptor struct { + interceptor.NoOp + buffer *JitterBuffer + m sync.Mutex + wg sync.WaitGroup + close chan struct{} + log logging.LeveledLogger +} + +// NewInterceptor returns a new InterceptorFactory +func NewInterceptor(opts ...ReceiverInterceptorOption) (*InterceptorFactory, error) { + return &InterceptorFactory{opts}, nil +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (i *ReceiverInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + buf := make([]byte, len(b)) + n, attr, err := reader.Read(buf, a) + if err != nil { + return n, attr, err + } + packet := &rtp.Packet{} + if err := packet.Unmarshal(buf); err != nil { + return 0, nil, err + } + i.m.Lock() + defer i.m.Unlock() + i.buffer.Push(packet) + if i.buffer.state == Emitting { + newPkt, err := i.buffer.Pop() + if err != nil { + return 0, nil, err + } + nlen, err := newPkt.MarshalTo(b) + return nlen, attr, err + } + return n, attr, ErrPopWhileBuffering + }) +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *ReceiverInterceptor) UnbindRemoteStream(_ *interceptor.StreamInfo) { + defer i.wg.Wait() + i.m.Lock() + defer i.m.Unlock() + i.buffer.Clear(true) +} + +// Close closes the interceptor +func (i *ReceiverInterceptor) Close() error { + defer i.wg.Wait() + i.m.Lock() + defer i.m.Unlock() + i.buffer.Clear(true) + return nil +} diff --git a/main.go b/main.go index f6e594f..1516746 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,9 @@ package main import ( + "fmt" "github.com/yangjiechina/live-server/flv" + "github.com/yangjiechina/live-server/gb28181" "github.com/yangjiechina/live-server/hls" "github.com/yangjiechina/live-server/log" "github.com/yangjiechina/live-server/rtc" @@ -12,89 +14,148 @@ import ( _ "net/http/pprof" - "github.com/yangjiechina/avformat/librtmp" - "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/rtmp" "github.com/yangjiechina/live-server/stream" ) -var rtspAddr *net.TCPAddr +func NewDefaultAppConfig() stream.AppConfig_ { + return stream.AppConfig_{ + GOPCache: true, + MergeWriteLatency: 350, -func CreateTransStream(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) stream.ITransStream { - if stream.ProtocolRtmp == protocol { - return rtmp.NewTransStream(librtmp.ChunkSize) - } else if stream.ProtocolHls == protocol { - id := source.Id() + Hls: stream.HlsConfig{ + Enable: true, + Dir: "../tmp", + Duration: 2, + PlaylistLength: 10, + }, - transStream, err := hls.NewTransStream("", stream.AppConfig.Hls.M3U8Format(id), stream.AppConfig.Hls.TSFormat(id, "%d"), stream.AppConfig.Hls.Dir, stream.AppConfig.Hls.Duration, stream.AppConfig.Hls.PlaylistLength) + Rtmp: stream.RtmpConfig{ + Enable: true, + Addr: "0.0.0.0:1935", + }, + + Rtsp: stream.RtmpConfig{ + Enable: true, + Addr: "0.0.0.0:554", + }, + + Log: stream.LogConfig{ + Level: int(zapcore.DebugLevel), + Name: "./logs/lkm.log", + MaxSize: 10, + MaxBackup: 100, + MaxAge: 7, + Compress: false, + }, + + Http: stream.HttpConfig{ + Enable: true, + Addr: "0.0.0.0:8080", + }, + + GB28181: stream.GB28181Config{ + Addr: "0.0.0.0", + Transport: "UDP|TCP", + Port: [2]uint16{20000, 30000}, + }, + } +} + +func init() { + stream.RegisterTransStreamFactory(stream.ProtocolRtmp, rtmp.TransStreamFactory) + stream.RegisterTransStreamFactory(stream.ProtocolHls, hls.TransStreamFactory) + stream.RegisterTransStreamFactory(stream.ProtocolFlv, flv.TransStreamFactory) + stream.RegisterTransStreamFactory(stream.ProtocolRtsp, rtsp.TransStreamFactory) + stream.RegisterTransStreamFactory(stream.ProtocolRtc, rtc.TransStreamFactory) + + stream.AppConfig = NewDefaultAppConfig() + + //初始化日志 + log.InitLogger(zapcore.Level(stream.AppConfig.Log.Level), stream.AppConfig.Log.Name, stream.AppConfig.Log.MaxSize, stream.AppConfig.Log.MaxBackup, stream.AppConfig.Log.MaxAge, stream.AppConfig.Log.Compress) + + if stream.AppConfig.GB28181.IsMultiPort() { + gb28181.TransportManger = stream.NewTransportManager(stream.AppConfig.GB28181.Port[0], stream.AppConfig.GB28181.Port[1]) + } +} + +func main() { + if stream.AppConfig.Rtmp.Enable { + rtmpAddr, err := net.ResolveTCPAddr("tcp", stream.AppConfig.Rtmp.Addr) if err != nil { panic(err) } - return transStream - } else if stream.ProtocolFlv == protocol { - return flv.NewHttpTransStream() - } else if stream.ProtocolRtsp == protocol { - trackFormat := source.Id() + "?track=%d" - return rtsp.NewTransStream(net.IPAddr{ - IP: rtspAddr.IP, - Zone: rtspAddr.Zone, - }, trackFormat) - } else if stream.ProtocolRtc == protocol { - return rtc.NewTransStream() + impl := rtmp.NewServer() + err = impl.Start(rtmpAddr) + if err != nil { + panic(err) + } + + log.Sugar.Info("启动rtmp服务成功 addr:", rtmpAddr.String()) } - return nil -} + if stream.AppConfig.Rtsp.Enable { + rtspAddr, err := net.ResolveTCPAddr("tcp", stream.AppConfig.Rtsp.Addr) + if err != nil { + panic(rtspAddr) + } -func init() { - stream.TransStreamFactory = CreateTransStream -} + rtspServer := rtsp.NewServer() + err = rtspServer.Start(rtspAddr) + if err != nil { + panic(err) + } -func main() { - //初始化日志 - log.InitLogger(zapcore.DebugLevel, "./logs/lkm.log", 10, 100, 7, false) - - stream.AppConfig.GOPCache = true - stream.AppConfig.MergeWriteLatency = 350 - - stream.AppConfig.Hls.Enable = true - stream.AppConfig.Hls.Dir = "../tmp" - stream.AppConfig.Hls.Duration = 2 - stream.AppConfig.Hls.PlaylistLength = 10 - - rtmpAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:1935") - if err != nil { - panic(err) + log.Sugar.Info("启动rtsp服务成功 addr:", rtspAddr.String()) } - impl := rtmp.NewServer() - err = impl.Start(rtmpAddr) - if err != nil { - panic(err) + if stream.AppConfig.Http.Enable { + log.Sugar.Info("启动Http服务 addr:", stream.AppConfig.Http.Addr) + + go startApiServer(stream.AppConfig.Http.Addr) } - println("启动rtmp服务成功:" + rtmpAddr.String()) + //单端口模式下, 启动时就创建收流端口 + //多端口模式下, 创建GBSource时才创建收流端口 + if !stream.AppConfig.GB28181.IsMultiPort() { + if stream.AppConfig.GB28181.EnableUDP() { + addr := fmt.Sprintf("%s:%d", stream.AppConfig.GB28181.Addr, stream.AppConfig.GB28181.Port[0]) + gbAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + panic(err) + } - rtspAddr, err = net.ResolveTCPAddr("tcp", "0.0.0.0:554") - if err != nil { - panic(rtspAddr) + server, err := gb28181.NewUDPServer(gbAddr, gb28181.NewSharedFilter(128)) + if err != nil { + panic(err) + } + + gb28181.SharedUDPServer = server + log.Sugar.Info("启动GB28181 UDP收流端口成功:" + gbAddr.String()) + } + + if stream.AppConfig.GB28181.EnableTCP() { + addr := fmt.Sprintf("%s:%d", stream.AppConfig.GB28181.Addr, stream.AppConfig.GB28181.Port[0]) + gbAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + panic(err) + } + + server, err := gb28181.NewTCPServer(gbAddr, gb28181.NewSharedFilter(128)) + if err != nil { + panic(err) + } + + gb28181.SharedTCPServer = server + log.Sugar.Info("启动GB28181 TCP收流端口成功:" + gbAddr.String()) + } } - rtspServer := rtsp.NewServer() - err = rtspServer.Start(rtspAddr) - if err != nil { - panic(err) - } - - println("启动rtsp服务成功:" + rtspAddr.String()) - - apiAddr := "0.0.0.0:8080" - go startApiServer(apiAddr) - loadConfigError := http.ListenAndServe(":19999", nil) if loadConfigError != nil { panic(loadConfigError) } + select {} } diff --git a/rtc/rtc_stream.go b/rtc/rtc_stream.go index 841ba32..4b49257 100644 --- a/rtc/rtc_stream.go +++ b/rtc/rtc_stream.go @@ -16,6 +16,10 @@ func NewTransStream() stream.ITransStream { return t } +func TransStreamFactory(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) (stream.ITransStream, error) { + return NewTransStream(), nil +} + func (t *transStream) Input(packet utils.AVPacket) error { if utils.AVMediaTypeAudio == packet.MediaType() { @@ -28,10 +32,7 @@ func (t *transStream) Input(packet utils.AVPacket) error { } if packet.KeyFrame() { - extra, err := t.TransStreamImpl.Tracks[packet.Index()].AnnexBExtraData() - if err != nil { - return err - } + extra := t.TransStreamImpl.Tracks[packet.Index()].CodecParameters().DecoderConfRecord().ToAnnexB() sink_.input(packet.Index(), extra, 0) } diff --git a/rtmp/rtmp_publisher.go b/rtmp/rtmp_publisher.go index d78cb95..c155d8b 100644 --- a/rtmp/rtmp_publisher.go +++ b/rtmp/rtmp_publisher.go @@ -9,137 +9,50 @@ import ( "net" ) -type Publisher interface { - - // Init 初始化内存池 - Init() - - // OnDiscardPacket GOP缓存溢出的包 - OnDiscardPacket(pkt interface{}) - - // OnVideo 从rtmp chunk中解析出来的整个视频包, 还需要进步封装成AVPacket - OnVideo(data []byte, ts uint32) - - // OnAudio 从rtmp chunk中解析出来的整个音频包 - OnAudio(data []byte, ts uint32) - - // OnPartPacket 从rtmp chunk中解析出来的一部分音视频包 - OnPartPacket(index int, data []byte, first bool) -} - -type publisher struct { +// Publisher RTMP推流Source +type Publisher struct { stream.SourceImpl - stack *librtmp.Stack - audioMemoryPool stream.MemoryPool - videoMemoryPool stream.MemoryPool - - audioMark bool - videoMark bool + stack *librtmp.Stack } -func NewPublisher(sourceId string, stack *librtmp.Stack, conn net.Conn) Publisher { +func NewPublisher(sourceId string, stack *librtmp.Stack, conn net.Conn) *Publisher { deMuxer := libflv.NewDeMuxer(libflv.TSModeRelative) - publisher_ := &publisher{SourceImpl: stream.SourceImpl{Id_: sourceId, Type_: stream.SourceTypeRtmp, TransDeMuxer: deMuxer, Conn: conn}, stack: stack, audioMark: false, videoMark: false} + publisher_ := &Publisher{SourceImpl: stream.SourceImpl{Id_: sourceId, Type_: stream.SourceTypeRtmp, TransDeMuxer: deMuxer, Conn: conn}, stack: stack} //设置回调,从flv解析出来的Stream和AVPacket都将统一回调到stream.SourceImpl deMuxer.SetHandler(publisher_) - publisher_.Input_ = publisher_.Input //为推流方分配足够多的缓冲区 conn.(*transport.Conn).ReallocateRecvBuffer(1024 * 1024) return publisher_ } -func (p *publisher) Init() { - //创建内存池 - p.audioMemoryPool = stream.NewMemoryPool(48000 * 64) - if stream.AppConfig.GOPCache { - //以每秒钟4M码率大小创建内存池 - p.videoMemoryPool = stream.NewMemoryPool(4096 * 1000) - } else { - p.videoMemoryPool = stream.NewMemoryPool(4096 * 1000 / 8) - } - - p.SourceImpl.Init() - go p.SourceImpl.LoopEvent() +func (p *Publisher) Input(data []byte) error { + return p.stack.Input(nil, data) } -func (p *publisher) Input(data []byte) { - p.stack.Input(nil, data) -} - -func (p *publisher) OnDiscardPacket(pkt interface{}) { - packet := pkt.(utils.AVPacket) - if utils.AVMediaTypeAudio == packet.MediaType() { - p.audioMemoryPool.FreeHead() - } else if utils.AVMediaTypeVideo == packet.MediaType() { - p.videoMemoryPool.FreeHead() - } -} - -func (p *publisher) OnDeMuxStream(stream_ utils.AVStream) { - //释放掉内存池中最新分配的内存 - if utils.AVMediaTypeAudio == stream_.Type() { - p.audioMemoryPool.FreeTail() - } else if utils.AVMediaTypeVideo == stream_.Type() { - p.videoMemoryPool.FreeTail() - } - - if ret, buffer := p.SourceImpl.OnDeMuxStream(stream_); ret && buffer != nil { - buffer.SetDiscardHandler(p.OnDiscardPacket) - } -} - -func (p *publisher) OnDeMuxPacket(packet utils.AVPacket) { - p.SourceImpl.OnDeMuxPacket(packet) - - if stream.AppConfig.GOPCache { - return - } - - //未开启GOP缓存,释放掉内存 - if utils.AVMediaTypeAudio == packet.MediaType() { - p.audioMemoryPool.FreeTail() - } else if utils.AVMediaTypeVideo == packet.MediaType() { - p.videoMemoryPool.FreeTail() - } +func (p *Publisher) OnDeMuxStream(stream utils.AVStream) { + //AVStream的ExtraData已经拷贝, 释放掉内存池中最新分配的内存 + p.FindOrCreatePacketBuffer(stream.Index(), stream.Type()).FreeTail() } // OnVideo 解析出来的完整视频包 // @ts rtmp chunk的相对时间戳 -func (p *publisher) OnVideo(data []byte, ts uint32) { - if data == nil { - data = p.videoMemoryPool.Fetch() - p.videoMark = false - } - +func (p *Publisher) OnVideo(index int, data []byte, ts uint32) { + data = p.FindOrCreatePacketBuffer(index, utils.AVMediaTypeVideo).Fetch() + //交给flv解复用器, 解析回调出AVPacket p.SourceImpl.TransDeMuxer.(libflv.DeMuxer).InputVideo(data, ts) } -func (p *publisher) OnAudio(data []byte, ts uint32) { - if data == nil { - data = p.audioMemoryPool.Fetch() - p.audioMark = false - } - - _ = p.SourceImpl.TransDeMuxer.(libflv.DeMuxer).InputAudio(data, ts) +func (p *Publisher) OnAudio(index int, data []byte, ts uint32) { + data = p.FindOrCreatePacketBuffer(index, utils.AVMediaTypeAudio).Fetch() + p.SourceImpl.TransDeMuxer.(libflv.DeMuxer).InputAudio(data, ts) } -func (p *publisher) OnPartPacket(index int, data []byte, first bool) { - //audio - if index == 0 { - if !p.audioMark { - p.audioMemoryPool.Mark() - p.audioMark = true - } - - p.audioMemoryPool.Write(data) - //video - } else if index == 1 { - if !p.videoMark { - p.videoMemoryPool.Mark() - p.videoMark = true - } - - p.videoMemoryPool.Write(data) +func (p *Publisher) OnPartPacket(index int, mediaType utils.AVMediaType, data []byte, first bool) { + buffer := p.FindOrCreatePacketBuffer(index, mediaType) + if first { + buffer.Mark() } + + buffer.Write(data) } diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index f847980..842dd0a 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -27,7 +27,7 @@ func NewSession(conn net.Conn) Session { type sessionImpl struct { //解析rtmp协议栈 stack *librtmp.Stack - //publisher/sink, 在publish或play成功后赋值 + //Publisher/sink, 在publish或play成功后赋值 handle interface{} isPublisher bool @@ -39,14 +39,16 @@ func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookSta sourceId := app + "_" + stream_ source := NewPublisher(sourceId, s.stack, s.conn) + //设置推流的音视频回调 s.stack.SetOnPublishHandler(source) - s.stack.SetOnTransDeMuxerHandler(source) //推流事件Source统一处理, 是否已经存在, Hook回调.... - source.(*publisher).Publish(source.(*publisher), func() { + source.Publish(source, func() { s.handle = source s.isPublisher = true - source.Init() + + source.Init(source.Input) + go source.LoopEvent() response <- utils.HookStateOK }, func(state utils.HookState) { @@ -72,7 +74,7 @@ func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) func (s *sessionImpl) Input(conn net.Conn, data []byte) error { //如果是推流,并且握手成功,后续收到的包,都将发送给LoopEvent处理 if s.isPublisher { - s.handle.(*publisher).AddEvent(stream.SourceEventInput, data) + s.handle.(*Publisher).AddEvent(stream.SourceEventInput, data) return nil } else { return s.stack.Input(conn, data) @@ -92,10 +94,10 @@ func (s *sessionImpl) Close() { return } - _, ok := s.handle.(*publisher) + _, ok := s.handle.(*Publisher) if ok { if s.isPublisher { - s.handle.(*publisher).AddEvent(stream.SourceEventClose, nil) + s.handle.(*Publisher).AddEvent(stream.SourceEventClose, nil) } } else { sink := s.handle.(stream.ISink) diff --git a/rtmp/rtmp_stream.go b/rtmp/rtmp_stream.go index d7b365d..c2e947f 100644 --- a/rtmp/rtmp_stream.go +++ b/rtmp/rtmp_stream.go @@ -25,6 +25,10 @@ func NewTransStream(chunkSize int) stream.ITransStream { return transStream } +func TransStreamFactory(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) (stream.ITransStream, error) { + return NewTransStream(librtmp.ChunkSize), nil +} + func (t *TransStream) Input(packet utils.AVPacket) error { utils.Assert(t.TransStreamImpl.Completed) @@ -38,6 +42,7 @@ func (t *TransStream) Input(packet utils.AVPacket) error { var chunkPayloadOffset int var dts int64 var pts int64 + chunkHeaderSize := 12 if utils.AVCodecIdAAC == packet.CodecId() { dts = packet.ConvertDts(1024) @@ -47,6 +52,10 @@ func (t *TransStream) Input(packet utils.AVPacket) error { pts = packet.ConvertPts(1000) } + if dts >= 0xFFFFFF { + chunkHeaderSize += 4 + } + ct := pts - dts if utils.AVMediaTypeAudio == packet.MediaType() { @@ -76,19 +85,18 @@ func (t *TransStream) Input(packet utils.AVPacket) error { } //分配内存 - allocate := t.StreamBuffers[0].Allocate(12 + payloadSize + ((payloadSize - 1) / t.chunkSize)) + allocate := t.StreamBuffers[0].Allocate(chunkHeaderSize + payloadSize + ((payloadSize - 1) / t.chunkSize)) //写rtmp chunk header chunk.Length = payloadSize chunk.Timestamp = uint32(dts) n := chunk.ToBytes(allocate) - utils.Assert(n == 12) //写flv if videoPkt { - n += t.muxer.WriteVideoData(allocate[12:], uint32(ct), packet.KeyFrame(), false) + n += t.muxer.WriteVideoData(allocate[chunkHeaderSize:], uint32(ct), packet.KeyFrame(), false) } else { - n += t.muxer.WriteAudioData(allocate[12:], false) + n += t.muxer.WriteAudioData(allocate[chunkHeaderSize:], false) } n += chunk.WriteData(allocate[n:], data, t.chunkSize, chunkPayloadOffset) @@ -183,7 +191,7 @@ func (t *TransStream) WriteHeader() error { if videoStream != nil { tmp := n n += t.muxer.WriteVideoData(t.header[n+12:], 0, false, true) - extra := videoStream.Extra() + extra := videoStream.CodecParameters().DecoderConfRecord().ToMP4VC() copy(t.header[n+12:], extra) n += len(extra) diff --git a/rtsp/rtsp_sink.go b/rtsp/rtsp_sink.go index a00b12d..9aa9f01 100644 --- a/rtsp/rtsp_sink.go +++ b/rtsp/rtsp_sink.go @@ -11,6 +11,10 @@ import ( "time" ) +var ( + TransportManger stream.TransportManager +) + // 对于UDP而言, 每个sink维护一对UDPTransport // TCP直接单端口传输 type sink struct { @@ -39,13 +43,13 @@ func (s *sink) setTrackCount(count int) { s.tracks = make([]*rtspTrack, count) } -func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (int, int, error) { +func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (uint16, uint16, error) { utils.Assert(index < cap(s.tracks)) utils.Assert(s.tracks[index] == nil) var err error - var rtpPort int - var rtcpPort int + var rtpPort uint16 + var rtcpPort uint16 track := rtspTrack{ ssrc: ssrc, @@ -53,7 +57,7 @@ func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (int, int, error) { if tcp { s.tcp = true } else { - err = rtspTransportManger.AllocPairTransport(func(port int) { + err = TransportManger.AllocPairTransport(func(port uint16) error { //rtp port var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", "0.0.0.0", port)) @@ -64,7 +68,8 @@ func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (int, int, error) { } rtpPort = port - }, func(port int) { + return nil + }, func(port uint16) error { //rtcp port var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", "0.0.0.0", port)) @@ -78,6 +83,8 @@ func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (int, int, error) { } rtcpPort = port + + return nil }) } diff --git a/rtsp/rtsp_stream.go b/rtsp/rtsp_stream.go index 95d8833..1fef4ec 100644 --- a/rtsp/rtsp_stream.go +++ b/rtsp/rtsp_stream.go @@ -45,6 +45,14 @@ func NewTransStream(addr net.IPAddr, urlFormat string) stream.ITransStream { return t } +func TransStreamFactory(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) (stream.ITransStream, error) { + trackFormat := source.Id() + "?track=%d" + return NewTransStream(net.IPAddr{ + IP: net.IP{}, + Zone: "", + }, trackFormat), nil +} + func (t *tranStream) onAllocBuffer(params interface{}) []byte { return t.rtpTracks[params.(int)].buffer[OverTcpHeaderSize:] } diff --git a/rtsp/transport_manager.go b/rtsp/transport_manager.go deleted file mode 100644 index 7fd313f..0000000 --- a/rtsp/transport_manager.go +++ /dev/null @@ -1,71 +0,0 @@ -package rtsp - -import ( - "fmt" - "github.com/yangjiechina/avformat/libbufio" - "github.com/yangjiechina/avformat/utils" -) - -type TransportManager interface { - init(startPort, endPort int) - - AllocTransport(tcp bool, cb func(port int)) error - - AllocPairTransport(cb func(port int)) error -} - -var rtspTransportManger transportManager - -func init() { - rtspTransportManger = transportManager{} - rtspTransportManger.init(20000, 30000) -} - -type transportManager struct { - startPort int - endPort int - nextPort int -} - -func (t *transportManager) init(startPort, endPort int) { - utils.Assert(endPort > startPort) - t.startPort = startPort - t.endPort = endPort + 1 - t.nextPort = startPort -} - -func (t *transportManager) AllocTransport(tcp bool, cb func(port int)) error { - loop := func(start, end int, tcp bool) int { - for i := start; i < end; i++ { - if used := utils.Used(i, tcp); !used { - cb(i) - return i - } - } - return -1 - } - - port := loop(t.nextPort, t.endPort, tcp) - if port == -1 { - port = loop(t.startPort, t.nextPort, tcp) - } - - if port == -1 { - return fmt.Errorf("no available ports in the [%d-%d] range", t.startPort, t.endPort) - } - - t.nextPort = t.nextPort + 1%t.endPort - t.nextPort = libbufio.MaxInt(t.nextPort, t.startPort) - return nil -} - -func (t *transportManager) AllocPairTransport(cb func(port int), cb2 func(port int)) error { - if err := t.AllocTransport(false, cb); err != nil { - return err - } - - if err := t.AllocTransport(false, cb2); err != nil { - return err - } - return nil -} diff --git a/stream/config.go b/stream/config.go index 03fea7c..ceaf039 100644 --- a/stream/config.go +++ b/stream/config.go @@ -1,5 +1,7 @@ package stream +import "strings" + const ( DefaultMergeWriteLatency = 350 ) @@ -9,6 +11,12 @@ type RtmpConfig struct { Addr string `json:"addr"` } +type RtspConfig struct { + RtmpConfig + Password string + Port [2]uint16 +} + type RecordConfig struct { Enable bool `json:"enable"` Format string `json:"format"` @@ -21,6 +29,38 @@ type HlsConfig struct { PlaylistLength int } +type LogConfig struct { + Level int + Name string + MaxSize int + MaxBackup int + MaxAge int + Compress bool +} + +type HttpConfig struct { + Enable bool + Addr string +} + +type GB28181Config struct { + Addr string + Transport string //"UDP|TCP" + Port [2]uint16 //单端口模式[0]=port/多端口模式[0]=start port, [0]=end port. +} + +func (g GB28181Config) EnableTCP() bool { + return strings.Contains(g.Transport, "TCP") +} + +func (g GB28181Config) EnableUDP() bool { + return strings.Contains(g.Transport, "UDP") +} + +func (g GB28181Config) IsMultiPort() bool { + return g.Port[1] > 0 && g.Port[1] > g.Port[0] +} + // M3U8Path 根据sourceId返回m3u8的磁盘路径 func (c HlsConfig) M3U8Path(sourceId string) string { return c.Dir + "/" + c.M3U8Format(sourceId) @@ -94,8 +134,16 @@ type AppConfig_ struct { //合并写的大小范围,应当大于一帧的时长,不超过一组GOP的时长,在实际发送流的时候也会遵循此条例. MergeWriteLatency int `json:"mw_latency"` Rtmp RtmpConfig - Hook HookConfig + Rtsp RtmpConfig + + Hook HookConfig Record RecordConfig Hls HlsConfig + + Log LogConfig + + Http HttpConfig + + GB28181 GB28181Config } diff --git a/stream/hook.go b/stream/hook.go index e9edb5f..909a532 100644 --- a/stream/hook.go +++ b/stream/hook.go @@ -30,11 +30,11 @@ type eventInfo struct { } func NewPlayHookEventInfo(stream, remoteAddr string, protocol Protocol) eventInfo { - return eventInfo{stream: stream, protocol: streamTypeToStr(protocol), remoteAddr: remoteAddr} + return eventInfo{stream: stream, protocol: protocol.ToString(), remoteAddr: remoteAddr} } func NewPublishHookEventInfo(stream, remoteAddr string, protocol SourceType) eventInfo { - return eventInfo{stream: stream, protocol: sourceTypeToStr(protocol), remoteAddr: remoteAddr} + return eventInfo{stream: stream, protocol: protocol.ToString(), remoteAddr: remoteAddr} } type HookHandler interface { diff --git a/stream/memory_pool.go b/stream/memory_pool.go index 3cb99c3..d0845ce 100644 --- a/stream/memory_pool.go +++ b/stream/memory_pool.go @@ -42,6 +42,10 @@ type MemoryPool interface { Clear() Empty() bool + + Capacity() int + + Size() int } func NewMemoryPool(capacity int) MemoryPool { @@ -221,3 +225,12 @@ func (m *memoryPool) Empty() bool { utils.Assert(!m.mark) return m.blockQueue.Size() < 1 } + +func (m *memoryPool) Capacity() int { + return m.capacity +} + +func (m *memoryPool) Size() int { + head, tail := m.Data() + return len(head) + len(tail) +} diff --git a/stream/sink.go b/stream/sink.go index 434fe55..5d1ada9 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -25,8 +25,6 @@ type ISink interface { Protocol() Protocol - ProtocolStr() string - // State 获取Sink状态, 调用前外部必须手动加锁 State() SessionState @@ -133,10 +131,6 @@ func (s *SinkImpl) Protocol() Protocol { return s.Protocol_ } -func (s *SinkImpl) ProtocolStr() string { - return streamTypeToStr(s.Protocol_) -} - func (s *SinkImpl) Lock() { s.lock.Lock() } @@ -213,5 +207,5 @@ func (s *SinkImpl) Close() { } func (s *SinkImpl) PrintInfo() string { - return fmt.Sprintf("%s-%v source:%s", s.ProtocolStr(), s.Id_, s.SourceId_) + return fmt.Sprintf("%s-%v source:%s", s.Protocol().ToString(), s.Id_, s.SourceId_) } diff --git a/stream/sink_hook.go b/stream/sink_hook.go index 5ee2870..e27a5fc 100644 --- a/stream/sink_hook.go +++ b/stream/sink_hook.go @@ -10,7 +10,7 @@ func HookPlaying(s ISink, success func(), failure func(state utils.HookState)) { f := func() { source := SourceManager.Find(s.SourceId()) if source == nil { - log.Sugar.Infof("添加sink到等待队列 sink:%s-%v source:%s", s.ProtocolStr(), s.Id(), s.SourceId()) + log.Sugar.Infof("添加sink到等待队列 sink:%s-%v source:%s", s.Protocol().ToString(), s.Id(), s.SourceId()) { s.Lock() @@ -24,7 +24,7 @@ func HookPlaying(s ISink, success func(), failure func(state utils.HookState)) { } } } else { - log.Sugar.Debugf("发送播放事件 sink:%s-%v source:%s", s.ProtocolStr(), s.Id(), s.SourceId()) + log.Sugar.Debugf("发送播放事件 sink:%s-%v source:%s", s.Protocol().ToString(), s.Id(), s.SourceId()) source.AddEvent(SourceEventPlay, s) } @@ -46,7 +46,7 @@ func HookPlaying(s ISink, success func(), failure func(state utils.HookState)) { success() } }, func(response *http.Response, err error) { - log.Sugar.Errorf("Hook播放事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId()) + log.Sugar.Errorf("Hook播放事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), s.Protocol().ToString(), s.Id(), s.SourceId()) if failure != nil { failure(utils.HookStateFailure) @@ -54,7 +54,7 @@ func HookPlaying(s ISink, success func(), failure func(state utils.HookState)) { }) if err != nil { - log.Sugar.Errorf("Hook播放事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId()) + log.Sugar.Errorf("Hook播放事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), s.Protocol().ToString(), s.Id(), s.SourceId()) if failure != nil { failure(utils.HookStateFailure) @@ -76,7 +76,7 @@ func HookPlayingDone(s ISink, success func(), failure func(state utils.HookState success() } }, func(response *http.Response, err error) { - log.Sugar.Errorf("Hook播放结束事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId()) + log.Sugar.Errorf("Hook播放结束事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), s.Protocol().ToString(), s.Id(), s.SourceId()) if failure != nil { failure(utils.HookStateFailure) @@ -84,7 +84,7 @@ func HookPlayingDone(s ISink, success func(), failure func(state utils.HookState }) if err != nil { - log.Sugar.Errorf("Hook播放结束事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId()) + log.Sugar.Errorf("Hook播放结束事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), s.Protocol().ToString(), s.Id(), s.SourceId()) if failure != nil { failure(utils.HookStateFailure) diff --git a/stream/source.go b/stream/source.go index 42519ba..635fae1 100644 --- a/stream/source.go +++ b/stream/source.go @@ -5,7 +5,6 @@ import ( "github.com/yangjiechina/live-server/log" "net" "net/http" - "sync" "time" "github.com/yangjiechina/avformat/stream" @@ -36,17 +35,11 @@ const ( ProtocolHls = Protocol(4) ProtocolRtc = Protocol(5) - ProtocolRtmpStr = "rtmp" - - SourceEventPlay = SourceEvent(1) - SourceEventPlayDone = SourceEvent(2) - SourceEventInput = SourceEvent(3) - SourceEventClose = SourceEvent(4) - - // TransMuxerHeaderMaxSize 传输流协议头的最大长度 - // 在解析流分配AVPacket的Data时, 如果没有开启合并写, 提前预留指定长度的字节数量. - // 在封装传输流时, 直接在预留头中添加对应传输流的协议头,减少或免内存拷贝. 在传输flv以及转换AVCC和AnnexB格式时有显著提升. - TransMuxerHeaderMaxSize = 30 + SourceEventPlay = SourceEvent(1) + SourceEventPlayDone = SourceEvent(2) + SourceEventInput = SourceEvent(3) + SourceEventClose = SourceEvent(4) + SourceEventProbeTimeout = SourceEvent(5) ) const ( @@ -59,40 +52,42 @@ const ( SessionStateClose = SessionState(7) //关闭状态 ) -func sourceTypeToStr(sourceType SourceType) string { - if SourceTypeRtmp == sourceType { +func (s SourceType) ToString() string { + if SourceTypeRtmp == s { return "rtmp" - } else if SourceType28181 == sourceType { + } else if SourceType28181 == s { return "28181" - } else if SourceType1078 == sourceType { - return "1078" + } else if SourceType1078 == s { + return "jt1078" } - return "" + panic(fmt.Sprintf("unknown source type %d", s)) } -func streamTypeToStr(protocol Protocol) string { - if ProtocolRtmp == protocol { +func (p Protocol) ToString() string { + if ProtocolRtmp == p { return "rtmp" - } else if ProtocolFlv == protocol { + } else if ProtocolFlv == p { return "flv" - } else if ProtocolRtsp == protocol { + } else if ProtocolRtsp == p { return "rtsp" - } else if ProtocolHls == protocol { + } else if ProtocolHls == p { return "hls" - } else if ProtocolRtc == protocol { + } else if ProtocolRtc == p { return "rtc" } - return "" + panic(fmt.Sprintf("unknown stream protocol %d", p)) } +// ISource 父类Source负责, 除解析流以外的所有事情 type ISource interface { // Id Source的唯一ID/** Id() string // Input 输入推流数据 - Input(data []byte) + //@Return bool fatal error.释放Source + Input(data []byte) error // OriginStreams 返回推流的原始Streams OriginStreams() []utils.AVStream @@ -100,8 +95,8 @@ type ISource interface { // TranscodeStreams 返回转码的Streams TranscodeStreams() []utils.AVStream - // AddSink 添加Sink, 在此之前请确保Sink已经握手、授权通过. 如果Source还未WriteHeader,将Sink添加到等待队列. - // 匹配拉流的编码器, 创建TransMuxer或向存在TransMuxer添加Sink + // AddSink 添加Sink, 在此之前请确保Sink已经握手、授权通过. 如果Source还未WriteHeader,先将Sink添加到等待队列. + // 匹配拉流的编码器, 创建TransStream或向存在TransStream添加Sink AddSink(sink ISink) bool // RemoveSink 删除Sink/** @@ -116,10 +111,34 @@ type ISource interface { // 将Sink添加到等待队列 Close() + // Type 推流类型 Type() SourceType -} -type CreateSource func(id string, type_ SourceType, handler stream.OnDeMuxerHandler) + // FindOrCreatePacketBuffer 查找或者创建AVPacket的内存池 + FindOrCreatePacketBuffer(index int, mediaType utils.AVMediaType) MemoryPool + + // OnDiscardPacket GOP缓存溢出回调, 释放AVPacket + OnDiscardPacket(pkt interface{}) + + // OnDeMuxStream 解析出AVStream回调 + OnDeMuxStream(stream utils.AVStream) + + // IsCompleted 是否已经WireHeader + IsCompleted() bool + + // OnDeMuxStreamDone 所有track解析完毕, 后续的OnDeMuxStream回调不再处理 + OnDeMuxStreamDone() + + // OnDeMuxPacket 解析出AvPacket回调 + OnDeMuxPacket(packet utils.AVPacket) + + // OnDeMuxDone 所有流解析完毕回调 + OnDeMuxDone() + + LoopEvent() + + Init(input func(data []byte) error) +} var TranscoderFactory func(src utils.AVStream, dst utils.AVStream) transcode.ITranscoder @@ -132,18 +151,18 @@ type SourceImpl struct { Conn net.Conn TransDeMuxer stream.DeMuxer //负责从推流协议中解析出AVStream和AVPacket - recordSink ISink //每个Source唯一的一个录制流 - hlsStream ITransStream //hls不等拉流,创建时直接生成 + recordSink ISink //每个Source的录制流 + hlsStream ITransStream //如果开开启HLS传输流, 不等拉流时, 创建直接生成 audioTranscoders []transcode.ITranscoder //音频解码器 videoTranscoders []transcode.ITranscoder //视频解码器 originStreams StreamManager //推流的音视频Streams - allStreams StreamManager //推流Streams+转码器获得的Streams - buffers []StreamBuffer + allStreams StreamManager //推流Streams+转码器获得的Stream + gopBuffers []StreamBuffer //推流每路的GOP缓存 + pktBuffers [8]MemoryPool //推流每路的AVPacket缓存, AVPacket的data从该内存池中分配. 在GOP缓存溢出时,释放池中内存. - Input_ func(data []byte) //解决多态无法传递给子类的问题 + Input_ func(data []byte) error //解决多态无法传递给子类的问题 completed bool - mutex sync.Mutex //只用作AddStream期间 probeTimer *time.Timer //所有的输出协议, 持有Sink @@ -156,21 +175,26 @@ type SourceImpl struct { closeEvent chan byte playingEventQueue chan ISink playingDoneEventQueue chan ISink + probeTimoutEvent chan bool } func (s *SourceImpl) Id() string { return s.Id_ } -func (s *SourceImpl) Init() { +func (s *SourceImpl) Init(input func(data []byte) error) { + s.Input_ = input + //初始化事件接收缓冲区 s.SetState(SessionStateTransferring) + //收流和网络断开的chan都阻塞执行 s.inputEvent = make(chan []byte) s.responseEvent = make(chan byte) s.closeEvent = make(chan byte) s.playingEventQueue = make(chan ISink, 128) s.playingDoneEventQueue = make(chan ISink, 128) + s.probeTimoutEvent = make(chan bool) if s.transStreams == nil { s.transStreams = make(map[TransStreamId]ITransStream, 10) @@ -183,20 +207,58 @@ func (s *SourceImpl) Init() { //创建HLS输出流 if AppConfig.Hls.Enable { - s.hlsStream = TransStreamFactory(s, ProtocolHls, nil) + hlsStream, err := CreateTransStream(s, ProtocolHls, nil) + if err != nil { + panic(err) + } + + s.hlsStream = hlsStream s.transStreams[0x100] = s.hlsStream } } +// FindOrCreatePacketBuffer 查找或者创建AVPacket的内存池 +func (s *SourceImpl) FindOrCreatePacketBuffer(index int, mediaType utils.AVMediaType) MemoryPool { + if index >= cap(s.pktBuffers) { + panic("流路数过多...") + } + + if s.pktBuffers[index] == nil { + if utils.AVMediaTypeAudio == mediaType { + s.pktBuffers[index] = NewMemoryPool(48000 * 64) + } else if AppConfig.GOPCache { + //开启GOP缓存 + //以每秒钟4M码率大小创建视频内存池 + s.pktBuffers[index] = NewMemoryPool(4096 * 1024) + } else { + //未开启GOP缓存 + //以每秒钟4M的1/8码率大小创建视频内存池 + s.pktBuffers[index] = NewMemoryPool(4096 * 1024 / 8) + } + } + + return s.pktBuffers[index] +} + func (s *SourceImpl) LoopEvent() { for { select { case data := <-s.inputEvent: - s.Input_(data) + if err := s.Input_(data); err != nil { + log.Sugar.Errorf("处理输入流失败 释放source:%s err:%s", s.Id_, err.Error()) + s.Close() + } + s.responseEvent <- 0 break case sink := <-s.playingEventQueue: - s.AddSink(sink) + if !s.completed { + AddSinkToWaitingQueue(sink.SourceId(), sink) + } else { + if !s.AddSink(sink) { + sink.Close() + } + } break case sink := <-s.playingDoneEventQueue: s.RemoveSink(sink) @@ -204,12 +266,15 @@ func (s *SourceImpl) LoopEvent() { case _ = <-s.closeEvent: s.Close() return + case _ = <-s.probeTimoutEvent: + s.writeHeader() + break } } } -func (s *SourceImpl) Input(data []byte) { - +func (s *SourceImpl) Input(data []byte) error { + return nil } func (s *SourceImpl) OriginStreams() []utils.AVStream { @@ -228,7 +293,7 @@ func IsSupportMux(protocol Protocol, audioCodecId, videoCodecId utils.AVCodecID) return true } -// 分发每路StreamBuffer给传输流 +// 将GOP缓存发送给TransStream // 按照时间戳升序发送 func (s *SourceImpl) dispatchStreamBuffer(transStream ITransStream, streams []utils.AVStream) { size := len(streams) @@ -238,12 +303,12 @@ func (s *SourceImpl) dispatchStreamBuffer(transStream ITransStream, streams []ut min := int64(0xFFFFFFFF) //找出最小的时间戳 - for index, stream := range streams[:size] { - if s.buffers[stream.Index()].Size() == indexs[index] { + for index, stream_ := range streams[:size] { + if s.gopBuffers[stream_.Index()].Size() == indexs[index] { continue } - pkt := s.buffers[stream.Index()].Peek(indexs[index]).(utils.AVPacket) + pkt := s.gopBuffers[stream_.Index()].Peek(indexs[index]).(utils.AVPacket) v := pkt.Dts() if min == 0xFFFFFFFF { min = v @@ -256,8 +321,8 @@ func (s *SourceImpl) dispatchStreamBuffer(transStream ITransStream, streams []ut break } - for index, stream := range streams[:size] { - buffer := s.buffers[stream.Index()] + for index, stream_ := range streams[:size] { + buffer := s.gopBuffers[stream_.Index()] if buffer.Size() == indexs[index] { continue } @@ -313,12 +378,12 @@ func (s *SourceImpl) AddSink(sink ISink) bool { var streams [5]utils.AVStream var size int - for _, stream := range s.originStreams.All() { - if disableVideo && stream.Type() == utils.AVMediaTypeVideo { + for _, stream_ := range s.originStreams.All() { + if disableVideo && stream_.Type() == utils.AVMediaTypeVideo { continue } - streams[size] = stream + streams[size] = stream_ size++ } @@ -329,9 +394,15 @@ func (s *SourceImpl) AddSink(sink ISink) bool { s.transStreams = make(map[TransStreamId]ITransStream, 10) } //创建一个新的传输流 - log.Sugar.Debugf("创建%s-stream", sink.ProtocolStr()) + log.Sugar.Debugf("创建%s-stream", sink.Protocol().ToString()) + + var err error + transStream, err = CreateTransStream(s, sink.Protocol(), streams[:size]) + if err != nil { + log.Sugar.Errorf("创建传输流失败 err:%s source:%s", err.Error(), s.Id_) + return false + } - transStream = TransStreamFactory(s, sink.Protocol(), streams[:size]) s.transStreams[transStreamId] = transStream for i := 0; i < size; i++ { @@ -420,53 +491,59 @@ func (s *SourceImpl) Close() { s.transStreams = nil } -func (s *SourceImpl) OnDeMuxStream(stream utils.AVStream) (bool, StreamBuffer) { - //整块都受保护 确保Add的Stream 都能WriteHeader - s.mutex.Lock() - defer s.mutex.Unlock() +func (s *SourceImpl) OnDiscardPacket(pkt interface{}) { + packet := pkt.(utils.AVPacket) + s.FindOrCreatePacketBuffer(packet.Index(), packet.MediaType()).FreeHead() +} +func (s *SourceImpl) OnDeMuxStream(stream utils.AVStream) { if s.completed { - fmt.Printf("添加Stream失败 Source: %s已经WriteHeader", s.Id_) - return false, nil + log.Sugar.Warnf("添加Stream失败 Source: %s已经WriteHeader", s.Id_) + return } s.originStreams.Add(stream) s.allStreams.Add(stream) //启动探测超时计时器 - if len(s.originStreams.All()) == 1 && AppConfig.ProbeTimeout > 100 { - s.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, s.writeHeader) + if len(s.originStreams.All()) == 1 { + if AppConfig.ProbeTimeout == 0 { + AppConfig.ProbeTimeout = 2000 + } + + s.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() { + s.probeTimoutEvent <- true + }) } //为每个Stream创建对应的Buffer if AppConfig.GOPCache { buffer := NewStreamBuffer(200) //OnDeMuxStream的调用顺序,就是AVStream和AVPacket的Index的递增顺序 - s.buffers = append(s.buffers, buffer) - return true, buffer + s.gopBuffers = append(s.gopBuffers, buffer) + //设置GOP缓存溢出回调 + buffer.SetDiscardHandler(s.OnDiscardPacket) } - - return true, nil } // 从DeMuxer解析完Stream后, 处理等待Sinks func (s *SourceImpl) writeHeader() { - { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.completed { - return - } - s.completed = true + if s.completed { + fmt.Printf("添加Stream失败 Source: %s已经WriteHeader", s.Id_) + return } + s.completed = true + if s.probeTimer != nil { s.probeTimer.Stop() } sinks := PopWaitingSinks(s.Id_) for _, sink := range sinks { - s.AddSink(sink) + if !s.AddSink(sink) { + sink.Close() + } } if s.hlsStream != nil { @@ -478,20 +555,31 @@ func (s *SourceImpl) writeHeader() { } } +func (s *SourceImpl) IsCompleted() bool { + return s.completed +} + func (s *SourceImpl) OnDeMuxStreamDone() { s.writeHeader() } func (s *SourceImpl) OnDeMuxPacket(packet utils.AVPacket) { if AppConfig.GOPCache { - buffer := s.buffers[packet.Index()] + buffer := s.gopBuffers[packet.Index()] buffer.AddPacket(packet, packet.KeyFrame(), packet.Dts()) } //分发给各个传输流 - for _, stream := range s.transStreams { - stream.Input(packet) + for _, stream_ := range s.transStreams { + stream_.Input(packet) } + + if AppConfig.GOPCache { + return + } + + //未开启GOP缓存,释放掉内存 + s.FindOrCreatePacketBuffer(packet.Index(), packet.MediaType()).FreeTail() } func (s *SourceImpl) OnDeMuxDone() { diff --git a/stream/trans_stream.go b/stream/trans_stream.go index c8ce2b8..a5ba311 100644 --- a/stream/trans_stream.go +++ b/stream/trans_stream.go @@ -1,6 +1,7 @@ package stream import ( + "fmt" "github.com/yangjiechina/avformat/stream" "github.com/yangjiechina/avformat/utils" ) @@ -8,8 +9,13 @@ import ( // TransStreamId 每个传输流的唯一Id,由协议+流Id组成 type TransStreamId uint64 -// AVCodecID转为byte的对应关系 -var narrowCodecIds map[int]byte +type TransStreamFactory func(source ISource, protocol Protocol, streams []utils.AVStream) (ITransStream, error) + +var ( + // AVCodecID转为byte的对应关系 + narrowCodecIds map[int]byte + transStreamFactories map[Protocol]TransStreamFactory +) func init() { narrowCodecIds = map[int]byte{ @@ -24,6 +30,35 @@ func init() { int(utils.AVCodecIdMP3): 102, int(utils.AVCodecIdOPUS): 103, } + + transStreamFactories = make(map[Protocol]TransStreamFactory, 8) +} + +func RegisterTransStreamFactory(protocol Protocol, streamFunc TransStreamFactory) { + _, ok := transStreamFactories[protocol] + if ok { + panic(fmt.Sprintf("%s has been registered", protocol.ToString())) + } + + transStreamFactories[protocol] = streamFunc +} + +func FindTransStreamFactory(protocol Protocol) (TransStreamFactory, error) { + f, ok := transStreamFactories[protocol] + if !ok { + return nil, fmt.Errorf("unknown protocol %s", protocol.ToString()) + } + + return f, nil +} + +func CreateTransStream(source ISource, protocol Protocol, streams []utils.AVStream) (ITransStream, error) { + factory, err := FindTransStreamFactory(protocol) + if err != nil { + return nil, err + } + + return factory(source, protocol, streams) } // GenerateTransStreamId 根据传入的推拉流协议和编码器ID生成StreamId @@ -62,8 +97,6 @@ func GenerateTransStreamId(protocol Protocol, ids ...utils.AVStream) TransStream return TransStreamId(streamId) } -var TransStreamFactory func(source ISource, protocol Protocol, streams []utils.AVStream) ITransStream - // ITransStream 讲AVPacket封装成传输流,转发给各个Sink type ITransStream interface { Init() diff --git a/stream/transport_manager.go b/stream/transport_manager.go new file mode 100644 index 0000000..41198b4 --- /dev/null +++ b/stream/transport_manager.go @@ -0,0 +1,72 @@ +package stream + +import ( + "fmt" + "github.com/yangjiechina/avformat/libbufio" + "github.com/yangjiechina/avformat/utils" + "sync" +) + +type TransportManager interface { + AllocTransport(tcp bool, cb func(port uint16) error) error + + AllocPairTransport(cb, c2 func(port uint16) error) error +} + +func NewTransportManager(start, end uint16) TransportManager { + utils.Assert(end > start) + + return &transportManager{ + startPort: start, + endPort: end, + nextPort: start, + } +} + +type transportManager struct { + startPort uint16 + endPort uint16 + nextPort uint16 + lock sync.Mutex +} + +func (t *transportManager) AllocTransport(tcp bool, cb func(port uint16) error) error { + loop := func(start, end uint16, tcp bool) (uint16, error) { + for i := start; i < end; i++ { + if used := utils.Used(int(i), tcp); !used { + return i, cb(i) + } + } + + return 0, nil + } + + t.lock.Lock() + defer t.lock.Unlock() + + port, err := loop(t.nextPort, t.endPort, tcp) + if port == 0 { + port, err = loop(t.startPort, t.nextPort, tcp) + } + + if port == 0 { + return fmt.Errorf("no available ports in the [%d-%d] range", t.startPort, t.endPort) + } else if err != nil { + return err + } + + t.nextPort = t.nextPort + 1%t.endPort + t.nextPort = uint16(libbufio.MaxInt(int(t.nextPort), int(t.startPort))) + return nil +} + +func (t *transportManager) AllocPairTransport(cb func(port uint16) error, cb2 func(port uint16) error) error { + if err := t.AllocTransport(false, cb); err != nil { + return err + } + + if err := t.AllocTransport(false, cb2); err != nil { + return err + } + return nil +}