diff --git a/internal/core/api_test.go b/internal/core/api_test.go index 36128ad0..d6c5762e 100644 --- a/internal/core/api_test.go +++ b/internal/core/api_test.go @@ -578,7 +578,7 @@ func TestAPIProtocolList(t *testing.T) { err = conn.InitializeClient(u, true) require.NoError(t, err) - err = conn.WriteTracks(testFormatH264, nil) + _, err = rtmp.NewWriter(conn, testFormatH264, nil) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -833,7 +833,7 @@ func TestAPIProtocolGet(t *testing.T) { err = conn.InitializeClient(u, true) require.NoError(t, err) - err = conn.WriteTracks(testFormatH264, nil) + _, err = rtmp.NewWriter(conn, testFormatH264, nil) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -1155,7 +1155,7 @@ func TestAPIProtocolKick(t *testing.T) { err = conn.InitializeClient(u, true) require.NoError(t, err) - err = conn.WriteTracks(testFormatH264, nil) + _, err = rtmp.NewWriter(conn, testFormatH264, nil) require.NoError(t, err) case "webrtc": diff --git a/internal/core/metrics_test.go b/internal/core/metrics_test.go index 245bc0bd..3a980f6a 100644 --- a/internal/core/metrics_test.go +++ b/internal/core/metrics_test.go @@ -101,7 +101,7 @@ webrtc_sessions_bytes_sent 0 PacketizationMode: 1, } - err = conn.WriteTracks(videoTrack, nil) + _, err = rtmp.NewWriter(conn, videoTrack, nil) require.NoError(t, err) time.Sleep(500 * time.Millisecond) diff --git a/internal/core/rtmp_conn.go b/internal/core/rtmp_conn.go index 16bf8dc2..2ad03352 100644 --- a/internal/core/rtmp_conn.go +++ b/internal/core/rtmp_conn.go @@ -13,20 +13,16 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/formats" "github.com/bluenviron/gortsplib/v3/pkg/media" "github.com/bluenviron/gortsplib/v3/pkg/ringbuffer" - "github.com/bluenviron/mediacommon/pkg/codecs/av1" "github.com/bluenviron/mediacommon/pkg/codecs/h264" "github.com/bluenviron/mediacommon/pkg/codecs/mpeg2audio" "github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio" "github.com/google/uuid" - "github.com/notedit/rtmp/format/flv/flvio" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/externalcmd" "github.com/bluenviron/mediamtx/internal/formatprocessor" "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/rtmp" - "github.com/bluenviron/mediamtx/internal/rtmp/h264conf" - "github.com/bluenviron/mediamtx/internal/rtmp/message" "github.com/bluenviron/mediamtx/internal/stream" ) @@ -42,158 +38,6 @@ func pathNameAndQuery(inURL *url.URL) (string, url.Values, string) { return pathName, ur.Query(), ur.RawQuery } -type rtmpWriteFunc func(msg interface{}) error - -func getRTMPWriteFunc(medi *media.Media, format formats.Format, stream *stream.Stream) rtmpWriteFunc { - switch format.(type) { - case *formats.H264: - return func(msg interface{}) error { - tmsg := msg.(*message.Video) - - switch tmsg.Type { - case message.VideoTypeConfig: - var conf h264conf.Conf - err := conf.Unmarshal(tmsg.Payload) - if err != nil { - return fmt.Errorf("unable to parse H264 config: %v", err) - } - - au := [][]byte{ - conf.SPS, - conf.PPS, - } - - stream.WriteUnit(medi, format, &formatprocessor.UnitH264{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS + tmsg.PTSDelta, - AU: au, - }) - - case message.VideoTypeAU: - au, err := h264.AVCCUnmarshal(tmsg.Payload) - if err != nil { - return fmt.Errorf("unable to decode AVCC: %v", err) - } - - stream.WriteUnit(medi, format, &formatprocessor.UnitH264{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS + tmsg.PTSDelta, - AU: au, - }) - } - - return nil - } - - case *formats.H265: - return func(msg interface{}) error { - switch tmsg := msg.(type) { - case *message.Video: - au, err := h264.AVCCUnmarshal(tmsg.Payload) - if err != nil { - return fmt.Errorf("unable to decode AVCC: %v", err) - } - - stream.WriteUnit(medi, format, &formatprocessor.UnitH265{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS + tmsg.PTSDelta, - AU: au, - }) - - case *message.ExtendedFramesX: - au, err := h264.AVCCUnmarshal(tmsg.Payload) - if err != nil { - return fmt.Errorf("unable to decode AVCC: %v", err) - } - - stream.WriteUnit(medi, format, &formatprocessor.UnitH265{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS, - AU: au, - }) - - case *message.ExtendedCodedFrames: - au, err := h264.AVCCUnmarshal(tmsg.Payload) - if err != nil { - return fmt.Errorf("unable to decode AVCC: %v", err) - } - - stream.WriteUnit(medi, format, &formatprocessor.UnitH265{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS + tmsg.PTSDelta, - AU: au, - }) - } - - return nil - } - - case *formats.AV1: - return func(msg interface{}) error { - if tmsg, ok := msg.(*message.ExtendedCodedFrames); ok { - obus, err := av1.BitstreamUnmarshal(tmsg.Payload, true) - if err != nil { - return fmt.Errorf("unable to decode bitstream: %v", err) - } - - stream.WriteUnit(medi, format, &formatprocessor.UnitAV1{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS, - OBUs: obus, - }) - } - - return nil - } - - case *formats.MPEG2Audio: - return func(msg interface{}) error { - tmsg := msg.(*message.Audio) - - stream.WriteUnit(medi, format, &formatprocessor.UnitMPEG2Audio{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS, - Frames: [][]byte{tmsg.Payload}, - }) - - return nil - } - - case *formats.MPEG4Audio: - return func(msg interface{}) error { - tmsg := msg.(*message.Audio) - - if tmsg.AACType == message.AudioAACTypeAU { - stream.WriteUnit(medi, format, &formatprocessor.UnitMPEG4AudioGeneric{ - BaseUnit: formatprocessor.BaseUnit{ - NTP: time.Now(), - }, - PTS: tmsg.DTS, - AUs: [][]byte{tmsg.Payload}, - }) - } - - return nil - } - } - - return nil -} - type rtmpConnState int const ( @@ -209,7 +53,7 @@ type rtmpConnPathManager interface { type rtmpConnParent interface { logger.Writer - connClose(*rtmpConn) + closeConn(*rtmpConn) } type rtmpConn struct { @@ -322,36 +166,32 @@ func (c *rtmpConn) run() { }() } - ctx, cancel := context.WithCancel(c.ctx) - runErr := make(chan error) - go func() { - runErr <- c.runInner(ctx) + err := func() error { + readerErr := make(chan error) + go func() { + readerErr <- c.runReader() + }() + + select { + case err := <-readerErr: + c.nconn.Close() + return err + + case <-c.ctx.Done(): + c.nconn.Close() + <-readerErr + return errors.New("terminated") + } }() - var err error - select { - case err = <-runErr: - cancel() - - case <-c.ctx.Done(): - cancel() - <-runErr - err = errors.New("terminated") - } - c.ctxCancel() - c.parent.connClose(c) + c.parent.closeConn(c) c.Log(logger.Info, "closed (%v)", err) } -func (c *rtmpConn) runInner(ctx context.Context) error { - go func() { - <-ctx.Done() - c.nconn.Close() - }() - +func (c *rtmpConn) runReader() error { c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) u, publish, err := c.conn.InitializeServer() @@ -360,12 +200,12 @@ func (c *rtmpConn) runInner(ctx context.Context) error { } if !publish { - return c.runRead(ctx, u) + return c.runRead(u) } return c.runPublish(u) } -func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { +func (c *rtmpConn) runRead(u *url.URL) error { pathName, query, rawQuery := pathNameAndQuery(u) res := c.pathManager.readerAdd(pathReaderAddReq{ @@ -399,22 +239,32 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { ringBuffer, _ := ringbuffer.New(uint64(c.readBufferCount)) go func() { - <-ctx.Done() + <-c.ctx.Done() ringBuffer.Close() }() var medias media.Medias videoFirstIDRFound := false var videoStartDTS time.Duration + var w *rtmp.Writer - videoMedia, videoFormat := c.findVideoFormat(res.stream, ringBuffer, - &videoFirstIDRFound, &videoStartDTS) + videoMedia, videoFormat := c.setupVideo( + &w, + res.stream, + ringBuffer, + &videoFirstIDRFound, + &videoStartDTS) if videoMedia != nil { medias = append(medias, videoMedia) } - audioMedia, audioFormat := c.findAudioFormat(res.stream, ringBuffer, - videoFormat, &videoFirstIDRFound, &videoStartDTS) + audioMedia, audioFormat := c.setupAudio( + &w, + res.stream, + ringBuffer, + videoFormat, + &videoFirstIDRFound, + &videoStartDTS) if audioFormat != nil { medias = append(medias, audioMedia) } @@ -447,7 +297,8 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { }() } - err := c.conn.WriteTracks(videoFormat, audioFormat) + var err error + w, err = rtmp.NewWriter(c.conn, videoFormat, audioFormat) if err != nil { return err } @@ -468,15 +319,19 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { } } -func (c *rtmpConn) findVideoFormat(stream *stream.Stream, ringBuffer *ringbuffer.RingBuffer, - videoFirstIDRFound *bool, videoStartDTS *time.Duration, +func (c *rtmpConn) setupVideo( + w **rtmp.Writer, + stream *stream.Stream, + ringBuffer *ringbuffer.RingBuffer, + videoFirstIDRFound *bool, + videoStartDTS *time.Duration, ) (*media.Media, formats.Format) { var videoFormatH264 *formats.H264 videoMedia := stream.Medias().FindFormat(&videoFormatH264) if videoFormatH264 != nil { - videoStartPTSFilled := false - var videoStartPTS time.Duration + startPTSFilled := false + var startPTS time.Duration var videoDTSExtractor *h264.DTSExtractor stream.AddReader(c, videoMedia, videoFormatH264, func(unit formatprocessor.Unit) { @@ -487,11 +342,11 @@ func (c *rtmpConn) findVideoFormat(stream *stream.Stream, ringBuffer *ringbuffer return nil } - if !videoStartPTSFilled { - videoStartPTSFilled = true - videoStartPTS = tunit.PTS + if !startPTSFilled { + startPTSFilled = true + startPTS = tunit.PTS } - pts := tunit.PTS - videoStartPTS + pts := tunit.PTS - startPTS idrPresent := false nonIDRPresent := false @@ -542,27 +397,8 @@ func (c *rtmpConn) findVideoFormat(stream *stream.Stream, ringBuffer *ringbuffer pts -= *videoStartDTS } - avcc, err := h264.AVCCMarshal(tunit.AU) - if err != nil { - return err - } - c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) - err = c.conn.WriteMessage(&message.Video{ - ChunkStreamID: message.VideoChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecH264, - IsKeyFrame: idrPresent, - Type: message.VideoTypeAU, - Payload: avcc, - DTS: dts, - PTSDelta: pts - dts, - }) - if err != nil { - return err - } - - return nil + return (*w).WriteH264(pts, dts, idrPresent, tunit.AU) }) }) @@ -572,7 +408,8 @@ func (c *rtmpConn) findVideoFormat(stream *stream.Stream, ringBuffer *ringbuffer return nil, nil } -func (c *rtmpConn) findAudioFormat( +func (c *rtmpConn) setupAudio( + w **rtmp.Writer, stream *stream.Stream, ringBuffer *ringbuffer.RingBuffer, videoFormat formats.Format, @@ -583,8 +420,8 @@ func (c *rtmpConn) findAudioFormat( audioMedia := stream.Medias().FindFormat(&audioFormatMPEG4Generic) if audioMedia != nil { - audioStartPTSFilled := false - var audioStartPTS time.Duration + startPTSFilled := false + var startPTS time.Duration stream.AddReader(c, audioMedia, audioFormatMPEG4Generic, func(unit formatprocessor.Unit) { ringBuffer.Push(func() error { @@ -594,11 +431,11 @@ func (c *rtmpConn) findAudioFormat( return nil } - if !audioStartPTSFilled { - audioStartPTSFilled = true - audioStartPTS = tunit.PTS + if !startPTSFilled { + startPTSFilled = true + startPTS = tunit.PTS } - pts := tunit.PTS - audioStartPTS + pts := tunit.PTS - startPTS if videoFormat != nil { if !*videoFirstIDRFound { @@ -613,18 +450,11 @@ func (c *rtmpConn) findAudioFormat( for i, au := range tunit.AUs { c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) - err := c.conn.WriteMessage(&message.Audio{ - ChunkStreamID: message.AudioChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecMPEG4Audio, - Rate: flvio.SOUND_44Khz, - Depth: flvio.SOUND_16BIT, - Channels: flvio.SOUND_STEREO, - AACType: message.AudioAACTypeAU, - Payload: au, - DTS: pts + time.Duration(i)*mpeg4audio.SamplesPerAccessUnit* + err := (*w).WriteMPEG4Audio( + pts+time.Duration(i)*mpeg4audio.SamplesPerAccessUnit* time.Second/time.Duration(audioFormatMPEG4Generic.ClockRate()), - }) + au, + ) if err != nil { return err } @@ -644,8 +474,8 @@ func (c *rtmpConn) findAudioFormat( audioFormatMPEG4AudioLATM.Config != nil && len(audioFormatMPEG4AudioLATM.Config.Programs) == 1 && len(audioFormatMPEG4AudioLATM.Config.Programs[0].Layers) == 1 { - audioStartPTSFilled := false - var audioStartPTS time.Duration + startPTSFilled := false + var startPTS time.Duration stream.AddReader(c, audioMedia, audioFormatMPEG4AudioLATM, func(unit formatprocessor.Unit) { ringBuffer.Push(func() error { @@ -655,11 +485,11 @@ func (c *rtmpConn) findAudioFormat( return nil } - if !audioStartPTSFilled { - audioStartPTSFilled = true - audioStartPTS = tunit.PTS + if !startPTSFilled { + startPTSFilled = true + startPTS = tunit.PTS } - pts := tunit.PTS - audioStartPTS + pts := tunit.PTS - startPTS if videoFormat != nil { if !*videoFirstIDRFound { @@ -673,22 +503,7 @@ func (c *rtmpConn) findAudioFormat( } c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) - err := c.conn.WriteMessage(&message.Audio{ - ChunkStreamID: message.AudioChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecMPEG4Audio, - Rate: flvio.SOUND_44Khz, - Depth: flvio.SOUND_16BIT, - Channels: flvio.SOUND_STEREO, - AACType: message.AudioAACTypeAU, - Payload: tunit.AU, - DTS: pts, - }) - if err != nil { - return err - } - - return nil + return (*w).WriteMPEG4Audio(pts, tunit.AU) }) }) @@ -699,18 +514,18 @@ func (c *rtmpConn) findAudioFormat( audioMedia = stream.Medias().FindFormat(&audioFormatMPEG2) if audioMedia != nil { - audioStartPTSFilled := false - var audioStartPTS time.Duration + startPTSFilled := false + var startPTS time.Duration stream.AddReader(c, audioMedia, audioFormatMPEG2, func(unit formatprocessor.Unit) { ringBuffer.Push(func() error { tunit := unit.(*formatprocessor.UnitMPEG2Audio) - if !audioStartPTSFilled { - audioStartPTSFilled = true - audioStartPTS = tunit.PTS + if !startPTSFilled { + startPTSFilled = true + startPTS = tunit.PTS } - pts := tunit.PTS - audioStartPTS + pts := tunit.PTS - startPTS if videoFormat != nil { if !*videoFirstIDRFound { @@ -734,34 +549,8 @@ func (c *rtmpConn) findAudioFormat( return fmt.Errorf("RTMP only supports MPEG-1 layer 3 audio") } - channels := uint8(flvio.SOUND_STEREO) - if h.ChannelMode == mpeg2audio.ChannelModeMono { - channels = flvio.SOUND_MONO - } - - rate := uint8(flvio.SOUND_44Khz) - switch h.SampleRate { - case 5500: - rate = flvio.SOUND_5_5Khz - case 11025: - rate = flvio.SOUND_11Khz - case 22050: - rate = flvio.SOUND_22Khz - } - - msg := &message.Audio{ - ChunkStreamID: message.AudioChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecMPEG2Audio, - Rate: rate, - Depth: flvio.SOUND_16BIT, - Channels: channels, - Payload: frame, - DTS: pts, - } - c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) - err = c.conn.WriteMessage(msg) + err = (*w).WriteMPEG2Audio(pts, &h, frame) if err != nil { return err } @@ -812,29 +601,88 @@ func (c *rtmpConn) runPublish(u *url.URL) error { c.pathName = pathName c.mutex.Unlock() - videoFormat, audioFormat, err := c.conn.ReadTracks() + r, err := rtmp.NewReader(c.conn) if err != nil { return err } + videoFormat, audioFormat := r.Tracks() var medias media.Medias - var videoMedia *media.Media - var audioMedia *media.Media + var stream *stream.Stream if videoFormat != nil { - videoMedia = &media.Media{ + videoMedia := &media.Media{ Type: media.TypeVideo, Formats: []formats.Format{videoFormat}, } medias = append(medias, videoMedia) + + switch videoFormat.(type) { + case *formats.AV1: + r.OnDataAV1(func(pts time.Duration, obus [][]byte) { + stream.WriteUnit(videoMedia, videoFormat, &formatprocessor.UnitAV1{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + OBUs: obus, + }) + }) + + case *formats.H265: + r.OnDataH265(func(pts time.Duration, au [][]byte) { + stream.WriteUnit(videoMedia, videoFormat, &formatprocessor.UnitH265{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + AU: au, + }) + }) + + case *formats.H264: + r.OnDataH264(func(pts time.Duration, au [][]byte) { + stream.WriteUnit(videoMedia, videoFormat, &formatprocessor.UnitH264{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + AU: au, + }) + }) + } } - if audioFormat != nil { - audioMedia = &media.Media{ + if audioFormat != nil { //nolint:dupl + audioMedia := &media.Media{ Type: media.TypeAudio, Formats: []formats.Format{audioFormat}, } medias = append(medias, audioMedia) + + switch audioFormat.(type) { + case *formats.MPEG4AudioGeneric: + r.OnDataMPEG4Audio(func(pts time.Duration, au []byte) { + stream.WriteUnit(audioMedia, audioFormat, &formatprocessor.UnitMPEG4AudioGeneric{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + AUs: [][]byte{au}, + }) + }) + + case *formats.MPEG2Audio: + r.OnDataMPEG2Audio(func(pts time.Duration, frame []byte) { + stream.WriteUnit(audioMedia, audioFormat, &formatprocessor.UnitMPEG2Audio{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + Frames: [][]byte{frame}, + }) + }) + } } rres := res.path.publisherStart(pathPublisherStartReq{ @@ -850,40 +698,17 @@ func (c *rtmpConn) runPublish(u *url.URL) error { res.path.name, sourceMediaInfo(medias)) + stream = rres.stream + // disable write deadline to allow outgoing acknowledges c.nconn.SetWriteDeadline(time.Time{}) - videoWriteFunc := getRTMPWriteFunc(videoMedia, videoFormat, rres.stream) - audioWriteFunc := getRTMPWriteFunc(audioMedia, audioFormat, rres.stream) - for { c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) - msg, err := c.conn.ReadMessage() + err := r.Read() if err != nil { return err } - - switch msg.(type) { - case *message.Video, *message.ExtendedFramesX, *message.ExtendedCodedFrames: - if videoFormat == nil { - return fmt.Errorf("received a video packet, but track is not set up") - } - - err := videoWriteFunc(msg) - if err != nil { - c.Log(logger.Warn, "%v", err) - } - - case *message.Audio: - if audioFormat == nil { - return fmt.Errorf("received an audio packet, but track is not set up") - } - - err := audioWriteFunc(msg) - if err != nil { - c.Log(logger.Warn, "%v", err) - } - } } } diff --git a/internal/core/rtmp_listener.go b/internal/core/rtmp_listener.go new file mode 100644 index 00000000..fe270fa7 --- /dev/null +++ b/internal/core/rtmp_listener.go @@ -0,0 +1,46 @@ +package core + +import ( + "net" + "sync" +) + +type rtmpListener struct { + ln net.Listener + wg *sync.WaitGroup + parent *rtmpServer +} + +func newRTMPListener( + ln net.Listener, + wg *sync.WaitGroup, + parent *rtmpServer, +) *rtmpListener { + l := &rtmpListener{ + ln: ln, + wg: wg, + parent: parent, + } + + l.wg.Add(1) + go l.run() + + return l +} + +func (l *rtmpListener) run() { + defer l.wg.Done() + + err := func() error { + for { + conn, err := l.ln.Accept() + if err != nil { + return err + } + + l.parent.newConn(conn) + } + }() + + l.parent.acceptError(err) +} diff --git a/internal/core/rtmp_server.go b/internal/core/rtmp_server.go index 15ba2dd3..35499c82 100644 --- a/internal/core/rtmp_server.go +++ b/internal/core/rtmp_server.go @@ -67,7 +67,9 @@ type rtmpServer struct { conns map[*rtmpConn]struct{} // in - chConnClose chan *rtmpConn + chNewConn chan net.Conn + chAcceptErr chan error + chCloseConn chan *rtmpConn chAPIConnsList chan rtmpServerAPIConnsListReq chAPIConnsGet chan rtmpServerAPIConnsGetReq chAPIConnsKick chan rtmpServerAPIConnsKickReq @@ -124,7 +126,9 @@ func newRTMPServer( ctxCancel: ctxCancel, ln: ln, conns: make(map[*rtmpConn]struct{}), - chConnClose: make(chan *rtmpConn), + chNewConn: make(chan net.Conn), + chAcceptErr: make(chan error), + chCloseConn: make(chan *rtmpConn), chAPIConnsList: make(chan rtmpServerAPIConnsListReq), chAPIConnsGet: make(chan rtmpServerAPIConnsGetReq), chAPIConnsKick: make(chan rtmpServerAPIConnsKickReq), @@ -136,6 +140,12 @@ func newRTMPServer( s.metrics.rtmpServerSet(s) } + newRTMPListener( + s.ln, + &s.wg, + s, + ) + s.wg.Add(1) go s.run() @@ -161,40 +171,14 @@ func (s *rtmpServer) close() { func (s *rtmpServer) run() { defer s.wg.Done() - s.wg.Add(1) - connNew := make(chan net.Conn) - acceptErr := make(chan error) - go func() { - defer s.wg.Done() - err := func() error { - for { - conn, err := s.ln.Accept() - if err != nil { - return err - } - - select { - case connNew <- conn: - case <-s.ctx.Done(): - conn.Close() - } - } - }() - - select { - case acceptErr <- err: - case <-s.ctx.Done(): - } - }() - outer: for { select { - case err := <-acceptErr: + case err := <-s.chAcceptErr: s.Log(logger.Error, "%s", err) break outer - case nconn := <-connNew: + case nconn := <-s.chNewConn: c := newRTMPConn( s.ctx, s.isTLS, @@ -211,7 +195,7 @@ outer: s) s.conns[c] = struct{}{} - case c := <-s.chConnClose: + case c := <-s.chCloseConn: delete(s.conns, c) case req := <-s.chAPIConnsList: @@ -272,10 +256,27 @@ func (s *rtmpServer) findConnByUUID(uuid uuid.UUID) *rtmpConn { return nil } -// connClose is called by rtmpConn. -func (s *rtmpServer) connClose(c *rtmpConn) { +// newConn is called by rtmpListener. +func (s *rtmpServer) newConn(conn net.Conn) { select { - case s.chConnClose <- c: + case s.chNewConn <- conn: + case <-s.ctx.Done(): + conn.Close() + } +} + +// acceptError is called by rtmpListener. +func (s *rtmpServer) acceptError(err error) { + select { + case s.chAcceptErr <- err: + case <-s.ctx.Done(): + } +} + +// closeConn is called by rtmpConn. +func (s *rtmpServer) closeConn(c *rtmpConn) { + select { + case s.chCloseConn <- c: case <-s.ctx.Done(): } } diff --git a/internal/core/rtmp_server_test.go b/internal/core/rtmp_server_test.go index 65e3c151..d9f846f5 100644 --- a/internal/core/rtmp_server_test.go +++ b/internal/core/rtmp_server_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/require" "github.com/bluenviron/mediamtx/internal/rtmp" - "github.com/bluenviron/mediamtx/internal/rtmp/message" ) func TestRTMPServerRunOnConnect(t *testing.T) { @@ -154,7 +153,7 @@ func TestRTMPServer(t *testing.T) { IndexDeltaLength: 3, } - err = conn1.WriteTracks(videoTrack, audioTrack) + w, err := rtmp.NewWriter(conn1, videoTrack, audioTrack) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -181,43 +180,40 @@ func TestRTMPServer(t *testing.T) { err = conn2.InitializeClient(u2, false) require.NoError(t, err) - videoTrack1, audioTrack2, err := conn2.ReadTracks() + r, err := rtmp.NewReader(conn2) require.NoError(t, err) + videoTrack1, audioTrack2 := r.Tracks() require.Equal(t, videoTrack, videoTrack1) require.Equal(t, audioTrack, audioTrack2) - err = conn1.WriteMessage(&message.Video{ - ChunkStreamID: message.VideoChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecH264, - IsKeyFrame: true, - Type: message.VideoTypeAU, - Payload: []byte{ - 0x00, 0x00, 0x00, 0x04, 0x05, 0x02, 0x03, 0x04, // IDR 1 - 0x00, 0x00, 0x00, 0x04, 0x05, 0x02, 0x03, 0x04, // IDR 2 - }, + err = w.WriteH264(0, 0, true, [][]byte{ + {0x05, 0x02, 0x03, 0x04}, // IDR 1 + {0x05, 0x02, 0x03, 0x04}, // IDR 2 }) require.NoError(t, err) - msg1, err := conn2.ReadMessage() + r.OnDataH264(func(pts time.Duration, au [][]byte) { + require.Equal(t, [][]byte{ + { // SPS + 0x67, 0x42, 0xc0, 0x28, 0xd9, 0x00, 0x78, 0x02, + 0x27, 0xe5, 0x84, 0x00, 0x00, 0x03, 0x00, 0x04, + 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc9, + 0x20, + }, + { // PPS + 0x08, 0x06, 0x07, 0x08, + }, + { // IDR 1 + 0x05, 0x02, 0x03, 0x04, + }, + { // IDR 2 + 0x05, 0x02, 0x03, 0x04, + }, + }, au) + }) + + err = r.Read() require.NoError(t, err) - require.Equal(t, &message.Video{ - ChunkStreamID: message.VideoChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecH264, - IsKeyFrame: true, - Type: message.VideoTypeAU, - Payload: []byte{ - 0x00, 0x00, 0x00, 0x19, // SPS - 0x67, 0x42, 0xc0, 0x28, 0xd9, 0x00, 0x78, 0x02, - 0x27, 0xe5, 0x84, 0x00, 0x00, 0x03, 0x00, 0x04, - 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc9, - 0x20, - 0x00, 0x00, 0x00, 0x04, 0x08, 0x06, 0x07, 0x08, // PPS - 0x00, 0x00, 0x00, 0x04, 0x05, 0x02, 0x03, 0x04, // IDR 1 - 0x00, 0x00, 0x00, 0x04, 0x05, 0x02, 0x03, 0x04, // IDR 2 - }, - }, msg1) }) } } @@ -259,7 +255,7 @@ func TestRTMPServerAuthFail(t *testing.T) { PacketizationMode: 1, } - err = conn1.WriteTracks(videoTrack, nil) + _, err = rtmp.NewWriter(conn1, videoTrack, nil) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -275,7 +271,7 @@ func TestRTMPServerAuthFail(t *testing.T) { err = conn2.InitializeClient(u2, false) require.NoError(t, err) - _, _, err = conn2.ReadTracks() + _, err = rtmp.NewReader(conn2) require.EqualError(t, err, "EOF") }) @@ -313,7 +309,7 @@ func TestRTMPServerAuthFail(t *testing.T) { PacketizationMode: 1, } - err = conn1.WriteTracks(videoTrack, nil) + _, err = rtmp.NewWriter(conn1, videoTrack, nil) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -329,7 +325,7 @@ func TestRTMPServerAuthFail(t *testing.T) { err = conn2.InitializeClient(u2, false) require.NoError(t, err) - _, _, err = conn2.ReadTracks() + _, err = rtmp.NewReader(conn2) require.EqualError(t, err, "EOF") }) @@ -368,7 +364,7 @@ func TestRTMPServerAuthFail(t *testing.T) { PacketizationMode: 1, } - err = conn1.WriteTracks(videoTrack, nil) + _, err = rtmp.NewWriter(conn1, videoTrack, nil) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -384,7 +380,7 @@ func TestRTMPServerAuthFail(t *testing.T) { err = conn2.InitializeClient(u2, false) require.NoError(t, err) - _, _, err = conn2.ReadTracks() + _, err = rtmp.NewReader(conn2) require.EqualError(t, err, "EOF") }) } diff --git a/internal/core/rtmp_source.go b/internal/core/rtmp_source.go index d6a9bf12..b7a63da1 100644 --- a/internal/core/rtmp_source.go +++ b/internal/core/rtmp_source.go @@ -12,9 +12,10 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/media" "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/formatprocessor" "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/rtmp" - "github.com/bluenviron/mediamtx/internal/rtmp/message" + "github.com/bluenviron/mediamtx/internal/stream" ) type rtmpSourceParent interface { @@ -60,10 +61,10 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha u.Host = net.JoinHostPort(u.Host, "1935") } - ctx2, cancel2 := context.WithTimeout(ctx, time.Duration(s.readTimeout)) - defer cancel2() - nconn, err := func() (net.Conn, error) { + ctx2, cancel2 := context.WithTimeout(ctx, time.Duration(s.readTimeout)) + defer cancel2() + if u.Scheme == "rtmp" { return (&net.Dialer{}).DialContext(ctx2, "tcp", u.Host) } @@ -76,98 +77,9 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha return err } - conn := rtmp.NewConn(nconn) - readDone := make(chan error) go func() { - readDone <- func() error { - nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) - nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout))) - err = conn.InitializeClient(u, false) - if err != nil { - return err - } - - nconn.SetWriteDeadline(time.Time{}) - nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) - videoFormat, audioFormat, err := conn.ReadTracks() - if err != nil { - return err - } - - switch videoFormat.(type) { - case *formats.H265, *formats.AV1: - return fmt.Errorf("proxying H265 or AV1 tracks with RTMP is not supported") - } - - var medias media.Medias - var videoMedia *media.Media - var audioMedia *media.Media - - if videoFormat != nil { - videoMedia = &media.Media{ - Type: media.TypeVideo, - Formats: []formats.Format{videoFormat}, - } - medias = append(medias, videoMedia) - } - - if audioFormat != nil { - audioMedia = &media.Media{ - Type: media.TypeAudio, - Formats: []formats.Format{audioFormat}, - } - medias = append(medias, audioMedia) - } - - res := s.parent.sourceStaticImplSetReady(pathSourceStaticSetReadyReq{ - medias: medias, - generateRTPPackets: true, - }) - if res.err != nil { - return res.err - } - - s.Log(logger.Info, "ready: %s", sourceMediaInfo(medias)) - - defer s.parent.sourceStaticImplSetNotReady(pathSourceStaticSetNotReadyReq{}) - - videoWriteFunc := getRTMPWriteFunc(videoMedia, videoFormat, res.stream) - audioWriteFunc := getRTMPWriteFunc(audioMedia, audioFormat, res.stream) - - // disable write deadline to allow outgoing acknowledges - nconn.SetWriteDeadline(time.Time{}) - - for { - nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) - msg, err := conn.ReadMessage() - if err != nil { - return err - } - - switch tmsg := msg.(type) { - case *message.Video: - if videoFormat == nil { - return fmt.Errorf("received an H264 packet, but track is not set up") - } - - err := videoWriteFunc(tmsg) - if err != nil { - s.Log(logger.Warn, "%v", err) - } - - case *message.Audio: - if audioFormat == nil { - return fmt.Errorf("received an AAC packet, but track is not set up") - } - - err := audioWriteFunc(tmsg) - if err != nil { - s.Log(logger.Warn, "%v", err) - } - } - } - }() + readDone <- s.runReader(u, nconn) }() for { @@ -186,6 +98,109 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha } } +func (s *rtmpSource) runReader(u *url.URL, nconn net.Conn) error { + conn := rtmp.NewConn(nconn) + + nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) + nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout))) + err := conn.InitializeClient(u, false) + if err != nil { + return err + } + + mc, err := rtmp.NewReader(conn) + if err != nil { + return err + } + + videoFormat, audioFormat := mc.Tracks() + + switch videoFormat.(type) { + case *formats.H265, *formats.AV1: + return fmt.Errorf("proxying H265 or AV1 tracks with RTMP is not supported") + } + + var medias media.Medias + var stream *stream.Stream + + if videoFormat != nil { + videoMedia := &media.Media{ + Type: media.TypeVideo, + Formats: []formats.Format{videoFormat}, + } + medias = append(medias, videoMedia) + + if _, ok := videoFormat.(*formats.H264); ok { + mc.OnDataH264(func(pts time.Duration, au [][]byte) { + stream.WriteUnit(videoMedia, videoFormat, &formatprocessor.UnitH264{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + AU: au, + }) + }) + } + } + + if audioFormat != nil { //nolint:dupl + audioMedia := &media.Media{ + Type: media.TypeAudio, + Formats: []formats.Format{audioFormat}, + } + medias = append(medias, audioMedia) + + switch audioFormat.(type) { + case *formats.MPEG4AudioGeneric: + mc.OnDataMPEG4Audio(func(pts time.Duration, au []byte) { + stream.WriteUnit(audioMedia, audioFormat, &formatprocessor.UnitMPEG4AudioGeneric{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + AUs: [][]byte{au}, + }) + }) + + case *formats.MPEG2Audio: + mc.OnDataMPEG2Audio(func(pts time.Duration, frame []byte) { + stream.WriteUnit(audioMedia, audioFormat, &formatprocessor.UnitMPEG2Audio{ + BaseUnit: formatprocessor.BaseUnit{ + NTP: time.Now(), + }, + PTS: pts, + Frames: [][]byte{frame}, + }) + }) + } + } + + res := s.parent.sourceStaticImplSetReady(pathSourceStaticSetReadyReq{ + medias: medias, + generateRTPPackets: true, + }) + if res.err != nil { + return res.err + } + + defer s.parent.sourceStaticImplSetNotReady(pathSourceStaticSetNotReadyReq{}) + + s.Log(logger.Info, "ready: %s", sourceMediaInfo(medias)) + + stream = res.stream + + // disable write deadline to allow outgoing acknowledges + nconn.SetWriteDeadline(time.Time{}) + + for { + nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) + err := mc.Read() + if err != nil { + return err + } + } +} + // apiSourceDescribe implements sourceStaticImpl. func (*rtmpSource) apiSourceDescribe() pathAPISourceOrReader { return pathAPISourceOrReader{ diff --git a/internal/core/rtmp_source_test.go b/internal/core/rtmp_source_test.go index 55608c8b..a308c03d 100644 --- a/internal/core/rtmp_source_test.go +++ b/internal/core/rtmp_source_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/require" "github.com/bluenviron/mediamtx/internal/rtmp" - "github.com/bluenviron/mediamtx/internal/rtmp/message" ) func TestRTMPSource(t *testing.T) { @@ -81,19 +80,12 @@ func TestRTMPSource(t *testing.T) { IndexDeltaLength: 3, } - err = conn.WriteTracks(videoTrack, audioTrack) + w, err := rtmp.NewWriter(conn, videoTrack, audioTrack) require.NoError(t, err) <-connected - err = conn.WriteMessage(&message.Video{ - ChunkStreamID: message.VideoChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecH264, - IsKeyFrame: true, - Type: message.VideoTypeAU, - Payload: []byte{0x00, 0x00, 0x00, 0x04, 0x05, 0x02, 0x03, 0x04}, - }) + err = w.WriteH264(0, 0, true, [][]byte{{0x05, 0x02, 0x03, 0x04}}) require.NoError(t, err) <-done diff --git a/internal/rtmp/tracks/boxes_av1.go b/internal/rtmp/boxes_av1.go similarity index 98% rename from internal/rtmp/tracks/boxes_av1.go rename to internal/rtmp/boxes_av1.go index f75b07e8..524b793a 100644 --- a/internal/rtmp/tracks/boxes_av1.go +++ b/internal/rtmp/boxes_av1.go @@ -1,4 +1,4 @@ -package tracks +package rtmp import ( gomp4 "github.com/abema/go-mp4" diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index ee2b7ba3..2a27a4df 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -7,13 +7,11 @@ import ( "net/url" "strings" - "github.com/bluenviron/gortsplib/v3/pkg/formats" "github.com/notedit/rtmp/format/flv/flvio" "github.com/bluenviron/mediamtx/internal/rtmp/bytecounter" "github.com/bluenviron/mediamtx/internal/rtmp/handshake" "github.com/bluenviron/mediamtx/internal/rtmp/message" - "github.com/bluenviron/mediamtx/internal/rtmp/tracks" ) func resultIsOK1(res *message.CommandAMF0) bool { @@ -98,6 +96,43 @@ func createURL(tcURL string, app string, play string) (*url.URL, error) { return u, nil } +func readCommand(mrw *message.ReadWriter) (*message.CommandAMF0, error) { + for { + msg, err := mrw.Read() + if err != nil { + return nil, err + } + + if cmd, ok := msg.(*message.CommandAMF0); ok { + return cmd, nil + } + } +} + +func readCommandResult( + mrw *message.ReadWriter, + commandID int, + commandName string, + isValid func(*message.CommandAMF0) bool, +) error { + for { + msg, err := mrw.Read() + if err != nil { + return err + } + + if cmd, ok := msg.(*message.CommandAMF0); ok { + if cmd.CommandID == commandID && cmd.Name == commandName { + if !isValid(cmd) { + return fmt.Errorf("server refused connect request") + } + + return nil + } + } + } +} + // Conn is a RTMP connection. type Conn struct { bc *bytecounter.ReadWriter @@ -121,36 +156,8 @@ func (c *Conn) BytesSent() uint64 { return c.bc.Writer.Count() } -func (c *Conn) readCommand() (*message.CommandAMF0, error) { - for { - msg, err := c.mrw.Read() - if err != nil { - return nil, err - } - - if cmd, ok := msg.(*message.CommandAMF0); ok { - return cmd, nil - } - } -} - -func (c *Conn) readCommandResult(commandID int, commandName string, isValid func(*message.CommandAMF0) bool) error { - for { - msg, err := c.mrw.Read() - if err != nil { - return err - } - - if cmd, ok := msg.(*message.CommandAMF0); ok { - if cmd.CommandID == commandID && cmd.Name == commandName { - if !isValid(cmd) { - return fmt.Errorf("server refused connect request") - } - - return nil - } - } - } +func (c *Conn) skipInitialization() { + c.mrw = message.NewReadWriter(c.bc, false) } // InitializeClient performs the initialization of a client-side connection. @@ -207,7 +214,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - err = c.readCommandResult(1, "_result", resultIsOK1) + err = readCommandResult(c.mrw, 1, "_result", resultIsOK1) if err != nil { return err } @@ -225,7 +232,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - err = c.readCommandResult(2, "_result", resultIsOK2) + err = readCommandResult(c.mrw, 2, "_result", resultIsOK2) if err != nil { return err } @@ -251,7 +258,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - return c.readCommandResult(3, "onStatus", resultIsOK1) + return readCommandResult(c.mrw, 3, "onStatus", resultIsOK1) } err = c.mrw.Write(&message.CommandAMF0{ @@ -292,7 +299,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - err = c.readCommandResult(4, "_result", resultIsOK2) + err = readCommandResult(c.mrw, 4, "_result", resultIsOK2) if err != nil { return err } @@ -312,7 +319,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - return c.readCommandResult(5, "onStatus", resultIsOK1) + return readCommandResult(c.mrw, 5, "onStatus", resultIsOK1) } // InitializeServer performs the initialization of a server-side connection. @@ -324,7 +331,7 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) { c.mrw = message.NewReadWriter(c.bc, false) - cmd, err := c.readCommand() + cmd, err := readCommand(c.mrw) if err != nil { return nil, false, err } @@ -403,7 +410,7 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) { } for { - cmd, err := c.readCommand() + cmd, err := readCommand(c.mrw) if err != nil { return nil, false, err } @@ -564,23 +571,12 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) { } } -// ReadMessage reads a message. -func (c *Conn) ReadMessage() (message.Message, error) { +// Read reads a message. +func (c *Conn) Read() (message.Message, error) { return c.mrw.Read() } -// WriteMessage writes a message. -func (c *Conn) WriteMessage(msg message.Message) error { +// Write writes a message. +func (c *Conn) Write(msg message.Message) error { return c.mrw.Write(msg) } - -// ReadTracks reads track informations. -// It returns the video track and the audio track. -func (c *Conn) ReadTracks() (formats.Format, formats.Format, error) { - return tracks.Read(c.mrw) -} - -// WriteTracks writes track informations. -func (c *Conn) WriteTracks(videoTrack formats.Format, audioTrack formats.Format) error { - return tracks.Write(c.mrw, videoTrack, audioTrack) -} diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index f50df0c0..db583c8f 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -491,6 +491,6 @@ func BenchmarkRead(b *testing.B) { conn := NewConn(&buf) for n := 0; n < b.N; n++ { - conn.ReadMessage() + conn.Read() } } diff --git a/internal/rtmp/tracks/read.go b/internal/rtmp/reader.go similarity index 63% rename from internal/rtmp/tracks/read.go rename to internal/rtmp/reader.go index 88b4c4b3..00af2491 100644 --- a/internal/rtmp/tracks/read.go +++ b/internal/rtmp/reader.go @@ -1,5 +1,4 @@ -// Package tracks contains functions to read and write track metadata. -package tracks +package rtmp import ( "bytes" @@ -19,6 +18,18 @@ import ( "github.com/bluenviron/mediamtx/internal/rtmp/message" ) +// OnDataAV1Func is the prototype of the callback passed to OnDataAV1(). +type OnDataAV1Func func(pts time.Duration, obus [][]byte) + +// OnDataH26xFunc is the prototype of the callback passed to OnDataH26x(). +type OnDataH26xFunc func(pts time.Duration, au [][]byte) + +// OnDataMPEG4AudioFunc is the prototype of the callback passed to OnDataMPEG4Audio(). +type OnDataMPEG4AudioFunc func(pts time.Duration, au []byte) + +// OnDataMPEG2AudioFunc is the prototype of the callback passed to OnDataMPEG2Audio(). +type OnDataMPEG2AudioFunc func(pts time.Duration, frame []byte) + func h265FindNALU(array []gomp4.HEVCNaluArray, typ h265.NALUType) []byte { for _, entry := range array { if entry.NaluType == byte(typ) && entry.NumNalus == 1 && @@ -62,7 +73,7 @@ func trackFromAACDecoderConfig(data []byte) (formats.Format, error) { var errEmptyMetadata = errors.New("metadata is empty") -func readTracksFromMetadata(r *message.ReadWriter, payload []interface{}) (formats.Format, formats.Format, error) { +func tracksFromMetadata(conn *Conn, payload []interface{}) (formats.Format, formats.Format, error) { if len(payload) != 1 { return nil, nil, fmt.Errorf("invalid metadata") } @@ -145,7 +156,7 @@ func readTracksFromMetadata(r *message.ReadWriter, payload []interface{}) (forma return videoTrack, audioTrack, nil } - msg, err := r.Read() + msg, err := conn.Read() if err != nil { return nil, nil, err } @@ -261,7 +272,7 @@ func readTracksFromMetadata(r *message.ReadWriter, payload []interface{}) (forma } } -func readTracksFromMessages(r *message.ReadWriter, msg message.Message) (formats.Format, formats.Format, error) { +func tracksFromMessages(conn *Conn, msg message.Message) (formats.Format, formats.Format, error) { var startTime *time.Duration var videoTrack formats.Format var audioTrack formats.Format @@ -322,7 +333,7 @@ outer: } var err error - msg, err = r.Read() + msg, err = conn.Read() if err != nil { return nil, nil, err } @@ -335,12 +346,34 @@ outer: return videoTrack, audioTrack, nil } -// Read reads track informations. -// It returns the video track and the audio track. -func Read(r *message.ReadWriter) (formats.Format, formats.Format, error) { +// Reader is a wrapper around Conn that provides utilities to demux incoming data. +type Reader struct { + conn *Conn + videoTrack formats.Format + audioTrack formats.Format + onDataVideo func(message.Message) error + onDataAudio func(*message.Audio) error +} + +// NewReader allocates a Reader. +func NewReader(conn *Conn) (*Reader, error) { + r := &Reader{ + conn: conn, + } + + var err error + r.videoTrack, r.audioTrack, err = r.readTracks() + if err != nil { + return nil, err + } + + return r, nil +} + +func (r *Reader) readTracks() (formats.Format, formats.Format, error) { msg, err := func() (message.Message, error) { for { - msg, err := r.Read() + msg, err := r.conn.Read() if err != nil { return nil, err } @@ -373,15 +406,15 @@ func Read(r *message.ReadWriter) (formats.Format, formats.Format, error) { if len(payload) >= 1 { if s, ok := payload[0].(string); ok && s == "onMetaData" { - videoTrack, audioTrack, err := readTracksFromMetadata(r, payload[1:]) + videoTrack, audioTrack, err := tracksFromMetadata(r.conn, payload[1:]) if err != nil { if err == errEmptyMetadata { - msg, err := r.Read() + msg, err := r.conn.Read() if err != nil { return nil, nil, err } - return readTracksFromMessages(r, msg) + return tracksFromMessages(r.conn, msg) } return nil, nil, err @@ -392,5 +425,135 @@ func Read(r *message.ReadWriter) (formats.Format, formats.Format, error) { } } - return readTracksFromMessages(r, msg) + return tracksFromMessages(r.conn, msg) +} + +// Tracks returns detected tracks +func (r *Reader) Tracks() (formats.Format, formats.Format) { + return r.videoTrack, r.audioTrack +} + +// OnDataAV1 sets a callback that is called when AV1 data is received. +func (r *Reader) OnDataAV1(cb OnDataAV1Func) { + r.onDataVideo = func(msg message.Message) error { + if msg, ok := msg.(*message.ExtendedCodedFrames); ok { + obus, err := av1.BitstreamUnmarshal(msg.Payload, true) + if err != nil { + return fmt.Errorf("unable to decode bitstream: %v", err) + } + + cb(msg.DTS, obus) + } + return nil + } +} + +// OnDataH265 sets a callback that is called when H265 data is received. +func (r *Reader) OnDataH265(cb OnDataH26xFunc) { + r.onDataVideo = func(msg message.Message) error { + switch msg := msg.(type) { + case *message.Video: + au, err := h264.AVCCUnmarshal(msg.Payload) + if err != nil { + return fmt.Errorf("unable to decode AVCC: %v", err) + } + + cb(msg.DTS+msg.PTSDelta, au) + + case *message.ExtendedFramesX: + au, err := h264.AVCCUnmarshal(msg.Payload) + if err != nil { + return fmt.Errorf("unable to decode AVCC: %v", err) + } + + cb(msg.DTS, au) + + case *message.ExtendedCodedFrames: + au, err := h264.AVCCUnmarshal(msg.Payload) + if err != nil { + return fmt.Errorf("unable to decode AVCC: %v", err) + } + + cb(msg.DTS+msg.PTSDelta, au) + } + + return nil + } +} + +// OnDataH264 sets a callback that is called when H264 data is received. +func (r *Reader) OnDataH264(cb OnDataH26xFunc) { + r.onDataVideo = func(msg message.Message) error { + if msg, ok := msg.(*message.Video); ok { + switch msg.Type { + case message.VideoTypeConfig: + var conf h264conf.Conf + err := conf.Unmarshal(msg.Payload) + if err != nil { + return fmt.Errorf("unable to parse H264 config: %v", err) + } + + au := [][]byte{ + conf.SPS, + conf.PPS, + } + + cb(msg.DTS+msg.PTSDelta, au) + + case message.VideoTypeAU: + au, err := h264.AVCCUnmarshal(msg.Payload) + if err != nil { + return fmt.Errorf("unable to decode AVCC: %v", err) + } + + cb(msg.DTS+msg.PTSDelta, au) + } + } + + return nil + } +} + +// OnDataMPEG4Audio sets a callback that is called when MPEG-4 Audio data is received. +func (r *Reader) OnDataMPEG4Audio(cb OnDataMPEG4AudioFunc) { + r.onDataAudio = func(msg *message.Audio) error { + if msg.AACType == message.AudioAACTypeAU { + cb(msg.DTS, msg.Payload) + } + return nil + } +} + +// OnDataMPEG2Audio sets a callback that is called when MPEG-2 Audio data is received. +func (r *Reader) OnDataMPEG2Audio(cb OnDataMPEG2AudioFunc) { + r.onDataAudio = func(msg *message.Audio) error { + cb(msg.DTS, msg.Payload) + return nil + } +} + +// Read reads data. +func (r *Reader) Read() error { + msg, err := r.conn.Read() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *message.Video, *message.ExtendedFramesX, *message.ExtendedCodedFrames: + if r.onDataVideo == nil { + return fmt.Errorf("received a video packet, but track is not set up") + } + + return r.onDataVideo(msg) + + case *message.Audio: + if r.onDataAudio == nil { + return fmt.Errorf("received an audio packet, but track is not set up") + } + + return r.onDataAudio(msg) + } + + return nil } diff --git a/internal/rtmp/tracks/read_test.go b/internal/rtmp/reader_test.go similarity index 98% rename from internal/rtmp/tracks/read_test.go rename to internal/rtmp/reader_test.go index 5a4e6847..0cc4a5de 100644 --- a/internal/rtmp/tracks/read_test.go +++ b/internal/rtmp/reader_test.go @@ -1,4 +1,4 @@ -package tracks +package rtmp import ( "bytes" @@ -16,7 +16,7 @@ import ( "github.com/bluenviron/mediamtx/internal/rtmp/message" ) -func TestRead(t *testing.T) { +func TestReadTracks(t *testing.T) { sps := []byte{ 0x67, 0x64, 0x00, 0x0c, 0xac, 0x3b, 0x50, 0xb0, 0x4b, 0x42, 0x00, 0x00, 0x03, 0x00, 0x02, 0x00, @@ -536,8 +536,12 @@ func TestRead(t *testing.T) { require.NoError(t, err) } - videoTrack, audioTrack, err := Read(mrw) + c := NewConn(&buf) + c.skipInitialization() + + r, err := NewReader(c) require.NoError(t, err) + videoTrack, audioTrack := r.Tracks() require.Equal(t, ca.videoTrack, videoTrack) require.Equal(t, ca.audioTrack, audioTrack) }) diff --git a/internal/rtmp/tracks/write.go b/internal/rtmp/tracks/write.go deleted file mode 100644 index 343303c2..00000000 --- a/internal/rtmp/tracks/write.go +++ /dev/null @@ -1,118 +0,0 @@ -package tracks - -import ( - "github.com/bluenviron/gortsplib/v3/pkg/formats" - "github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio" - "github.com/notedit/rtmp/format/flv/flvio" - - "github.com/bluenviron/mediamtx/internal/rtmp/h264conf" - "github.com/bluenviron/mediamtx/internal/rtmp/message" -) - -// Write writes track informations. -func Write(w *message.ReadWriter, videoTrack formats.Format, audioTrack formats.Format) error { - err := w.Write(&message.DataAMF0{ - ChunkStreamID: 4, - MessageStreamID: 0x1000000, - Payload: []interface{}{ - "@setDataFrame", - "onMetaData", - flvio.AMFMap{ - { - K: "videodatarate", - V: float64(0), - }, - { - K: "videocodecid", - V: func() float64 { - switch videoTrack.(type) { - case *formats.H264: - return message.CodecH264 - - default: - return 0 - } - }(), - }, - { - K: "audiodatarate", - V: float64(0), - }, - { - K: "audiocodecid", - V: func() float64 { - switch audioTrack.(type) { - case *formats.MPEG2Audio: - return message.CodecMPEG2Audio - - case *formats.MPEG4AudioGeneric, *formats.MPEG4AudioLATM: - return message.CodecMPEG4Audio - - default: - return 0 - } - }(), - }, - }, - }, - }) - if err != nil { - return err - } - - if videoTrack, ok := videoTrack.(*formats.H264); ok { - // write decoder config only if SPS and PPS are available. - // if they're not available yet, they're sent later. - if sps, pps := videoTrack.SafeParams(); sps != nil && pps != nil { - buf, _ := h264conf.Conf{ - SPS: sps, - PPS: pps, - }.Marshal() - - err = w.Write(&message.Video{ - ChunkStreamID: message.VideoChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecH264, - IsKeyFrame: true, - Type: message.VideoTypeConfig, - Payload: buf, - }) - if err != nil { - return err - } - } - } - - var audioConfig *mpeg4audio.AudioSpecificConfig - - switch track := audioTrack.(type) { - case *formats.MPEG4Audio: - audioConfig = track.Config - - case *formats.MPEG4AudioLATM: - audioConfig = track.Config.Programs[0].Layers[0].AudioSpecificConfig - } - - if audioConfig != nil { - enc, err := audioConfig.Marshal() - if err != nil { - return err - } - - err = w.Write(&message.Audio{ - ChunkStreamID: message.AudioChunkStreamID, - MessageStreamID: 0x1000000, - Codec: message.CodecMPEG4Audio, - Rate: flvio.SOUND_44Khz, - Depth: flvio.SOUND_16BIT, - Channels: flvio.SOUND_STEREO, - AACType: message.AudioAACTypeConfig, - Payload: enc, - }) - if err != nil { - return err - } - } - - return nil -} diff --git a/internal/rtmp/writer.go b/internal/rtmp/writer.go new file mode 100644 index 00000000..31841dc3 --- /dev/null +++ b/internal/rtmp/writer.go @@ -0,0 +1,208 @@ +package rtmp + +import ( + "time" + + "github.com/bluenviron/gortsplib/v3/pkg/formats" + "github.com/bluenviron/mediacommon/pkg/codecs/h264" + "github.com/bluenviron/mediacommon/pkg/codecs/mpeg2audio" + "github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio" + "github.com/notedit/rtmp/format/flv/flvio" + + "github.com/bluenviron/mediamtx/internal/rtmp/h264conf" + "github.com/bluenviron/mediamtx/internal/rtmp/message" +) + +func mpeg2AudioRate(sr int) uint8 { + switch sr { + case 5500: + return flvio.SOUND_5_5Khz + case 11025: + return flvio.SOUND_11Khz + case 22050: + return flvio.SOUND_22Khz + default: + return flvio.SOUND_44Khz + } +} + +func mpeg2AudioChannels(m mpeg2audio.ChannelMode) uint8 { + if m == mpeg2audio.ChannelModeMono { + return flvio.SOUND_MONO + } + return flvio.SOUND_STEREO +} + +// Writer is a wrapper around Conn that provides utilities to mux outgoing data. +type Writer struct { + conn *Conn +} + +// NewWriter allocates a Writer. +func NewWriter(conn *Conn, videoTrack formats.Format, audioTrack formats.Format) (*Writer, error) { + w := &Writer{ + conn: conn, + } + + err := w.writeTracks(videoTrack, audioTrack) + if err != nil { + return nil, err + } + + return w, nil +} + +func (w *Writer) writeTracks(videoTrack formats.Format, audioTrack formats.Format) error { + err := w.conn.Write(&message.DataAMF0{ + ChunkStreamID: 4, + MessageStreamID: 0x1000000, + Payload: []interface{}{ + "@setDataFrame", + "onMetaData", + flvio.AMFMap{ + { + K: "videodatarate", + V: float64(0), + }, + { + K: "videocodecid", + V: func() float64 { + switch videoTrack.(type) { + case *formats.H264: + return message.CodecH264 + + default: + return 0 + } + }(), + }, + { + K: "audiodatarate", + V: float64(0), + }, + { + K: "audiocodecid", + V: func() float64 { + switch audioTrack.(type) { + case *formats.MPEG2Audio: + return message.CodecMPEG2Audio + + case *formats.MPEG4AudioGeneric, *formats.MPEG4AudioLATM: + return message.CodecMPEG4Audio + + default: + return 0 + } + }(), + }, + }, + }, + }) + if err != nil { + return err + } + + if videoTrack, ok := videoTrack.(*formats.H264); ok { + // write decoder config only if SPS and PPS are available. + // if they're not available yet, they're sent later. + if sps, pps := videoTrack.SafeParams(); sps != nil && pps != nil { + buf, _ := h264conf.Conf{ + SPS: sps, + PPS: pps, + }.Marshal() + + err = w.conn.Write(&message.Video{ + ChunkStreamID: message.VideoChunkStreamID, + MessageStreamID: 0x1000000, + Codec: message.CodecH264, + IsKeyFrame: true, + Type: message.VideoTypeConfig, + Payload: buf, + }) + if err != nil { + return err + } + } + } + + var audioConfig *mpeg4audio.AudioSpecificConfig + + switch track := audioTrack.(type) { + case *formats.MPEG4Audio: + audioConfig = track.Config + + case *formats.MPEG4AudioLATM: + audioConfig = track.Config.Programs[0].Layers[0].AudioSpecificConfig + } + + if audioConfig != nil { + enc, err := audioConfig.Marshal() + if err != nil { + return err + } + + err = w.conn.Write(&message.Audio{ + ChunkStreamID: message.AudioChunkStreamID, + MessageStreamID: 0x1000000, + Codec: message.CodecMPEG4Audio, + Rate: flvio.SOUND_44Khz, + Depth: flvio.SOUND_16BIT, + Channels: flvio.SOUND_STEREO, + AACType: message.AudioAACTypeConfig, + Payload: enc, + }) + if err != nil { + return err + } + } + + return nil +} + +// WriteH264 writes H264 data. +func (w *Writer) WriteH264(pts time.Duration, dts time.Duration, idrPresent bool, au [][]byte) error { + avcc, err := h264.AVCCMarshal(au) + if err != nil { + return err + } + + return w.conn.Write(&message.Video{ + ChunkStreamID: message.VideoChunkStreamID, + MessageStreamID: 0x1000000, + Codec: message.CodecH264, + IsKeyFrame: idrPresent, + Type: message.VideoTypeAU, + Payload: avcc, + DTS: dts, + PTSDelta: pts - dts, + }) +} + +// WriteMPEG4Audio writes MPEG-4 Audio data. +func (w *Writer) WriteMPEG4Audio(pts time.Duration, au []byte) error { + return w.conn.Write(&message.Audio{ + ChunkStreamID: message.AudioChunkStreamID, + MessageStreamID: 0x1000000, + Codec: message.CodecMPEG4Audio, + Rate: flvio.SOUND_44Khz, + Depth: flvio.SOUND_16BIT, + Channels: flvio.SOUND_STEREO, + AACType: message.AudioAACTypeAU, + Payload: au, + DTS: pts, + }) +} + +// WriteMPEG2Audio writes MPEG-2 Audio data. +func (w *Writer) WriteMPEG2Audio(pts time.Duration, h *mpeg2audio.FrameHeader, frame []byte) error { + return w.conn.Write(&message.Audio{ + ChunkStreamID: message.AudioChunkStreamID, + MessageStreamID: 0x1000000, + Codec: message.CodecMPEG2Audio, + Rate: mpeg2AudioRate(h.SampleRate), + Depth: flvio.SOUND_16BIT, + Channels: mpeg2AudioChannels(h.ChannelMode), + Payload: frame, + DTS: pts, + }) +} diff --git a/internal/rtmp/tracks/write_test.go b/internal/rtmp/writer_test.go similarity index 94% rename from internal/rtmp/tracks/write_test.go rename to internal/rtmp/writer_test.go index f08042fd..ed8bb945 100644 --- a/internal/rtmp/tracks/write_test.go +++ b/internal/rtmp/writer_test.go @@ -1,4 +1,4 @@ -package tracks +package rtmp import ( "bytes" @@ -13,11 +13,7 @@ import ( "github.com/bluenviron/mediamtx/internal/rtmp/message" ) -func TestWrite(t *testing.T) { - var buf bytes.Buffer - bc := bytecounter.NewReadWriter(&buf) - mrw := message.NewReadWriter(bc, true) - +func TestWriteTracks(t *testing.T) { videoTrack := &formats.H264{ PayloadTyp: 96, SPS: []byte{ @@ -43,9 +39,16 @@ func TestWrite(t *testing.T) { IndexDeltaLength: 3, } - err := Write(mrw, videoTrack, audioTrack) + var buf bytes.Buffer + c := NewConn(&buf) + c.skipInitialization() + + _, err := NewWriter(c, videoTrack, audioTrack) require.NoError(t, err) + bc := bytecounter.NewReadWriter(&buf) + mrw := message.NewReadWriter(bc, true) + msg, err := mrw.Read() require.NoError(t, err) require.Equal(t, &message.DataAMF0{