diff --git a/example/multiple/main.go b/example/multiple/main.go index 035f93a..9fa8c57 100644 --- a/example/multiple/main.go +++ b/example/multiple/main.go @@ -16,6 +16,6 @@ import ( func main() { ctx := context.Background() // ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(time.Second*100)) - go m7s.Run(ctx, "config1.yaml") - m7s.NewServer().Run(ctx, "config2.yaml") + m7s.AddRootTaskWithContext(ctx, m7s.NewServer("config2.yaml")) + m7s.Run(ctx, "config1.yaml") } diff --git a/example/rtmp-push/main.go b/example/rtmp-push/main.go index 0706b67..5bf5815 100644 --- a/example/rtmp-push/main.go +++ b/example/rtmp-push/main.go @@ -17,8 +17,8 @@ func main() { flag.BoolVar(&multi, "multi", false, "debug") flag.Parse() if multi { - go m7s.Run(ctx, "config1.yaml") + m7s.AddRootTaskWithContext(ctx, m7s.NewServer("config2.yaml")) } time.Sleep(time.Second) - m7s.NewServer().Run(ctx, "config2.yaml") + m7s.Run(ctx, "config1.yaml") } diff --git a/example/rtsp-pull/main.go b/example/rtsp-pull/main.go index 300d2de..da0a2a5 100644 --- a/example/rtsp-pull/main.go +++ b/example/rtsp-pull/main.go @@ -17,8 +17,8 @@ func main() { flag.BoolVar(&multi, "multi", false, "debug") flag.Parse() if multi { - go m7s.Run(ctx, "config1.yaml") + m7s.AddRootTaskWithContext(ctx, m7s.NewServer("config2.yaml")) } time.Sleep(time.Second) - m7s.NewServer().Run(ctx, "config2.yaml") + m7s.Run(ctx, "config1.yaml") } diff --git a/example/rtsp-push/main.go b/example/rtsp-push/main.go index 300d2de..da0a2a5 100644 --- a/example/rtsp-push/main.go +++ b/example/rtsp-push/main.go @@ -17,8 +17,8 @@ func main() { flag.BoolVar(&multi, "multi", false, "debug") flag.Parse() if multi { - go m7s.Run(ctx, "config1.yaml") + m7s.AddRootTaskWithContext(ctx, m7s.NewServer("config2.yaml")) } time.Sleep(time.Second) - m7s.NewServer().Run(ctx, "config2.yaml") + m7s.Run(ctx, "config1.yaml") } diff --git a/pkg/error.go b/pkg/error.go index 958891d..6bd29e0 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -11,7 +11,6 @@ var ( ErrPublishTimeout = errors.New("publish timeout") ErrPublishIdleTimeout = errors.New("publish idle timeout") ErrPublishDelayCloseTimeout = errors.New("publish delay close timeout") - ErrPublishWaitCloseTimeout = errors.New("publish wait close timeout") ErrPushRemoteURLExist = errors.New("push remote url exist") ErrSubscribeTimeout = errors.New("subscribe timeout") ErrRestart = errors.New("restart") diff --git a/pkg/util/index.go b/pkg/util/index.go index dd1ed57..7a08abb 100644 --- a/pkg/util/index.go +++ b/pkg/util/index.go @@ -67,3 +67,8 @@ func initFatalLog() *os.File { } return logFile } + +func Exist(filename string) bool { + _, err := os.Stat(filename) + return err == nil || os.IsExist(err) +} diff --git a/pkg/util/task-macro.go b/pkg/util/task-macro.go index a2b8e0d..4073781 100644 --- a/pkg/util/task-macro.go +++ b/pkg/util/task-macro.go @@ -10,15 +10,19 @@ import ( "sync/atomic" ) -var RootTask MarcoLongTask var idG atomic.Uint32 func GetNextTaskID() uint32 { return idG.Add(1) } +var RootTask MarcoLongTask + func init() { RootTask.initTask(context.Background(), &RootTask) + RootTask.Description = map[string]any{ + "title": "RootTask", + } RootTask.Logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) } @@ -158,6 +162,10 @@ func (mt *MarcoTask) AddChan(channel any, callback any) *ChannelTask { func (mt *MarcoTask) run() { cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(mt.addSub)}} defer func() { + err := recover() + if err != nil { + mt.Stop(err.(error)) + } stopReason := mt.StopReason() for _, task := range mt.children { task.Stop(stopReason) diff --git a/pkg/util/task.go b/pkg/util/task.go index dc87ff8..c0c60ad 100644 --- a/pkg/util/task.go +++ b/pkg/util/task.go @@ -122,9 +122,12 @@ func (task *Task) StopReason() error { } func (task *Task) Stop(err error) { + if err == nil { + panic("task stop with nil error") + } if task.CancelCauseFunc != nil { if task.Logger != nil { - task.Debug("task stop", "reason", err.Error(), "elapsed", time.Since(task.StartTime), "taskId", task.ID, "taskType", task.GetTaskType(), "ownerType", task.GetOwnerType()) + task.Debug("task stop", "reason", err, "elapsed", time.Since(task.StartTime), "taskId", task.ID, "taskType", task.GetTaskType(), "ownerType", task.GetOwnerType()) } task.CancelCauseFunc(err) } diff --git a/plugin/flv/api.go b/plugin/flv/api.go new file mode 100644 index 0000000..40c55a1 --- /dev/null +++ b/plugin/flv/api.go @@ -0,0 +1,266 @@ +package plugin_flv + +import ( + "bufio" + "encoding/binary" + "io" + "io/fs" + "m7s.live/m7s/v5/pkg/util" + flv "m7s.live/m7s/v5/plugin/flv/pkg" + rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +func (plugin *FLVPlugin) Download(w http.ResponseWriter, r *http.Request) { + streamPath := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/download/"), ".flv") + singleFile := filepath.Join(plugin.Path, streamPath+".flv") + query := r.URL.Query() + rangeStr := strings.Split(query.Get("range"), "-") + s, err := strconv.Atoi(rangeStr[0]) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + startTime := time.UnixMilli(int64(s)) + e, err := strconv.Atoi(rangeStr[1]) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + endTime := time.UnixMilli(int64(e)) + timeRange := endTime.Sub(startTime) + plugin.Info("download", "stream", streamPath, "start", startTime, "end", endTime) + dir := filepath.Join(plugin.Path, streamPath) + if util.Exist(singleFile) { + + } else if util.Exist(dir) { + var fileList []fs.FileInfo + var found bool + var startOffsetTime time.Duration + err = filepath.Walk(dir, func(path string, info fs.FileInfo, err error) error { + if info.IsDir() || !strings.HasSuffix(info.Name(), ".flv") { + return nil + } + modTime := info.ModTime() + //tmp, _ := strconv.Atoi(strings.TrimSuffix(info.Name(), ".flv")) + //fileStartTime := time.Unix(tmp, 10) + if !found { + if modTime.After(startTime) { + found = true + //fmt.Println(path, modTime, startTime, found) + } else { + fileList = []fs.FileInfo{info} + startOffsetTime = startTime.Sub(modTime) + //fmt.Println(path, modTime, startTime, found) + return nil + } + } + if modTime.After(endTime) { + return fs.ErrInvalid + } + fileList = append(fileList, info) + return nil + }) + if !found { + http.NotFound(w, r) + return + } + + w.Header().Set("Content-Type", "video/x-flv") + w.Header().Set("Content-Disposition", "attachment") + var writer io.Writer = w + flvHead := make([]byte, 9+4) + tagHead := make(util.Buffer, 11) + var contentLength uint64 + + var amf *rtmp.AMF + var metaData rtmp.EcmaArray + var filepositions []uint64 + var times []float64 + for pass := 0; pass < 2; pass++ { + offsetTime := startOffsetTime + var offsetTimestamp, lastTimestamp uint32 + var init, seqAudioWritten, seqVideoWritten bool + if pass == 1 { + metaData["keyframes"] = map[string]any{ + "filepositions": filepositions, + "times": times, + } + amf.Marshals("onMetaData", metaData) + offsetDelta := amf.Len() + 15 + offset := offsetDelta + len(flvHead) + contentLength += uint64(offset) + metaData["duration"] = timeRange.Seconds() + metaData["filesize"] = contentLength + for i := range filepositions { + filepositions[i] += uint64(offset) + } + metaData["keyframes"] = map[string]any{ + "filepositions": filepositions, + "times": times, + } + amf.Reset() + amf.Marshals("onMetaData", metaData) + plugin.Info("start download", "metaData", metaData) + w.Header().Set("Content-Length", strconv.FormatInt(int64(contentLength), 10)) + w.WriteHeader(http.StatusOK) + } + if offsetTime == 0 { + init = true + } else { + offsetTimestamp = -uint32(offsetTime.Milliseconds()) + } + for i, info := range fileList { + if r.Context().Err() != nil { + return + } + filePath := filepath.Join(dir, info.Name()) + plugin.Debug("read", "file", filePath) + file, err := os.Open(filePath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + reader := bufio.NewReader(file) + if i == 0 { + _, err = io.ReadFull(reader, flvHead) + if pass == 1 { + // 第一次写入头 + _, err = writer.Write(flvHead) + tagHead[0] = flv.FLV_TAG_TYPE_SCRIPT + l := amf.Len() + tagHead[1] = byte(l >> 16) + tagHead[2] = byte(l >> 8) + tagHead[3] = byte(l) + flv.PutFlvTimestamp(tagHead, 0) + writer.Write(tagHead) + writer.Write(amf.Buffer) + l += 11 + binary.BigEndian.PutUint32(tagHead[:4], uint32(l)) + writer.Write(tagHead[:4]) + } + } else { + // 后面的头跳过 + _, err = reader.Discard(13) + if !init { + offsetTime = 0 + offsetTimestamp = 0 + } + } + for err == nil { + _, err = io.ReadFull(reader, tagHead) + if err != nil { + break + } + tmp := tagHead + t := tmp.ReadByte() + dataLen := tmp.ReadUint24() + lastTimestamp = tmp.ReadUint24() | uint32(tmp.ReadByte())<<24 + //fmt.Println(lastTimestamp, tagHead) + if init { + if t == flv.FLV_TAG_TYPE_SCRIPT { + _, err = reader.Discard(int(dataLen) + 4) + } else { + lastTimestamp += offsetTimestamp + if lastTimestamp >= uint32(timeRange.Milliseconds()) { + break + } + if pass == 0 { + data := make([]byte, dataLen+4) + _, err = io.ReadFull(reader, data) + frameType := (data[0] >> 4) & 0b0111 + idr := frameType == 1 || frameType == 4 + if idr { + filepositions = append(filepositions, contentLength) + times = append(times, float64(lastTimestamp)/1000) + } + contentLength += uint64(11 + dataLen + 4) + } else { + //fmt.Println("write", lastTimestamp) + flv.PutFlvTimestamp(tagHead, lastTimestamp) + _, err = writer.Write(tagHead) + _, err = io.CopyN(writer, reader, int64(dataLen+4)) + } + } + continue + } + + switch t { + case flv.FLV_TAG_TYPE_SCRIPT: + if pass == 0 { + data := make([]byte, dataLen+4) + _, err = io.ReadFull(reader, data) + amf = &rtmp.AMF{ + Buffer: util.Buffer(data[1+2+len("onMetaData") : len(data)-4]), + } + var obj any + obj, err = amf.Unmarshal() + metaData = obj.(map[string]any) + } else { + _, err = reader.Discard(int(dataLen) + 4) + } + case flv.FLV_TAG_TYPE_AUDIO: + if !seqAudioWritten { + if pass == 0 { + contentLength += uint64(11 + dataLen + 4) + _, err = reader.Discard(int(dataLen) + 4) + } else { + flv.PutFlvTimestamp(tagHead, 0) + _, err = writer.Write(tagHead) + _, err = io.CopyN(writer, reader, int64(dataLen+4)) + } + seqAudioWritten = true + } else { + _, err = reader.Discard(int(dataLen) + 4) + } + case flv.FLV_TAG_TYPE_VIDEO: + if !seqVideoWritten { + if pass == 0 { + contentLength += uint64(11 + dataLen + 4) + _, err = reader.Discard(int(dataLen) + 4) + } else { + flv.PutFlvTimestamp(tagHead, 0) + _, err = writer.Write(tagHead) + _, err = io.CopyN(writer, reader, int64(dataLen+4)) + } + seqVideoWritten = true + } else { + if lastTimestamp >= uint32(offsetTime.Milliseconds()) { + data := make([]byte, dataLen+4) + _, err = io.ReadFull(reader, data) + frameType := (data[0] >> 4) & 0b0111 + idr := frameType == 1 || frameType == 4 + if idr { + init = true + plugin.Debug("init", "lastTimestamp", lastTimestamp) + if pass == 0 { + filepositions = append(filepositions, contentLength) + times = append(times, float64(lastTimestamp)/1000) + contentLength += uint64(11 + dataLen + 4) + } else { + flv.PutFlvTimestamp(tagHead, 0) + _, err = writer.Write(tagHead) + _, err = writer.Write(data) + } + } + } else { + _, err = reader.Discard(int(dataLen) + 4) + } + } + } + } + offsetTimestamp = lastTimestamp + err = file.Close() + } + } + plugin.Info("end download") + } else { + http.NotFound(w, r) + return + } +} diff --git a/plugin/flv/index.go b/plugin/flv/index.go index 8a60c2b..c0fcf2c 100644 --- a/plugin/flv/index.go +++ b/plugin/flv/index.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "net" "net/http" + "path/filepath" + "strconv" "strings" "time" @@ -15,21 +17,22 @@ import ( type FLVPlugin struct { m7s.Plugin + Path string } const defaultConfig m7s.DefaultYaml = `publish: speed: 1` -func (p *FLVPlugin) OnInit() error { - for streamPath, url := range p.GetCommonConf().PullOnStart { - p.Pull(streamPath, url) +func (plugin *FLVPlugin) OnInit() error { + for streamPath, url := range plugin.GetCommonConf().PullOnStart { + plugin.Pull(streamPath, url) } return nil } var _ = m7s.InstallPlugin[FLVPlugin](defaultConfig, PullFLV, RecordFlv) -func (p *FLVPlugin) WriteFlvHeader(sub *m7s.Subscriber) (flv net.Buffers) { +func (plugin *FLVPlugin) WriteFlvHeader(sub *m7s.Subscriber) (flv net.Buffers) { at, vt := &sub.Publisher.AudioTrack, &sub.Publisher.VideoTrack hasAudio, hasVideo := at.AVTrack != nil && sub.SubAudio, vt.AVTrack != nil && sub.SubVideo var amf rtmp.AMF @@ -70,13 +73,33 @@ func (p *FLVPlugin) WriteFlvHeader(sub *m7s.Subscriber) (flv net.Buffers) { return } -func (p *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (plugin *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { streamPath := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/"), ".flv") if r.URL.RawQuery != "" { streamPath += "?" + r.URL.RawQuery } - - sub, err := p.Subscribe(r.Context(), streamPath) + query := r.URL.Query() + startTimeStr := query.Get("start") + speedStr := query.Get("speed") + speed, err := strconv.ParseFloat(speedStr, 64) + if err != nil { + speed = 1 + } + s, err := strconv.Atoi(startTimeStr) + if err == nil { + startTime := time.UnixMilli(int64(s)) + var vod Vod + vod.Context = r.Context() + vod.Logger = plugin.Logger.With("streamPath", streamPath) + if err = vod.Init(startTime, filepath.Join(plugin.Path, streamPath)); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + vod.SetSpeed(speed) + err = vod.Run() + return + } + sub, err := plugin.Subscribe(r.Context(), streamPath) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -84,7 +107,7 @@ func (p *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "video/x-flv") w.Header().Set("Transfer-Encoding", "identity") w.WriteHeader(http.StatusOK) - wto := p.GetCommonConf().WriteTimeout + wto := plugin.GetCommonConf().WriteTimeout var gotFlvTag func(net.Buffers) error var b [15]byte @@ -103,7 +126,7 @@ func (p *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.(http.Flusher).Flush() } - flv := p.WriteFlvHeader(sub) + flv := plugin.WriteFlvHeader(sub) copy(b[:4], flv[3]) gotFlvTag(flv[:3]) rtmpData2FlvTag := func(t byte, data *rtmp.RTMPData) error { diff --git a/plugin/flv/pkg/flv.go b/plugin/flv/pkg/flv.go index c2e2a70..b0c20bc 100644 --- a/plugin/flv/pkg/flv.go +++ b/plugin/flv/pkg/flv.go @@ -52,10 +52,17 @@ func (w *FlvWriter) WriteTag(t byte, ts, dataSize uint32, payload ...[]byte) (er // return append(append(append(flv, b), avcc...), util.PutBE(b.Malloc(4), dataSize+11)) //} +func PutFlvTimestamp(header []byte, timestamp uint32) { + header[4] = byte(timestamp >> 16) + header[5] = byte(timestamp >> 8) + header[6] = byte(timestamp) + header[7] = byte(timestamp >> 24) +} + func WriteFLVTagHead(t uint8, ts, dataSize uint32, b []byte) { b[0] = t b[1], b[2], b[3] = byte(dataSize>>16), byte(dataSize>>8), byte(dataSize) - b[4], b[5], b[6], b[7] = byte(ts>>16), byte(ts>>8), byte(ts), byte(ts>>24) + PutFlvTimestamp(b, ts) } //func WriteFLVTag(w io.Writer, t byte, timestamp uint32, payload ...[]byte) (n int64, err error) { diff --git a/plugin/flv/pkg/record.go b/plugin/flv/pkg/record.go index d787f89..0b9c3a4 100644 --- a/plugin/flv/pkg/record.go +++ b/plugin/flv/pkg/record.go @@ -13,10 +13,14 @@ import ( "time" ) -var writeMetaTagQueueTask util.MarcoLongTask +type WriteFlvMetaTagQueueTask struct { + util.MarcoLongTask +} + +var writeMetaTagQueueTask WriteFlvMetaTagQueueTask func init() { - util.RootTask.AddTask(&writeMetaTagQueueTask) + m7s.AddRootTask(&writeMetaTagQueueTask) } type writeMetaTagTask struct { diff --git a/plugin/flv/pkg/vod.go b/plugin/flv/pkg/vod.go index 7c6558a..65dfa0a 100644 --- a/plugin/flv/pkg/vod.go +++ b/plugin/flv/pkg/vod.go @@ -1,11 +1,170 @@ package flv -import "m7s.live/m7s/v5/pkg/util" +import ( + "bufio" + "io" + "io/fs" + "m7s.live/m7s/v5/pkg/util" + "os" + "path/filepath" + "strings" + "time" +) type Vod struct { util.Task + io.WriteCloser + Dir string + lastTimestamp uint32 + speed float64 + singleFile bool + offsetTime time.Duration + offsetTimestamp uint32 + fileList []fs.FileInfo } -func (v *Vod) Start() error { - return nil +func (v *Vod) SetSpeed(speed float64) { + v.speed = speed +} + +func (v *Vod) speedControl() { + targetTime := time.Duration(float64(time.Since(v.StartTime)) * v.speed) + sleepTime := time.Duration(v.lastTimestamp)*time.Millisecond - targetTime + //fmt.Println("sleepTime", sleepTime, time.Since(start).Milliseconds(), lastTimestamp) + if sleepTime > 0 { + time.Sleep(sleepTime) + } +} + +func (v *Vod) Init(startTime time.Time, dir string) (err error) { + v.Dir = dir + singleFile := filepath.Join(dir, ".flv") + if util.Exist(singleFile) { + v.singleFile = true + } else if util.Exist(dir) { + var found bool + err = filepath.Walk(dir, func(path string, info fs.FileInfo, err error) error { + if info.IsDir() || !strings.HasSuffix(info.Name(), ".flv") { + return nil + } + modTime := info.ModTime() + //tmp, _ := strconv.Atoi(strings.TrimSuffix(info.Name(), ".flv")) + //fileStartTime := time.Unix(tmp, 10) + if !found { + if modTime.After(startTime) { + found = true + //fmt.Println(path, modTime, startTime, found) + } else { + v.fileList = []fs.FileInfo{info} + v.offsetTime = startTime.Sub(modTime) + //fmt.Println(path, modTime, startTime, found) + return nil + } + } + v.fileList = append(v.fileList, info) + return nil + }) + if !found { + return os.ErrNotExist + } + } + return +} + +func (v *Vod) Run() (err error) { + flvHead := make([]byte, 9+4) + tagHead := make(util.Buffer, 11) + var file *os.File + var init, seqAudioWritten, seqVideoWritten bool + if v.offsetTime == 0 { + init = true + } else { + v.offsetTimestamp = -uint32(v.offsetTime.Milliseconds()) + } + for i, info := range v.fileList { + if v.IsStopped() { + return + } + filePath := filepath.Join(v.Dir, info.Name()) + v.Debug("read", "file", filePath) + file, err = os.Open(filePath) + if err != nil { + return + } + reader := bufio.NewReader(file) + if i == 0 { + // 第一次写入头 + _, err = io.ReadFull(reader, flvHead) + _, err = v.Write(flvHead) + } else { + // 后面的头跳过 + _, err = reader.Discard(13) + if !init { + v.offsetTime = 0 + v.offsetTimestamp = 0 + } + } + for err == nil { + _, err = io.ReadFull(reader, tagHead) + if err != nil { + break + } + tmp := tagHead + t := tmp.ReadByte() + dataLen := tmp.ReadUint24() + v.lastTimestamp = tmp.ReadUint24() | uint32(tmp.ReadByte())<<24 + //fmt.Println(lastTimestamp, tagHead) + if init { + if t == FLV_TAG_TYPE_SCRIPT { + _, err = reader.Discard(int(dataLen) + 4) + } else { + v.lastTimestamp += v.offsetTimestamp + PutFlvTimestamp(tagHead, v.lastTimestamp) + _, err = v.Write(tagHead) + _, err = io.CopyN(v, reader, int64(dataLen+4)) + v.speedControl() + } + continue + } + switch t { + case FLV_TAG_TYPE_SCRIPT: + _, err = reader.Discard(int(dataLen) + 4) + case FLV_TAG_TYPE_AUDIO: + if !seqAudioWritten { + PutFlvTimestamp(tagHead, 0) + _, err = v.Write(tagHead) + _, err = io.CopyN(v, reader, int64(dataLen+4)) + seqAudioWritten = true + } else { + _, err = reader.Discard(int(dataLen) + 4) + } + case FLV_TAG_TYPE_VIDEO: + if !seqVideoWritten { + PutFlvTimestamp(tagHead, 0) + _, err = v.Write(tagHead) + _, err = io.CopyN(v, reader, int64(dataLen+4)) + seqVideoWritten = true + } else { + if v.lastTimestamp >= uint32(v.offsetTime.Milliseconds()) { + data := make([]byte, dataLen+4) + _, err = io.ReadFull(reader, data) + frameType := (data[0] >> 4) & 0b0111 + idr := frameType == 1 || frameType == 4 + if idr { + init = true + //fmt.Println("init", lastTimestamp) + PutFlvTimestamp(tagHead, 0) + _, err = v.Write(tagHead) + _, err = v.Write(data) + } + } else { + _, err = reader.Discard(int(dataLen) + 4) + } + } + } + } + v.offsetTimestamp = v.lastTimestamp + err = file.Close() + } + return } diff --git a/server.go b/server.go index 7ff1525..3988f70 100644 --- a/server.go +++ b/server.go @@ -91,7 +91,21 @@ func NewServer(conf any) (s *Server) { } func Run(ctx context.Context, conf any) error { - return util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped() + for { + if err := util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped(); err != ErrRestart { + return err + } + } +} + +func AddRootTask[T util.ITask](task T) T { + util.RootTask.AddTask(task) + return task +} + +func AddRootTaskWithContext[T util.ITask](ctx context.Context, task T) T { + util.RootTask.AddTaskWithContext(ctx, task) + return task } type rawconfig = map[string]map[string]any @@ -259,15 +273,11 @@ func (s *Server) Dispose() { _ = s.tcplis.Close() _ = s.grpcClientConn.Close() s.config.HTTP.StopListen() - if err := s.StopReason(); err == ErrRestart { - var server Server - server.ID = s.ID - server.Meta = s.Meta - server.DB = s.DB - *s = server - util.RootTask.AddTask(s) - } else { - s.Info("server stopped", "err", err) + if s.DB != nil { + db, err := s.DB.DB() + if err == nil { + err = db.Close() + } } } diff --git a/test/server_test.go b/test/server_test.go index 098485b..0515527 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -1,7 +1,6 @@ package test import ( - "context" "m7s.live/m7s/v5" "m7s.live/m7s/v5/pkg" "testing" @@ -9,8 +8,8 @@ import ( ) func TestRestart(b *testing.T) { - ctx := context.TODO() - var server = m7s.NewServer() + conf := map[string]map[string]any{"global": {"loglevel": "debug"}} + var server *m7s.Server go func() { time.Sleep(time.Second * 2) server.Stop(pkg.ErrRestart) @@ -22,7 +21,13 @@ func TestRestart(b *testing.T) { server.Stop(pkg.ErrStopFromAPI) b.Log("server stop3") }() - if err := server.Run(ctx, map[string]map[string]any{"global": {"loglevel": "debug"}}); err != pkg.ErrStopFromAPI { - b.Error("server.Run should return ErrStopFromAPI", err) + for { + server = m7s.NewServer(conf) + if err := m7s.AddRootTask(server).WaitStopped(); err != pkg.ErrRestart { + return + } } + //if err := util.RootTask.AddTask(server).WaitStopped(); err != pkg.ErrStopFromAPI { + // b.Error("server.Run should return ErrStopFromAPI", err) + //} }