From 4d8e2ca5d26ef98c47e67a2bd4d381aede6fc512 Mon Sep 17 00:00:00 2001 From: dexter <178529795@qq.com> Date: Sun, 6 Feb 2022 08:50:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E4=B8=8Ertmp=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E8=B0=83=E9=80=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codec/mp4.go | 6 +- common/frame.go | 42 +++++++---- common/index.go | 5 +- common/stream.go | 2 +- events.go | 3 +- main.go | 136 ++++++++++++++++++++++------------- publisher.go | 37 +++++++++- stream.go | 175 ++++++++++++++++++++++++++++----------------- subscriber.go | 94 +++++++++++++++--------- track/aac.go | 10 +-- track/audio.go | 31 ++++---- track/base.go | 10 +++ track/g711.go | 6 +- track/h264.go | 34 +++++---- track/h265.go | 27 ++++--- track/video.go | 80 ++++++++++++++++----- tracks.go | 17 +++-- util/big_endian.go | 9 +++ util/buffer.go | 98 ++++++++++++++++++++++++- util/index.go | 93 ++++++++++++++++++++++++ util/socket.go | 28 ++++---- 21 files changed, 682 insertions(+), 261 deletions(-) diff --git a/codec/mp4.go b/codec/mp4.go index 27d590f..8ac326a 100644 --- a/codec/mp4.go +++ b/codec/mp4.go @@ -80,7 +80,7 @@ type FileTypeBox struct { func NewFileTypeBox() (box *FileTypeBox) { box = new(FileTypeBox) - box.MP4BoxHeader.BoxType = util.ReadBE[uint32]([]byte("ftyp")) + util.GetBE([]byte("ftyp"), &box.MP4BoxHeader.BoxType) return } @@ -121,9 +121,7 @@ type MovieBox struct { func NewMovieBox() (box *MovieBox) { box = new(MovieBox) - - box.MP4BoxHeader.BoxType = util.ReadBE[uint32]([]byte("moov")) - + util.GetBE([]byte("moov"), &box.MP4BoxHeader.BoxType) return } diff --git a/common/frame.go b/common/frame.go index 5caca1e..cead579 100644 --- a/common/frame.go +++ b/common/frame.go @@ -11,16 +11,6 @@ import ( type NALUSlice net.Buffers type H264Slice NALUSlice type H265Slice NALUSlice -type BuffersType interface { - NALUSlice | net.Buffers -} - -func SizeOfBuffers[T BuffersType](buf T) (size int) { - for _, b := range buf { - size += len(b) - } - return -} type H264NALU []NALUSlice type H265NALU []NALUSlice @@ -34,14 +24,11 @@ type RawSlice interface { NALUSlice | AudioSlice } -func (nalu H264NALU) IFrame() bool { - return H264Slice(nalu[0]).Type() == codec.NALU_IDR_Picture -} func (nalu *H264NALU) Append(slice ...NALUSlice) { *nalu = append(*nalu, slice...) } func (nalu H264Slice) Type() byte { - return nalu[0][0] & 0b0001_1111 + return nalu[0][0] & 0x1F } func (nalu H265Slice) Type() byte { return nalu[0][0] & 0x7E >> 1 @@ -124,3 +111,30 @@ func (avcc AVCCFrame) VideoCodecID() byte { func (avcc AVCCFrame) AudioCodecID() byte { return avcc[0] >> 4 } + +// func (annexb AnnexBFrame) ToSlices() (ret []NALUSlice) { +// for len(annexb) > 0 { +// before, after, found := bytes.Cut(annexb, codec.NALU_Delimiter1) +// if !found { +// return append(ret, NALUSlice{annexb}) +// } +// if len(before) > 0 { +// ret = append(ret, NALUSlice{before}) +// } +// annexb = after +// } +// return +// } +// func (annexb AnnexBFrame) ToNALUs() (ret [][]NALUSlice) { +// for len(annexb) > 0 { +// before, after, found := bytes.Cut(annexb, codec.NALU_Delimiter1) +// if !found { +// return append(ret, annexb.ToSlices()) +// } +// if len(before) > 0 { +// ret = append(ret, AnnexBFrame(before).ToSlices()) +// } +// annexb = after +// } +// return +// } diff --git a/common/index.go b/common/index.go index 1887e79..c310afc 100644 --- a/common/index.go +++ b/common/index.go @@ -3,14 +3,13 @@ package common import "time" type Track interface { - Get(size int) (result []byte) - Put(b []byte) + GetName() string } type AVTrack interface { Track WriteAVCC(ts uint32, frame AVCCFrame) //写入AVCC格式的数据 - Flush() + Flush() } type BPS struct { diff --git a/common/stream.go b/common/stream.go index 183c4a9..4236cac 100644 --- a/common/stream.go +++ b/common/stream.go @@ -5,5 +5,5 @@ import "context" type IStream interface { context.Context Update() uint32 - AddTrack(string, Track) + AddTrack(Track) } diff --git a/events.go b/events.go index dbee6fd..1713062 100644 --- a/events.go +++ b/events.go @@ -10,6 +10,7 @@ type TransCodeReq struct { } const ( + Event_REQUEST_PUBLISH = "RequestPublish" //当前流丢失发布者,或者订阅者订阅了空流时触发 Event_SUBSCRIBE = "Subscribe" Event_UNSUBSCRIBE = "UnSubscibe" Event_STREAMCLOSE = "StreamClose" @@ -17,4 +18,4 @@ const ( Event_REQUEST_TRANSAUDIO = "RequestTransAudio" ) -var Bus = EventBus.New() \ No newline at end of file +var Bus = EventBus.New() diff --git a/main.go b/main.go index 14d5c24..5b79aac 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "path/filepath" + "reflect" "runtime" "strings" "time" // colorable @@ -30,50 +31,75 @@ func (s Second) Duration() time.Duration { } // StreamConfig 流的三级覆盖配置(全局,插件,流) -type StreamConfig struct { + +type PublishConfig struct { EnableAudio bool EnableVideo bool + KillExit bool // 是否踢掉已经存在的发布者 AutoReconnect bool // 自动重连 PullOnStart bool // 启动时拉流 PullOnSubscribe bool // 订阅时自动拉流 PublishTimeout Second // 发布无数据超时 - WaitTimeout Second // 等待流超时 WaitCloseTimeout Second // 延迟自动关闭(无订阅时) } +type SubscribeConfig struct { + EnableAudio bool + EnableVideo bool + IFrameOnly bool // 只要关键帧 + WaitTimeout Second // 等待流超时 +} + var ( + DefaultPublishConfig = PublishConfig{ + true, true, false, true, true, true, 10, 10, + } + DefaultSubscribeConfig = SubscribeConfig{ + true, true, false, 10, + } config = &struct { - StreamConfig + Publish PublishConfig + Subscribe SubscribeConfig RTPReorder bool - }{StreamConfig{true, true, true, true, true, 10, 10, 0}, false} + }{DefaultPublishConfig, DefaultSubscribeConfig, false} // ConfigRaw 配置信息的原始数据 ConfigRaw []byte - StartTime time.Time //启动时间 - Plugins = make(map[string]*PluginConfig) // Plugins 所有的插件配置 + StartTime time.Time //启动时间 + Plugins = make(map[string]*Plugin) // Plugins 所有的插件配置 Ctx context.Context settingDir string ) -//PluginConfig 插件配置定义 -type PluginConfig struct { - Name string //插件名称 - Config interface{} //插件配置 - Version string //插件版本 - Dir string //插件代码路径 - Run func() //插件启动函数 - HotConfig map[string]func(interface{}) //热修改配置 +type PluginConfig interface { + Update(map[string]any) } -// InstallPlugin 安装插件 -func (opt *PluginConfig) Install(run func()) { - opt.Run = run - _, pluginFilePath, _, _ := runtime.Caller(1) - opt.Dir = filepath.Dir(pluginFilePath) - if parts := strings.Split(opt.Dir, "@"); len(parts) > 1 { - opt.Version = parts[len(parts)-1] +func InstallPlugin(config PluginConfig) *Plugin { + name := strings.TrimSuffix(reflect.TypeOf(config).Elem().Name(), "Config") + plugin := &Plugin{ + Name: name, + Config: config, + Modified: make(map[string]any), } - Plugins[opt.Name] = opt - util.Print(Green("install plugin"), BrightCyan(opt.Name), BrightBlue(opt.Version)) + _, pluginFilePath, _, _ := runtime.Caller(1) + configDir := filepath.Dir(pluginFilePath) + if parts := strings.Split(configDir, "@"); len(parts) > 1 { + plugin.Version = parts[len(parts)-1] + } + if _, ok := Plugins[name]; ok { + return nil + } + Plugins[name] = plugin + log.Print(Green("install plugin"), BrightCyan(name), BrightBlue(plugin.Version)) + return plugin +} + +// Plugin 插件配置定义 +type Plugin struct { + Name string //插件名称 + Config PluginConfig //插件配置 + Version string //插件版本 + Modified map[string]any //修改过的配置项 } func init() { @@ -86,20 +112,20 @@ func init() { func Run(ctx context.Context, configFile string) (err error) { Ctx = ctx if err := util.CreateShutdownScript(); err != nil { - util.Print(Red("create shutdown script error:"), err) + log.Print(Red("create shutdown script error:"), err) } StartTime = time.Now() if ConfigRaw, err = ioutil.ReadFile(configFile); err != nil { - util.Print(Red("read config file error:"), err) + log.Print(Red("read config file error:"), err) return } settingDir = filepath.Join(filepath.Dir(configFile), ".m7s") if err = os.MkdirAll(settingDir, 0755); err != nil { - util.Print(Red("create dir .m7s error:"), err) + log.Print(Red("create dir .m7s error:"), err) return } util.Print(BgGreen(Black("Ⓜ starting m7s ")), BrightBlue(Version)) - var cg map[string]interface{} + var cg map[string]any if _, err = toml.Decode(string(ConfigRaw), &cg); err == nil { if cfg, ok := cg["Engine"]; ok { b, _ := json.Marshal(cfg) @@ -107,23 +133,13 @@ func Run(ctx context.Context, configFile string) (err error) { log.Println(err) } } - for name, config := range Plugins { - if cfg, ok := cg[name]; ok { - config.updateSettings(cfg.(map[string]interface{})) - b, _ := json.Marshal(cfg) - if err = json.Unmarshal(b, config.Config); err != nil { - log.Println(err) - continue - } - } else if config.Config != nil { - continue - } - if config.Run != nil { - go config.Run() - } + } + for name, config := range Plugins { + var cfg map[string]any + if v, ok := cg[name]; ok { + cfg = v.(map[string]any) } - } else { - util.Print(Red("decode config file error:"), err) + config.Update(cfg) } UUID := uuid.NewString() reportTimer := time.NewTimer(time.Minute) @@ -142,36 +158,54 @@ func Run(ctx context.Context, configFile string) (err error) { } } } -func objectAssign(target, source map[string]interface{}) { +func objectAssign(target, source map[string]any) { for k, v := range source { if _, ok := target[k]; !ok { target[k] = v } else { switch v := v.(type) { - case map[string]interface{}: - objectAssign(target[k].(map[string]interface{}), v) + case map[string]any: + objectAssign(target[k].(map[string]any), v) default: target[k] = v } } } } -func (opt *PluginConfig) updateSettings(cfg map[string]interface{}) { + +// Update 更新配置 +func (opt *Plugin) Update(cfg map[string]any) { if setting, err := ioutil.ReadFile(opt.settingPath()); err == nil { var cg map[string]interface{} if _, err = toml.Decode(string(setting), &cg); err == nil { - objectAssign(cfg, cg) + if cfg == nil { + cfg = cg + } else { + objectAssign(cfg, cg) + } } } + // TODO: map转成struct优化 + if cfg != nil { + b, _ := json.Marshal(cfg) + for k, v := range cfg { + opt.Modified[k] = v + } + if err := json.Unmarshal(b, opt.Config); err != nil { + log.Println(err) + } + } + go opt.Config.Update(cfg) } -func (opt *PluginConfig) settingPath() string { +func (opt *Plugin) settingPath() string { return filepath.Join(settingDir, opt.Name+".toml") } -func (opt *PluginConfig) Save() error { + +func (opt *Plugin) Save() error { file, err := os.OpenFile(opt.settingPath(), os.O_CREATE|os.O_WRONLY, 0644) if err == nil { defer file.Close() - err = toml.NewEncoder(file).Encode(opt.Config) + err = toml.NewEncoder(file).Encode(opt.Modified) } return err } diff --git a/publisher.go b/publisher.go index 6e5cba5..4aa468d 100644 --- a/publisher.go +++ b/publisher.go @@ -1,5 +1,40 @@ package engine -type Publisher interface { +import ( + "time" +) + +type IPublisher interface { + Close() // 流关闭时或者被踢时触发 OnStateChange(oldState StreamState, newState StreamState) bool } + +type Publisher struct { + Stream *Stream + Config PublishConfig +} + +func (pub *Publisher) Publish(streamPath string, realPub IPublisher) bool { + Streams.Lock() + defer Streams.Unlock() + s, created := findOrCreateStream(streamPath, time.Second) + if s.IsClosed() { + return false + } + if s.Publisher != nil && pub.Config.KillExit { + s.Publisher.Close() + } + pub.Stream = s + s.Publisher = realPub + if created { + s.PublishTimeout = pub.Config.PublishTimeout.Duration() + s.WaitCloseTimeout = pub.Config.WaitCloseTimeout.Duration() + go s.run() + } + s.actionChan <- PublishAction{} + return true +} + +func (pub *Publisher) OnStateChange(oldState StreamState, newState StreamState) bool { + return true +} diff --git a/stream.go b/stream.go index e12740c..b5b0158 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,8 @@ package engine import ( "context" + "net/url" + "strings" "sync/atomic" "time" @@ -18,7 +20,8 @@ const ( STATE_WAITTRACK // 等待Track STATE_PUBLISHING // 正在发布流状态 STATE_WAITCLOSE // 等待关闭状态(自动关闭延时开启) - STATE_CLOSED + STATE_CLOSED // 流已关闭,不可使用 + STATE_DESTROYED // 资源已释放 ) const ( @@ -30,7 +33,7 @@ const ( ACTION_FIRSTENTER // 第一个订阅者进入 ) -var StreamFSM = [STATE_CLOSED + 1]map[StreamAction]StreamState{ +var StreamFSM = [STATE_DESTROYED + 1]map[StreamAction]StreamState{ { ACTION_PUBLISH: STATE_WAITTRACK, ACTION_LASTLEAVE: STATE_CLOSED, @@ -53,125 +56,174 @@ var StreamFSM = [STATE_CLOSED + 1]map[StreamAction]StreamState{ ACTION_FIRSTENTER: STATE_PUBLISHING, ACTION_CLOSE: STATE_CLOSED, }, - {}, + { + ACTION_TIMEOUT: STATE_DESTROYED, + }, } // Streams 所有的流集合 var Streams = util.Map[string, *Stream]{Map: make(map[string]*Stream)} -type SubscribeAction *Subscriber type UnSubscibeAction *Subscriber +type PublishAction struct{} +type UnPublishAction struct{} +type StreamTimeoutConfig struct { + WaitTimeout time.Duration + PublishTimeout time.Duration + WaitCloseTimeout time.Duration +} // Stream 流定义 type Stream struct { context.Context cancel context.CancelFunc - Publisher + StreamTimeoutConfig + *url.URL + Publisher IPublisher State StreamState timeout *time.Timer //当前状态的超时定时器 actionChan chan any - Config StreamConfig - URL string //远程地址,仅远程拉流有值 - StreamPath string + RemoteURL string //远程地址,仅远程拉流有值 StartTime time.Time //流的创建时间 Subscribers util.Slice[*Subscriber] // 订阅者 Tracks FrameCount uint32 //帧总数 + AppName string + StreamName string } -func (r *Stream) Register(streamPath string) (result bool) { - if r == nil { - r = &Stream{ - Config: config.StreamConfig, - } +func (s *Stream) UnPublish() { + if !s.IsClosed() { + s.actionChan <- UnPublishAction{} } - r.StreamPath = streamPath - if result = Streams.Add(streamPath, r); result { - r.actionChan = make(chan any, 1) - r.StartTime = time.Now() - r.timeout = time.NewTimer(r.Config.WaitTimeout.Duration()) - r.Context, r.cancel = context.WithCancel(Ctx) - r.Init(r) - go r.run() - } - return } -// ForceRegister 强制注册流,会将已有的流踢掉 -func (r *Stream) ForceRegister(streamPath string) { - if ok := r.Register(streamPath); !ok { - if s := Streams.Get(streamPath); s != nil { - s.Close() - <-s.Done() - } - r.ForceRegister(streamPath) +func findOrCreateStream(streamPath string, waitTimeout time.Duration) (s *Stream, created bool) { + streamPath = strings.Trim(streamPath, "/") + u, err := url.Parse(streamPath) + if err != nil { + return nil, false + } + p := strings.Split(u.Path, "/") + if len(p) < 2 { + util.Println(Red("Stream Path Format Error:"), streamPath) + return nil, false + } + if s, ok := Streams.Map[u.Path]; ok { + util.Println(Green("Stream Found:"), u.Path) + return s, false } else { - return + util.Println(Green("Stream Created:"), u.Path) + p := strings.Split(u.Path, "/") + s = &Stream{ + URL: u, + AppName: p[0], + StreamName: p[len(p)-1], + } + s.WaitTimeout = waitTimeout + Streams.Map[u.Path] = s + s.actionChan = make(chan any, 1) + s.StartTime = time.Now() + s.timeout = time.NewTimer(waitTimeout) + s.Context, s.cancel = context.WithCancel(Ctx) + s.Init(s) + return s, true } } -func (r *Stream) action(action StreamAction) { +func (r *Stream) action(action StreamAction) bool { if next, ok := StreamFSM[r.State][action]; ok { - if r.Publisher == nil || r.OnStateChange(r.State, next) { - util.Print(Yellow("Stream "), BrightCyan(r.StreamPath), " state changed :", r.State, "->", next) + if r.Publisher == nil || r.Publisher.OnStateChange(r.State, next) { + util.Print(Yellow("Stream "), BrightCyan(r.Path), action, " :", r.State, "->", next) r.State = next switch next { case STATE_WAITPUBLISH: - r.timeout.Reset(r.Config.WaitTimeout.Duration()) + r.Publisher = nil + Bus.Publish(Event_REQUEST_PUBLISH, r) + r.timeout.Reset(r.WaitTimeout) case STATE_WAITTRACK: r.timeout.Reset(time.Second * 5) case STATE_PUBLISHING: r.WaitDone() - r.timeout.Reset(r.Config.PublishTimeout.Duration()) + r.timeout.Reset(r.PublishTimeout) + Bus.Publish(Event_PUBLISH, r) case STATE_WAITCLOSE: - r.timeout.Reset(r.Config.WaitCloseTimeout.Duration()) + r.timeout.Reset(r.WaitCloseTimeout) case STATE_CLOSED: r.cancel() + if r.Publisher != nil { + r.Publisher.Close() + } r.WaitDone() + Bus.Publish(Event_STREAMCLOSE, r) + Streams.Delete(r.Path) + r.timeout.Reset(time.Second) // 延迟1秒钟销毁,防止访问到已关闭的channel + case STATE_DESTROYED: close(r.actionChan) - Streams.Delete(r.StreamPath) fallthrough default: r.timeout.Stop() } } + return true } + return false +} +func (r *Stream) IsClosed() bool { + if r == nil { + return true + } + return r.State == STATE_CLOSED } func (r *Stream) Close() { - r.actionChan <- ACTION_CLOSE + if !r.IsClosed() { + r.actionChan <- ACTION_CLOSE + } } + func (r *Stream) UnSubscribe(sub *Subscriber) { - r.actionChan <- UnSubscibeAction(sub) + if !r.IsClosed() { + r.actionChan <- UnSubscibeAction(sub) + } } func (r *Stream) Subscribe(sub *Subscriber) { - r.actionChan <- SubscribeAction(sub) + if !r.IsClosed() { + sub.Stream = r + sub.Context, sub.cancel = context.WithCancel(r) + r.actionChan <- sub + } } func (r *Stream) run() { for { select { case <-r.timeout.C: - util.Print(Yellow("Stream "), BrightCyan(r.StreamPath), "timeout:", r.State) + util.Print(Yellow("Stream "), BrightCyan(r.Path), " timeout:", r.State) r.action(ACTION_TIMEOUT) case <-r.Done(): r.action(ACTION_CLOSE) + case action, ok := <-r.actionChan: if ok { switch v := action.(type) { + case PublishAction: + r.action(ACTION_PUBLISH) + case UnPublishAction: + r.action(ACTION_PUBLISHLOST) case StreamAction: r.action(v) - case SubscribeAction: - v.Stream = r - v.Context, v.cancel = context.WithCancel(r) + case *Subscriber: r.Subscribers.Add(v) - util.Print(Sprintf(Yellow("%s subscriber %s added remains:%d"), BrightCyan(r.StreamPath), Cyan(v.ID), Blue(len(r.Subscribers)))) + Bus.Publish(Event_SUBSCRIBE, v) + util.Print(Sprintf(Yellow("%s subscriber %s added remains:%d"), BrightCyan(r.Path), Cyan(v.ID), Blue(len(r.Subscribers)))) if r.Subscribers.Len() == 1 { r.action(ACTION_FIRSTENTER) } case UnSubscibeAction: if r.Subscribers.Delete(v) { - util.Print(Sprintf(Yellow("%s subscriber %s removed remains:%d"), BrightCyan(r.StreamPath), Cyan(v.ID), Blue(len(r.Subscribers)))) - if r.Subscribers.Len() == 0 && r.Config.WaitCloseTimeout > 0 { + Bus.Publish(Event_UNSUBSCRIBE, v) + util.Print(Sprintf(Yellow("%s subscriber %s removed remains:%d"), BrightCyan(r.Path), Cyan(v.ID), Blue(len(r.Subscribers)))) + if r.Subscribers.Len() == 0 && r.WaitCloseTimeout > 0 { r.action(ACTION_LASTLEAVE) } } @@ -186,7 +238,7 @@ func (r *Stream) run() { // Update 更新数据重置超时定时器 func (r *Stream) Update() uint32 { if r.State == STATE_PUBLISHING { - r.timeout.Reset(r.Config.PublishTimeout.Duration()) + r.timeout.Reset(r.PublishTimeout) } return atomic.AddUint32(&r.FrameCount, 1) } @@ -198,7 +250,12 @@ func (r *Stream) NewVideoTrack() (vt *track.UnknowVideo) { } return } - +func (r *Stream) NewAudioTrack() (at *track.UnknowAudio) { + at = &track.UnknowAudio{ + Stream: r, + } + return +} func (r *Stream) NewH264Track() (vt *track.H264) { return track.NewH264(r) } @@ -211,19 +268,3 @@ func (r *Stream) NewH265Track() (vt *track.H265) { // t := <-r.WaitTrack(names...) // return t.(DataTrack) // } - -func (r *Stream) WaitVideoTrack(names ...string) track.Video { - if !r.Config.EnableVideo { - return nil - } - t := <-r.WaitTrack(names...) - return t.(track.Video) -} - -func (r *Stream) WaitAudioTrack(names ...string) track.Audio { - if !r.Config.EnableAudio { - return nil - } - t := <-r.WaitTrack(names...) - return t.(track.Audio) -} diff --git a/subscriber.go b/subscriber.go index d99962f..18cba14 100644 --- a/subscriber.go +++ b/subscriber.go @@ -3,12 +3,10 @@ package engine import ( "context" "net/url" - "sync" "time" . "github.com/Monibuca/engine/v4/common" "github.com/Monibuca/engine/v4/track" - "github.com/pkg/errors" ) type AudioFrame AVFrame[AudioSlice] @@ -18,7 +16,8 @@ type VideoFrame AVFrame[NALUSlice] type Subscriber struct { context.Context `json:"-"` cancel context.CancelFunc - *Stream `json:"-"` + Config SubscribeConfig + Stream *Stream `json:"-"` ID string TotalDrop int //总丢帧 TotalPacket int @@ -29,44 +28,36 @@ type Subscriber struct { SubscribeArgs url.Values OnAudio func(*AudioFrame) bool `json:"-"` OnVideo func(*VideoFrame) bool `json:"-"` - closeOnce sync.Once } -func (s *Subscriber) close() { - if s.Stream != nil { - s.UnSubscribe(s) - } +// Close 关闭订阅者 +func (s *Subscriber) Close() { + s.Stream.UnSubscribe(s) if s.cancel != nil { s.cancel() } } -// Close 关闭订阅者 -func (s *Subscriber) Close() { - s.closeOnce.Do(s.close) -} - //Subscribe 开始订阅 将Subscriber与Stream关联 -func (s *Subscriber) Subscribe(streamPath string) error { - if u, err := url.Parse(streamPath); err != nil { - return err - } else if s.SubscribeArgs, err = url.ParseQuery(u.RawQuery); err != nil { - return err - } else { - streamPath = u.Path +func (sub *Subscriber) Subscribe(streamPath string, config SubscribeConfig) bool { + Streams.Lock() + defer Streams.Unlock() + s, created := findOrCreateStream(streamPath, config.WaitTimeout.Duration()) + if s.IsClosed() { + return false } - if stream := Streams.Get(streamPath); stream == nil { - return errors.Errorf("subscribe %s faild :stream not found", streamPath) - } else { - if stream.Subscribe(s); s.Context == nil { - return errors.Errorf("subscribe %s faild :stream closed", streamPath) - } + if created { + Bus.Publish(Event_REQUEST_PUBLISH, s) + go s.run() } - return nil + if s.Subscribe(sub); sub.Stream != nil { + sub.Config = config + } + return true } //Play 开始播放 -func (s *Subscriber) Play(at track.Audio, vt track.Video) { +func (s *Subscriber) Play(at *track.Audio, vt *track.Video) { defer s.Close() if vt == nil && at == nil { return @@ -83,14 +74,18 @@ func (s *Subscriber) Play(at track.Audio, vt track.Video) { vp := vr.Read() ap := ar.TryRead() // chase := true - for { + for s.Err() == nil { if ap == nil && vp == nil { time.Sleep(time.Millisecond * 10) } else if ap != nil && (vp == nil || vp.SeqInStream > ap.SeqInStream) { - s.onAudio(ap) + if !s.onAudio(ap) { + return + } ar.MoveNext() } else if vp != nil && (ap == nil || ap.SeqInStream > vp.SeqInStream) { - s.onVideo(vp) + if !s.onVideo(vp) { + return + } // if chase { // if add10 := vst.Add(time.Millisecond * 10); realSt.After(add10) { // vst = add10 @@ -111,9 +106,42 @@ func (s *Subscriber) onAudio(af *AVFrame[AudioSlice]) bool { func (s *Subscriber) onVideo(vf *AVFrame[NALUSlice]) bool { return s.OnVideo((*VideoFrame)(vf)) } -func (s *Subscriber) PlayAudio(vt track.Audio) { +func (s *Subscriber) PlayAudio(vt *track.Audio) { vt.Play(s.onAudio) } -func (s *Subscriber) PlayVideo(vt track.Video) { +func (s *Subscriber) PlayVideo(vt *track.Video) { vt.Play(s.onVideo) } +func (r *Subscriber) WaitVideoTrack(names ...string) *track.Video { + if !r.Config.EnableVideo { + return nil + } + if t := <-r.Stream.WaitTrack(names...); t == nil { + return nil + } else { + switch vt := t.(type) { + case *track.H264: + return (*track.Video)(vt) + case *track.H265: + return (*track.Video)(vt) + } + return nil + } +} + +func (r *Subscriber) WaitAudioTrack(names ...string) *track.Audio { + if !r.Config.EnableAudio { + return nil + } + if t := <-r.Stream.WaitTrack(names...); t == nil { + return nil + } else { + switch at := t.(type) { + case *track.AAC: + return (*track.Audio)(at) + case *track.G711: + return (*track.Audio)(at) + } + return nil + } +} diff --git a/track/aac.go b/track/aac.go index b92f494..ddc5647 100644 --- a/track/aac.go +++ b/track/aac.go @@ -1,13 +1,15 @@ package track import ( + "time" + "github.com/Monibuca/engine/v4/codec" . "github.com/Monibuca/engine/v4/common" - "time" ) func NewAAC(stream IStream) (aac *AAC) { aac = &AAC{} + aac.Name = "aac" aac.Stream = stream aac.CodecID = codec.CodecID_AAC aac.Init(stream, 32) @@ -15,9 +17,7 @@ func NewAAC(stream IStream) (aac *AAC) { return } -type AAC struct { - BaseAudio -} +type AAC Audio func (aac *AAC) WriteAVCC(ts uint32, frame AVCCFrame) { if frame.IsSequence() { @@ -33,6 +33,6 @@ func (aac *AAC) WriteAVCC(ts uint32, frame AVCCFrame) { aac.SampleRate = HZ(codec.SamplingFrequencies[((config1&0x7)<<1)|(config2>>7)]) aac.DecoderConfiguration.AppendRaw(AudioSlice(frame[2:])) } else { - aac.BaseAudio.WriteAVCC(ts, frame) + (*Audio)(aac).WriteAVCC(ts, frame) } } diff --git a/track/audio.go b/track/audio.go index 7122fc6..dc87fd8 100644 --- a/track/audio.go +++ b/track/audio.go @@ -8,22 +8,25 @@ import ( "github.com/Monibuca/engine/v4/util" ) -type Audio interface { - AVTrack - ReadRing() *AVRing[AudioSlice] - Play(onAudio func(*AVFrame[AudioSlice]) bool) -} - -type BaseAudio struct { +type Audio struct { Media[AudioSlice] Channels byte avccHead []byte } -func (at *BaseAudio) ReadRing() *AVRing[AudioSlice] { +func (av *Audio) GetName() string { + if av.Name == "" { + return strings.ToLower(codec.SoundFormat[av.CodecID]) + } + return av.Name +} +func (at *Audio) GetInfo() *Audio { + return at +} +func (at *Audio) ReadRing() *AVRing[AudioSlice] { return util.Clone(at.AVRing) } -func (at *BaseAudio) Play(onAudio func(*AVFrame[AudioSlice]) bool) { +func (at *Audio) Play(onAudio func(*AVFrame[AudioSlice]) bool) { ar := at.ReadRing() for ap := ar.Read(); at.Stream.Err() == nil; ap = ar.Read() { if !onAudio(ap) { @@ -33,12 +36,12 @@ func (at *BaseAudio) Play(onAudio func(*AVFrame[AudioSlice]) bool) { } } -func (at *BaseAudio) WriteAVCC(ts uint32, frame AVCCFrame) { +func (at *Audio) WriteAVCC(ts uint32, frame AVCCFrame) { at.Media.WriteAVCC(ts, frame) at.Flush() } -func (at *BaseAudio) Flush() { +func (at *Audio) Flush() { if at.Value.AVCC == nil { at.Value.AppendAVCC(at.avccHead) for _, raw := range at.Value.Raw { @@ -51,7 +54,7 @@ func (at *BaseAudio) Flush() { type UnknowAudio struct { Name string Stream IStream - Know Audio + Know AVTrack } func (at *UnknowAudio) WriteAVCC(ts uint32, frame AVCCFrame) { @@ -69,7 +72,7 @@ func (at *UnknowAudio) WriteAVCC(ts uint32, frame AVCCFrame) { at.Know = a a.avccHead = []byte{frame[0], 1} a.WriteAVCC(0, frame) - a.Stream.AddTrack(a.Name, a) + a.Stream.AddTrack(a) case codec.CodecID_PCMA, codec.CodecID_PCMU: alaw := true @@ -81,7 +84,7 @@ func (at *UnknowAudio) WriteAVCC(ts uint32, frame AVCCFrame) { a.SampleRate = HZ(codec.SoundRate[(frame[0]&0x0c)>>2]) a.Channels = frame[0]&0x01 + 1 a.avccHead = frame[:1] - a.Stream.AddTrack(a.Name, a) + a.Stream.AddTrack(a) } } else { at.Know.WriteAVCC(ts, frame) diff --git a/track/base.go b/track/base.go index 1c154df..3522a69 100644 --- a/track/base.go +++ b/track/base.go @@ -13,6 +13,10 @@ type Base struct { BPS } +func (bt *Base) GetName() string { + return bt.Name +} + func (bt *Base) Flush(bf *BaseFrame) { bt.ComputeBPS(bf.BytesIn) bf.SeqInStream = bt.Stream.Update() @@ -26,6 +30,7 @@ type Media[T RawSlice] struct { SampleRate HZ DecoderConfiguration AVFrame[T] `json:"-"` //H264(SPS、PPS) H265(VPS、SPS、PPS) AAC(config) util.BytesPool //无锁内存池,用于发布者(在同一个协程中)复用小块的内存,通常是解包时需要临时使用 + lastAvccTS uint32 //上一个avcc帧的时间戳 } func (av *Media[T]) WriteRTP(raw []byte) { @@ -44,6 +49,11 @@ func (av *Media[T]) WriteSlice(slice T) { av.Value.AppendRaw(slice) } func (av *Media[T]) WriteAVCC(ts uint32, frame AVCCFrame) { + if av.lastAvccTS == 0 { + av.lastAvccTS = ts + } else { + av.Value.DeltaTime = ts - av.lastAvccTS + } av.Value.BytesIn = len(frame) av.Value.AppendAVCC(frame) av.Value.DTS = av.SampleRate.ToNTS(ts) diff --git a/track/g711.go b/track/g711.go index 7b88125..bfdc0d5 100644 --- a/track/g711.go +++ b/track/g711.go @@ -20,11 +20,9 @@ func NewG711(stream IStream, alaw bool) (g711 *G711) { return } -type G711 struct { - BaseAudio -} +type G711 Audio func (g711 *G711) WriteAVCC(ts uint32, frame AVCCFrame) { g711.Value.AppendRaw(AudioSlice(frame[1:])) - g711.BaseAudio.WriteAVCC(ts, frame) + (*Audio)(g711).WriteAVCC(ts, frame) } diff --git a/track/h264.go b/track/h264.go index a60d29d..da712a0 100644 --- a/track/h264.go +++ b/track/h264.go @@ -1,6 +1,7 @@ package track import ( + "net" "time" "github.com/Monibuca/engine/v4/codec" @@ -8,12 +9,11 @@ import ( "github.com/Monibuca/engine/v4/util" ) -type H264 struct { - H264H265 -} +type H264 Video func NewH264(stream IStream) (vt *H264) { vt = &H264{} + vt.Name = "h264" vt.CodecID = codec.CodecID_H264 vt.SampleRate = 90000 vt.Stream = stream @@ -21,7 +21,10 @@ func NewH264(stream IStream) (vt *H264) { vt.Poll = time.Millisecond * 20 return } - +func (vt *H264) WriteAnnexB(pts uint32, dts uint32, frame AnnexBFrame) { + (*Video)(vt).WriteAnnexB(pts, dts, frame) + vt.Flush() +} func (vt *H264) WriteSlice(slice NALUSlice) { switch H264Slice(slice).Type() { case codec.NALU_SPS: @@ -30,8 +33,8 @@ func (vt *H264) WriteSlice(slice NALUSlice) { case codec.NALU_PPS: vt.DecoderConfiguration.AppendRaw(slice) vt.SPSInfo, _ = codec.ParseSPS(slice[0]) - lenSPS := SizeOfBuffers(vt.DecoderConfiguration.Raw[0]) - lenPPS := SizeOfBuffers(vt.DecoderConfiguration.Raw[1]) + lenSPS := util.SizeOfBuffers(net.Buffers(vt.DecoderConfiguration.Raw[0])) + lenPPS := util.SizeOfBuffers(net.Buffers(vt.DecoderConfiguration.Raw[1])) if lenSPS > 3 { vt.DecoderConfiguration.AppendAVCC(codec.RTMP_AVC_HEAD[:6], vt.DecoderConfiguration.Raw[0][0][1:4]) } else { @@ -40,8 +43,10 @@ func (vt *H264) WriteSlice(slice NALUSlice) { tmp := []byte{0xE1, 0, 0, 0x01, 0, 0} vt.DecoderConfiguration.AppendAVCC(tmp[:1], util.PutBE(tmp[1:3], lenSPS), vt.DecoderConfiguration.Raw[0][0], tmp[3:4], util.PutBE(tmp[3:6], lenPPS), vt.DecoderConfiguration.Raw[1][0]) case codec.NALU_IDR_Picture: - case codec.NALU_Non_IDR_Picture: - case codec.NALU_SEI: + vt.Value.IFrame = true + fallthrough + case codec.NALU_Non_IDR_Picture, + codec.NALU_SEI: vt.Media.WriteSlice(slice) } } @@ -58,21 +63,22 @@ func (vt *H264) WriteAVCC(ts uint32, frame AVCCFrame) { vt.DecoderConfiguration.AppendRaw(NALUSlice{info.SequenceParameterSetNALUnit}, NALUSlice{info.PictureParameterSetNALUnit}) } } else { - vt.H264H265.WriteAVCC(ts, frame) + (*Video)(vt).WriteAVCC(ts, frame) + vt.Value.IFrame = frame.IsIDR() + vt.Flush() } } func (vt *H264) Flush() { - if H264NALU(vt.Value.Raw).IFrame() { - vt.Value.IFrame = true + if vt.Value.IFrame { if vt.IDRing == nil { - defer vt.Stream.AddTrack(vt.Name, vt) + defer vt.Stream.AddTrack(vt) } - vt.ComputeGOP() + (*Video)(vt).ComputeGOP() } // RTP格式补完 if vt.Value.RTP == nil { } - vt.H264H265.Flush() + (*Video)(vt).Flush() } diff --git a/track/h265.go b/track/h265.go index 28dab6a..903697e 100644 --- a/track/h265.go +++ b/track/h265.go @@ -7,12 +7,11 @@ import ( . "github.com/Monibuca/engine/v4/common" ) -type H265 struct { - H264H265 -} +type H265 Video func NewH265(stream IStream) (vt *H265) { vt = &H265{} + vt.Name = "h265" vt.CodecID = codec.CodecID_H265 vt.SampleRate = 90000 vt.Stream = stream @@ -20,6 +19,10 @@ func NewH265(stream IStream) (vt *H265) { vt.Poll = time.Millisecond * 20 return } +func (vt *H265) WriteAnnexB(pts uint32, dts uint32, frame AnnexBFrame) { + (*Video)(vt).WriteAnnexB(pts, dts, frame) + vt.Flush() +} func (vt *H265) WriteSlice(slice NALUSlice) { switch H265Slice(slice).Type() { case codec.NAL_UNIT_VPS: @@ -34,13 +37,16 @@ func (vt *H265) WriteSlice(slice NALUSlice) { if err == nil { vt.DecoderConfiguration.AppendAVCC(extraData) } - case 0, 1, 2, 3, 4, 5, 6, 7, 9, + case codec.NAL_UNIT_CODED_SLICE_BLA, codec.NAL_UNIT_CODED_SLICE_BLANT, codec.NAL_UNIT_CODED_SLICE_BLA_N_LP, codec.NAL_UNIT_CODED_SLICE_IDR, codec.NAL_UNIT_CODED_SLICE_IDR_N_LP, codec.NAL_UNIT_CODED_SLICE_CRA: + vt.Value.IFrame = true + fallthrough + case 0, 1, 2, 3, 4, 5, 6, 7, 9: vt.Media.WriteSlice(slice) } } @@ -55,21 +61,22 @@ func (vt *H265) WriteAVCC(ts uint32, frame AVCCFrame) { vt.DecoderConfiguration.AppendRaw(NALUSlice{vps}, NALUSlice{sps}, NALUSlice{pps}) } } else { - vt.H264H265.WriteAVCC(ts, frame) + (*Video)(vt).WriteAVCC(ts, frame) + vt.Value.IFrame = frame.IsIDR() + vt.Flush() } } func (vt *H265) Flush() { - if H265NALU(vt.Value.Raw).IFrame() { - vt.Value.IFrame = true + if vt.Value.IFrame { if vt.IDRing == nil { - defer vt.Stream.AddTrack(vt.Name, vt) + defer vt.Stream.AddTrack(vt) } - vt.ComputeGOP() + (*Video)(vt).ComputeGOP() } // RTP格式补完 if vt.Value.RTP == nil { } - vt.H264H265.Flush() + (*Video)(vt).Flush() } diff --git a/track/video.go b/track/video.go index 13e0814..392b70f 100644 --- a/track/video.go +++ b/track/video.go @@ -1,6 +1,8 @@ package track import ( + "bytes" + "net" "strings" "github.com/Monibuca/engine/v4/codec" @@ -8,13 +10,7 @@ import ( "github.com/Monibuca/engine/v4/util" ) -type Video interface { - AVTrack - ReadRing() *AVRing[NALUSlice] - Play(onVideo func(*AVFrame[NALUSlice]) bool) -} - -type H264H265 struct { +type Video struct { Media[NALUSlice] IDRing *util.Ring[AVFrame[NALUSlice]] `json:"-"` //最近的关键帧位置,首屏渲染 SPSInfo codec.SPSInfo @@ -23,7 +19,14 @@ type H264H265 struct { idrCount int //缓存中包含的idr数量 } -func (t *H264H265) ComputeGOP() { +func (t *Video) GetName() string { + if t.Name == "" { + return strings.ToLower(codec.CodecID[t.CodecID]) + } + return t.Name +} + +func (t *Video) ComputeGOP() { t.idrCount++ if t.IDRing != nil { t.GOP = int(t.Value.SeqInTrack - t.IDRing.Value.SeqInTrack) @@ -41,7 +44,38 @@ func (t *H264H265) ComputeGOP() { t.IDRing = t.Ring } -func (vt *H264H265) WriteAVCC(ts uint32, frame AVCCFrame) { +func (vt *Video) writeAnnexBSlice(annexb AnnexBFrame) { + for len(annexb) > 0 { + before, after, found := bytes.Cut(annexb, codec.NALU_Delimiter1) + if !found { + vt.WriteSlice(NALUSlice{annexb}) + return + } + if len(before) > 0 { + vt.WriteSlice(NALUSlice{before}) + } + annexb = after + } +} + +func (vt *Video) WriteAnnexB(pts uint32, dts uint32, frame AnnexBFrame) { + for len(frame) > 0 { + before, after, found := bytes.Cut(frame, codec.NALU_Delimiter2) + if !found { + vt.writeAnnexBSlice(frame) + if len(vt.Value.Raw) > 0 { + vt.Value.PTS = pts + vt.Value.DTS = dts + } + return + } + if len(before) > 0 { + vt.writeAnnexBSlice(AnnexBFrame(before)) + } + frame = after + } +} +func (vt *Video) WriteAVCC(ts uint32, frame AVCCFrame) { vt.Media.WriteAVCC(ts, frame) for nalus := frame[5:]; len(nalus) > vt.nalulenSize; { nalulen := util.ReadBE[int](nalus[:vt.nalulenSize]) @@ -53,10 +87,9 @@ func (vt *H264H265) WriteAVCC(ts uint32, frame AVCCFrame) { break } } - vt.Flush() } -func (vt *H264H265) Flush() { +func (vt *Video) Flush() { // AVCC格式补完 if vt.Value.AVCC == nil { b := []byte{vt.CodecID, 1, 0, 0, 0} @@ -69,7 +102,7 @@ func (vt *H264H265) Flush() { util.PutBE(b[2:5], vt.SampleRate.ToMini(vt.Value.PTS-vt.Value.DTS)) vt.Value.AppendAVCC(b) for _, nalu := range vt.Value.Raw { - vt.Value.AppendAVCC(util.PutBE(make([]byte, 4), SizeOfBuffers(nalu))) + vt.Value.AppendAVCC(util.PutBE(make([]byte, 4), util.SizeOfBuffers(net.Buffers(nalu)))) vt.Value.AppendAVCC(nalu...) } } @@ -86,12 +119,12 @@ func (vt *H264H265) Flush() { } vt.Media.Flush() } -func (vt *H264H265) ReadRing() *AVRing[NALUSlice] { +func (vt *Video) ReadRing() *AVRing[NALUSlice] { vr := util.Clone(vt.AVRing) vr.Ring = vt.IDRing return vr } -func (vt *H264H265) Play(onVideo func(*AVFrame[NALUSlice]) bool) { +func (vt *Video) Play(onVideo func(*AVFrame[NALUSlice]) bool) { vr := vt.ReadRing() for vp := vr.Read(); vt.Stream.Err() == nil; vp = vr.Read() { if !onVideo(vp) { @@ -104,9 +137,24 @@ func (vt *H264H265) Play(onVideo func(*AVFrame[NALUSlice]) bool) { type UnknowVideo struct { Name string Stream IStream - Know Video + Know AVTrack } +/* +Access Unit的首个nalu是4字节起始码。 +这里举个例子说明,用JM可以生成这样一段码流(不要使用JM8.6,它在这部分与标准不符),这个码流可以见本楼附件: + SPS (4字节头) + PPS (4字节头) + SEI (4字节头) + I0(slice0) (4字节头) + I0(slice1) (3字节头) + P1(slice0) (4字节头) + P1(slice1) (3字节头) + P2(slice0) (4字节头) + P2(slice1) (3字节头) +I0(slice0)是序列第一帧(I帧)的第一个slice,是当前Access Unit的首个nalu,所以是4字节头。而I0(slice1)表示第一帧的第二个slice,所以是3字节头。P1(slice0) 、P1(slice1)同理。 + +*/ func (vt *UnknowVideo) WriteAnnexB(pts uint32, dts uint32, frame AnnexBFrame) { } @@ -123,12 +171,10 @@ func (vt *UnknowVideo) WriteAVCC(ts uint32, frame AVCCFrame) { v := NewH264(vt.Stream) vt.Know = v v.WriteAVCC(0, frame) - v.Stream.AddTrack(v.Name, v) case codec.CodecID_H265: v := NewH265(vt.Stream) vt.Know = v v.WriteAVCC(0, frame) - v.Stream.AddTrack(v.Name, v) } } } else { diff --git a/tracks.go b/tracks.go index 25c9bbe..adfdba2 100644 --- a/tracks.go +++ b/tracks.go @@ -6,6 +6,7 @@ import ( "sync" . "github.com/Monibuca/engine/v4/common" + "github.com/Monibuca/engine/v4/util" ) type Tracks struct { @@ -27,16 +28,18 @@ func (ts *Tracks) Init(ctx context.Context) { ts.Context = ctx } -func (ts *Tracks) AddTrack(name string, t Track) { +func (ts *Tracks) AddTrack(t Track) { ts.Lock() defer ts.Unlock() + name := t.GetName() if _, ok := ts.m[name]; !ok { + util.Println("Track", name, "added") if ts.m[name] = t; ts.Err() == nil { - for i, ch := range ts.waiters[name] { - if ch != nil { + for _, ch := range ts.waiters[name] { + if *ch != nil { *ch <- t close(*ch) - ts.waiters[name][i] = nil //通过设置为nil,防止重复通知 + *ch = nil //通过设置为nil,防止重复通知 } } } @@ -54,10 +57,10 @@ func (ts *Tracks) WaitDone() { ts.Lock() defer ts.Unlock() for _, chs := range ts.waiters { - for i, ch := range chs { - if ch != nil { + for _, ch := range chs { + if *ch != nil { close(*ch) - chs[i] = nil //通过设置为nil,防止重复关闭 + *ch = nil //通过设置为nil,防止重复关闭 } } } diff --git a/util/big_endian.go b/util/big_endian.go index 598dc3d..77ca3eb 100644 --- a/util/big_endian.go +++ b/util/big_endian.go @@ -10,8 +10,17 @@ func PutBE[T constraints.Integer](b []byte, num T) []byte { } func ReadBE[T constraints.Integer](b []byte) (num T) { + num = 0 for i, n := 0, len(b); i < n; i++ { num += T(b[i]) << ((n - i - 1) << 3) } return } + +func GetBE[T constraints.Integer](b []byte, num *T) T { + *num = 0 + for i, n := 0, len(b); i < n; i++ { + *num += T(b[i]) << ((n - i - 1) << 3) + } + return *num +} diff --git a/util/buffer.go b/util/buffer.go index d82f225..b9a698e 100644 --- a/util/buffer.go +++ b/util/buffer.go @@ -1,7 +1,55 @@ package util +import ( + "encoding/binary" + "math" + "net" +) + type Buffer []byte +func (b *Buffer) ReadN(n int) Buffer { + l := b.Len() + r := (*b)[:n] + *b = (*b)[n:l] + return r +} +func (b *Buffer) ReadFloat64() float64 { + return math.Float64frombits(b.ReadUint64()) +} +func (b *Buffer) ReadUint64() uint64 { + return binary.BigEndian.Uint64(b.ReadN(8)) +} +func (b *Buffer) ReadUint32() uint32 { + return binary.BigEndian.Uint32(b.ReadN(4)) +} +func (b *Buffer) ReadUint24() uint32 { + return ReadBE[uint32](b.ReadN(3)) +} +func (b *Buffer) ReadUint16() uint16 { + return binary.BigEndian.Uint16(b.ReadN(2)) +} +func (b *Buffer) ReadByte() byte { + return b.ReadN(1)[0] +} +func (b *Buffer) WriteFloat64(v float64) { + PutBE(b.Malloc(8), math.Float64bits(v)) +} +func (b *Buffer) WriteUint32(v uint32) { + binary.BigEndian.PutUint32(b.Malloc(4), v) +} +func (b *Buffer) WriteUint24(v uint32) { + PutBE(b.Malloc(3), v) +} +func (b *Buffer) WriteUint16(v uint16) { + binary.BigEndian.PutUint16(b.Malloc(2), v) +} +func (b *Buffer) WriteUint8(v byte) { + b.Malloc(1)[0] = v +} +func (b *Buffer) WriteString(a string) { + *b = append(*b, a...) +} func (b *Buffer) Write(a []byte) (n int, err error) { *b = append(*b, a...) return len(a), nil @@ -18,13 +66,59 @@ func (b Buffer) SubBuf(start int, length int) Buffer { func (b *Buffer) Malloc(count int) Buffer { l := b.Len() - if l+count > b.Cap() { - n := make(Buffer, l+count) + newL := l + count + if newL > b.Cap() { + n := make(Buffer, newL) copy(n, *b) *b = n + } else { + *b = b.SubBuf(0, newL) } return b.SubBuf(l, count) } func (b *Buffer) Reset() { *b = b.SubBuf(0, 0) } +func (b *Buffer) Glow(n int) { + l := b.Len() + b.Malloc(n) + *b = b.SubBuf(0, l) +} + +// SizeOfBuffers 计算Buffers的内容长度 +func SizeOfBuffers(buf net.Buffers) (size int) { + for _, b := range buf { + size += len(b) + } + return +} +func CutBuffers(buf net.Buffers, size int) { + +} +// SplitBuffers 按照一定大小分割 Buffers +func SplitBuffers(buf net.Buffers, size int) (result []net.Buffers) { + for total := SizeOfBuffers(buf); total > 0; { + if total <= size { + return append(result, buf) + } else { + var before net.Buffers + sizeOfBefore := 0 + for _, b := range buf { + need := size - sizeOfBefore + if lenOfB := len(b); lenOfB > need { + before = append(before, b[:need]) + result = append(result, before) + total -= need + buf[0] = b[need:] + break + } else { + sizeOfBefore += lenOfB + before = append(before, b) + total -= lenOfB + buf = buf[1:] + } + } + } + } + return +} diff --git a/util/index.go b/util/index.go index 9e461db..1ae918e 100644 --- a/util/index.go +++ b/util/index.go @@ -1,5 +1,98 @@ package util +import ( + "constraints" + "os" + "path/filepath" + "runtime" +) + func Clone[T any](x T) *T { return &x +} + +func CurrentDir(path ...string) string { + if _, currentFilePath, _, _ := runtime.Caller(1); len(path) == 0 { + return filepath.Dir(currentFilePath) + } else { + return filepath.Join(filepath.Dir(currentFilePath), filepath.Join(path...)) + } +} + +// 检查文件或目录是否存在 +// 如果由 filename 指定的文件或目录存在则返回 true,否则返回 false +func Exist(filename string) bool { + _, err := os.Stat(filename) + return err == nil || os.IsExist(err) +} + +func ConvertNum[F constraints.Integer, T constraints.Integer](from F, to T) T { + return T(from) +} + +func ToFloat64(num any) float64 { + switch v := num.(type) { + case uint: + return float64(v) + case int: + return float64(v) + case uint8: + return float64(v) + case uint16: + return float64(v) + case uint32: + return float64(v) + case uint64: + return float64(v) + case int8: + return float64(v) + case int16: + return float64(v) + case int32: + return float64(v) + case int64: + return float64(v) + case float64: + return v + case float32: + return float64(v) + } + return 0 +} +func GetPtsDts(v uint64) uint64 { + // 4 + 3 + 1 + 15 + 1 + 15 + 1 + // 0011 + // 0010 + PTS[30-32] + marker_bit + PTS[29-15] + marker_bit + PTS[14-0] + marker_bit + pts1 := ((v >> 33) & 0x7) << 30 + pts2 := ((v >> 17) & 0x7fff) << 15 + pts3 := ((v >> 1) & 0x7fff) + + return pts1 | pts2 | pts3 +} + +func PutPtsDts(v uint64) uint64 { + // 4 + 3 + 1 + 15 + 1 + 15 + 1 + // 0011 + // 0010 + PTS[30-32] + marker_bit + PTS[29-15] + marker_bit + PTS[14-0] + marker_bit + // 0x100010001 + // 0001 0000 0000 0000 0001 0000 0000 0000 0001 + // 3个 market_it + pts1 := (v >> 30) & 0x7 << 33 + pts2 := (v >> 15) & 0x7fff << 17 + pts3 := (v & 0x7fff) << 1 + + return pts1 | pts2 | pts3 | 0x100010001 +} + +func GetPCR(v uint64) uint64 { + // program_clock_reference_base(33) + Reserved(6) + program_clock_reference_extension(9) + base := v >> 15 + ext := v & 0x1ff + return base*300 + ext +} + +func PutPCR(pcr uint64) uint64 { + base := pcr / 300 + ext := pcr % 300 + return base<<15 | 0x3f<<9 | ext } \ No newline at end of file diff --git a/util/socket.go b/util/socket.go index 8ec1dcf..808b25a 100644 --- a/util/socket.go +++ b/util/socket.go @@ -1,15 +1,20 @@ package util import ( + "context" "log" "net" "net/http" - "os" "time" "golang.org/x/sync/errgroup" ) +type TCPListener interface { + context.Context + Process(*net.TCPConn) +} + // ListenAddrs Listen http and https func ListenAddrs(addr, addTLS, cert, key string, handler http.Handler) { var g errgroup.Group @@ -26,15 +31,18 @@ func ListenAddrs(addr, addTLS, cert, key string, handler http.Handler) { } } -func ListenTCP(addr string, process func(net.Conn)) error { - listener, err := net.Listen("tcp", addr) +func ListenTCP(addr string, process TCPListener) error { + l, err := net.Listen("tcp", addr) if err != nil { return err } + go func() { + <-process.Done() + l.Close() + }() var tempDelay time.Duration for { - conn, err := listener.Accept() - conn.(*net.TCPConn).SetNoDelay(false) + conn, err := l.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -51,8 +59,9 @@ func ListenTCP(addr string, process func(net.Conn)) error { } return err } + conn.(*net.TCPConn).SetNoDelay(false) tempDelay = 0 - go process(conn) + go process.Process(conn.(*net.TCPConn)) } } @@ -83,10 +92,3 @@ func CORS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", origin[0]) } } - -// 检查文件或目录是否存在 -// 如果由 filename 指定的文件或目录存在则返回 true,否则返回 false -func Exist(filename string) bool { - _, err := os.Stat(filename) - return err == nil || os.IsExist(err) -}