diff --git a/pkg/codec/vpx/params.go b/pkg/codec/vpx/params.go index 18767dc..8dcecc9 100644 --- a/pkg/codec/vpx/params.go +++ b/pkg/codec/vpx/params.go @@ -16,6 +16,7 @@ type Params struct { RateControlOvershootPercent uint RateControlMinQuantizer uint RateControlMaxQuantizer uint + LagInFrames uint ErrorResilient ErrorResilientMode } diff --git a/pkg/codec/vpx/vpx.go b/pkg/codec/vpx/vpx.go index e21a936..06fe7bb 100644 --- a/pkg/codec/vpx/vpx.go +++ b/pkg/codec/vpx/vpx.go @@ -23,6 +23,9 @@ package vpx // int pktSz(vpx_codec_cx_pkt_t *pkt) { // return pkt->data.frame.sz; // } +// vpx_codec_frame_flags_t pktFrameFlags(vpx_codec_cx_pkt_t *pkt) { +// return pkt->data.frame.flags; +// } // // // Alloc helpers // vpx_codec_ctx_t *newCtx() { @@ -61,15 +64,17 @@ import ( ) type encoder struct { - codec *C.vpx_codec_ctx_t - raw *C.vpx_image_t - cfg *C.vpx_codec_enc_cfg_t - r video.Reader - frameIndex int - tStart int - tLastFrame int - frame []byte - deadline int + codec *C.vpx_codec_ctx_t + raw *C.vpx_image_t + cfg *C.vpx_codec_enc_cfg_t + r video.Reader + frameIndex int + tStart int + tLastFrame int + frame []byte + deadline int + requireKeyFrame bool + isKeyFrame bool mu sync.Mutex closed bool @@ -141,6 +146,7 @@ func newParams(codecIface *C.vpx_codec_iface_t) (Params, error) { RateControlOvershootPercent: uint(cfg.rc_overshoot_pct), RateControlMinQuantizer: uint(cfg.rc_min_quantizer), RateControlMaxQuantizer: uint(cfg.rc_max_quantizer), + LagInFrames: uint(cfg.g_lag_in_frames), ErrorResilient: ErrorResilientMode(cfg.g_error_resilient), }, nil } @@ -171,6 +177,7 @@ func newEncoder(r video.Reader, p prop.Media, params Params, codecIface *C.vpx_c cfg.g_h = C.uint(p.Height) cfg.g_timebase.num = 1 cfg.g_timebase.den = 1000 + cfg.g_lag_in_frames = C.uint(params.LagInFrames) cfg.rc_target_bitrate = C.uint(params.BitRate) / 1000 cfg.kf_max_dist = C.uint(params.KeyFrameInterval) @@ -254,6 +261,9 @@ func (e *encoder) Read() ([]byte, func(), error) { duration = 1 } var flags int + if e.requireKeyFrame { + flags = flags | C.VPX_EFLAG_FORCE_KF + } if ec := C.encode_wrapper( e.codec, e.raw, C.long(t-e.tStart), C.ulong(duration), C.long(flags), C.ulong(e.deadline), @@ -262,6 +272,7 @@ func (e *encoder) Read() ([]byte, func(), error) { return nil, func() {}, fmt.Errorf("vpx_codec_encode failed (%d)", ec) } + e.requireKeyFrame = false e.frameIndex++ e.tLastFrame = t @@ -273,6 +284,7 @@ func (e *encoder) Read() ([]byte, func(), error) { break } if pkt.kind == C.VPX_CODEC_CX_FRAME_PKT { + e.isKeyFrame = C.pktFrameFlags(pkt)&C.VPX_FRAME_IS_KEY == C.VPX_FRAME_IS_KEY encoded := C.GoBytes(unsafe.Pointer(C.pktBuf(pkt)), C.pktSz(pkt)) e.frame = append(e.frame, encoded...) } @@ -288,7 +300,10 @@ func (e *encoder) SetBitRate(b int) error { } func (e *encoder) ForceKeyFrame() error { - panic("ForceKeyFrame is not implemented") + e.mu.Lock() + defer e.mu.Unlock() + e.requireKeyFrame = true + return nil } func (e *encoder) Close() error { diff --git a/pkg/codec/vpx/vpx_test.go b/pkg/codec/vpx/vpx_test.go index 16bca55..0e3cf8e 100644 --- a/pkg/codec/vpx/vpx_test.go +++ b/pkg/codec/vpx/vpx_test.go @@ -20,6 +20,8 @@ func TestImageSizeChange(t *testing.T) { }, "VP9": func() (codec.VideoEncoderBuilder, error) { p, err := NewVP9Params() + // Disable latency to ease test and begin to receive packets for each input frame + p.LagInFrames = 0 return &p, err }, } { @@ -87,3 +89,72 @@ func TestImageSizeChange(t *testing.T) { }) } } + +func TestRequestKeyFrame(t *testing.T) { + for name, factory := range map[string]func() (codec.VideoEncoderBuilder, error){ + "VP8": func() (codec.VideoEncoderBuilder, error) { + p, err := NewVP8Params() + return &p, err + }, + "VP9": func() (codec.VideoEncoderBuilder, error) { + p, err := NewVP9Params() + // Disable latency to ease test and begin to receive packets for each input frame + p.LagInFrames = 0 + return &p, err + }, + } { + factory := factory + t.Run(name, func(t *testing.T) { + param, err := factory() + if err != nil { + t.Fatal(err) + } + + var initialWidth, initialHeight, width, height int = 320, 240, 320, 240 + + var cnt uint32 + r, err := param.BuildVideoEncoder( + video.ReaderFunc(func() (image.Image, func(), error) { + i := atomic.AddUint32(&cnt, 1) + if i == 3 { + return nil, nil, io.EOF + } + return image.NewYCbCr( + image.Rect(0, 0, width, height), + image.YCbCrSubsampleRatio420, + ), func() {}, nil + }), + prop.Media{ + Video: prop.Video{ + Width: initialWidth, + Height: initialHeight, + FrameRate: 1, + FrameFormat: frame.FormatI420, + }, + }, + ) + if err != nil { + t.Fatal(err) + } + _, rel, err := r.Read() + if err != nil { + t.Fatal(err) + } + rel() + r.ForceKeyFrame() + _, rel, err = r.Read() + if err != nil { + t.Fatal(err) + } + if !r.(*encoder).isKeyFrame { + t.Fatal("Not a key frame") + } + rel() + _, _, err = r.Read() + if err != io.EOF { + t.Fatal(err) + } + }) + + } +}