diff --git a/api.go b/api.go new file mode 100644 index 0000000..6af643f --- /dev/null +++ b/api.go @@ -0,0 +1,63 @@ +package main + +import ( + "github.com/gorilla/mux" + "github.com/yangjiechina/avformat/utils" + "github.com/yangjiechina/live-server/hls" + "github.com/yangjiechina/live-server/stream" + "net/http" + "time" +) + +func startApiServer(addr string) { + r := mux.NewRouter() + r.HandleFunc("/live/hls/{id}", onHLS) + http.Handle("/", r) + + srv := &http.Server{ + Handler: r, + Addr: addr, + // Good practice: enforce timeouts for servers you create! + WriteTimeout: 15 * time.Second, + ReadTimeout: 15 * time.Second, + } + + err := srv.ListenAndServe() + + if err != nil { + panic(err) + } +} + +func onHLS(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + sourceId := vars["id"] + + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) + return + } + + conn, _, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/vnd.apple.mpegurl") + sinkId := stream.GenerateSinkId(conn) + + /* requestTS := strings.HasSuffix(r.URL.Path, ".ts") + if requestTS { + stream.sink + }*/ + + sink := hls.NewSink(sinkId, sourceId, w) + sink.(*stream.SinkImpl).Play(sink, func() { + + }, func(state utils.HookState) { + w.WriteHeader(http.StatusForbidden) + }) + +} diff --git a/config.json b/config.json index 2d911fb..d359a62 100644 --- a/config.json +++ b/config.json @@ -3,11 +3,22 @@ "probe_timeout": 2000, "mw_latency": 350, + "http": { + "addr": "0.0.0.0:8080" + }, + "rtmp": { "enable": true, "addr": "0.0.0.0:1935" }, + "hls": { + "enable": false, + "segment_duration": 2, + "playlist_length": 10, + "path": "../tmp" + }, + "rtsp": { "enable": true, "addr": "0.0.0.0:554", @@ -30,6 +41,7 @@ }, "hook": { + "timeout": 10, "on_publish": "http://localhost:8080/api/v1/live/publish/auth", "on_publish_done": "http://localhost:8080/api/v1/live/publishdone", "on_play" : "http://localhost:8080/api/v1/live/play/auth", diff --git a/go.mod b/go.mod index c6ff28c..98bbc7c 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/yangjiechina/live-server require github.com/yangjiechina/avformat v0.0.0 -require golang.org/x/sys v0.15.0 // indirect +require github.com/gorilla/mux v1.8.1 replace github.com/yangjiechina/avformat => ../avformat diff --git a/hls/hls_sink.go b/hls/hls_sink.go new file mode 100644 index 0000000..1a8afd1 --- /dev/null +++ b/hls/hls_sink.go @@ -0,0 +1,25 @@ +package hls + +import ( + "github.com/yangjiechina/live-server/stream" + "net/http" +) + +type sink struct { + stream.SinkImpl + conn http.ResponseWriter +} + +func NewSink(id stream.SinkId, sourceId string, w http.ResponseWriter) stream.ISink { + return &sink{stream.SinkImpl{Id_: id, SourceId_: sourceId}, w} +} + +func (s *sink) Input(data []byte) error { + if s.conn != nil { + _, err := s.conn.Write(data) + + return err + } + + return nil +} diff --git a/hls/hls_stream.go b/hls/hls_stream.go new file mode 100644 index 0000000..bd6dd80 --- /dev/null +++ b/hls/hls_stream.go @@ -0,0 +1,37 @@ +package hls + +import ( + "github.com/yangjiechina/avformat/libmpeg" + "github.com/yangjiechina/avformat/utils" + "github.com/yangjiechina/live-server/stream" +) + +type Stream struct { + stream.TransStreamImpl + muxer libmpeg.TSMuxer +} + +func NewTransStream(segmentDuration, playlistLength int) stream.ITransStream { + return &Stream{muxer: libmpeg.NewTSMuxer()} +} + +func (t *Stream) Input(packet utils.AVPacket) { + if utils.AVMediaTypeVideo == packet.MediaType() { + if packet.KeyFrame() { + t.Tracks[packet.Index()].AnnexBExtraData() + t.muxer.Input() + } + } + +} + +func (t *Stream) AddTrack(stream utils.AVStream) { + t.TransStreamImpl.AddTrack(stream) + + t.muxer.AddTrack(stream.Type(), stream.CodecId()) +} + +func (t *Stream) WriteHeader() error { + t.muxer.WriteHeader() + return nil +} diff --git a/hls/m3u8.go b/hls/m3u8.go new file mode 100644 index 0000000..e4777d0 --- /dev/null +++ b/hls/m3u8.go @@ -0,0 +1,117 @@ +package hls + +import ( + "bytes" + "github.com/yangjiechina/live-server/stream" + "math" + "strconv" +) + +const ( + ExtM3u = "EXTM3U" + ExtXVersion = "EXT-X-VERSION" //在文件中唯一 + + ExtINF = "EXTINF" //(浮点类型, 版本小于3用整型),[] + ExXByteRange = "EXT-X-BYTERANGE" //版本4及以上,分片位置 + ExtXDiscontinuity = "EXT-X-DISCONTINUITY" //后面的切片不连续, 文件格式、时间戳发生变化 + ExtXKey = "EXT-X-KEY" //加密使用 + ExtXMap = "EXT-X-MAP" //音视频的元数据 + ExtXProgramDateTime = "EXT-X-PROGRAM-DATE-TIME" + ExtXDateRange = "EXT-X-DATERANGE" + + ExtXTargetDuration = "EXT-X-TARGETDURATION" //切片最大时长, 整型单位秒 + ExtXMediaSequence = "EXT-X-MEDIA-SEQUENCE" //第一个切片序号 + ExtXDiscontinuitySequence = "EXT-X-DISCONTINUITY-SEQUENCE" + ExtXEndList = "EXT-X-ENDLIST" + ExtXPlaylistType = "EXT-X-PLAYLIST-TYPE" + ExtXIFramesOnly = "EXT-X-I-FRAMES-ONLY" + + ExtXMedia = "EXT-X-MEDIA" + ExtXStreamINF = "EXT-X-STREAM-INF" + ExtXIFrameStreamINF = "EXT-X-I-FRAME-STREAM-INF" + ExtXSessionData = "EXT-X-SESSION-DATA" + ExtXSessionKey = "EXT-X-SESSION-KEY" + + ExtXIndependentSegments = "EXT-X-INDEPENDENT-SEGMENTS" + ExtXStart = "EXT-X-START" +) + +//HttpContent-Type头必须是"application/vnd.apple.mpegurl"或"audio/mpegurl" +//无BOM + +type M3U8Writer interface { + AddSegment(duration float32, url string, sequence int) + + ToString() string +} + +func NewM3U8Writer(len int) M3U8Writer { + return &m3u8Writer{ + stringBuffer: bytes.NewBuffer(make([]byte, 1024*10)), + playlist: stream.NewQueue(len), + } +} + +type Segment struct { + duration float32 + url string + sequence int +} + +type m3u8Writer struct { + stringBuffer *bytes.Buffer + targetDuration int + playlist *stream.Queue +} + +func (m *m3u8Writer) AddSegment(duration float32 /*title string,*/, url string, sequence int) { + //影响播放器缓存. + round := int(math.Ceil(float64(duration))) + if round > m.targetDuration { + m.targetDuration = round + } + + if m.playlist.IsFull() { + m.playlist.Pop() + } + + m.playlist.Push(Segment{duration: duration, url: url, sequence: sequence}) +} + +func (m *m3u8Writer) ToString() string { + //暂时只实现简单的播放列表 + head, tail := m.playlist.Data() + if head == nil { + return "" + } + + m.stringBuffer.WriteString("#EXTM3U\r\n") + //暂时只实现第三个版本 + m.stringBuffer.WriteString("#EXT-X-VERSION:3\r\n") + m.stringBuffer.WriteString("#EXT-X-TARGETDURATION:") + m.stringBuffer.WriteString(strconv.Itoa(m.targetDuration)) + m.stringBuffer.WriteString("\r\n") + m.stringBuffer.WriteString("#ExtXMediaSequence:") + m.stringBuffer.WriteString(strconv.Itoa(head[0].(Segment).sequence)) + m.stringBuffer.WriteString("\r\n") + + appendSegments := func(playlist []interface{}) { + for _, segment := range playlist { + m.stringBuffer.WriteString("#EXTINF:") + m.stringBuffer.WriteString(strconv.FormatFloat(float64(segment.(Segment).duration), 'f', -1, 32)) + m.stringBuffer.WriteString(",\r\n") + m.stringBuffer.WriteString(segment.(Segment).url) + m.stringBuffer.WriteString("\r\n") + } + } + + if head != nil { + appendSegments(head) + } + + if tail != nil { + appendSegments(tail) + } + + return m.stringBuffer.String() +} diff --git a/main.go b/main.go index c908ee3..1da1010 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,10 @@ func CreateTransStream(protocol stream.Protocol, streams []utils.AVStream) strea return nil } +func requestStream(sourceId string) { + +} + func init() { stream.TransStreamFactory = CreateTransStream } @@ -41,6 +45,9 @@ func main() { println("启动rtmp服务成功:" + addr) + apiAddr := "0.0.0.0:8080" + go startApiServer(apiAddr) + loadConfigError := http.ListenAndServe(":19999", nil) if loadConfigError != nil { panic(loadConfigError) diff --git a/rtmp/rtmp_publisher.go b/rtmp/rtmp_publisher.go index 7485fc7..bc2c7e9 100644 --- a/rtmp/rtmp_publisher.go +++ b/rtmp/rtmp_publisher.go @@ -5,9 +5,28 @@ import ( "github.com/yangjiechina/avformat/librtmp" "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/stream" + "net" ) -type Publisher struct { +type Publisher interface { + + // Init 初始化内存池 + Init() + + // OnDiscardPacket GOP缓存溢出的包 + OnDiscardPacket(pkt interface{}) + + // OnVideo 从rtmp chunk中解析出来的整个视频包, 还需要进步封装成AVPacket + OnVideo(data []byte, ts uint32) + + // OnAudio 从rtmp chunk中解析出来的整个音频包 + OnAudio(data []byte, ts uint32) + + // OnPartPacket 从rtmp chunk中解析出来的一部分音视频包 + OnPartPacket(index int, data []byte, first bool) +} + +type publisher struct { stream.SourceImpl stack *librtmp.Stack @@ -18,17 +37,17 @@ type Publisher struct { videoMark bool } -func NewPublisher(sourceId string, stack *librtmp.Stack) *Publisher { +func NewPublisher(sourceId string, stack *librtmp.Stack, conn net.Conn) Publisher { deMuxer := libflv.NewDeMuxer() - publisher := &Publisher{SourceImpl: stream.SourceImpl{Id_: sourceId, Type_: stream.SourceTypeRtmp, TransDeMuxer: deMuxer}, stack: stack, audioMark: false, videoMark: false} + publisher_ := &publisher{SourceImpl: stream.SourceImpl{Id_: sourceId, Type_: stream.SourceTypeRtmp, TransDeMuxer: deMuxer, Conn: conn}, stack: stack, audioMark: false, videoMark: false} //设置回调,从flv解析出来的Stream和AVPacket都将统一回调到stream.SourceImpl - deMuxer.SetHandler(publisher) - publisher.Input_ = publisher.Input + deMuxer.SetHandler(publisher_) + publisher_.Input_ = publisher_.Input - return publisher + return publisher_ } -func (p *Publisher) Init() { +func (p *publisher) Init() { //创建内存池 p.audioMemoryPool = stream.NewMemoryPool(48000 * 1) if stream.AppConfig.GOPCache { @@ -42,11 +61,11 @@ func (p *Publisher) Init() { go p.SourceImpl.LoopEvent() } -func (p *Publisher) Input(data []byte) { +func (p *publisher) Input(data []byte) { p.stack.Input(nil, data) } -func (p *Publisher) OnDiscardPacket(pkt interface{}) { +func (p *publisher) OnDiscardPacket(pkt interface{}) { packet := pkt.(utils.AVPacket) if utils.AVMediaTypeAudio == packet.MediaType() { p.audioMemoryPool.FreeHead() @@ -55,7 +74,7 @@ func (p *Publisher) OnDiscardPacket(pkt interface{}) { } } -func (p *Publisher) OnDeMuxStream(stream_ utils.AVStream) { +func (p *publisher) OnDeMuxStream(stream_ utils.AVStream) { //AVStream的Data单独拷贝出来 //释放掉内存池中最新分配的内存 tmp := stream_.Extra() @@ -74,7 +93,7 @@ func (p *Publisher) OnDeMuxStream(stream_ utils.AVStream) { } } -func (p *Publisher) OnDeMuxPacket(packet utils.AVPacket) { +func (p *publisher) OnDeMuxPacket(packet utils.AVPacket) { p.SourceImpl.OnDeMuxPacket(packet) if stream.AppConfig.GOPCache { @@ -88,8 +107,7 @@ func (p *Publisher) OnDeMuxPacket(packet utils.AVPacket) { } } -// OnVideo 从rtm chunk解析过来的视频包 -func (p *Publisher) OnVideo(data []byte, ts uint32) { +func (p *publisher) OnVideo(data []byte, ts uint32) { if data == nil { data = p.videoMemoryPool.Fetch() p.videoMark = false @@ -98,7 +116,7 @@ func (p *Publisher) OnVideo(data []byte, ts uint32) { p.SourceImpl.TransDeMuxer.(*libflv.DeMuxer).InputVideo(data, ts) } -func (p *Publisher) OnAudio(data []byte, ts uint32) { +func (p *publisher) OnAudio(data []byte, ts uint32) { if data == nil { data = p.audioMemoryPool.Fetch() p.audioMark = false @@ -107,8 +125,7 @@ func (p *Publisher) OnAudio(data []byte, ts uint32) { _ = p.SourceImpl.TransDeMuxer.(*libflv.DeMuxer).InputAudio(data, ts) } -// OnPartPacket 从rtmp解析过来的部分音视频包 -func (p *Publisher) OnPartPacket(index int, data []byte, first bool) { +func (p *publisher) OnPartPacket(index int, data []byte, first bool) { //audio if index == 0 { if !p.audioMark { diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index c37741b..b7e4ca1 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -16,8 +16,6 @@ type Session interface { func NewSession(conn net.Conn) Session { impl := &sessionImpl{} - impl.Protocol = stream.ProtocolRtmpStr - impl.RemoteAddr = conn.RemoteAddr().String() stack := librtmp.NewStack(impl) impl.stack = stack @@ -26,27 +24,26 @@ func NewSession(conn net.Conn) Session { } type sessionImpl struct { - stream.SessionImpl //解析rtmp协议栈 stack *librtmp.Stack //publisher/sink handle interface{} - isPublish bool - conn net.Conn + isPublisher bool + conn net.Conn } func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookState) { - s.SessionImpl.Stream = app + "/" + stream_ - publisher := NewPublisher(s.SessionImpl.Stream, s.stack) - s.stack.SetOnPublishHandler(publisher) - s.stack.SetOnTransDeMuxerHandler(publisher) + sourceId := app + "_" + stream_ + source := NewPublisher(sourceId, s.stack, s.conn) + s.stack.SetOnPublishHandler(source) + s.stack.SetOnTransDeMuxerHandler(source) - //stream.SessionImpl统一处理, Source是否已经存在, Hook回调.... - s.SessionImpl.OnPublish(publisher, nil, func() { - s.handle = publisher - s.isPublish = true - publisher.Init() + //推流事件Source统一处理, 是否已经存在, Hook回调.... + source.(*publisher).Publish(source.(*publisher), func() { + s.handle = source + s.isPublisher = true + source.Init() response <- utils.HookStateOK }, func(state utils.HookState) { @@ -55,10 +52,11 @@ func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookSta } func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) { - s.SessionImpl.Stream = app + "/" + stream_ + sourceId := app + "_" + stream_ - sink := NewSink(stream.GenerateSinkId(s.conn), s.SessionImpl.Stream, s.conn) - s.SessionImpl.OnPlay(sink, nil, func() { + //拉流事件Sink统一处理 + sink := NewSink(stream.GenerateSinkId(s.conn), sourceId, s.conn) + sink.(*stream.SinkImpl).Play(sink, func() { s.handle = sink response <- utils.HookStateOK }, func(state utils.HookState) { @@ -68,8 +66,8 @@ func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) func (s *sessionImpl) Input(conn net.Conn, data []byte) error { //如果是推流,并且握手成功,后续收到的包,都将发送给LoopEvent处理 - if s.isPublish { - s.handle.(*Publisher).AddEvent(stream.SourceEventInput, data) + if s.isPublisher { + s.handle.(*publisher).AddEvent(stream.SourceEventInput, data) return nil } else { return s.stack.Input(conn, data) @@ -83,8 +81,8 @@ func (s *sessionImpl) Close() { _, ok := s.handle.(*Publisher) if ok { - if s.isPublish { - s.handle.(*Publisher).AddEvent(stream.SourceEventClose, nil) + if s.isPublisher { + s.handle.(*publisher).AddEvent(stream.SourceEventClose, nil) } } else { sink := s.handle.(stream.ISink) diff --git a/rtmp/rtmp_transtream.go b/rtmp/rtmp_stream.go similarity index 100% rename from rtmp/rtmp_transtream.go rename to rtmp/rtmp_stream.go index cf688ab..6edb1d1 100644 --- a/rtmp/rtmp_transtream.go +++ b/rtmp/rtmp_stream.go @@ -41,6 +41,11 @@ type TransStream struct { incompleteSinks []stream.ISink } +func NewTransStream(chunkSize int) stream.ITransStream { + transStream := &TransStream{chunkSize: chunkSize, TransStreamImpl: stream.TransStreamImpl{Sinks: make(map[stream.SinkId]stream.ISink, 64)}} + return transStream +} + func (t *TransStream) Input(packet utils.AVPacket) { utils.Assert(t.TransStreamImpl.Completed) @@ -286,8 +291,3 @@ func (t *TransStream) WriteHeader() error { t.headerSize = n return nil } - -func NewTransStream(chunkSize int) stream.ITransStream { - transStream := &TransStream{chunkSize: chunkSize, TransStreamImpl: stream.TransStreamImpl{Sinks: make(map[stream.SinkId]stream.ISink, 64)}} - return transStream -} diff --git a/stream/config.go b/stream/config.go index 0821fad..5619e3b 100644 --- a/stream/config.go +++ b/stream/config.go @@ -10,6 +10,7 @@ type RtmpConfig struct { } type HookConfig struct { + Time int Enable bool `json:"enable"` OnPublish string `json:"on_publish"` //推流回调 OnPublishDone string `json:"on_publish_done"` //推流结束回调 diff --git a/stream/hook.go b/stream/hook.go index e3d568a..140cc46 100644 --- a/stream/hook.go +++ b/stream/hook.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "net/http" + "time" ) type HookFunc func(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error @@ -24,8 +25,17 @@ type Hook interface { DoRecvTimeout(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error } -type hookImpl struct { -} +type HookEvent int + +const ( + HookEventPublish = HookEvent(0x1) + HookEventPublishDone = HookEvent(0x2) + HookEventPlay = HookEvent(0x3) + HookEventPlayDone = HookEvent(0x4) + HookEventRecord = HookEvent(0x5) + HookEventIdleTimeout = HookEvent(0x6) + HookEventRecvTimeout = HookEvent(0x6) +) // 每个通知的时间都需要携带的字段 type eventInfo struct { @@ -34,56 +44,64 @@ type eventInfo struct { remoteAddr string //peer地址 } -func (hook *hookImpl) send(url string, m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - marshal, err := json.Marshal(m) +func NewHookEventInfo(stream, protocol, remoteAddr string) eventInfo { + return eventInfo{stream: stream, protocol: protocol, remoteAddr: remoteAddr} +} + +type HookSession interface { + send(url string, body interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error + + Hook(event HookEvent, body interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error +} + +var hookUrls map[HookEvent]string + +func init() { + hookUrls = map[HookEvent]string{ + HookEventPublish: "", + HookEventPublishDone: "", + HookEventPlay: "", + HookEventPlayDone: "", + HookEventRecord: "", + HookEventIdleTimeout: "", + HookEventRecvTimeout: "", + } +} + +type hookSessionImpl struct { +} + +func (h *hookSessionImpl) send(url string, body interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { + marshal, err := json.Marshal(body) if err != nil { return err } - client := &http.Client{} + client := &http.Client{ + Timeout: time.Second * time.Duration(AppConfig.Hook.Time), + } request, err := http.NewRequest("post", url, bytes.NewBuffer(marshal)) if err != nil { return err } request.Header.Set("Content-Type", "application/json") - go func() { - response, err := client.Do(request) - if err != nil || response.StatusCode != http.StatusOK { - failure(response, err) - return - } - + response, err := client.Do(request) + if err != nil || response.StatusCode != http.StatusOK { + failure(response, err) + } else { success(response) - }() + } return nil } -func (hook *hookImpl) DoPublish(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnPublish, m, success, failure) -} +func (h *hookSessionImpl) Hook(event HookEvent, body interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { + url := hookUrls[event] + if url == "" { + success(nil) + return nil + } -func (hook *hookImpl) DoPublishDone(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnPublishDone, m, success, failure) -} - -func (hook *hookImpl) DoPlay(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnPlay, m, success, failure) -} - -func (hook *hookImpl) DoPlayDone(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnPlayDone, m, success, failure) -} - -func (hook *hookImpl) DoRecord(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnRecord, m, success, failure) -} - -func (hook *hookImpl) DoIdleTimeout(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnIdleTimeout, m, success, failure) -} - -func (hook *hookImpl) DoRecvTimeout(m map[string]interface{}, success func(response *http.Response), failure func(response *http.Response, err error)) error { - return hook.send(AppConfig.Hook.OnRecvTimeout, m, success, failure) + return h.send(url, body, success, failure) } diff --git a/stream/session.go b/stream/session.go index 3aca7c4..0f33965 100644 --- a/stream/session.go +++ b/stream/session.go @@ -1,118 +1,17 @@ package stream import ( - "fmt" "github.com/yangjiechina/avformat/utils" - "net/http" ) -// Session 封装推拉流Session 统一管理,统一 hook回调 -type Session interface { - OnPublish(source ISource, pra map[string]interface{}, success func(), failure func(state utils.HookState)) +type SourceHook interface { + Publish(source ISource, success func(), failure func(state utils.HookState)) - OnPublishDone() - - OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state utils.HookState)) - - OnPlayDone(pra map[string]interface{}, success func(), failure func(state utils.HookState)) + PublishDone(source ISource, success func(), failure func(state utils.HookState)) } -type SessionImpl struct { - hookImpl - Stream string //stream id - Protocol string //推拉流协议 - RemoteAddr string //peer地址 -} - -// AddInfoParams 为每个需要通知的时间添加必要的信息 -func (s *SessionImpl) AddInfoParams(data map[string]interface{}) { - data["stream"] = s.Stream - data["protocol"] = s.Protocol - data["remoteAddr"] = s.RemoteAddr -} - -func (s *SessionImpl) OnPublish(source_ ISource, pra map[string]interface{}, success func(), failure func(state utils.HookState)) { - //streamId 已经被占用 - source := SourceManager.Find(s.Stream) - if source != nil { - fmt.Printf("推流已经占用 Source:%s", source_.Id()) - failure(utils.HookStateOccupy) - return - } - - if !AppConfig.Hook.EnableOnPublish() { - if err := SourceManager.Add(source_); err == nil { - success() - } else { - fmt.Printf("添加失败 Source:%s", source_.Id()) - failure(utils.HookStateOccupy) - } - - return - } - - if pra == nil { - pra = make(map[string]interface{}, 5) - } - - s.AddInfoParams(pra) - err := s.DoPublish(pra, func(response *http.Response) { - if err := SourceManager.Add(source_); err == nil { - success() - } else { - failure(utils.HookStateOccupy) - } - }, func(response *http.Response, err error) { - failure(utils.HookStateFailure) - }) - - //hook地址连接失败 - if err != nil { - failure(utils.HookStateFailure) - return - } -} - -func (s *SessionImpl) OnPublishDone() { - -} - -func (s *SessionImpl) OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state utils.HookState)) { - f := func() { - source := SourceManager.Find(s.Stream) - if source == nil { - fmt.Printf("添加到等待队列 sink:%s", sink.Id()) - sink.SetState(SessionStateWait) - AddSinkToWaitingQueue(s.Stream, sink) - } else { - source.AddEvent(SourceEventPlay, sink) - } - } - - if !AppConfig.Hook.EnableOnPlay() { - f() - success() - return - } - - if pra == nil { - pra = make(map[string]interface{}, 5) - } - - s.AddInfoParams(pra) - err := s.DoPlay(pra, func(response *http.Response) { - f() - success() - }, func(response *http.Response, err error) { - failure(utils.HookStateFailure) - }) - - if err != nil { - failure(utils.HookStateFailure) - return - } -} - -func (s *SessionImpl) OnPlayDone(pra map[string]interface{}, success func(), failure func(state utils.HookState)) { - +type SinkHook interface { + Play(sink ISink, success func(), failure func(state utils.HookState)) + + PlayDone(source ISink, success func(), failure func(state utils.HookState)) } diff --git a/stream/sink.go b/stream/sink.go index 2b35cbb..fe71271 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -1,8 +1,10 @@ package stream import ( + "fmt" "github.com/yangjiechina/avformat/utils" "net" + "net/http" "sync/atomic" ) @@ -60,6 +62,8 @@ func GenerateSinkId(conn net.Conn) SinkId { } type SinkImpl struct { + hookSessionImpl + Id_ SinkId SourceId_ string Protocol_ Protocol @@ -165,3 +169,38 @@ func (s *SinkImpl) Close() { s.closed.Store(true) } } + +func (s *SinkImpl) Play(sink ISink, success func(), failure func(state utils.HookState)) { + f := func() { + source := SourceManager.Find(sink.SourceId()) + if source == nil { + fmt.Printf("添加到等待队列 sink:%s", sink.Id()) + sink.SetState(SessionStateWait) + AddSinkToWaitingQueue(sink.SourceId(), sink) + } else { + source.AddEvent(SourceEventPlay, sink) + } + } + + if !AppConfig.Hook.EnableOnPlay() { + f() + success() + return + } + + err := s.Hook(HookEventPlay, NewHookEventInfo(sink.SourceId(), streamTypeToStr(sink.Protocol()), ""), func(response *http.Response) { + f() + success() + }, func(response *http.Response, err error) { + failure(utils.HookStateFailure) + }) + + if err != nil { + failure(utils.HookStateFailure) + return + } +} + +func (s *SinkImpl) PlayDone(source ISink, success func(), failure func(state utils.HookState)) { + +} diff --git a/stream/source.go b/stream/source.go index 82815d4..67b7d34 100644 --- a/stream/source.go +++ b/stream/source.go @@ -2,6 +2,8 @@ package stream import ( "fmt" + "net" + "net/http" "sync" "time" @@ -56,6 +58,34 @@ const ( SessionStateClose = SessionState(7) ) +func sourceTypeToStr(sourceType SourceType) string { + if SourceTypeRtmp == sourceType { + return "rtmp" + } else if SourceType28181 == sourceType { + return "28181" + } else if SourceType1078 == sourceType { + return "1078" + } + + return "" +} + +func streamTypeToStr(protocol Protocol) string { + if ProtocolRtmp == protocol { + return "rtmp" + } else if ProtocolFlv == protocol { + return "flv" + } else if ProtocolRtsp == protocol { + return "rtsp" + } else if ProtocolHls == protocol { + return "hls" + } else if ProtocolRtc == protocol { + return "rtc" + } + + return "" +} + type ISource interface { // Id Source的唯一ID/** Id() string @@ -84,6 +114,8 @@ type ISource interface { // 停止一切封装和转发流以及转码工作 // 将Sink添加到等待队列 Close() + + Type() SourceType } type CreateSource func(id string, type_ SourceType, handler stream.OnDeMuxerHandler) @@ -91,9 +123,12 @@ type CreateSource func(id string, type_ SourceType, handler stream.OnDeMuxerHand var TranscoderFactory func(src utils.AVStream, dst utils.AVStream) transcode.ITranscoder type SourceImpl struct { + hookSessionImpl + Id_ string Type_ SourceType state SessionState + Conn net.Conn TransDeMuxer stream.DeMuxer //负责从推流协议中解析出AVStream和AVPacket recordSink ISink //每个Source唯一的一个录制流 @@ -338,7 +373,7 @@ func (s *SourceImpl) Close() { //释放每路转协议流, 将所有sink添加到等待队列 _, _ = SourceManager.Remove(s.Id_) for _, transStream := range s.transStreams { - transStream.PopAllSinks(func(sink ISink) { + transStream.PopAllSink(func(sink ISink) { sink.SetTransStreamId(0) state := sink.SetState(SessionStateWait) if state { @@ -418,3 +453,47 @@ func (s *SourceImpl) OnDeMuxPacket(packet utils.AVPacket) { func (s *SourceImpl) OnDeMuxDone() { } + +func (s *SourceImpl) Publish(source ISource, success func(), failure func(state utils.HookState)) { + //streamId 已经被占用 + if source_ := SourceManager.Find(source.Id()); source_ != nil { + fmt.Printf("推流已经占用 Source:%s", source.Id()) + failure(utils.HookStateOccupy) + } + + if !AppConfig.Hook.EnableOnPublish() { + if err := SourceManager.Add(source); err == nil { + success() + } else { + fmt.Printf("添加失败 Source:%s", source.Id()) + failure(utils.HookStateOccupy) + } + + return + } + + err := s.Hook(HookEventPublish, NewHookEventInfo(source.Id(), sourceTypeToStr(source.Type()), ""), + func(response *http.Response) { + if err := SourceManager.Add(source); err == nil { + success() + } else { + failure(utils.HookStateOccupy) + } + }, func(response *http.Response, err error) { + failure(utils.HookStateFailure) + }) + + //hook地址连接失败 + if err != nil { + failure(utils.HookStateFailure) + return + } +} + +func (s *SourceImpl) PublishDone(source ISource, success func(), failure func(state utils.HookState)) { + +} + +func (s *SourceImpl) Type() SourceType { + return s.Type_ +} diff --git a/stream/trans_stream.go b/stream/trans_stream.go index 937e3c6..1c28e90 100644 --- a/stream/trans_stream.go +++ b/stream/trans_stream.go @@ -64,6 +64,7 @@ func GenerateTransStreamId(protocol Protocol, ids ...utils.AVStream) TransStream var TransStreamFactory func(protocol Protocol, streams []utils.AVStream) ITransStream +// ITransStream 讲AVPacket封装成传输流,转发给各个Sink type ITransStream interface { Input(packet utils.AVPacket) @@ -75,7 +76,7 @@ type ITransStream interface { RemoveSink(id SinkId) (ISink, bool) - PopAllSinks(handler func(sink ISink)) + PopAllSink(handler func(sink ISink)) AllSink() []ISink } @@ -109,7 +110,7 @@ func (t *TransStreamImpl) RemoveSink(id SinkId) (ISink, bool) { return sink, ok } -func (t *TransStreamImpl) PopAllSinks(handler func(sink ISink)) { +func (t *TransStreamImpl) PopAllSink(handler func(sink ISink)) { for _, sink := range t.Sinks { handler(sink) }