diff --git a/config/config.go b/config/config.go index 7aa3c72..f0ba6dc 100644 --- a/config/config.go +++ b/config/config.go @@ -80,8 +80,14 @@ func (config Config) Unmarshal(s any) { //字段映射,小写对应的大写 nameMap := make(map[string]string) for i, j := 0, t.NumField(); i < j; i++ { - name := t.Field(i).Name - nameMap[strings.ToLower(name)] = name + field := t.Field(i) + name := field.Name + if tag := field.Tag.Get("yaml"); tag != "" { + name, _, _ = strings.Cut(tag, ",") + } else { + name = strings.ToLower(name) + } + nameMap[name] = field.Name } for k, v := range config { name, ok := nameMap[k] @@ -190,6 +196,12 @@ func Struct2Config(s any, prefix ...string) (config Config) { continue } name := strings.ToLower(ft.Name) + if tag := ft.Tag.Get("yaml"); tag != "" { + if tag == "-" { + continue + } + name, _, _ = strings.Cut(tag, ",") + } var envPath []string if len(prefix) > 0 { envPath = append(prefix, strings.ToUpper(ft.Name)) @@ -201,9 +213,6 @@ func Struct2Config(s any, prefix ...string) (config Config) { return } } - if ft.Tag.Get("json") == "-" { - continue - } switch ft.Type.Kind() { case reflect.Struct: config[name] = Struct2Config(fv, envPath...) diff --git a/config/types.go b/config/types.go index dcb82d2..67bfcf3 100755 --- a/config/types.go +++ b/config/types.go @@ -109,7 +109,7 @@ func (p *Push) AddPush(url string, streamPath string) { if p.PushList == nil { p.PushList = make(map[string]string) } - p.PushList[url] = streamPath + p.PushList[streamPath] = url } type Console struct { diff --git a/http.go b/http.go index d7c5167..8ccda43 100644 --- a/http.go +++ b/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "os" + "strconv" "strings" "time" @@ -233,14 +234,18 @@ func (conf *GlobalConfig) API_replay_rtpdump(w http.ResponseWriter, r *http.Requ } cv := q.Get("vcodec") ca := q.Get("acodec") + cvp := q.Get("vpayload") + cap := q.Get("apayload") var pub RTPDumpPublisher + i, _ := strconv.ParseInt(cvp, 10, 64) + pub.VPayloadType = byte(i) + i, _ = strconv.ParseInt(cap, 10, 64) + pub.APayloadType = byte(i) switch cv { case "h264": pub.VCodec = codec.CodecID_H264 case "h265": pub.VCodec = codec.CodecID_H265 - default: - pub.VCodec = codec.CodecID_H264 } switch ca { case "aac": @@ -249,8 +254,6 @@ func (conf *GlobalConfig) API_replay_rtpdump(w http.ResponseWriter, r *http.Requ pub.ACodec = codec.CodecID_PCMA case "pcmu": pub.ACodec = codec.CodecID_PCMU - default: - pub.ACodec = codec.CodecID_AAC } ss := strings.Split(dumpFile, ",") if len(ss) > 1 { @@ -331,4 +334,4 @@ func (conf *GlobalConfig) API_replay_mp4(w http.ResponseWriter, r *http.Request) w.Write([]byte("ok")) go pub.ReadMP4Data(f) } -} \ No newline at end of file +} diff --git a/publisher-rtpdump.go b/publisher-rtpdump.go index 83be4f8..a9772fd 100644 --- a/publisher-rtpdump.go +++ b/publisher-rtpdump.go @@ -16,14 +16,16 @@ import ( type RTPDumpPublisher struct { Publisher - VCodec codec.VideoCodecID - ACodec codec.AudioCodecID - other *rtpdump.Packet + VCodec codec.VideoCodecID + ACodec codec.AudioCodecID + VPayloadType uint8 + APayloadType uint8 + other *rtpdump.Packet sync.Mutex } func (t *RTPDumpPublisher) Feed(file *os.File) { - + r, h, err := rtpdump.NewReader(file) if err != nil { t.Stream.Error("RTPDumpPublisher open file error", zap.Error(err)) @@ -38,7 +40,9 @@ func (t *RTPDumpPublisher) Feed(file *os.File) { case codec.CodecID_H265: t.VideoTrack = track.NewH265(t.Publisher.Stream) } - t.VideoTrack.SetSpeedLimit(500 * time.Millisecond) + if t.VideoTrack != nil { + t.VideoTrack.SetSpeedLimit(500 * time.Millisecond) + } } if t.AudioTrack == nil { switch t.ACodec { @@ -55,7 +59,9 @@ func (t *RTPDumpPublisher) Feed(file *os.File) { case codec.CodecID_PCMU: t.AudioTrack = track.NewG711(t.Publisher.Stream, false) } - t.AudioTrack.SetSpeedLimit(500 * time.Millisecond) + if t.AudioTrack != nil { + t.AudioTrack.SetSpeedLimit(500 * time.Millisecond) + } } t.Unlock() needLock := true @@ -92,9 +98,9 @@ func (t *RTPDumpPublisher) WriteRTP(raw []byte) { var frame common.RTPFrame frame.Unmarshal(raw) switch frame.PayloadType { - case 96: + case t.VPayloadType: t.VideoTrack.WriteRTP(&util.ListItem[common.RTPFrame]{Value: frame}) - case 97, 0, 8: + case t.APayloadType: t.AudioTrack.WriteRTP(&util.ListItem[common.RTPFrame]{Value: frame}) default: t.Stream.Warn("RTPDumpPublisher unknown payload type", zap.Uint8("payloadType", frame.PayloadType)) diff --git a/stream.go b/stream.go index b027b38..98d8912 100644 --- a/stream.go +++ b/stream.go @@ -4,6 +4,7 @@ import ( "encoding/json" "sort" "strings" + "sync" "time" "unsafe" @@ -127,10 +128,17 @@ type StreamTimeoutConfig struct { IdleTimeout time.Duration //无订阅者后超时,不需要订阅即可激活 } type Tracks struct { - util.Map[string, Track] + sync.Map MainVideo *track.Video } +func (tracks *Tracks) Range(f func(name string, t Track)) { + tracks.Map.Range(func(k, v any) bool { + f(k.(string), v.(Track)) + return true + }) +} + func (tracks *Tracks) Add(name string, t Track) bool { switch v := t.(type) { case *track.Video: @@ -143,12 +151,13 @@ func (tracks *Tracks) Add(name string, t Track) bool { v.Narrow() } } - return tracks.Map.Add(name, t) + _, loaded := tracks.LoadOrStore(name, t) + return !loaded } func (tracks *Tracks) SetIDR(video Track) { if video == tracks.MainVideo { - tracks.Map.Range(func(_ string, t Track) { + tracks.Range(func(_ string, t Track) { if v, ok := t.(*track.Audio); ok { v.Narrow() } @@ -157,10 +166,12 @@ func (tracks *Tracks) SetIDR(video Track) { } func (tracks *Tracks) MarshalJSON() ([]byte, error) { - return json.Marshal(util.MapList(&tracks.Map, func(_ string, t Track) Track { + var trackList []Track + tracks.Range(func(_ string, t Track) { t.SnapForJson() - return t - })) + trackList = append(trackList, t) + }) + return json.Marshal(trackList) } // Stream 流定义 @@ -209,10 +220,10 @@ func (s *Stream) Summary() (r StreamSummay) { if s.Publisher != nil { r.Type = s.Publisher.GetPublisher().Type } - r.Tracks = util.MapList(&s.Tracks.Map, func(name string, t Track) string { + s.Tracks.Range(func(name string, t Track) { b := t.GetBase() r.BPS += b.BPS - return name + r.Tracks = append(r.Tracks, name) }) r.Path = s.Path r.State = s.State @@ -250,7 +261,6 @@ func findOrCreateStream(streamPath string, waitTimeout time.Duration) (s *Stream s.Info("created") Streams.Map[streamPath] = s s.actionChan.Init(1) - s.Tracks.Init() go s.run() return s, true } @@ -411,7 +421,9 @@ func (s *Stream) run() { } } hasTrackTimeout := false + trackCount := 0 s.Tracks.Range(func(name string, t Track) { + trackCount++ if _, ok := t.(track.Custom); ok { return } @@ -421,7 +433,7 @@ func (s *Stream) run() { hasTrackTimeout = true } }) - if hasTrackTimeout || (s.Publisher != nil && s.Publisher.IsClosed()) { + if trackCount == 0 || hasTrackTimeout || (s.Publisher != nil && s.Publisher.IsClosed()) { s.action(ACTION_PUBLISHLOST) } else { s.timeout.Reset(time.Second * 5) @@ -444,11 +456,12 @@ func (s *Stream) run() { if s.IsClosed() { v.Reject(ErrStreamIsClosed) } - republish := s.Publisher == v.Value // 重复发布 + republish := s.Publisher == v.Value // 重复发布 + kicked := !republish && s.Publisher != nil && s.Publisher.IsClosed() // 被踢下线 if !republish { s.Publisher = v.Value } - if s.action(ACTION_PUBLISH) || republish { + if s.action(ACTION_PUBLISH) || republish || kicked { v.Resolve() } else { v.Reject(ErrBadStreamName) @@ -507,12 +520,9 @@ func (s *Stream) run() { case TrackRemoved: timeOutInfo = zap.String("action", "TrackRemoved") name := v.GetBase().Name - if t, ok := s.Tracks.Delete(name); ok { + if t, ok := s.Tracks.LoadAndDelete(name); ok { s.Info("track -1", zap.String("name", name)) s.Subscribers.Broadcast(t) - if s.Tracks.Len() == 0 { - s.action(ACTION_PUBLISHLOST) - } if dt, ok := t.(track.Custom); ok { dt.Dispose() } @@ -526,11 +536,11 @@ func (s *Stream) run() { name := v.Value.GetBase().Name if _, ok := v.Value.(*track.Video); ok && !pubConfig.PubVideo { v.Reject(ErrTrackMute) - return + continue } if _, ok := v.Value.(*track.Audio); ok && !pubConfig.PubAudio { v.Reject(ErrTrackMute) - return + continue } if s.Tracks.Add(name, v.Value) { v.Resolve() diff --git a/track/audio.go b/track/audio.go index 8b3730a..d929bec 100644 --- a/track/audio.go +++ b/track/audio.go @@ -23,6 +23,7 @@ func (a *Audio) Attach() { if a.Attached.CompareAndSwap(false, true) { if err := a.Stream.AddTrack(a).Await(); err != nil { a.Error("attach audio track failed", zap.Error(err)) + a.Attached.Store(false) } else { a.Info("audio track attached", zap.Uint32("sample rate", a.SampleRate)) } diff --git a/track/base.go b/track/base.go index 246c9d0..10ca67f 100644 --- a/track/base.go +++ b/track/base.go @@ -121,6 +121,8 @@ func (av *Media) SnapForJson() { v := av.LastValue if av.RawPart != nil { av.RawPart = av.RawPart[:0] + } else { + av.RawPart = make([]int, 0, 10) } if av.RawSize = v.AUList.ByteLength; av.RawSize > 0 { r := v.AUList.NewReader() diff --git a/track/video.go b/track/video.go index f1afbdd..5de4667 100644 --- a/track/video.go +++ b/track/video.go @@ -32,6 +32,7 @@ func (v *Video) Attach() { if v.Attached.CompareAndSwap(false, true) { if err := v.Stream.AddTrack(v).Await(); err != nil { v.Error("attach video track failed", zap.Error(err)) + v.Attached.Store(false) } else { v.Info("video track attached", zap.Uint("width", v.Width), zap.Uint("height", v.Height)) } diff --git a/util/strutct.go b/util/strutct.go new file mode 100644 index 0000000..0bb1d4e --- /dev/null +++ b/util/strutct.go @@ -0,0 +1,101 @@ +package util + +import ( + "errors" + "reflect" +) + +// 构造器 +type Builder struct { // 用于存储属性字段 + fileId []reflect.StructField +} + +func NewBuilder() *Builder { + return &Builder{} +} + +// 添加字段 +func (b *Builder) AddField(field string, typ reflect.Type) *Builder { + b.fileId = append(b.fileId, reflect.StructField{Name: field, Type: typ}) + return b +} + +// 根据预先添加的字段构建出结构体 +func (b *Builder) Build() *Struct { + stu := reflect.StructOf(b.fileId) + index := make(map[string]int) + for i := 0; i < stu.NumField(); i++ { + index[stu.Field(i).Name] = i + } + return &Struct{stu, index} +} +func (b *Builder) AddString(name string) *Builder { + return b.AddField(name, reflect.TypeOf("")) +} +func (b *Builder) AddBool(name string) *Builder { + return b.AddField(name, reflect.TypeOf(true)) +} +func (b *Builder) AddInt64(name string) *Builder { + return b.AddField(name, reflect.TypeOf(int64(0))) +} +func (b *Builder) AddFloat64(name string) *Builder { + return b.AddField(name, reflect.TypeOf(float64(1.2))) +} + +// 实际生成的结构体,基类 +// 结构体的类型 +type Struct struct { + typ reflect.Type + // + // 用于通过字段名称,从Builder的[]reflect.StructField中获取reflect.StructField + index map[string]int +} + +func (s Struct) New() *Instance { + return &Instance{reflect.New(s.typ).Elem(), s.index} +} + +// 结构体的值 +type Instance struct { + instance reflect.Value + // + index map[string]int +} + +var ( + FieldNoExist error = errors.New("field no exist") +) + +func (in Instance) Field(name string) (reflect.Value, error) { + if i, ok := in.index[name]; ok { + return in.instance.Field(i), nil + } else { + return reflect.Value{}, FieldNoExist + } +} +func (in *Instance) SetString(name, value string) { + if i, ok := in.index[name]; ok { + in.instance.Field(i).SetString(value) + } +} +func (in *Instance) SetBool(name string, value bool) { + if i, ok := in.index[name]; ok { + in.instance.Field(i).SetBool(value) + } +} +func (in *Instance) SetInt64(name string, value int64) { + if i, ok := in.index[name]; ok { + in.instance.Field(i).SetInt(value) + } +} +func (in *Instance) SetFloat64(name string, value float64) { + if i, ok := in.index[name]; ok { + in.instance.Field(i).SetFloat(value) + } +} +func (i *Instance) Interface() interface{} { + return i.instance.Interface() +} +func (i *Instance) Addr() interface{} { + return i.instance.Addr().Interface() +}