diff --git a/internal/mjpeg/init.go b/internal/mjpeg/init.go index f9a6d451..e9d6dbee 100644 --- a/internal/mjpeg/init.go +++ b/internal/mjpeg/init.go @@ -80,8 +80,6 @@ func handlerKeyframe(w http.ResponseWriter, r *http.Request) { } } -const header = "--frame\r\nContent-Type: image/jpeg\r\nContent-Length: " - func handlerStream(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { outputMjpeg(w, r) @@ -98,26 +96,10 @@ func outputMjpeg(w http.ResponseWriter, r *http.Request) { return } - flusher := w.(http.Flusher) - cons := &mjpeg.Consumer{ RemoteAddr: tcp.RemoteAddr(r), UserAgent: r.UserAgent(), } - cons.Listen(func(msg any) { - switch msg := msg.(type) { - case []byte: - data := []byte(header + strconv.Itoa(len(msg))) - data = append(data, '\r', '\n', '\r', '\n') - data = append(data, msg...) - data = append(data, '\r', '\n') - - // Chrome bug: mjpeg image always shows the second to last image - // https://bugs.chromium.org/p/chromium/issues/detail?id=527446 - _, _ = w.Write(data) - flusher.Flush() - } - }) if err := stream.AddConsumer(cons); err != nil { log.Error().Err(err).Msg("[api.mjpeg] add consumer") @@ -130,11 +112,33 @@ func outputMjpeg(w http.ResponseWriter, r *http.Request) { h.Set("Connection", "close") h.Set("Pragma", "no-cache") - <-r.Context().Done() + wr := &writer{wr: w, buf: []byte(header)} + _, _ = cons.WriteTo(wr) stream.RemoveConsumer(cons) +} - //log.Trace().Msg("[api.mjpeg] close") +const header = "--frame\r\nContent-Type: image/jpeg\r\nContent-Length: " + +type writer struct { + wr io.Writer + buf []byte +} + +func (w *writer) Write(p []byte) (n int, err error) { + w.buf = w.buf[:len(header)] + w.buf = append(w.buf, strconv.Itoa(len(p))...) + w.buf = append(w.buf, "\r\n\r\n"...) + w.buf = append(w.buf, p...) + w.buf = append(w.buf, "\r\n"...) + + // Chrome bug: mjpeg image always shows the second to last image + // https://bugs.chromium.org/p/chromium/issues/detail?id=527446 + if n, err = w.wr.Write(w.buf); err == nil { + w.wr.(http.Flusher).Flush() + } + + return } func inputMjpeg(w http.ResponseWriter, r *http.Request) { @@ -168,11 +172,6 @@ func handlerWS(tr *ws.Transport, _ *ws.Message) error { RemoteAddr: tcp.RemoteAddr(tr.Request), UserAgent: tr.Request.UserAgent(), } - cons.Listen(func(msg any) { - if data, ok := msg.([]byte); ok { - tr.Write(data) - } - }) if err := stream.AddConsumer(cons); err != nil { log.Error().Err(err).Caller().Send() @@ -181,9 +180,21 @@ func handlerWS(tr *ws.Transport, _ *ws.Message) error { tr.Write(&ws.Message{Type: "mjpeg"}) + wr := &writer2{tr: tr} // TODO: fixme + go cons.WriteTo(wr) + tr.OnClose(func() { stream.RemoveConsumer(cons) }) return nil } + +type writer2 struct { + tr *ws.Transport +} + +func (w *writer2) Write(p []byte) (n int, err error) { + w.tr.Write(p) + return len(p), nil +} diff --git a/pkg/core/writebuffer.go b/pkg/core/writebuffer.go new file mode 100644 index 00000000..dce7affb --- /dev/null +++ b/pkg/core/writebuffer.go @@ -0,0 +1,83 @@ +package core + +import ( + "bytes" + "io" + "sync" +) + +type WriteBuffer struct { + io.Writer + err error + mu sync.Mutex + wg sync.WaitGroup + state byte +} + +func NewWriteBuffer(wr io.Writer) *WriteBuffer { + if wr == nil { + wr = bytes.NewBuffer(nil) + } + return &WriteBuffer{Writer: wr} +} + +func (w *WriteBuffer) Write(p []byte) (n int, err error) { + w.mu.Lock() + if w.err != nil { + err = w.err + } else if n, err = w.Writer.Write(p); err != nil { + w.err = err + w.done() + } + w.mu.Unlock() + return +} + +func (w *WriteBuffer) WriteTo(wr io.Writer) (n int64, err error) { + w.Reset(wr) + w.wg.Wait() + return 0, w.err // TODO: fix counter +} + +func (w *WriteBuffer) Close() error { + if closer, ok := w.Writer.(io.Closer); ok { + return closer.Close() + } + w.mu.Lock() + w.done() + w.mu.Unlock() + return nil +} + +func (w *WriteBuffer) Reset(wr io.Writer) { + w.mu.Lock() + w.add() + if buf, ok := wr.(*bytes.Buffer); ok { + if _, err := io.Copy(wr, buf); err != nil { + w.err = err + w.done() + } + } + w.Writer = wr + w.mu.Unlock() +} + +const ( + none = iota + start + end +) + +func (w *WriteBuffer) add() { + if w.state == none { + w.state = start + w.wg.Add(1) + } +} + +func (w *WriteBuffer) done() { + if w.state == start { + w.state = end + w.wg.Done() + } +} diff --git a/pkg/mjpeg/consumer.go b/pkg/mjpeg/consumer.go index 88244337..3fa36040 100644 --- a/pkg/mjpeg/consumer.go +++ b/pkg/mjpeg/consumer.go @@ -2,19 +2,21 @@ package mjpeg import ( "encoding/json" + "io" + "github.com/AlexxIT/go2rtc/pkg/core" "github.com/pion/rtp" ) type Consumer struct { - core.Listener - UserAgent string RemoteAddr string medias []*core.Media sender *core.Sender + wr *core.WriteBuffer + send int } @@ -34,11 +36,16 @@ func (c *Consumer) GetMedias() []*core.Media { } func (c *Consumer) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver) error { + if c.wr == nil { + c.wr = core.NewWriteBuffer(nil) + } + if c.sender == nil { c.sender = core.NewSender(media, track.Codec) c.sender.Handler = func(packet *rtp.Packet) { - c.Fire(packet.Payload) - c.send += len(packet.Payload) + if n, err := c.wr.Write(packet.Payload); err == nil { + c.send += n + } } if track.Codec.IsRTP() { @@ -50,10 +57,17 @@ func (c *Consumer) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiv return nil } +func (c *Consumer) WriteTo(wr io.Writer) (int64, error) { + return c.wr.WriteTo(wr) +} + func (c *Consumer) Stop() error { if c.sender != nil { c.sender.Close() } + if c.wr != nil { + _ = c.wr.Close() + } return nil }