封装rtsp udp拉流

This commit is contained in:
yangjiechina
2024-03-17 14:18:21 +08:00
parent 14b297827e
commit d325daeb8b
12 changed files with 921 additions and 6 deletions

View File

@@ -22,7 +22,9 @@
"rtsp": {
"enable": true,
"addr": "0.0.0.0:554",
"password": "123456"
"port": [20000,30000],
"password": "123456",
"transport": "UDP|TCP"
},
"webrtc": {

31
main.go
View File

@@ -3,6 +3,7 @@ package main
import (
"github.com/yangjiechina/live-server/flv"
"github.com/yangjiechina/live-server/hls"
"github.com/yangjiechina/live-server/rtsp"
"net"
"net/http"
@@ -14,6 +15,8 @@ import (
"github.com/yangjiechina/live-server/stream"
)
var rtspAddr *net.TCPAddr
func CreateTransStream(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) stream.ITransStream {
if stream.ProtocolRtmp == protocol {
return rtmp.NewTransStream(librtmp.ChunkSize)
@@ -30,6 +33,13 @@ func CreateTransStream(source stream.ISource, protocol stream.Protocol, streams
return transStream
} else if stream.ProtocolFlv == protocol {
return flv.NewHttpTransStream()
} else if stream.ProtocolRtsp == protocol {
trackFormat := source.Id() + "?track=%d"
return rtsp.NewTransStream(net.IPAddr{
IP: rtspAddr.IP,
Zone: rtspAddr.Zone,
}, trackFormat)
}
return nil
@@ -42,19 +52,32 @@ func init() {
func main() {
stream.AppConfig.GOPCache = true
stream.AppConfig.MergeWriteLatency = 350
rtmpAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:1935")
if err != nil {
panic(err)
}
impl := rtmp.NewServer()
addr := "0.0.0.0:1935"
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
err = impl.Start(rtmpAddr)
if err != nil {
panic(err)
}
err = impl.Start(tcpAddr)
println("启动rtmp服务成功:" + rtmpAddr.String())
rtspAddr, err = net.ResolveTCPAddr("tcp", "0.0.0.0:554")
if err != nil {
panic(rtspAddr)
}
rtspServer := rtsp.NewServer()
err = rtspServer.Start(rtspAddr)
if err != nil {
panic(err)
}
println("启动rtmp服务成功:" + addr)
println("启动rtsp服务成功:" + rtspAddr.String())
apiAddr := "0.0.0.0:8080"
go startApiServer(apiAddr)

View File

@@ -26,7 +26,7 @@ func (s *serverImpl) Start(addr net.Addr) error {
server := &transport.TCPServer{}
server.SetHandler(s)
err := server.Bind(addr.String())
err := server.Bind(addr)
if err != nil {
return err

31
rtsp/rtp_track.go Normal file
View File

@@ -0,0 +1,31 @@
package rtsp
import (
"github.com/yangjiechina/avformat/librtp"
"github.com/yangjiechina/avformat/utils"
)
type rtpTrack struct {
pt byte
rate int
mediaType utils.AVMediaType
//目前用于缓存带有SPS和PPS的RTP包
buffer []byte
muxer librtp.Muxer
cache bool
header [][]byte
tmp [][]byte
}
func NewRTPTrack(muxer librtp.Muxer, pt byte, rate int) *rtpTrack {
stream := &rtpTrack{
pt: pt,
rate: rate,
muxer: muxer,
buffer: make([]byte, 1500),
}
return stream
}

87
rtsp/rtsp_server.go Normal file
View File

@@ -0,0 +1,87 @@
package rtsp
import (
"fmt"
"github.com/yangjiechina/avformat/transport"
"github.com/yangjiechina/avformat/utils"
"net"
"net/textproto"
)
type IServer interface {
Start(addr net.Addr) error
Close()
}
func NewServer() IServer {
return &serverImpl{
publicHeader: "OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN, PAUSE, GET_PARAMETER, SET_PARAMETER, REDIRECT, RECORD",
}
}
type serverImpl struct {
tcp *transport.TCPServer
handlers map[string]func(source string, headers textproto.MIMEHeader)
publicHeader string
}
func (s *serverImpl) Start(addr net.Addr) error {
utils.Assert(s.tcp == nil)
server := &transport.TCPServer{}
server.SetHandler(s)
err := server.Bind(addr)
if err != nil {
return err
}
s.tcp = server
for key, _ := range s.handlers {
s.publicHeader += key + ", "
}
s.publicHeader = s.publicHeader[:len(s.publicHeader)-2]
return nil
}
func (s *serverImpl) closeSession(conn net.Conn) {
t := conn.(*transport.Conn)
if t.Data != nil {
t.Data.(*session).close()
t.Data = nil
}
}
func (s *serverImpl) Close() {
}
func (s *serverImpl) OnConnected(conn net.Conn) {
t := conn.(*transport.Conn)
t.Data = NewSession(conn)
}
func (s *serverImpl) OnPacket(conn net.Conn, data []byte) {
t := conn.(*transport.Conn)
message, url, header, err := parseMessage(data)
if err != nil {
println(fmt.Sprintf("failed to prase message:%s. err:%s peer:%s", string(data), err.Error(), conn.RemoteAddr().String()))
_ = conn.Close()
s.closeSession(conn)
return
}
err = t.Data.(*session).Input(message, url, header)
if err != nil {
println(fmt.Sprintf("failed to process message of RTSP. err:%s peer:%s msg:%s", err.Error(), conn.RemoteAddr().String(), string(data)))
_ = conn.Close()
}
}
func (s *serverImpl) OnDisConnected(conn net.Conn, err error) {
s.closeSession(conn)
}

278
rtsp/rtsp_session.go Normal file
View File

@@ -0,0 +1,278 @@
package rtsp
import (
"bufio"
"bytes"
"fmt"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/stream"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"time"
)
const (
MethodOptions = "OPTIONS"
MethodDescribe = "DESCRIBE"
MethodSetup = "SETUP"
MethodPlay = "PLAY"
MethodTeardown = "TEARDOWN"
MethodPause = "PAUSE"
MethodGetParameter = "GET_PARAMETER"
MethodSetParameter = "SET_PARAMETER"
MethodRedirect = "REDIRECT"
MethodRecord = "RECORD"
Version = "RTSP/1.0"
)
type requestHandler interface {
onOptions(sourceId string, headers textproto.MIMEHeader)
onDescribe(sourceId string, headers textproto.MIMEHeader)
onSetup(sourceId string, index int, headers textproto.MIMEHeader)
onPlay(sourceId string)
onTeardown()
onPause()
}
type session struct {
conn net.Conn
sink_ *sink
sessionId string
writeBuffer *bytes.Buffer
}
func NewSession(conn net.Conn) *session {
milli := int(time.Now().UnixMilli() & 0xFFFFFFFF)
return &session{
conn: conn,
sessionId: strconv.Itoa(milli),
writeBuffer: bytes.NewBuffer(make([]byte, 0, 1024*10)),
}
}
func NewOKResponse(cseq string) http.Response {
rep := http.Response{
Proto: Version,
StatusCode: http.StatusOK,
Status: http.StatusText(http.StatusOK),
Header: make(http.Header),
}
if cseq == "" {
cseq = "1"
}
rep.Header.Set("Cseq", cseq)
return rep
}
func parseMessage(data []byte) (string, *url.URL, textproto.MIMEHeader, error) {
reader := bufio.NewReader(bytes.NewReader(data))
tp := textproto.NewReader(reader)
line, err := tp.ReadLine()
split := strings.Split(line, " ")
if len(split) < 3 {
panic(fmt.Errorf("unknow response line of response:%s", line))
}
method := strings.ToUpper(split[0])
//version
_ = split[2]
url_, err := url.Parse(split[1])
if err != nil {
return "", nil, nil, err
}
header, err := tp.ReadMIMEHeader()
if err != nil {
return "", nil, nil, err
}
return method, url_, header, nil
}
func (s *session) response(response http.Response, body []byte) error {
//添加Content-Length
if body != nil {
response.Header.Set("Content-Length", strconv.Itoa(len(body)))
}
// 将响应头和正文封装成字符串
s.writeBuffer.Reset()
_, err := fmt.Fprintf(s.writeBuffer, "%s %d %s\r\n", response.Proto, response.StatusCode, response.Status)
if err != nil {
return err
}
for k, v := range response.Header {
for _, hv := range v {
s.writeBuffer.WriteString(fmt.Sprintf("%s: %s\r\n", k, hv))
}
}
//分隔头部与主体
s.writeBuffer.WriteString("\r\n")
if body != nil {
s.writeBuffer.Write(body)
if body[len(body)-2] != 0x0D || body[len(body)-1] != 0x0A {
s.writeBuffer.WriteString("\r\n")
}
}
data := s.writeBuffer.Bytes()
_, err = s.conn.Write(data)
return err
}
func (s *session) onOptions(sourceId string, headers textproto.MIMEHeader) error {
rep := NewOKResponse(headers.Get("Cseq"))
rep.Header.Set("Public", "OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN, PAUSE, GET_PARAMETER, SET_PARAMETER, REDIRECT, RECORD")
return s.response(rep, nil)
}
func (s *session) onDescribe(source string, headers textproto.MIMEHeader) error {
var err error
sinkId := stream.GenerateSinkId(s.conn.RemoteAddr())
sink_ := NewSink(sinkId, source, s.conn, func(sdp string) {
response := NewOKResponse(headers.Get("Cseq"))
response.Header.Set("Content-Type", "application/sdp")
err = s.response(response, []byte(sdp))
})
code := utils.HookStateOK
s.sink_ = sink_.(*sink)
sink_.(*sink).Play(sink_, func() {
}, func(state utils.HookState) {
code = state
})
if utils.HookStateOK != code {
return fmt.Errorf("hook failed. code:%d", code)
}
return err
}
func (s *session) onSetup(sourceId string, index int, headers textproto.MIMEHeader) error {
transportHeader := headers.Get("Transport")
if transportHeader == "" {
return fmt.Errorf("not find transport header")
}
split := strings.Split(transportHeader, ";")
if len(split) < 3 {
return fmt.Errorf("failed to parsing TRANSPORT header:%s", split)
}
var clientRtpPort int
var clientRtcpPort int
tcp := "RTP/AVP" != split[0] && "RTP/AVP/UDP" != split[0]
for _, value := range split {
if !strings.HasPrefix(value, "client_port=") {
continue
}
pairPort := strings.Split(value[len("client_port="):], "-")
if len(pairPort) != 2 {
return fmt.Errorf("failed to parsing client_port:%s", value)
}
port, err := strconv.Atoi(pairPort[0])
if err != nil {
return err
}
clientRtpPort = port
port, err = strconv.Atoi(pairPort[1])
if err != nil {
return err
}
clientRtcpPort = port
}
rtpPort, rtcpPort, err := s.sink_.addTrack(index, tcp)
if err != nil {
return err
}
println(clientRtpPort)
println(clientRtcpPort)
responseHeader := transportHeader + ";server_port=" + fmt.Sprintf("%d-%d", rtpPort, rtcpPort) + ";ssrc=FFFFFFFF"
response := NewOKResponse(headers.Get("Cseq"))
response.Header.Set("Transport", responseHeader)
response.Header.Set("Session", s.sessionId)
return s.response(response, nil)
}
func (s *session) onPlay(sourceId string, headers textproto.MIMEHeader) error {
response := NewOKResponse(headers.Get("Cseq"))
sessionHeader := headers.Get("Session")
if sessionHeader != "" {
response.Header.Set("Session", sessionHeader)
}
return s.response(response, nil)
}
func (s *session) onTeardown() {
}
func (s *session) onPause() {
}
func (s *session) Input(method string, url_ *url.URL, headers textproto.MIMEHeader) error {
//_ = url_.User.Username()
//_, _ = url_.User.Password()
var err error
split := strings.Split(url_.Path, "/")
source := split[len(split)-1]
if MethodOptions == method {
err = s.onOptions(source, headers)
} else if MethodDescribe == method {
err = s.onDescribe(source, headers)
} else if MethodSetup == method {
query, err := url.ParseQuery(url_.RawQuery)
if err != nil {
return err
}
track := query.Get("track")
index, err := strconv.Atoi(track)
if err != nil {
return err
}
if err = s.onSetup(source, index, headers); err != nil {
return err
}
} else if MethodPlay == method {
err = s.onPlay(source, headers)
} else if MethodTeardown == method {
s.onTeardown()
} else if MethodPause == method {
s.onPause()
}
return err
}
func (s *session) close() {
}

133
rtsp/rtsp_sink.go Normal file
View File

@@ -0,0 +1,133 @@
package rtsp
import (
"fmt"
"github.com/yangjiechina/avformat/transport"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/stream"
"net"
)
// 对于UDP而言, 每个sink维护一对UDPTransport
// TCP直接单端口传输
type sink struct {
stream.SinkImpl
//一个rtsp源可能存在多个流, 每个流都需要拉取拉取
tracks []*rtspTrack
sdpCB func(sdp string)
}
func NewSink(id stream.SinkId, sourceId string, conn net.Conn, cb func(sdp string)) stream.ISink {
return &sink{
stream.SinkImpl{Id_: id, SourceId_: sourceId, Protocol_: stream.ProtocolRtsp, Conn: conn},
nil,
cb,
}
}
func (s *sink) setTrackCount(count int) {
s.tracks = make([]*rtspTrack, count)
}
func (s *sink) addTrack(index int, tcp bool) (int, int, error) {
utils.Assert(index < cap(s.tracks))
utils.Assert(s.tracks[index] == nil)
var err error
var rtpPort int
var rtcpPort int
track := rtspTrack{}
if tcp {
err = rtspTransportManger.AllocTransport(true, func(port int) {
var addr *net.TCPAddr
addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", "0.0.0.0", port))
if err == nil {
track.rtp = &transport.TCPServer{}
track.rtp.SetHandler2(track.onTCPConnected, nil, track.onTCPDisconnected)
err = track.rtp.Bind(addr)
}
rtpPort = port
})
} else {
err = rtspTransportManger.AllocPairTransport(func(port int) {
//rtp port
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", "0.0.0.0", port))
if err == nil {
track.rtp = &transport.UDPTransport{}
track.rtp.SetHandler2(nil, track.onRTPPacket, nil)
err = track.rtp.Bind(addr)
}
rtpPort = port
}, func(port int) {
//rtcp port
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", "0.0.0.0", port))
if err == nil {
track.rtcp = &transport.UDPTransport{}
track.rtcp.SetHandler2(nil, track.onRTCPPacket, nil)
err = track.rtcp.Bind(addr)
} else {
track.rtp.Close()
track.rtp = nil
}
rtcpPort = port
})
}
if err != nil {
return 0, 0, err
}
s.tracks[index] = &track
return rtpPort, rtcpPort, err
}
func (s *sink) input(index int, data []byte) error {
utils.Assert(index < cap(s.tracks))
//拉流方还没有连上来
s.tracks[index].pktCount++
s.tracks[index].rtpConn.Write(data)
return nil
}
func (s *sink) isConnected(index int) bool {
return s.tracks[index] != nil && s.tracks[index].rtpConn != nil
}
func (s *sink) pktCount(index int) int {
return s.tracks[index].pktCount
}
// SendHeader 回调rtsp流的sdp信息
func (s *sink) SendHeader(data []byte) error {
s.sdpCB(string(data))
return nil
}
func (s *sink) TrackConnected(index int) bool {
utils.Assert(index < cap(s.tracks))
utils.Assert(s.tracks[index].rtp != nil)
return s.tracks[index].rtcpConn != nil
}
func (s *sink) Close() {
for _, track := range s.tracks {
if track.rtp != nil {
track.rtp.Close()
}
if track.rtcp != nil {
track.rtcp.Close()
}
}
}

236
rtsp/rtsp_stream.go Normal file
View File

@@ -0,0 +1,236 @@
package rtsp
import (
"encoding/binary"
"fmt"
"github.com/yangjiechina/avformat/librtp"
"github.com/yangjiechina/avformat/librtsp/sdp"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/stream"
"net"
)
// 低延迟是rtsp特性, 不考虑实现GOP缓存
type tranStream struct {
stream.TransStreamImpl
addr net.IPAddr
addrType string
urlFormat string
rtpTracks []*rtpTrack
sdp string
}
func NewTransStream(addr net.IPAddr, urlFormat string) stream.ITransStream {
t := &tranStream{
addr: addr,
urlFormat: urlFormat,
}
if addr.IP.To4() != nil {
t.addrType = "IP4"
} else {
t.addrType = "IP6"
}
t.Init()
return t
}
func (t *tranStream) onAllocBuffer(params interface{}) []byte {
return t.rtpTracks[params.(int)].buffer
}
func (t *tranStream) onRtpPacket(data []byte, timestamp uint32, params interface{}) {
index := params.(int)
if t.rtpTracks[index].cache && t.rtpTracks[index].header == nil {
bytes := make([]byte, len(data))
copy(bytes, data)
t.rtpTracks[index].tmp = append(t.rtpTracks[index].tmp, bytes)
return
}
for _, iSink := range t.Sinks {
if !iSink.(*sink).isConnected(index) {
continue
}
if iSink.(*sink).pktCount(index) < 1 && utils.AVMediaTypeVideo == t.rtpTracks[index].mediaType {
seq := binary.BigEndian.Uint16(data[2:])
count := len(t.rtpTracks[index].header)
for i, rtp := range t.rtpTracks[index].header {
librtp.RollbackSeq(rtp, int(seq)-(count-i-1))
iSink.(*sink).input(index, rtp)
}
}
iSink.(*sink).input(index, data)
}
}
func (t *tranStream) Input(packet utils.AVPacket) error {
stream_ := t.rtpTracks[packet.Index()]
if utils.AVMediaTypeAudio == packet.MediaType() {
stream_.muxer.Input(packet.Data(), uint32(packet.ConvertPts(stream_.rate)))
} else if utils.AVMediaTypeVideo == packet.MediaType() {
//将sps和pps按照单一模式打包
if stream_.header == nil {
if !packet.KeyFrame() {
return nil
}
extra, err := t.TransStreamImpl.Tracks[packet.Index()].AnnexBExtraData()
if err != nil {
return err
}
var count int
stream_.cache = true
utils.SplitNalU(extra, func(nalu []byte) {
data := utils.RemoveStartCode(nalu)
stream_.muxer.Input(data, uint32(packet.ConvertPts(stream_.rate)))
count++
})
stream_.header = stream_.tmp
}
data := utils.RemoveStartCode(packet.AnnexBPacketData())
stream_.muxer.Input(data, uint32(packet.ConvertPts(stream_.rate)))
}
return nil
}
func (t *tranStream) AddSink(sink_ stream.ISink) error {
sink_.(*sink).setTrackCount(len(t.TransStreamImpl.Tracks))
if err := sink_.SendHeader([]byte(t.sdp)); err != nil {
return err
}
return t.TransStreamImpl.AddSink(sink_)
}
func (t *tranStream) AddTrack(stream utils.AVStream) error {
if err := t.TransStreamImpl.AddTrack(stream); err != nil {
return err
}
payloadType, ok := librtp.CodecIdPayloads[stream.CodecId()]
if !ok {
return fmt.Errorf("no payload type was found for codecid:%d", stream.CodecId())
}
//创建RTP封装
var muxer librtp.Muxer
if utils.AVCodecIdH264 == stream.CodecId() {
muxer = librtp.NewH264Muxer(payloadType.Pt, 0, 0xFFFFFFFF)
} else if utils.AVCodecIdAAC == stream.CodecId() {
muxer = librtp.NewAACMuxer(payloadType.Pt, 0, 0xFFFFFFFF)
}
muxer.SetAllocHandler(t.onAllocBuffer)
muxer.SetWriteHandler(t.onRtpPacket)
t.rtpTracks = append(t.rtpTracks, NewRTPTrack(muxer, byte(payloadType.Pt), payloadType.ClockRate))
muxer.SetParams(len(t.rtpTracks) - 1)
return nil
}
func (t *tranStream) WriteHeader() error {
description := sdp.SessionDescription{
Version: 0,
Origin: sdp.Origin{
Username: "-",
SessionID: 0,
SessionVersion: 0,
NetworkType: "IN",
AddressType: t.addrType,
UnicastAddress: t.addr.IP.String(),
},
SessionName: "Stream",
TimeDescriptions: []sdp.TimeDescription{{
Timing: sdp.Timing{
StartTime: 0,
StopTime: 0,
},
RepeatTimes: nil,
},
},
MediaDescriptions: []*sdp.MediaDescription{
{
MediaName: sdp.MediaName{
Media: "video",
Protos: []string{"RTP", "AVP"},
Formats: []string{"108"},
},
ConnectionInformation: &sdp.ConnectionInformation{
NetworkType: "IN",
AddressType: t.addrType,
Address: &sdp.Address{Address: t.addr.IP.String()},
},
Attributes: []sdp.Attribute{
sdp.NewAttribute("recvonly", ""),
sdp.NewAttribute("control:"+fmt.Sprintf(t.urlFormat, 0), ""),
sdp.NewAttribute("rtpmap:108 H264/90000", ""),
},
},
{
MediaName: sdp.MediaName{
Media: "audio",
Protos: []string{"RTP", "AVP"},
Formats: []string{"97"},
},
ConnectionInformation: &sdp.ConnectionInformation{
NetworkType: "IN",
AddressType: t.addrType,
Address: &sdp.Address{Address: t.addr.IP.String()},
},
Attributes: []sdp.Attribute{
sdp.NewAttribute("recvonly", ""),
sdp.NewAttribute("control:"+fmt.Sprintf(t.urlFormat, 1), ""),
//用MP4A-LATM更准确一点
sdp.NewAttribute("rtpmap:97 mpeg4-generic/48000", ""),
//[14496-3], [RFC6416] profile-level-id:
//1 : Main Audio Profile Level 1
//9 : Speech Audio Profile Level 1
//15: High Quality Audio Profile Level 2
//30: Natural Audio Profile Level 1
//44: High Efficiency AAC Profile Level 2
//48: High Efficiency AAC v2 Profile Level 2
//55: Baseline MPEG Surround Profile (see ISO/IEC 23003-1) Level 3
//[RFC5619]
//a=fmtp:96 streamType=5; profile-level-id=44; mode=AAC-hbr; config=131
// 056E598; sizeLength=13; indexLength=3; indexDeltaLength=3; constant
// Duration=2048; MPS-profile-level-id=55; MPS-config=F1B4CF920442029B
// 501185B6DA00;
//低比特率用sizelength=6;indexlength=2;indexdeltalength=2
//[RFC3640]
//mode=AAC-hbr
sdp.NewAttribute("fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3;", ""),
},
},
},
}
marshal, err := description.Marshal()
if err != nil {
return err
}
println(marshal)
t.sdp = string(marshal)
return nil
}

41
rtsp/rtsp_track.go Normal file
View File

@@ -0,0 +1,41 @@
package rtsp
import (
"github.com/yangjiechina/avformat/transport"
"net"
)
type rtspTrack struct {
rtp transport.ITransport
rtcp transport.ITransport
rtpConn net.Conn
rtcpConn net.Conn
//rtcp
pktCount int
}
func (s *rtspTrack) onRTPPacket(conn net.Conn, data []byte) {
if s.rtpConn == nil {
s.rtpConn = conn
}
}
func (s *rtspTrack) onRTCPPacket(conn net.Conn, data []byte) {
if s.rtcpConn == nil {
s.rtcpConn = conn
}
}
// tcp链接成功回调
func (s *rtspTrack) onTCPConnected(conn net.Conn) {
if s.rtcpConn != nil {
s.rtcpConn = conn
}
}
// tcp断开链接回调
func (s *rtspTrack) onTCPDisconnected(conn net.Conn, err error) {
}

70
rtsp/transport_manager.go Normal file
View File

@@ -0,0 +1,70 @@
package rtsp
import (
"fmt"
"github.com/yangjiechina/avformat/utils"
)
type TransportManager interface {
init(startPort, endPort int)
AllocTransport(tcp bool, cb func(port int)) error
AllocPairTransport(cb func(port int)) error
}
var rtspTransportManger transportManager
func init() {
rtspTransportManger = transportManager{}
rtspTransportManger.init(20000, 30000)
}
type transportManager struct {
startPort int
endPort int
nextPort int
}
func (t *transportManager) init(startPort, endPort int) {
utils.Assert(endPort > startPort)
t.startPort = startPort
t.endPort = endPort + 1
t.nextPort = startPort
}
func (t *transportManager) AllocTransport(tcp bool, cb func(port int)) error {
loop := func(start, end int, tcp bool) int {
for i := start; i < end; i++ {
if used := utils.Used(i, tcp); !used {
cb(i)
return i
}
}
return -1
}
port := loop(t.nextPort, t.endPort, tcp)
if port == -1 {
port = loop(t.startPort, t.nextPort, tcp)
}
if port == -1 {
return fmt.Errorf("no available ports in the [%d-%d] range", t.startPort, t.endPort)
}
t.nextPort = t.nextPort + 1%t.endPort
t.nextPort = utils.MaxInt(t.nextPort, t.startPort)
return nil
}
func (t *transportManager) AllocPairTransport(cb func(port int), cb2 func(port int)) error {
if err := t.AllocTransport(false, cb); err != nil {
return err
}
if err := t.AllocTransport(false, cb2); err != nil {
return err
}
return nil
}

View File

@@ -14,6 +14,8 @@ type ISink interface {
Input(data []byte) error
SendHeader(data []byte) error
SourceId() string
TransStreamId() TransStreamId
@@ -99,6 +101,10 @@ func (s *SinkImpl) Input(data []byte) error {
return nil
}
func (s *SinkImpl) SendHeader(data []byte) error {
return s.Input(data)
}
func (s *SinkImpl) SourceId() string {
return s.SourceId_
}

View File

@@ -76,6 +76,8 @@ type ITransStream interface {
AddSink(sink ISink) error
ExistSink(id SinkId) bool
RemoveSink(id SinkId) (ISink, bool)
PopAllSink(handler func(sink ISink))
@@ -117,6 +119,11 @@ func (t *TransStreamImpl) AddSink(sink ISink) error {
return nil
}
func (t *TransStreamImpl) ExistSink(id SinkId) bool {
_, ok := t.Sinks[id]
return ok
}
func (t *TransStreamImpl) RemoveSink(id SinkId) (ISink, bool) {
sink, ok := t.Sinks[id]
if ok {
@@ -142,6 +149,7 @@ func (t *TransStreamImpl) AllSink() []ISink {
func (t *TransStreamImpl) Close() error {
return nil
}
func (t *TransStreamImpl) SendPacket(data []byte) error {
for _, sink := range t.Sinks {
sink.Input(data)