diff --git a/go.mod b/go.mod index ac84c69..bad26a6 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,11 @@ module mediasource -go 1.23 +go 1.23.3 + +require ( + github.com/harshabose/simple_webrtc_comm/transcode v0.0.0 + github.com/harshabose/tools/buffer v0.0.0 +) require ( github.com/google/uuid v1.6.0 // indirect @@ -24,8 +29,6 @@ require ( golang.org/x/crypto v0.32.0 // indirect golang.org/x/net v0.34.0 // indirect golang.org/x/sys v0.29.0 // indirect - github.com/harshabose/simple_webrtc_comm/transcode v0.0.0 - github.com/harshabose/tools/buffer v0.0.0 ) replace ( diff --git a/pkg/bandwidth.go b/pkg/bandwidth.go new file mode 100644 index 0000000..3331ed8 --- /dev/null +++ b/pkg/bandwidth.go @@ -0,0 +1,79 @@ +package mediasource + +import ( + "context" + "errors" + "github.com/pion/interceptor/pkg/cc" + "sync" + "time" +) + +type consumer struct { + channel chan int64 + track *Track +} + +type bandwidthEstimator struct { + estimator cc.BandwidthEstimator + consumers map[string]*consumer + interval time.Duration + mutex sync.RWMutex + ctx context.Context +} + +func (be *bandwidthEstimator) Start() { + go be.loop() +} + +func (be *bandwidthEstimator) SetConsumer(id string, setChannel func(chan int64), track *Track) error { + be.mutex.Lock() + defer be.mutex.Unlock() + + if _, exits := be.consumers[id]; exits { + return errors.New("consumer already exists") + } + + be.consumers[id] = &consumer{channel: make(chan int64), track: track} + setChannel(be.consumers[id].channel) + + return nil +} + +func (be *bandwidthEstimator) loop() { + // wait here + for { + be.mutex.RLock() + + select { + case <-be.ctx.Done(): + return + default: + be.estimate() + } + + be.mutex.RUnlock() + } +} + +func (be *bandwidthEstimator) estimate() { + var totalPriority Priority + + if len(be.consumers) == 0 { + return + } + + for _, consumer := range be.consumers { + totalPriority += consumer.track.priority + } + + totalBitrate := be.estimator.GetTargetBitrate() + + for _, consumer := range be.consumers { + if consumer.track.priority == Level0 { + continue + } + select { + case consumer.channel <- int64(float64(totalBitrate) * float64(consumer.track.priority) / float64(totalPriority)): + } + } +} diff --git a/pkg/localtrack_options.go b/pkg/localtrack_options.go deleted file mode 100644 index 684eebd..0000000 --- a/pkg/localtrack_options.go +++ /dev/null @@ -1,38 +0,0 @@ -package mediasource - -import "github.com/pion/webrtc/v4" - -type Option = func(*Track) error - -func WithH264(clockrate uint32) Option { - return func(track *Track) error { - var ( - err error = nil - ) - if track.track, err = webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, - ClockRate: clockrate, - SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=420029", - }, "video", "webrtc"); err != nil { - return err - } - return nil - } -} - -func WithOpus(samplerate uint32, channelLayout uint16) Option { - return func(track *Track) error { - var ( - err error = nil - ) - - if track.track, err = webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeOpus, - ClockRate: samplerate, - Channels: channelLayout, - }, "audio", "webrtc"); err != nil { - return err - } - return nil - } -} diff --git a/internal/stream.go b/pkg/stream.go similarity index 60% rename from internal/stream.go rename to pkg/stream.go index 758c6b0..b2fb27e 100644 --- a/internal/stream.go +++ b/pkg/stream.go @@ -1,8 +1,8 @@ -package internal +package mediasource import ( "context" - "fmt" + "mediasource/internal" "time" "github.com/asticode/go-astiav" @@ -11,13 +11,6 @@ import ( "github.com/pion/webrtc/v4/pkg/media" ) -type Options struct { - DemuxerOptions []transcode.DemuxerOption - DecoderOptions []transcode.DecoderOption - FilterOptions []transcode.FilterOption - EncoderOptions []transcode.EncoderOption -} - type Stream struct { demuxer *transcode.Demuxer decoder *transcode.Decoder @@ -27,38 +20,21 @@ type Stream struct { ctx context.Context } -func CreateStream(ctx context.Context, containerAddress string, options *Options) (*Stream, error) { +func CreateStream(ctx context.Context, options ...StreamOption) (*Stream, error) { var ( - demuxer *transcode.Demuxer - decoder *transcode.Decoder - filter *transcode.Filter - encoder *transcode.Encoder - err error + err error + stream *Stream = &Stream{ctx: ctx} ) - if demuxer, err = transcode.CreateDemuxer(ctx, containerAddress, options.DemuxerOptions...); err != nil { - return nil, err - } - if decoder, err = transcode.CreateDecoder(ctx, demuxer, append([]transcode.DecoderOption{demuxer.GetDecoderContextOptions()}, options.DecoderOptions...)...); err != nil { - return nil, err - } - if filter, err = transcode.CreateFilter(ctx, decoder, transcode.VideoFilters, decoder.GetSrcFilterContextOptions(), transcode.WithDefaultVideoFilterContentOptions); err != nil { - return nil, err - } - if encoder, err = transcode.CreateEncoder(ctx, filter, transcode.WithLowLatencyVideoEncoderSetting); err != nil { - return nil, err + for _, option := range options { + if err = option(stream); err != nil { + return nil, err + } } - fmt.Println("started encoder with settings:") + stream.buffer = buffer.CreateChannelBuffer(ctx, stream.encoder.GetFPS()*3, internal.CreateSamplePool()) - return &Stream{ - demuxer: demuxer, - decoder: decoder, - filter: filter, - encoder: encoder, - buffer: buffer.CreateChannelBuffer(ctx, encoder.GetFPS()*3, CreateSamplePool()), - ctx: ctx, - }, nil + return stream, nil } func (stream *Stream) Start() { diff --git a/pkg/stream_options.go b/pkg/stream_options.go new file mode 100644 index 0000000..6d2cd0c --- /dev/null +++ b/pkg/stream_options.go @@ -0,0 +1,48 @@ +package mediasource + +import ( + "github.com/asticode/go-astiav" + transcode "github.com/harshabose/simple_webrtc_comm/transcode/pkg" +) + +type StreamOption = func(*Stream) error + +func WithDemuxer(containerAddress string, options ...transcode.DemuxerOption) StreamOption { + return func(stream *Stream) error { + var err error + if stream.demuxer, err = transcode.CreateDemuxer(stream.ctx, containerAddress, options...); err != nil { + return err + } + return nil + } +} + +func WithDecoder(options ...transcode.DecoderOption) StreamOption { + return func(stream *Stream) error { + var err error + if stream.decoder, err = transcode.CreateDecoder(stream.ctx, stream.demuxer, options...); err != nil { + return err + } + return nil + } +} + +func WithFilter(filterConfig *transcode.FilterConfig, options ...transcode.FilterOption) StreamOption { + return func(stream *Stream) error { + var err error + if stream.filter, err = transcode.CreateFilter(stream.ctx, stream.decoder, filterConfig, options...); err != nil { + return err + } + return nil + } +} + +func WithEncoder(codec astiav.CodecID, options ...transcode.EncoderOption) StreamOption { + return func(stream *Stream) error { + var err error + if stream.encoder, err = transcode.CreateEncoder(stream.ctx, codec, stream.filter, options...); err != nil { + return err + } + return nil + } +} diff --git a/pkg/track.go b/pkg/track.go index 0c24592..efbbd4d 100644 --- a/pkg/track.go +++ b/pkg/track.go @@ -3,24 +3,23 @@ package mediasource import ( "context" "fmt" - "time" - "github.com/pion/webrtc/v4" "github.com/pion/webrtc/v4/pkg/media" - - "mediasource/internal" ) // NO BUFFER IMPLEMENTATION type Track struct { - track *webrtc.TrackLocalStaticSample - stream *internal.Stream - ctx context.Context + track *webrtc.TrackLocalStaticSample + rtpSender *webrtc.RTPSender + stream *Stream + priority Priority + ctx context.Context } -func CreateLocalTrack(ctx context.Context, stream *internal.Stream, options ...Option) (*Track, error) { - track := &Track{stream: stream, ctx: ctx} +func CreateTrack(ctx context.Context, peerConnection *webrtc.PeerConnection, options ...TrackOption) (*Track, error) { + var err error + track := &Track{ctx: ctx} for _, option := range options { if err := option(track); err != nil { @@ -28,6 +27,10 @@ func CreateLocalTrack(ctx context.Context, stream *internal.Stream, options ...O } } + if track.rtpSender, err = peerConnection.AddTrack(track.track); err != nil { + return nil, err + } + return track, nil } @@ -49,12 +52,8 @@ func (track *Track) loop() { var ( sample *media.Sample = nil err error = nil - ticker *time.Ticker = nil ) - ticker = time.NewTicker(time.Second) - defer ticker.Stop() - loop: for { select { diff --git a/pkg/track_options.go b/pkg/track_options.go new file mode 100644 index 0000000..fe9b527 --- /dev/null +++ b/pkg/track_options.go @@ -0,0 +1,76 @@ +package mediasource + +import ( + "github.com/pion/webrtc/v4" +) + +type TrackOption = func(*Track) error + +func WithH264Track(clockrate uint32, id string) TrackOption { + return func(track *Track) error { + var ( + err error = nil + ) + + if track.track, err = webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeH264, + ClockRate: clockrate, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=420029", + }, id, "webrtc"); err != nil { + return err + } + return nil + } +} + +func WithOpusTrack(samplerate uint32, channelLayout uint16, id string) TrackOption { + return func(track *Track) error { + var ( + err error = nil + ) + + if track.track, err = webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: samplerate, + Channels: channelLayout, + }, id, "webrtc"); err != nil { + return err + } + return nil + } +} + +func WithStreamOptions(options ...StreamOption) TrackOption { + return func(track *Track) error { + for _, option := range options { + if err := option(track.stream); err != nil { + return err + } + } + return nil + } +} + +func WithPriority(level Priority) TrackOption { + return func(track *Track) error { + track.priority = level + return nil + } +} + +type Priority uint8 + +const ( + Level0 Priority = 0 + Level1 Priority = 1 + Level2 Priority = 2 + Level3 Priority = 3 + Level4 Priority = 4 + Level5 Priority = 5 +) + +func withBandwidthControl(estimator *bandwidthEstimator) TrackOption { + return func(track *Track) error { + return estimator.SetConsumer(track.track.ID(), track.stream.encoder.SetBitrateChannel, track) + } +} diff --git a/pkg/tracks.go b/pkg/tracks.go new file mode 100644 index 0000000..54f2013 --- /dev/null +++ b/pkg/tracks.go @@ -0,0 +1,63 @@ +package mediasource + +import ( + "context" + "errors" + "github.com/pion/interceptor" + "github.com/pion/webrtc/v4" +) + +type Tracks struct { + bwEstimator *bandwidthEstimator + mediaEngine *webrtc.MediaEngine + interceptorRegistry *interceptor.Registry + tracks map[string]*Track + ctx context.Context +} + +func CreateTracks(ctx context.Context, mediaEngine *webrtc.MediaEngine, interceptorRegistry *interceptor.Registry, options ...TracksOption) (*Tracks, error) { + tracks := &Tracks{ + tracks: make(map[string]*Track), + mediaEngine: mediaEngine, + interceptorRegistry: interceptorRegistry, + ctx: ctx, + } + + for _, option := range options { + if err := option(tracks); err != nil { + return nil, err + } + } + + return tracks, nil +} + +func (tracks *Tracks) CreateTrack(peerConnection *webrtc.PeerConnection, options ...TrackOption) error { + var ( + track *Track + err error + ) + if track, err = CreateTrack(tracks.ctx, peerConnection, append(options, withBandwidthControl(tracks.bwEstimator))...); err != nil { + return err + } + if _, exists := tracks.tracks[track.track.ID()]; exists { + return errors.New("track already exists") + } + tracks.tracks[track.track.ID()] = track + return nil +} + +func (tracks *Tracks) StartTrack(id string) { + if track, ok := tracks.tracks[id]; ok { + track.Start() + } + if tracks.bwEstimator != nil { + tracks.bwEstimator.Start() + } +} + +func (tracks *Tracks) StartAll() { + for _, track := range tracks.tracks { + tracks.StartTrack(track.track.ID()) + } +} diff --git a/pkg/tracks_options.go b/pkg/tracks_options.go new file mode 100644 index 0000000..910883c --- /dev/null +++ b/pkg/tracks_options.go @@ -0,0 +1,256 @@ +package mediasource + +import ( + "fmt" + "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/pkg/flexfec" + "github.com/pion/interceptor/pkg/gcc" + "github.com/pion/interceptor/pkg/jitterbuffer" + "github.com/pion/interceptor/pkg/nack" + "github.com/pion/interceptor/pkg/report" + "github.com/pion/interceptor/pkg/twcc" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "time" +) + +type TracksOption = func(*Tracks) error + +type PacketisationMode uint8 + +const ( + PacketisationMode0 PacketisationMode = 0 + PacketisationMode1 PacketisationMode = 1 + PacketisationMode2 PacketisationMode = 2 +) + +type ProfileLevel string + +const ( + ProfileLevelBaseline21 ProfileLevel = "420015" // Level 2.1 (480p) + ProfileLevelBaseline31 ProfileLevel = "42001f" // Level 3.1 (720p) + ProfileLevelBaseline41 ProfileLevel = "420029" // Level 4.1 (1080p) + ProfileLevelBaseline42 ProfileLevel = "42002a" // Level 4.2 (2K) + + ProfileLevelMain21 ProfileLevel = "4D0015" // Level 2.1 + ProfileLevelMain31 ProfileLevel = "4D001f" // Level 3.1 + ProfileLevelMain41 ProfileLevel = "4D0029" // Level 4.1 + ProfileLevelMain42 ProfileLevel = "4D002a" // Level 4.2 + + ProfileLevelHigh21 ProfileLevel = "640015" // Level 2.1 + ProfileLevelHigh31 ProfileLevel = "64001f" // Level 3.1 + ProfileLevelHigh41 ProfileLevel = "640029" // Level 4.1 + ProfileLevelHigh42 ProfileLevel = "64002a" // Level 4.2 +) + +func WithH264MediaEngine(clockrate uint32, packetisationMode PacketisationMode, profileLevelID ProfileLevel) TracksOption { + return func(tracks *Tracks) error { + if err := tracks.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeH264, + ClockRate: clockrate, + Channels: 0, + SDPFmtpLine: fmt.Sprintf("level-asymmetry-allowed=1;packetization-mode=%d;profile-level-id=%s", packetisationMode, profileLevelID), + }, + PayloadType: 96, + }, webrtc.RTPCodecTypeVideo); err != nil { + return err + } + return nil + } +} + +type StereoType uint8 + +const ( + Mono StereoType = 0 + Stereo StereoType = 1 +) + +func WithOpusMediaEngine(samplerate uint32, channelLayout uint16, stereo StereoType) TracksOption { + return func(tracks *Tracks) error { + if err := tracks.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: samplerate, + Channels: channelLayout, + SDPFmtpLine: fmt.Sprintf("minptime=10;useinbandfec=1;stereo=%d", stereo), + }, + PayloadType: 111, + }, webrtc.RTPCodecTypeAudio); err != nil { + return err + } + return nil + } +} + +type NACKGeneratorOptions []nack.GeneratorOption + +var ( + NACKGeneratorLowLatency NACKGeneratorOptions = []nack.GeneratorOption{nack.GeneratorSize(256), nack.GeneratorSkipLastN(2), nack.GeneratorMaxNacksPerPacket(1), nack.GeneratorInterval(50 * time.Millisecond)} + NACKGeneratorDefault NACKGeneratorOptions = []nack.GeneratorOption{nack.GeneratorSize(512), nack.GeneratorSkipLastN(5), nack.GeneratorMaxNacksPerPacket(2), nack.GeneratorInterval(100 * time.Millisecond)} + NACKGeneratorHighQuality NACKGeneratorOptions = []nack.GeneratorOption{nack.GeneratorSize(2048), nack.GeneratorSkipLastN(10), nack.GeneratorMaxNacksPerPacket(3), nack.GeneratorInterval(200 * time.Millisecond)} + NACKGeneratorLowBandwidth NACKGeneratorOptions = []nack.GeneratorOption{nack.GeneratorSize(4096), nack.GeneratorSkipLastN(15), nack.GeneratorMaxNacksPerPacket(4), nack.GeneratorInterval(150 * time.Millisecond)} +) + +type NACKResponderOptions []nack.ResponderOption + +var ( + NACKResponderLowLatency NACKResponderOptions = []nack.ResponderOption{nack.ResponderSize(256), nack.DisableCopy()} + NACKResponderDefault NACKResponderOptions = []nack.ResponderOption{nack.ResponderSize(1024)} + NACKResponderHighQuality NACKResponderOptions = []nack.ResponderOption{nack.ResponderSize(2048)} + NACKResponderLowBandwidth NACKResponderOptions = []nack.ResponderOption{nack.ResponderSize(4096)} +) + +func WithNACKInterceptor(generatorOptions NACKGeneratorOptions, responderOptions NACKResponderOptions) TracksOption { + return func(tracks *Tracks) error { + var ( + generator *nack.GeneratorInterceptorFactory + responder *nack.ResponderInterceptorFactory + err error + ) + if generator, err = nack.NewGeneratorInterceptor(generatorOptions...); err != nil { + return err + } + if responder, err = nack.NewResponderInterceptor(responderOptions...); err != nil { + return err + } + + tracks.mediaEngine.RegisterFeedback(webrtc.RTCPFeedback{Type: "nack"}, webrtc.RTPCodecTypeVideo) + tracks.interceptorRegistry.Add(responder) + tracks.interceptorRegistry.Add(generator) + + return nil + } +} + +type TWCCSenderInterval time.Duration + +const ( + TWCCIntervalLowLatency = TWCCSenderInterval(50 * time.Millisecond) + TWCCIntervalDefault = TWCCSenderInterval(100 * time.Millisecond) + TWCCIntervalHighQuality = TWCCSenderInterval(200 * time.Millisecond) + TWCCIntervalLowBandwidth = TWCCSenderInterval(500 * time.Millisecond) +) + +func WithTWCCSenderInterceptor(interval TWCCSenderInterval) TracksOption { + return func(tracks *Tracks) error { + var ( + generator *twcc.SenderInterceptorFactory + err error + ) + + tracks.mediaEngine.RegisterFeedback(webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}, webrtc.RTPCodecTypeVideo) + if err := tracks.mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, webrtc.RTPCodecTypeVideo); err != nil { + return err + } + + tracks.mediaEngine.RegisterFeedback(webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}, webrtc.RTPCodecTypeAudio) + if err := tracks.mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, webrtc.RTPCodecTypeAudio); err != nil { + return err + } + + if generator, err = twcc.NewSenderInterceptor(twcc.SendInterval(time.Duration(interval))); err != nil { + return err + } + + tracks.interceptorRegistry.Add(generator) + return nil + } +} + +// NOTE: THIS SHOULD BE USED WITH WithTWCCSenderInterceptor and the interval needs to be same + +func WithBandwidthEstimatorInterceptor(initialBitrate int, interval time.Duration) TracksOption { + return func(tracks *Tracks) error { + var ( + congestionController *cc.InterceptorFactory + err error + ) + + tracks.bwEstimator = &bandwidthEstimator{ctx: tracks.ctx, consumers: make(map[string]*consumer), interval: interval} + tracks.bwEstimator.Start() + + if congestionController, err = cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { + return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(initialBitrate)) + }); err != nil { + return err + } + + congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) { + tracks.bwEstimator.estimator = estimator + tracks.bwEstimator.Start() + }) + + tracks.interceptorRegistry.Add(congestionController) + if err = webrtc.ConfigureTWCCHeaderExtensionSender(tracks.mediaEngine, tracks.interceptorRegistry); err != nil { + return err + } + + return nil + } +} + +func WithJitterBufferInterceptor() TracksOption { + return func(tracks *Tracks) error { + var ( + jitterBuffer *jitterbuffer.InterceptorFactory + err error + ) + + if jitterBuffer, err = jitterbuffer.NewInterceptor(); err != nil { + return err + } + tracks.interceptorRegistry.Add(jitterBuffer) + return nil + } +} + +type RTCPReportInterval time.Duration + +const ( + RTCPReportIntervalLowLatency = RTCPReportInterval(50 * time.Millisecond) + RTCPReportIntervalDefault = RTCPReportInterval(1 * time.Second) + RTCPReportIntervalHighQuality = RTCPReportInterval(200 * time.Millisecond) + RTCPReportIntervalLowBandwidth = RTCPReportInterval(2 * time.Second) +) + +func WithRTCPReportsInterceptor(interval RTCPReportInterval) TracksOption { + return func(tracks *Tracks) error { + var ( + sender *report.SenderInterceptorFactory + receiver *report.ReceiverInterceptorFactory + err error + ) + + if sender, err = report.NewSenderInterceptor(report.SenderInterval(time.Duration(interval))); err != nil { + return err + } + if receiver, err = report.NewReceiverInterceptor(report.ReceiverInterval(time.Duration(interval))); err != nil { + return err + } + + tracks.interceptorRegistry.Add(receiver) + tracks.interceptorRegistry.Add(sender) + + return nil + } +} + +// WARN: DO NOT USE FLEXFEC YET, AS THE FECOPTION ARE NOT YET IMPLEMENTED +func WithFLEXFECInterceptor() TracksOption { + return func(tracks *Tracks) error { + var ( + fecInterceptor *flexfec.FecInterceptorFactory + err error + ) + + // NOTE: Pion's FLEXFEC does not implement FecOption yet, if needed, someone needs to contribute to the repo + if fecInterceptor, err = flexfec.NewFecInterceptor(); err != nil { + return err + } + + tracks.interceptorRegistry.Add(fecInterceptor) + return nil + } +}