diff --git a/api.go b/api.go index c10d302..6d70893 100644 --- a/api.go +++ b/api.go @@ -374,12 +374,12 @@ func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) { } suber, err := s.SubscribeWithConfig(r.Context(), streamPath, config.Subscribe{ SubVideo: true, + SubType: SubscribeTypeAPI, }) if err != nil { http.Error(rw, err.Error(), http.StatusBadRequest) return } - suber.Type = SubscribeTypeAPI sse := util.NewSSE(rw, r.Context()) PlayBlock(suber, (func(frame *pkg.AVFrame) (err error))(nil), func(frame *pkg.AVFrame) (err error) { var snap pb.TrackSnapShot @@ -410,12 +410,12 @@ func (s *Server) api_AudioTrack_SSE(rw http.ResponseWriter, r *http.Request) { } suber, err := s.SubscribeWithConfig(r.Context(), streamPath, config.Subscribe{ SubAudio: true, + SubType: SubscribeTypeAPI, }) if err != nil { http.Error(rw, err.Error(), http.StatusBadRequest) return } - suber.Type = SubscribeTypeAPI sse := util.NewSSE(rw, r.Context()) PlayBlock(suber, func(frame *pkg.AVFrame) (err error) { var snap pb.TrackSnapShot diff --git a/pkg/config/types.go b/pkg/config/types.go index 39b6681..75e3b52 100755 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -54,6 +54,7 @@ type ( WaitTimeout time.Duration `default:"10s" desc:"等待流超时时间"` // 等待流超时 WriteBufferSize int `desc:"写缓冲大小"` // 写缓冲大小 Key string `desc:"订阅鉴权key"` // 订阅鉴权key + SubType string `desc:"订阅类型"` // 订阅类型 } HTTPValus map[string][]string Pull struct { diff --git a/plugin.go b/plugin.go index 70774e8..1d04d64 100644 --- a/plugin.go +++ b/plugin.go @@ -581,20 +581,22 @@ func (p *Plugin) Publish(ctx context.Context, streamPath string) (publisher *Pub func (p *Plugin) SubscribeWithConfig(ctx context.Context, streamPath string, conf config.Subscribe) (subscriber *Subscriber, err error) { subscriber = createSubscriber(p, streamPath, conf) - if p.config.EnableAuth { - onAuthSub := p.Meta.OnAuthSub - if onAuthSub == nil { - onAuthSub = p.Server.Meta.OnAuthSub - } - if onAuthSub != nil { - if err = onAuthSub(subscriber).Await(); err != nil { - p.Warn("auth failed", "error", err) - return + if subscriber.Type == SubscribeTypeServer { + if p.config.EnableAuth { + onAuthSub := p.Meta.OnAuthSub + if onAuthSub == nil { + onAuthSub = p.Server.Meta.OnAuthSub } - } else if conf.Key != "" { - if err = p.auth(subscriber.StreamPath, conf.Key, subscriber.Args.Get("secret"), subscriber.Args.Get("expire")); err != nil { - p.Warn("auth failed", "error", err) - return + if onAuthSub != nil { + if err = onAuthSub(subscriber).Await(); err != nil { + p.Warn("auth failed", "error", err) + return + } + } else if conf.Key != "" { + if err = p.auth(subscriber.StreamPath, conf.Key, subscriber.Args.Get("secret"), subscriber.Args.Get("expire")); err != nil { + p.Warn("auth failed", "error", err) + return + } } } } diff --git a/pusher.go b/pusher.go index 8522e83..ec1b4c6 100644 --- a/pusher.go +++ b/pusher.go @@ -44,13 +44,12 @@ func (p *PushJob) Init(pusher IPusher, plugin *Plugin, streamPath string, conf c func (p *PushJob) Subscribe() (err error) { if p.SubConf != nil { + p.SubConf.SubType = SubscribeTypePush p.Subscriber, err = p.Plugin.SubscribeWithConfig(p.pusher.GetTask().Context, p.StreamPath, *p.SubConf) } else { + p.SubConf = &config.Subscribe{SubType: SubscribeTypePush} p.Subscriber, err = p.Plugin.Subscribe(p.pusher.GetTask().Context, p.StreamPath) } - if p.Subscriber != nil { - p.Subscriber.Type = SubscribeTypePush - } return } diff --git a/recoder.go b/recoder.go index 883398a..7a14295 100644 --- a/recoder.go +++ b/recoder.go @@ -79,12 +79,11 @@ func (p *RecordJob) GetKey() string { func (p *RecordJob) Subscribe() (err error) { if p.SubConf != nil { + p.SubConf.SubType = SubscribeTypeVod p.Subscriber, err = p.Plugin.SubscribeWithConfig(p.recorder.GetTask().Context, p.StreamPath, *p.SubConf) } else { - p.Subscriber, err = p.Plugin.Subscribe(p.recorder.GetTask().Context, p.StreamPath) - } - if p.Subscriber != nil { - p.Subscriber.Type = SubscribeTypeVod + p.SubConf = &config.Subscribe{SubType: SubscribeTypeVod} + p.Subscriber, err = p.Plugin.SubscribeWithConfig(p.recorder.GetTask().Context, p.StreamPath, *p.SubConf) } return } diff --git a/subscriber.go b/subscriber.go index 9cceb93..5d98290 100644 --- a/subscriber.go +++ b/subscriber.go @@ -84,7 +84,11 @@ func createSubscriber(p *Plugin, streamPath string, conf config.Subscribe) *Subs subscriber := &Subscriber{Subscribe: conf, waitPublishDone: make(chan struct{})} subscriber.ID = task.GetNextTaskID() subscriber.Plugin = p - subscriber.Type = SubscribeTypeServer + if conf.SubType != "" { + subscriber.Type = conf.SubType + } else { + subscriber.Type = SubscribeTypeServer + } subscriber.Logger = p.Logger.With("streamPath", streamPath, "sId", subscriber.ID) subscriber.Init(streamPath, &subscriber.Subscribe) if subscriber.Subscribe.BufferTime > 0 {