fix: change tracks to sync.Map to void dead lock

This commit is contained in:
langhuihui
2023-05-21 14:27:57 +08:00
parent 3ee931ce61
commit e21f8f765a
9 changed files with 170 additions and 37 deletions

View File

@@ -80,8 +80,14 @@ func (config Config) Unmarshal(s any) {
//字段映射,小写对应的大写 //字段映射,小写对应的大写
nameMap := make(map[string]string) nameMap := make(map[string]string)
for i, j := 0, t.NumField(); i < j; i++ { for i, j := 0, t.NumField(); i < j; i++ {
name := t.Field(i).Name field := t.Field(i)
nameMap[strings.ToLower(name)] = name name := field.Name
if tag := field.Tag.Get("yaml"); tag != "" {
name, _, _ = strings.Cut(tag, ",")
} else {
name = strings.ToLower(name)
}
nameMap[name] = field.Name
} }
for k, v := range config { for k, v := range config {
name, ok := nameMap[k] name, ok := nameMap[k]
@@ -190,6 +196,12 @@ func Struct2Config(s any, prefix ...string) (config Config) {
continue continue
} }
name := strings.ToLower(ft.Name) name := strings.ToLower(ft.Name)
if tag := ft.Tag.Get("yaml"); tag != "" {
if tag == "-" {
continue
}
name, _, _ = strings.Cut(tag, ",")
}
var envPath []string var envPath []string
if len(prefix) > 0 { if len(prefix) > 0 {
envPath = append(prefix, strings.ToUpper(ft.Name)) envPath = append(prefix, strings.ToUpper(ft.Name))
@@ -201,9 +213,6 @@ func Struct2Config(s any, prefix ...string) (config Config) {
return return
} }
} }
if ft.Tag.Get("json") == "-" {
continue
}
switch ft.Type.Kind() { switch ft.Type.Kind() {
case reflect.Struct: case reflect.Struct:
config[name] = Struct2Config(fv, envPath...) config[name] = Struct2Config(fv, envPath...)

View File

@@ -109,7 +109,7 @@ func (p *Push) AddPush(url string, streamPath string) {
if p.PushList == nil { if p.PushList == nil {
p.PushList = make(map[string]string) p.PushList = make(map[string]string)
} }
p.PushList[url] = streamPath p.PushList[streamPath] = url
} }
type Console struct { type Console struct {

13
http.go
View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
@@ -233,14 +234,18 @@ func (conf *GlobalConfig) API_replay_rtpdump(w http.ResponseWriter, r *http.Requ
} }
cv := q.Get("vcodec") cv := q.Get("vcodec")
ca := q.Get("acodec") ca := q.Get("acodec")
cvp := q.Get("vpayload")
cap := q.Get("apayload")
var pub RTPDumpPublisher var pub RTPDumpPublisher
i, _ := strconv.ParseInt(cvp, 10, 64)
pub.VPayloadType = byte(i)
i, _ = strconv.ParseInt(cap, 10, 64)
pub.APayloadType = byte(i)
switch cv { switch cv {
case "h264": case "h264":
pub.VCodec = codec.CodecID_H264 pub.VCodec = codec.CodecID_H264
case "h265": case "h265":
pub.VCodec = codec.CodecID_H265 pub.VCodec = codec.CodecID_H265
default:
pub.VCodec = codec.CodecID_H264
} }
switch ca { switch ca {
case "aac": case "aac":
@@ -249,8 +254,6 @@ func (conf *GlobalConfig) API_replay_rtpdump(w http.ResponseWriter, r *http.Requ
pub.ACodec = codec.CodecID_PCMA pub.ACodec = codec.CodecID_PCMA
case "pcmu": case "pcmu":
pub.ACodec = codec.CodecID_PCMU pub.ACodec = codec.CodecID_PCMU
default:
pub.ACodec = codec.CodecID_AAC
} }
ss := strings.Split(dumpFile, ",") ss := strings.Split(dumpFile, ",")
if len(ss) > 1 { if len(ss) > 1 {
@@ -331,4 +334,4 @@ func (conf *GlobalConfig) API_replay_mp4(w http.ResponseWriter, r *http.Request)
w.Write([]byte("ok")) w.Write([]byte("ok"))
go pub.ReadMP4Data(f) go pub.ReadMP4Data(f)
} }
} }

View File

@@ -16,14 +16,16 @@ import (
type RTPDumpPublisher struct { type RTPDumpPublisher struct {
Publisher Publisher
VCodec codec.VideoCodecID VCodec codec.VideoCodecID
ACodec codec.AudioCodecID ACodec codec.AudioCodecID
other *rtpdump.Packet VPayloadType uint8
APayloadType uint8
other *rtpdump.Packet
sync.Mutex sync.Mutex
} }
func (t *RTPDumpPublisher) Feed(file *os.File) { func (t *RTPDumpPublisher) Feed(file *os.File) {
r, h, err := rtpdump.NewReader(file) r, h, err := rtpdump.NewReader(file)
if err != nil { if err != nil {
t.Stream.Error("RTPDumpPublisher open file error", zap.Error(err)) t.Stream.Error("RTPDumpPublisher open file error", zap.Error(err))
@@ -38,7 +40,9 @@ func (t *RTPDumpPublisher) Feed(file *os.File) {
case codec.CodecID_H265: case codec.CodecID_H265:
t.VideoTrack = track.NewH265(t.Publisher.Stream) t.VideoTrack = track.NewH265(t.Publisher.Stream)
} }
t.VideoTrack.SetSpeedLimit(500 * time.Millisecond) if t.VideoTrack != nil {
t.VideoTrack.SetSpeedLimit(500 * time.Millisecond)
}
} }
if t.AudioTrack == nil { if t.AudioTrack == nil {
switch t.ACodec { switch t.ACodec {
@@ -55,7 +59,9 @@ func (t *RTPDumpPublisher) Feed(file *os.File) {
case codec.CodecID_PCMU: case codec.CodecID_PCMU:
t.AudioTrack = track.NewG711(t.Publisher.Stream, false) t.AudioTrack = track.NewG711(t.Publisher.Stream, false)
} }
t.AudioTrack.SetSpeedLimit(500 * time.Millisecond) if t.AudioTrack != nil {
t.AudioTrack.SetSpeedLimit(500 * time.Millisecond)
}
} }
t.Unlock() t.Unlock()
needLock := true needLock := true
@@ -92,9 +98,9 @@ func (t *RTPDumpPublisher) WriteRTP(raw []byte) {
var frame common.RTPFrame var frame common.RTPFrame
frame.Unmarshal(raw) frame.Unmarshal(raw)
switch frame.PayloadType { switch frame.PayloadType {
case 96: case t.VPayloadType:
t.VideoTrack.WriteRTP(&util.ListItem[common.RTPFrame]{Value: frame}) t.VideoTrack.WriteRTP(&util.ListItem[common.RTPFrame]{Value: frame})
case 97, 0, 8: case t.APayloadType:
t.AudioTrack.WriteRTP(&util.ListItem[common.RTPFrame]{Value: frame}) t.AudioTrack.WriteRTP(&util.ListItem[common.RTPFrame]{Value: frame})
default: default:
t.Stream.Warn("RTPDumpPublisher unknown payload type", zap.Uint8("payloadType", frame.PayloadType)) t.Stream.Warn("RTPDumpPublisher unknown payload type", zap.Uint8("payloadType", frame.PayloadType))

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"sort" "sort"
"strings" "strings"
"sync"
"time" "time"
"unsafe" "unsafe"
@@ -127,10 +128,17 @@ type StreamTimeoutConfig struct {
IdleTimeout time.Duration //无订阅者后超时,不需要订阅即可激活 IdleTimeout time.Duration //无订阅者后超时,不需要订阅即可激活
} }
type Tracks struct { type Tracks struct {
util.Map[string, Track] sync.Map
MainVideo *track.Video MainVideo *track.Video
} }
func (tracks *Tracks) Range(f func(name string, t Track)) {
tracks.Map.Range(func(k, v any) bool {
f(k.(string), v.(Track))
return true
})
}
func (tracks *Tracks) Add(name string, t Track) bool { func (tracks *Tracks) Add(name string, t Track) bool {
switch v := t.(type) { switch v := t.(type) {
case *track.Video: case *track.Video:
@@ -143,12 +151,13 @@ func (tracks *Tracks) Add(name string, t Track) bool {
v.Narrow() v.Narrow()
} }
} }
return tracks.Map.Add(name, t) _, loaded := tracks.LoadOrStore(name, t)
return !loaded
} }
func (tracks *Tracks) SetIDR(video Track) { func (tracks *Tracks) SetIDR(video Track) {
if video == tracks.MainVideo { if video == tracks.MainVideo {
tracks.Map.Range(func(_ string, t Track) { tracks.Range(func(_ string, t Track) {
if v, ok := t.(*track.Audio); ok { if v, ok := t.(*track.Audio); ok {
v.Narrow() v.Narrow()
} }
@@ -157,10 +166,12 @@ func (tracks *Tracks) SetIDR(video Track) {
} }
func (tracks *Tracks) MarshalJSON() ([]byte, error) { func (tracks *Tracks) MarshalJSON() ([]byte, error) {
return json.Marshal(util.MapList(&tracks.Map, func(_ string, t Track) Track { var trackList []Track
tracks.Range(func(_ string, t Track) {
t.SnapForJson() t.SnapForJson()
return t trackList = append(trackList, t)
})) })
return json.Marshal(trackList)
} }
// Stream 流定义 // Stream 流定义
@@ -209,10 +220,10 @@ func (s *Stream) Summary() (r StreamSummay) {
if s.Publisher != nil { if s.Publisher != nil {
r.Type = s.Publisher.GetPublisher().Type r.Type = s.Publisher.GetPublisher().Type
} }
r.Tracks = util.MapList(&s.Tracks.Map, func(name string, t Track) string { s.Tracks.Range(func(name string, t Track) {
b := t.GetBase() b := t.GetBase()
r.BPS += b.BPS r.BPS += b.BPS
return name r.Tracks = append(r.Tracks, name)
}) })
r.Path = s.Path r.Path = s.Path
r.State = s.State r.State = s.State
@@ -250,7 +261,6 @@ func findOrCreateStream(streamPath string, waitTimeout time.Duration) (s *Stream
s.Info("created") s.Info("created")
Streams.Map[streamPath] = s Streams.Map[streamPath] = s
s.actionChan.Init(1) s.actionChan.Init(1)
s.Tracks.Init()
go s.run() go s.run()
return s, true return s, true
} }
@@ -411,7 +421,9 @@ func (s *Stream) run() {
} }
} }
hasTrackTimeout := false hasTrackTimeout := false
trackCount := 0
s.Tracks.Range(func(name string, t Track) { s.Tracks.Range(func(name string, t Track) {
trackCount++
if _, ok := t.(track.Custom); ok { if _, ok := t.(track.Custom); ok {
return return
} }
@@ -421,7 +433,7 @@ func (s *Stream) run() {
hasTrackTimeout = true hasTrackTimeout = true
} }
}) })
if hasTrackTimeout || (s.Publisher != nil && s.Publisher.IsClosed()) { if trackCount == 0 || hasTrackTimeout || (s.Publisher != nil && s.Publisher.IsClosed()) {
s.action(ACTION_PUBLISHLOST) s.action(ACTION_PUBLISHLOST)
} else { } else {
s.timeout.Reset(time.Second * 5) s.timeout.Reset(time.Second * 5)
@@ -444,11 +456,12 @@ func (s *Stream) run() {
if s.IsClosed() { if s.IsClosed() {
v.Reject(ErrStreamIsClosed) v.Reject(ErrStreamIsClosed)
} }
republish := s.Publisher == v.Value // 重复发布 republish := s.Publisher == v.Value // 重复发布
kicked := !republish && s.Publisher != nil && s.Publisher.IsClosed() // 被踢下线
if !republish { if !republish {
s.Publisher = v.Value s.Publisher = v.Value
} }
if s.action(ACTION_PUBLISH) || republish { if s.action(ACTION_PUBLISH) || republish || kicked {
v.Resolve() v.Resolve()
} else { } else {
v.Reject(ErrBadStreamName) v.Reject(ErrBadStreamName)
@@ -507,12 +520,9 @@ func (s *Stream) run() {
case TrackRemoved: case TrackRemoved:
timeOutInfo = zap.String("action", "TrackRemoved") timeOutInfo = zap.String("action", "TrackRemoved")
name := v.GetBase().Name name := v.GetBase().Name
if t, ok := s.Tracks.Delete(name); ok { if t, ok := s.Tracks.LoadAndDelete(name); ok {
s.Info("track -1", zap.String("name", name)) s.Info("track -1", zap.String("name", name))
s.Subscribers.Broadcast(t) s.Subscribers.Broadcast(t)
if s.Tracks.Len() == 0 {
s.action(ACTION_PUBLISHLOST)
}
if dt, ok := t.(track.Custom); ok { if dt, ok := t.(track.Custom); ok {
dt.Dispose() dt.Dispose()
} }
@@ -526,11 +536,11 @@ func (s *Stream) run() {
name := v.Value.GetBase().Name name := v.Value.GetBase().Name
if _, ok := v.Value.(*track.Video); ok && !pubConfig.PubVideo { if _, ok := v.Value.(*track.Video); ok && !pubConfig.PubVideo {
v.Reject(ErrTrackMute) v.Reject(ErrTrackMute)
return continue
} }
if _, ok := v.Value.(*track.Audio); ok && !pubConfig.PubAudio { if _, ok := v.Value.(*track.Audio); ok && !pubConfig.PubAudio {
v.Reject(ErrTrackMute) v.Reject(ErrTrackMute)
return continue
} }
if s.Tracks.Add(name, v.Value) { if s.Tracks.Add(name, v.Value) {
v.Resolve() v.Resolve()

View File

@@ -23,6 +23,7 @@ func (a *Audio) Attach() {
if a.Attached.CompareAndSwap(false, true) { if a.Attached.CompareAndSwap(false, true) {
if err := a.Stream.AddTrack(a).Await(); err != nil { if err := a.Stream.AddTrack(a).Await(); err != nil {
a.Error("attach audio track failed", zap.Error(err)) a.Error("attach audio track failed", zap.Error(err))
a.Attached.Store(false)
} else { } else {
a.Info("audio track attached", zap.Uint32("sample rate", a.SampleRate)) a.Info("audio track attached", zap.Uint32("sample rate", a.SampleRate))
} }

View File

@@ -121,6 +121,8 @@ func (av *Media) SnapForJson() {
v := av.LastValue v := av.LastValue
if av.RawPart != nil { if av.RawPart != nil {
av.RawPart = av.RawPart[:0] av.RawPart = av.RawPart[:0]
} else {
av.RawPart = make([]int, 0, 10)
} }
if av.RawSize = v.AUList.ByteLength; av.RawSize > 0 { if av.RawSize = v.AUList.ByteLength; av.RawSize > 0 {
r := v.AUList.NewReader() r := v.AUList.NewReader()

View File

@@ -32,6 +32,7 @@ func (v *Video) Attach() {
if v.Attached.CompareAndSwap(false, true) { if v.Attached.CompareAndSwap(false, true) {
if err := v.Stream.AddTrack(v).Await(); err != nil { if err := v.Stream.AddTrack(v).Await(); err != nil {
v.Error("attach video track failed", zap.Error(err)) v.Error("attach video track failed", zap.Error(err))
v.Attached.Store(false)
} else { } else {
v.Info("video track attached", zap.Uint("width", v.Width), zap.Uint("height", v.Height)) v.Info("video track attached", zap.Uint("width", v.Width), zap.Uint("height", v.Height))
} }

101
util/strutct.go Normal file
View File

@@ -0,0 +1,101 @@
package util
import (
"errors"
"reflect"
)
// 构造器
type Builder struct { // 用于存储属性字段
fileId []reflect.StructField
}
func NewBuilder() *Builder {
return &Builder{}
}
// 添加字段
func (b *Builder) AddField(field string, typ reflect.Type) *Builder {
b.fileId = append(b.fileId, reflect.StructField{Name: field, Type: typ})
return b
}
// 根据预先添加的字段构建出结构体
func (b *Builder) Build() *Struct {
stu := reflect.StructOf(b.fileId)
index := make(map[string]int)
for i := 0; i < stu.NumField(); i++ {
index[stu.Field(i).Name] = i
}
return &Struct{stu, index}
}
func (b *Builder) AddString(name string) *Builder {
return b.AddField(name, reflect.TypeOf(""))
}
func (b *Builder) AddBool(name string) *Builder {
return b.AddField(name, reflect.TypeOf(true))
}
func (b *Builder) AddInt64(name string) *Builder {
return b.AddField(name, reflect.TypeOf(int64(0)))
}
func (b *Builder) AddFloat64(name string) *Builder {
return b.AddField(name, reflect.TypeOf(float64(1.2)))
}
// 实际生成的结构体,基类
// 结构体的类型
type Struct struct {
typ reflect.Type
// <fieldName : 索引>
// 用于通过字段名称从Builder的[]reflect.StructField中获取reflect.StructField
index map[string]int
}
func (s Struct) New() *Instance {
return &Instance{reflect.New(s.typ).Elem(), s.index}
}
// 结构体的值
type Instance struct {
instance reflect.Value
// <fieldName : 索引>
index map[string]int
}
var (
FieldNoExist error = errors.New("field no exist")
)
func (in Instance) Field(name string) (reflect.Value, error) {
if i, ok := in.index[name]; ok {
return in.instance.Field(i), nil
} else {
return reflect.Value{}, FieldNoExist
}
}
func (in *Instance) SetString(name, value string) {
if i, ok := in.index[name]; ok {
in.instance.Field(i).SetString(value)
}
}
func (in *Instance) SetBool(name string, value bool) {
if i, ok := in.index[name]; ok {
in.instance.Field(i).SetBool(value)
}
}
func (in *Instance) SetInt64(name string, value int64) {
if i, ok := in.index[name]; ok {
in.instance.Field(i).SetInt(value)
}
}
func (in *Instance) SetFloat64(name string, value float64) {
if i, ok := in.index[name]; ok {
in.instance.Field(i).SetFloat(value)
}
}
func (i *Instance) Interface() interface{} {
return i.instance.Interface()
}
func (i *Instance) Addr() interface{} {
return i.instance.Addr().Interface()
}