完善sink断开处理

This commit is contained in:
yangjiechina
2024-04-09 16:00:48 +08:00
parent 0376ccf604
commit 8dc824494e
10 changed files with 287 additions and 106 deletions

206
api.go
View File

@@ -1,17 +1,19 @@
package main
import (
"context"
"encoding/json"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/flv"
"github.com/yangjiechina/live-server/hls"
"github.com/yangjiechina/live-server/log"
"github.com/yangjiechina/live-server/rtc"
"github.com/yangjiechina/live-server/stream"
"io"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
@@ -29,10 +31,51 @@ func init() {
func startApiServer(addr string) {
r := mux.NewRouter()
r.HandleFunc("/live/flv/{source}", onFLV)
r.HandleFunc("/live/hls/{source}", onHLS)
r.HandleFunc("/live/rtc/{source}", onRtc)
r.HandleFunc("/live/flv/ws/{source}", onWSFlv)
/**
http://host:port/xxx.flv
http://host:port/xxx.rtc
http://host:port/xxx.m3u8
http://host:port/xxx_0.ts
ws://host:port/xxx.flv
*/
r.HandleFunc("/live/{source}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
source := vars["source"]
index := strings.LastIndex(source, ".")
if index < 0 || index == len(source)-1 {
log.Sugar.Errorf("bad request:%s. stream format must be passed at the end of the URL", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
return
}
sourceId := source[:index]
format := source[index+1:]
if "flv" == format {
//判断是否是websocket请求
ws := true
if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) {
ws = false
} else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) {
ws = false
} else if !("13" == r.Header.Get("Sec-Websocket-Version")) {
ws = false
}
if ws {
onWSFlv(sourceId, w, r)
} else {
onFLV(sourceId, w, r)
}
} else if "m3u8" == format {
onHLS(sourceId, w, r)
} else if "ts" == format {
onTS(sourceId, w, r)
} else if "rtc" == format {
onRtc(sourceId, w, r)
}
})
r.HandleFunc("/rtc.html", func(writer http.ResponseWriter, request *http.Request) {
http.ServeFile(writer, request, "./rtc.html")
@@ -54,40 +97,40 @@ func startApiServer(addr string) {
}
}
func onWSFlv(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/x-flv")
w.Header().Set("Connection", "Keep-Alive")
w.Header().Set("Transfer-Encoding", "chunked")
func onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
panic(err)
}
vars := mux.Vars(r)
sourceId := vars["source"]
if index := strings.LastIndex(sourceId, "."); index > -1 {
sourceId = sourceId[:index]
log.Sugar.Errorf("websocket头检查失败 err:%s", err.Error())
w.WriteHeader(http.StatusBadRequest)
return
}
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
sinkId := stream.GenerateSinkId(tcpAddr)
sink := flv.NewFLVSink(sinkId, sourceId, flv.NewWSConn(conn))
go func() {
log.Sugar.Infof("ws-flv 连接 sink:%s", sink.PrintInfo())
sink.(*stream.SinkImpl).Play(sink, func() {
//sink.(*stream.SinkImpl).PlayDone(sink, nil, nil)
}, func(state utils.HookState) {
w.WriteHeader(http.StatusForbidden)
conn.Close()
})
}()
netConn := conn.NetConn()
bytes := make([]byte, 64)
for {
select {}
if _, err := netConn.Read(bytes); err != nil {
log.Sugar.Infof("ws-flv 断开连接 sink:%s", sink.PrintInfo())
sink.Close()
break
}
}
}
func onFLV(w http.ResponseWriter, r *http.Request) {
func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/x-flv")
w.Header().Set("Connection", "Keep-Alive")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -97,58 +140,112 @@ func onFLV(w http.ResponseWriter, r *http.Request) {
http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
return
}
context_ := r.Context()
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusOK)
conn, _, err := hj.Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
vars := mux.Vars(r)
sourceId := vars["source"]
if index := strings.LastIndex(sourceId, "."); index > -1 {
sourceId = sourceId[:index]
}
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
sinkId := stream.GenerateSinkId(tcpAddr)
sink := flv.NewFLVSink(sinkId, sourceId, conn)
go func(ctx context.Context) {
log.Sugar.Infof("http-flv 连接 sink:%s", sink.PrintInfo())
sink.(*stream.SinkImpl).Play(sink, func() {
//sink.(*stream.SinkImpl).PlayDone(sink, nil, nil)
}, func(state utils.HookState) {
w.WriteHeader(http.StatusForbidden)
conn.Close()
})
}(context_)
bytes := make([]byte, 64)
for {
if _, err := conn.Read(bytes); err != nil {
log.Sugar.Infof("http-flv 断开连接 sink:%s", sink.PrintInfo())
break
}
}
}
func onHLS(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
source := vars["source"]
func onTS(source string, w http.ResponseWriter, r *http.Request) {
if !stream.AppConfig.Hls.Enable {
log.Sugar.Warnf("处理m3u8请求失败 server未开启hls request:%s", r.URL.Path)
http.Error(w, "hls disable", http.StatusInternalServerError)
return
}
index := strings.LastIndex(source, "_")
if index < 0 || index == len(source)-1 {
w.WriteHeader(http.StatusBadRequest)
return
}
seq := source[index+1:]
sourceId := source[:index]
tsPath := stream.AppConfig.Hls.TSPath(sourceId, seq)
if _, err := os.Stat(tsPath); err != nil {
w.WriteHeader(http.StatusNotFound)
return
}
//链路复用无法获取http断开回调
//Hijack需要自行解析http
http.ServeFile(w, r, tsPath)
}
func onHLS(sourceId string, w http.ResponseWriter, r *http.Request) {
if !stream.AppConfig.Hls.Enable {
log.Sugar.Warnf("处理hls请求失败 server未开启hls request:%s", r.URL.Path)
http.Error(w, "hls disable", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/vnd.apple.mpegurl")
//m3u8和ts会一直刷新, 每个请求只hook一次.
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
sinkId := stream.GenerateSinkId(tcpAddr)
//删除末尾的.ts/.m3u8, 请确保id中不存在.
//var sourceId string
//if index := strings.LastIndex(source, "."); index > -1 {
// sourceId = source[:index]
//}
//
//tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
//sinkId := stream.GenerateSinkId(tcpAddr)
if strings.HasSuffix(source, ".m3u8") {
//查询是否存在hls流, 不存在-等生成后再响应m3u8文件. 存在-直接响应m3u8文件
http.ServeFile(w, r, "../tmp/"+source)
} else if strings.HasSuffix(source, ".ts") {
http.ServeFile(w, r, "../tmp/"+source)
//hook成功后, 如果还没有m3u8文件等生成m3u8文件
//后续直接返回当前m3u8文件
if stream.SinkManager.Exist(sinkId) {
http.ServeFile(w, r, stream.AppConfig.Hls.M3U8Path(sourceId))
} else {
context := r.Context()
done := make(chan int, 0)
sink := hls.NewM3U8Sink(sinkId, sourceId, func(m3u8 []byte) {
w.Write(m3u8)
done <- 0
})
hookState := utils.HookStateOK
sink.Play(sink, func() {
err := stream.SinkManager.Add(sink)
utils.Assert(err == nil)
}, func(state utils.HookState) {
log.Sugar.Warnf("hook播放事件失败 request:%s", r.URL.Path)
hookState = state
w.WriteHeader(http.StatusForbidden)
})
if utils.HookStateOK != hookState {
return
}
select {
case <-done:
case <-context.Done():
log.Sugar.Infof("http m3u8连接断开")
break
}
}
}
func onRtc(w http.ResponseWriter, r *http.Request) {
func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) {
v := struct {
Type string `json:"type"`
SDP string `json:"sdp"`
@@ -163,12 +260,12 @@ func onRtc(w http.ResponseWriter, r *http.Request) {
panic(err)
}
sinkId := stream.SinkId(123)
split := strings.Split(r.URL.Path, "/")
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
sinkId := stream.GenerateSinkId(tcpAddr)
group := sync.WaitGroup{}
group.Add(1)
sink := rtc.NewSink(sinkId, split[len(split)-1], v.SDP, func(sdp string) {
sink := rtc.NewSink(sinkId, sourceId, v.SDP, func(sdp string) {
response := struct {
Type string `json:"type"`
SDP string `json:"sdp"`
@@ -191,7 +288,10 @@ func onRtc(w http.ResponseWriter, r *http.Request) {
sink.Play(sink, func() {
}, func(state utils.HookState) {
w.WriteHeader(http.StatusForbidden)
group.Done()
})
group.Wait()
}

View File

@@ -2,37 +2,30 @@ package hls
import (
"github.com/yangjiechina/live-server/stream"
"net/http"
)
type sink struct {
type tsSink struct {
stream.SinkImpl
conn http.ResponseWriter
}
func NewSink(id stream.SinkId, sourceId string, w http.ResponseWriter) stream.ISink {
return &sink{stream.SinkImpl{Id_: id, SourceId_: sourceId, Protocol_: stream.ProtocolHls}, w}
}
func (s *sink) Input(data []byte) error {
if s.conn != nil {
_, err := s.conn.Write(data)
return err
func NewTSSink(id stream.SinkId, sourceId string) stream.ISink {
return &tsSink{stream.SinkImpl{Id_: id, SourceId_: sourceId, Protocol_: stream.ProtocolHls}}
}
func (s *tsSink) Input(data []byte) error {
return nil
}
type m3u8Sink struct {
stream.SinkImpl
cb func(m3u8 []byte)
}
func (s *m3u8Sink) Input(data []byte) error {
s.cb(data)
return nil
}
func NewM3U8Sink(id stream.SinkId, sourceId string, w http.ResponseWriter) stream.ISink {
return &m3u8Sink{stream.SinkImpl{Id_: id, SourceId_: sourceId, Protocol_: stream.ProtocolHls}}
func NewM3U8Sink(id stream.SinkId, sourceId string, cb func(m3u8 []byte)) stream.ISink {
return &m3u8Sink{stream.SinkImpl{Id_: id, SourceId_: sourceId, Protocol_: stream.ProtocolHls}, cb}
}

View File

@@ -32,12 +32,14 @@ type transStream struct {
duration int
m3u8File *os.File
playlistLength int
m3u8Sinks map[stream.SinkId]stream.ISink
}
// NewTransStream 创建HLS传输流
// @url url前缀
// @m3u8Name m3u8文件名
// @tsFormat ts文件格式, 例如: test_%d.ts
// @tsFormat ts文件格式, 例如: %d.ts
// @parentDir 保存切片的绝对路径. mu38和ts切片放在同一目录下, 目录地址使用parentDir+urlPrefix
// @segmentDuration 单个切片时长
// @playlistLength 缓存多少个切片
@@ -47,6 +49,7 @@ func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playli
return nil, err
}
//创建m3u8文件
m3u8Path := fmt.Sprintf("%s/%s", dir, m3u8Name)
file, err := os.OpenFile(m3u8Path, os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
@@ -62,6 +65,7 @@ func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playli
playlistLength: playlistLength,
}
//创建TS封装器
muxer := libmpeg.NewTSMuxer()
muxer.SetWriteHandler(stream_.onTSWrite)
muxer.SetAllocHandler(stream_.onTSAlloc)
@@ -75,6 +79,8 @@ func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playli
stream_.muxer = muxer
stream_.m3u8 = NewM3U8Writer(playlistLength)
stream_.m3u8File = file
stream_.m3u8Sinks = make(map[stream.SinkId]stream.ISink, 24)
return stream_, nil
}
@@ -90,10 +96,12 @@ func (t *transStream) Input(packet utils.AVPacket) error {
}
}
pts := packet.ConvertPts(90000)
dts := packet.ConvertDts(90000)
if utils.AVMediaTypeVideo == packet.MediaType() {
return t.muxer.Input(packet.Index(), packet.AnnexBPacketData(), packet.Pts()*90, packet.Dts()*90, packet.KeyFrame())
return t.muxer.Input(packet.Index(), packet.AnnexBPacketData(), pts, dts, packet.KeyFrame())
} else {
return t.muxer.Input(packet.Index(), packet.Data(), packet.Pts()*90, packet.Dts()*90, packet.KeyFrame())
return t.muxer.Input(packet.Index(), packet.Data(), pts, dts, packet.KeyFrame())
}
}
@@ -117,9 +125,24 @@ func (t *transStream) AddTrack(stream utils.AVStream) error {
}
func (t *transStream) WriteHeader() error {
t.Init()
return t.createSegment()
}
func (t *transStream) AddSink(sink stream.ISink) error {
if sink_, ok := sink.(*m3u8Sink); ok {
if t.m3u8.Size() > 0 {
return sink.Input([]byte(t.m3u8.ToString()))
} else {
t.m3u8Sinks[sink.Id()] = sink_
return nil
}
}
return t.TransStreamImpl.AddSink(sink)
}
func (t *transStream) onTSWrite(data []byte) {
t.context.writeBufferSize += len(data)
}
@@ -166,22 +189,33 @@ func (t *transStream) flushSegment() error {
return err
}
//通知等待m3u8的sink
if len(t.m3u8Sinks) > 0 {
for _, sink := range t.m3u8Sinks {
sink.Input([]byte(m3u8Txt))
}
t.m3u8Sinks = make(map[stream.SinkId]stream.ISink, 0)
}
return nil
}
// 创建一个新的ts切片
func (t *transStream) createSegment() error {
//保存上一个ts切片
if t.context.file != nil {
err := t.flushSegment()
t.context.segmentSeq++
if err != nil {
return err
}
}
tsName := fmt.Sprintf(t.tsFormat, t.context.segmentSeq)
t.context.path = fmt.Sprintf("%s%s", t.dir, tsName)
//ts文件
t.context.path = fmt.Sprintf("%s/%s", t.dir, tsName)
//m3u8中的url
t.context.url = fmt.Sprintf("%s%s", t.url, tsName)
file, err := os.OpenFile(t.context.path, os.O_WRONLY|os.O_CREATE, 0666)
if err != nil {
return err

View File

@@ -115,7 +115,7 @@ func (m *m3u8Writer) ToString() string {
m.stringBuffer.WriteString("#EXT-X-TARGETDURATION:")
m.stringBuffer.WriteString(strconv.Itoa(m.targetDuration()))
m.stringBuffer.WriteString("\r\n")
m.stringBuffer.WriteString("#ExtXMediaSequence:")
m.stringBuffer.WriteString("#Ext-X-MEDIA-SEQUENCE:")
m.stringBuffer.WriteString(strconv.Itoa(head[0].(Segment).sequence))
m.stringBuffer.WriteString("\r\n")

View File

@@ -25,10 +25,8 @@ func CreateTransStream(source stream.ISource, protocol stream.Protocol, streams
return rtmp.NewTransStream(librtmp.ChunkSize)
} else if stream.ProtocolHls == protocol {
id := source.Id()
m3u8Name := id + ".m3u8"
tsFormat := id + "_%d.ts"
transStream, err := hls.NewTransStream("", m3u8Name, tsFormat, "../tmp/", 2, 10)
transStream, err := hls.NewTransStream("", stream.AppConfig.Hls.M3U8Format(id), stream.AppConfig.Hls.TSFormat(id, "%d"), stream.AppConfig.Hls.Dir, stream.AppConfig.Hls.Duration, stream.AppConfig.Hls.PlaylistLength)
if err != nil {
panic(err)
}
@@ -60,6 +58,11 @@ func main() {
stream.AppConfig.GOPCache = true
stream.AppConfig.MergeWriteLatency = 350
stream.AppConfig.Hls.Enable = true
stream.AppConfig.Hls.Dir = "../tmp"
stream.AppConfig.Hls.Duration = 2
stream.AppConfig.Hls.PlaylistLength = 10
rtmpAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:1935")
if err != nil {
panic(err)

View File

@@ -63,11 +63,11 @@ func (t *transStream) AddSink(sink_ stream.ISink) error {
}
if _, err := connection.AddTransceiverFromTrack(videoTrack, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}); err != nil {
panic(err)
return err
}
if _, err = connection.AddTrack(videoTrack); err != nil {
panic(err)
return err
}
rtcSink.addTrack(index, videoTrack)
@@ -80,14 +80,17 @@ func (t *transStream) AddSink(sink_ stream.ISink) error {
complete := webrtc.GatheringCompletePromise(connection)
answer, err := connection.CreateAnswer(nil)
if err != nil {
panic(err)
return err
} else if err = connection.SetLocalDescription(answer); err != nil {
panic(err)
return err
}
<-complete
connection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
rtcSink.state = state
if webrtc.ICEConnectionStateDisconnected > state {
rtcSink.Close()
}
})
rtcSink.peer = connection

View File

@@ -9,6 +9,36 @@ type RtmpConfig struct {
Addr string `json:"addr"`
}
type RecordConfig struct {
Enable bool `json:"enable"`
Format string `json:"format"`
}
type HlsConfig struct {
Enable bool
Dir string
Duration int
PlaylistLength int
}
// M3U8Path 根据sourceId返回m3u8的磁盘路径
func (c HlsConfig) M3U8Path(sourceId string) string {
return c.Dir + "/" + c.M3U8Format(sourceId)
}
func (c HlsConfig) M3U8Format(sourceId string) string {
return sourceId + ".m3u8"
}
// TSPath 根据sourceId和ts文件名返回ts的磁盘路径
func (c HlsConfig) TSPath(sourceId string, tsSeq string) string {
return c.Dir + "/" + c.TSFormat(sourceId, tsSeq)
}
func (c HlsConfig) TSFormat(sourceId string, tsSeq string) string {
return sourceId + "_" + tsSeq + ".ts"
}
type HookConfig struct {
Time int
Enable bool `json:"enable"`
@@ -65,4 +95,7 @@ type AppConfig_ struct {
MergeWriteLatency int `json:"mw_latency"`
Rtmp RtmpConfig
Hook HookConfig
Record RecordConfig
Hls HlsConfig
}

View File

@@ -1,6 +1,7 @@
package stream
import (
"fmt"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/log"
"net"
@@ -43,7 +44,10 @@ type ISink interface {
// DesiredVideoCodecId DescribeVideoCodecId 允许客户端拉取指定的视频流
DesiredVideoCodecId() utils.AVCodecID
// Close 关闭释放Sink, 从传输流或等待队列中删除sink
Close()
PrintInfo() string
}
// GenerateSinkId 根据网络地址生成SinkId IPV4使用一个uint64, IPV6使用String
@@ -184,6 +188,10 @@ func (s *SinkImpl) Close() {
}
}
func (s *SinkImpl) PrintInfo() string {
return fmt.Sprintf("%s-%v source:%s", s.ProtocolStr(), s.Id_, s.SourceId_)
}
func (s *SinkImpl) Play(sink ISink, success func(), failure func(state utils.HookState)) {
f := func() {
source := SourceManager.Find(sink.SourceId())

View File

@@ -91,7 +91,7 @@ func ExistSink(sourceId string, sinkId SinkId) bool {
// ISinkManager 添加到TransStream的所有Sink
type ISinkManager interface {
Add(source ISink) error
Add(sink ISink) error
Find(id SinkId) ISink
@@ -110,10 +110,10 @@ type sinkManagerImpl struct {
m sync.Map
}
func (s *sinkManagerImpl) Add(source ISink) error {
_, ok := s.m.LoadOrStore(source.Id(), source)
func (s *sinkManagerImpl) Add(sink ISink) error {
_, ok := s.m.LoadOrStore(sink.Id(), sink)
if ok {
return fmt.Errorf("the source %s has been exist", source.Id())
return fmt.Errorf("the sink %s has been exist", sink.Id())
}
return nil

View File

@@ -132,13 +132,14 @@ type SourceImpl struct {
TransDeMuxer stream.DeMuxer //负责从推流协议中解析出AVStream和AVPacket
recordSink ISink //每个Source唯一的一个录制流
hlsStream ITransStream //hls不等拉流创建时直接生成
audioTranscoders []transcode.ITranscoder //音频解码器
videoTranscoders []transcode.ITranscoder //视频解码器
originStreams StreamManager //推流的音视频Streams
allStreams StreamManager //推流Streams+转码器获得的Streams
buffers []StreamBuffer
Input_ func(data []byte) //解决无法多态传递给子类的问题
Input_ func(data []byte) //解决多态无法传递给子类的问题
completed bool
mutex sync.Mutex //只用作AddStream期间
@@ -154,8 +155,6 @@ type SourceImpl struct {
closeEvent chan byte
playingEventQueue chan ISink
playingDoneEventQueue chan ISink
testTransStream ITransStream
}
func (s *SourceImpl) Id() string {
@@ -175,9 +174,17 @@ func (s *SourceImpl) Init() {
if s.transStreams == nil {
s.transStreams = make(map[TransStreamId]ITransStream, 10)
}
//测试传输流
s.testTransStream = TransStreamFactory(s, ProtocolHls, nil)
s.transStreams[0x100] = s.testTransStream
//创建录制流
if AppConfig.Record.Enable {
}
//创建HLS输出流
if AppConfig.Hls.Enable {
s.hlsStream = TransStreamFactory(s, ProtocolHls, nil)
s.transStreams[0x100] = s.hlsStream
}
}
func (s *SourceImpl) LoopEvent() {
@@ -220,7 +227,7 @@ func IsSupportMux(protocol Protocol, audioCodecId, videoCodecId utils.AVCodecID)
return true
}
// 分发每路StreamBuffer给传输流
// 分发每路StreamBuffer给传输流
// 按照时间戳升序发送
func (s *SourceImpl) dispatchStreamBuffer(transStream ITransStream, streams []utils.AVStream) {
size := len(streams)
@@ -447,12 +454,12 @@ func (s *SourceImpl) writeHeader() {
s.AddSink(sink)
}
if s.testTransStream != nil {
if s.hlsStream != nil {
for _, stream_ := range s.originStreams.All() {
s.testTransStream.AddTrack(stream_)
s.hlsStream.AddTrack(stream_)
}
s.testTransStream.WriteHeader()
s.hlsStream.WriteHeader()
}
}