rtsp支持密码校验

This commit is contained in:
yangjiechina
2024-06-03 19:19:20 +08:00
parent fead132ba2
commit 1fa1fc2ff4
12 changed files with 564 additions and 422 deletions

View File

@@ -2,6 +2,7 @@
"gop_cache": true,
"probe_timeout": 2000,
"mw_latency": 350,
"public_ip": "192.168.2.148",
"http": {
"addr": "0.0.0.0:8080"
@@ -29,12 +30,11 @@
"webrtc": {
"port": 8000,
"public_ip": "192.168.31.123",
"transport": "UDP"
},
"gb28181": {
"port": "50000-60000",
"port": [50000,60000],
"transport": "UDP|TCP"
},

10
go.mod
View File

@@ -6,8 +6,13 @@ require (
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.1
github.com/natefinch/lumberjack v2.0.0+incompatible
github.com/pion/interceptor v0.1.25
github.com/pion/logging v0.2.2
github.com/pion/rtcp v1.2.12
github.com/pion/rtp v1.8.3
github.com/pion/webrtc/v3 v3.2.29
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.9.0
github.com/x-cray/logrus-prefixed-formatter v0.5.2
go.uber.org/zap v1.27.0
)
@@ -22,12 +27,8 @@ require (
github.com/pion/datachannel v1.5.5 // indirect
github.com/pion/dtls/v2 v2.2.7 // indirect
github.com/pion/ice/v2 v2.3.13 // indirect
github.com/pion/interceptor v0.1.25 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.12 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.12 // indirect
github.com/pion/rtp v1.8.3 // indirect
github.com/pion/sctp v1.8.12 // indirect
github.com/pion/sdp/v3 v3.0.8 // indirect
github.com/pion/srtp/v2 v2.0.18 // indirect
@@ -35,7 +36,6 @@ require (
github.com/pion/transport/v2 v2.2.3 // indirect
github.com/pion/turn/v2 v2.1.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.9.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/crypto v0.18.0 // indirect
golang.org/x/net v0.20.0 // indirect

28
main.go
View File

@@ -22,12 +22,13 @@ func NewDefaultAppConfig() stream.AppConfig_ {
return stream.AppConfig_{
GOPCache: true,
MergeWriteLatency: 350,
PublicIP: "192.168.2.148",
Hls: stream.HlsConfig{
Enable: true,
Dir: "../tmp",
Duration: 2,
PlaylistLength: 10,
PlaylistLength: 0xFFFF,
},
Rtmp: stream.RtmpConfig{
@@ -35,9 +36,14 @@ func NewDefaultAppConfig() stream.AppConfig_ {
Addr: "0.0.0.0:1935",
},
Rtsp: stream.RtmpConfig{
Enable: true,
Addr: "0.0.0.0:554",
Rtsp: stream.RtspConfig{
TransportConfig: stream.TransportConfig{
Transport: "UDP|TCP",
Port: [2]uint16{30000, 40000},
},
Enable: true,
Addr: "0.0.0.0:554",
Password: "123456",
},
Log: stream.LogConfig{
@@ -55,9 +61,11 @@ func NewDefaultAppConfig() stream.AppConfig_ {
},
GB28181: stream.GB28181Config{
Addr: "0.0.0.0",
Transport: "UDP|TCP",
Port: [2]uint16{20000, 30000},
Addr: "0.0.0.0",
TransportConfig: stream.TransportConfig{
Transport: "UDP|TCP",
Port: [2]uint16{20000, 30000},
},
},
}
}
@@ -77,6 +85,10 @@ func init() {
if stream.AppConfig.GB28181.IsMultiPort() {
gb28181.TransportManger = stream.NewTransportManager(stream.AppConfig.GB28181.Port[0], stream.AppConfig.GB28181.Port[1])
}
if stream.AppConfig.Rtsp.IsMultiPort() {
rtsp.TransportManger = stream.NewTransportManager(stream.AppConfig.Rtsp.Port[0], stream.AppConfig.Rtsp.Port[1])
}
}
func main() {
@@ -101,7 +113,7 @@ func main() {
panic(rtspAddr)
}
rtspServer := rtsp.NewServer()
rtspServer := rtsp.NewServer(stream.AppConfig.Rtsp.Password)
err = rtspServer.Start(rtspAddr)
if err != nil {
panic(err)

70
rtsp/http_digst.go Normal file
View File

@@ -0,0 +1,70 @@
package rtsp
import (
"crypto/md5"
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
)
import "math/rand"
func generateNonce() string {
k := make([]byte, 12)
for bytes := 0; bytes < len(k); {
n, err := rand.Read(k[bytes:])
if err != nil {
panic("rand.Read() failed")
}
bytes += n
}
return base64.StdEncoding.EncodeToString(k)
}
func generateAuthHeader(realm string) string {
return fmt.Sprintf(`Digest realm="%s", nonce="%s" algorithm=MD5`,
realm, generateNonce())
}
func h(data string) string {
hash := md5.New()
hash.Write([]byte(data))
return hex.EncodeToString(hash.Sum(nil))
}
func calculateResponse(username, realm, nonce, uri, password string) string {
//H(data) = MD5(data)
//KD(secret, data) = H(concat(secret, ":", data))
//request-digest = <"> < KD ( H(A1), unq(nonce-value) ":" H(A2) ) > <">
A1 := fmt.Sprintf("%s:%s:%s", username, realm, password)
A2 := fmt.Sprintf("%s:%s", "DESCRIBE", uri)
return h(h(A1) + ":" + nonce + ":" + h(A2))
}
func parseAuthParams(value string) (map[string]string, error) {
index := strings.Index(value, "Digest ")
if index == -1 {
return nil, fmt.Errorf("unknow scheme %s", value)
}
pairs := strings.Split(value[len("Digest "):], ",")
m := make(map[string]string, len(pairs))
for _, pair := range pairs {
i := strings.Index(pair, "=")
if i < 0 {
m[pair] = ""
} else if i == len(pair)-1 {
m[pair[:i]] = ""
} else {
m[strings.TrimSpace(pair[:i])] = strings.Trim(pair[i+1:], "\"")
}
}
return m, nil
}
func DoAuthenticatePlainTextPassword(params map[string]string, password string) bool {
response := calculateResponse(params["username"], params["realm"], params["nonce"], params["uri"], password)
return response == params["response"]
}

View File

@@ -1,32 +0,0 @@
package rtsp
import (
"github.com/yangjiechina/avformat/librtp"
"github.com/yangjiechina/avformat/utils"
)
type rtpTrack struct {
pt byte
rate int
mediaType utils.AVMediaType
//目前用于缓存带有SPS和PPS的RTP包
buffer []byte
muxer librtp.Muxer
cache bool
header [][]byte
tmp [][]byte
}
func NewRTPTrack(muxer librtp.Muxer, pt byte, rate int, mediaType utils.AVMediaType) *rtpTrack {
stream := &rtpTrack{
pt: pt,
rate: rate,
muxer: muxer,
buffer: make([]byte, 1500),
mediaType: mediaType,
}
return stream
}

272
rtsp/rtsp_handler.go Normal file
View File

@@ -0,0 +1,272 @@
package rtsp
import (
"fmt"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/log"
"github.com/yangjiechina/live-server/stream"
"net/http"
"net/textproto"
"net/url"
"reflect"
"strconv"
"strings"
)
type Request struct {
session *session
sourceId string
method string
url *url.URL
headers textproto.MIMEHeader
}
// Handler 处理RTSP各个请求消息
type Handler interface {
// Process 路由请求给具体的handler, 并发送响应
Process(session *session, method string, url_ *url.URL, headers textproto.MIMEHeader) error
OnOptions(request Request) (*http.Response, []byte, error)
// OnDescribe 获取spd
OnDescribe(request Request) (*http.Response, []byte, error)
// OnSetup 订阅track
OnSetup(request Request) (*http.Response, []byte, error)
// OnPlay 请求播放
OnPlay(request Request) (*http.Response, []byte, error)
// OnTeardown 结束播放
OnTeardown(request Request) (*http.Response, []byte, error)
OnPause(request Request) (*http.Response, []byte, error)
OnGetParameter(request Request) (*http.Response, []byte, error)
OnSetParameter(request Request) (*http.Response, []byte, error)
OnRedirect(request Request) (*http.Response, []byte, error)
// OnRecord 推流
OnRecord(request Request) (*http.Response, []byte, error)
}
type handler struct {
methods map[string]reflect.Value
password string
publicHeader string
}
func (h handler) Process(session *session, method string, url_ *url.URL, headers textproto.MIMEHeader) error {
m, ok := h.methods[method]
if !ok {
return fmt.Errorf("the method %s is not implmented", method)
}
//确保拉流要经过授权
state, ok := method2StateMap[method]
if ok && state > SessionStateSetup && session.sink_ == nil {
return fmt.Errorf("please establish a session first")
}
var err error
split := strings.Split(url_.Path, "/")
source := split[len(split)-1]
//反射调用各个处理函数
results := m.Call([]reflect.Value{
reflect.ValueOf(&h),
reflect.ValueOf(Request{session, source, method, url_, headers}),
})
err, _ = results[2].Interface().(error)
if err != nil {
return err
}
response := results[0].Interface().(*http.Response)
if ok {
session.state = state
}
if response == nil {
return nil
}
body := results[1].Bytes()
err = session.response(response, body)
return err
}
func (h handler) OnOptions(request Request) (*http.Response, []byte, error) {
rep := NewOKResponse(request.headers.Get("Cseq"))
rep.Header.Set("Public", h.publicHeader)
return rep, nil, nil
}
func (h handler) OnDescribe(request Request) (*http.Response, []byte, error) {
var err error
var response *http.Response
var body []byte
//校验密码
if h.password != "" {
var success bool
authorization := request.headers.Get("Authorization")
if authorization != "" {
params, err := parseAuthParams(authorization)
success = err == nil && DoAuthenticatePlainTextPassword(params, h.password)
}
if !success {
response401 := NewResponse(http.StatusUnauthorized, request.headers.Get("Cseq"))
response401.Header.Set("WWW-Authenticate", generateAuthHeader("lkm"))
return response401, nil, nil
}
}
sinkId := stream.GenerateSinkId(request.session.conn.RemoteAddr())
sink_ := NewSink(sinkId, request.sourceId, request.session.conn, func(sdp string) {
response = NewOKResponse(request.headers.Get("Cseq"))
response.Header.Set("Content-Type", "application/sdp")
request.session.response(response, []byte(sdp))
})
code := utils.HookStateOK
stream.HookPlaying(sink_, func() {
}, func(state utils.HookState) {
code = state
})
if utils.HookStateOK != code {
return nil, nil, fmt.Errorf("hook failed. code:%d", code)
}
request.session.sink_ = sink_.(*sink)
return nil, body, err
}
func (h handler) OnSetup(request Request) (*http.Response, []byte, error) {
var response *http.Response
query, err := url.ParseQuery(request.url.RawQuery)
if err != nil {
return nil, nil, err
}
track := query.Get("track")
index, err := strconv.Atoi(track)
if err != nil {
return nil, nil, err
}
transportHeader := request.headers.Get("Transport")
if transportHeader == "" {
return nil, nil, fmt.Errorf("not find transport header")
}
split := strings.Split(transportHeader, ";")
if len(split) < 3 {
return nil, nil, fmt.Errorf("failed to parsing TRANSPORT header:%s", transportHeader)
}
tcp := "RTP/AVP" != split[0] && "RTP/AVP/UDP" != split[0]
if !tcp {
for _, value := range split {
if !strings.HasPrefix(value, "client_port=") {
continue
}
pairPort := strings.Split(value[len("client_port="):], "-")
if len(pairPort) != 2 {
return nil, nil, fmt.Errorf("failed to parsing client_port:%s", value)
}
port, err := strconv.Atoi(pairPort[0])
if err != nil {
return nil, nil, err
}
_ = port
port2, err := strconv.Atoi(pairPort[1])
if err != nil {
return nil, nil, err
}
_ = port2
log.Sugar.Debugf("client port:%d-%d", port, port2)
}
}
ssrc := 0xFFFFFFFF
rtpPort, rtcpPort, err := request.session.sink_.addSender(index, tcp, uint32(ssrc))
if err != nil {
return nil, nil, err
}
responseHeader := transportHeader
if tcp {
//修改interleaved为实际的stream index
responseHeader += ";interleaved=" + fmt.Sprintf("%d-%d", index, index)
} else {
responseHeader += ";server_port=" + fmt.Sprintf("%d-%d", rtpPort, rtcpPort)
}
responseHeader += ";ssrc=" + strconv.FormatInt(int64(ssrc), 16)
response = NewOKResponse(request.headers.Get("Cseq"))
response.Header.Set("Transport", responseHeader)
response.Header.Set("Session", request.session.sessionId)
return response, nil, nil
}
func (h handler) OnPlay(request Request) (*http.Response, []byte, error) {
response := NewOKResponse(request.headers.Get("Cseq"))
sessionHeader := request.headers.Get("Session")
if sessionHeader != "" {
response.Header.Set("Session", sessionHeader)
}
request.session.sink_.playing = true
return response, nil, nil
}
func (h handler) OnTeardown(request Request) (*http.Response, []byte, error) {
response := NewOKResponse(request.headers.Get("Cseq"))
return response, nil, nil
}
func (h handler) OnPause(request Request) (*http.Response, []byte, error) {
response := NewOKResponse(request.headers.Get("Cseq"))
return response, nil, nil
}
func newHandler(password string) *handler {
h := handler{
methods: make(map[string]reflect.Value, 10),
password: password,
}
//反射获取所有成员函数, 映射对应的RTSP请求方法
t := reflect.TypeOf(&h)
numMethod := t.NumMethod()
headers := make([]string, 0, 10)
for i := 0; i < numMethod; i++ {
method := t.Method(i)
if !strings.HasPrefix(method.Name, "On") {
continue
}
//确保函数名和RTSP标准的请求方法保持一致
methodName := strings.ToUpper(method.Name[2:])
h.methods[methodName] = method.Func
headers = append(headers, methodName)
}
h.publicHeader = strings.Join(headers, ",")
return &h
}

View File

@@ -5,7 +5,6 @@ import (
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/log"
"net"
"net/textproto"
)
type IServer interface {
@@ -14,36 +13,29 @@ type IServer interface {
Close()
}
func NewServer() IServer {
func NewServer(password string) IServer {
return &serverImpl{
publicHeader: "OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN, PAUSE, GET_PARAMETER, SET_PARAMETER, REDIRECT, RECORD",
handler: newHandler(password),
}
}
type serverImpl struct {
tcp *transport.TCPServer
handlers map[string]func(source string, headers textproto.MIMEHeader)
publicHeader string
tcp *transport.TCPServer
handler *handler
}
func (s *serverImpl) Start(addr net.Addr) error {
utils.Assert(s.tcp == nil)
//监听TCP端口
server := &transport.TCPServer{}
server.SetHandler(s)
err := server.Bind(addr)
if err != nil {
return err
}
s.tcp = server
for key, _ := range s.handlers {
s.publicHeader += key + ", "
}
s.publicHeader = s.publicHeader[:len(s.publicHeader)-2]
return nil
}
@@ -69,18 +61,16 @@ func (s *serverImpl) OnConnected(conn net.Conn) {
func (s *serverImpl) OnPacket(conn net.Conn, data []byte) {
t := conn.(*transport.Conn)
message, url, header, err := parseMessage(data)
method, url, header, err := parseMessage(data)
if err != nil {
log.Sugar.Errorf("failed to prase message:%s. err:%s peer:%s", string(data), err.Error(), conn.RemoteAddr().String())
log.Sugar.Errorf("failed to prase message:%s. err:%s conn:%s", string(data), err.Error(), conn.RemoteAddr().String())
_ = conn.Close()
s.closeSession(conn)
return
}
err = t.Data.(*session).Input(message, url, header)
err = s.handler.Process(t.Data.(*session), method, url, header)
if err != nil {
log.Sugar.Errorf("failed to process message of RTSP. err:%s peer:%s msg:%s", err.Error(), conn.RemoteAddr().String(), string(data))
log.Sugar.Errorf("failed to process message of RTSP. err:%s conn:%s msg:%s", err.Error(), conn.RemoteAddr().String(), string(data))
_ = conn.Close()
}
}

View File

@@ -4,9 +4,6 @@ import (
"bufio"
"bytes"
"fmt"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/log"
"github.com/yangjiechina/live-server/stream"
"net"
"net/http"
"net/textproto"
@@ -19,43 +16,29 @@ import (
type SessionState int
const (
MethodOptions = "OPTIONS"
MethodDescribe = "DESCRIBE"
MethodSetup = "SETUP"
MethodPlay = "PLAY"
MethodTeardown = "TEARDOWN"
MethodPause = "PAUSE"
MethodGetParameter = "GET_PARAMETER"
MethodSetParameter = "SET_PARAMETER"
MethodRedirect = "REDIRECT"
MethodRecord = "RECORD"
Version = "RTSP/1.0"
SessionSateOptions = SessionState(0x1)
SessionSateDescribe = SessionState(0x2)
SessionSateSetup = SessionState(0x3)
SessionSatePlay = SessionState(0x4)
SessionSateTeardown = SessionState(0x5)
SessionSatePause = SessionState(0x6)
SessionStateOptions = SessionState(0x1)
SessionStateDescribe = SessionState(0x2)
SessionStateSetup = SessionState(0x3)
SessionStatePlay = SessionState(0x4)
SessionStateTeardown = SessionState(0x5)
SessionStatePause = SessionState(0x6)
)
type requestHandler interface {
onOptions(sourceId string, headers textproto.MIMEHeader)
var (
method2StateMap map[string]SessionState
)
//获取spd
onDescribe(sourceId string, headers textproto.MIMEHeader)
//订阅track
onSetup(sourceId string, index int, headers textproto.MIMEHeader)
//播放
onPlay(sourceId string)
onTeardown()
onPause()
func init() {
method2StateMap = map[string]SessionState{
"OPTIONS": SessionStateOptions,
"DESCRIBE": SessionStateDescribe,
"SETUP": SessionStateSetup,
"PLAY": SessionStatePlay,
"TEARDOWN": SessionStateTeardown,
"PAUSE": SessionStatePause,
}
}
type session struct {
@@ -63,61 +46,11 @@ type session struct {
sink_ *sink
sessionId string
writeBuffer *bytes.Buffer
writeBuffer *bytes.Buffer //响应体缓冲区
state SessionState
}
func NewSession(conn net.Conn) *session {
milli := int(time.Now().UnixMilli() & 0xFFFFFFFF)
return &session{
conn: conn,
sessionId: strconv.Itoa(milli),
writeBuffer: bytes.NewBuffer(make([]byte, 0, 1024*10)),
}
}
func NewOKResponse(cseq string) http.Response {
rep := http.Response{
Proto: Version,
StatusCode: http.StatusOK,
Status: http.StatusText(http.StatusOK),
Header: make(http.Header),
}
if cseq == "" {
cseq = "1"
}
rep.Header.Set("Cseq", cseq)
return rep
}
func parseMessage(data []byte) (string, *url.URL, textproto.MIMEHeader, error) {
reader := bufio.NewReader(bytes.NewReader(data))
tp := textproto.NewReader(reader)
line, err := tp.ReadLine()
split := strings.Split(line, " ")
if len(split) < 3 {
panic(fmt.Errorf("unknow response line of response:%s", line))
}
method := strings.ToUpper(split[0])
//version
_ = split[2]
url_, err := url.Parse(split[1])
if err != nil {
return "", nil, nil, err
}
header, err := tp.ReadMIMEHeader()
if err != nil {
return "", nil, nil, err
}
return method, url_, header, nil
}
func (s *session) response(response http.Response, body []byte) error {
func (s *session) response(response *http.Response, body []byte) error {
//添加Content-Length
if body != nil {
response.Header.Set("Content-Length", strconv.Itoa(len(body)))
@@ -150,161 +83,65 @@ func (s *session) response(response http.Response, body []byte) error {
return err
}
func (s *session) onOptions(sourceId string, headers textproto.MIMEHeader) error {
rep := NewOKResponse(headers.Get("Cseq"))
rep.Header.Set("Public", "OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN, PAUSE, GET_PARAMETER, SET_PARAMETER, REDIRECT, RECORD")
return s.response(rep, nil)
}
func (s *session) onDescribe(source string, headers textproto.MIMEHeader) error {
var err error
sinkId := stream.GenerateSinkId(s.conn.RemoteAddr())
sink_ := NewSink(sinkId, source, s.conn, func(sdp string) {
response := NewOKResponse(headers.Get("Cseq"))
response.Header.Set("Content-Type", "application/sdp")
err = s.response(response, []byte(sdp))
})
code := utils.HookStateOK
s.sink_ = sink_.(*sink)
stream.HookPlaying(sink_, func() {
}, func(state utils.HookState) {
code = state
})
if utils.HookStateOK != code {
return fmt.Errorf("hook failed. code:%d", code)
}
return err
}
func (s *session) onSetup(sourceId string, index int, headers textproto.MIMEHeader) error {
transportHeader := headers.Get("Transport")
if transportHeader == "" {
return fmt.Errorf("not find transport header")
}
split := strings.Split(transportHeader, ";")
if len(split) < 3 {
return fmt.Errorf("failed to parsing TRANSPORT header:%s", split)
}
tcp := "RTP/AVP" != split[0] && "RTP/AVP/UDP" != split[0]
if !tcp {
for _, value := range split {
if !strings.HasPrefix(value, "client_port=") {
continue
}
pairPort := strings.Split(value[len("client_port="):], "-")
if len(pairPort) != 2 {
return fmt.Errorf("failed to parsing client_port:%s", value)
}
port, err := strconv.Atoi(pairPort[0])
if err != nil {
return err
}
_ = port
port2, err := strconv.Atoi(pairPort[1])
if err != nil {
return err
}
_ = port2
log.Sugar.Debugf("client port:%d-%d", port, port2)
}
}
ssrc := 0xFFFFFFFF
rtpPort, rtcpPort, err := s.sink_.addTrack(index, tcp, uint32(ssrc))
if err != nil {
return err
}
responseHeader := transportHeader
if tcp {
//修改interleaved为实际的stream index
responseHeader += ";interleaved=" + fmt.Sprintf("%d-%d", index, index)
} else {
responseHeader += ";server_port=" + fmt.Sprintf("%d-%d", rtpPort, rtcpPort)
}
responseHeader += ";ssrc=" + strconv.FormatInt(int64(ssrc), 16)
response := NewOKResponse(headers.Get("Cseq"))
response.Header.Set("Transport", responseHeader)
response.Header.Set("Session", s.sessionId)
return s.response(response, nil)
}
func (s *session) onPlay(sourceId string, headers textproto.MIMEHeader) error {
response := NewOKResponse(headers.Get("Cseq"))
sessionHeader := headers.Get("Session")
if sessionHeader != "" {
response.Header.Set("Session", sessionHeader)
}
err := s.response(response, nil)
if err == nil {
s.sink_.playing = true
}
return err
}
func (s *session) onTeardown() {
}
func (s *session) onPause() {
}
func (s *session) Input(method string, url_ *url.URL, headers textproto.MIMEHeader) error {
//_ = url_.User.Username()
//_, _ = url_.User.Password()
var err error
split := strings.Split(url_.Path, "/")
source := split[len(split)-1]
if MethodOptions == method {
err = s.onOptions(source, headers)
} else if MethodDescribe == method {
err = s.onDescribe(source, headers)
} else if MethodSetup == method {
query, err := url.ParseQuery(url_.RawQuery)
if err != nil {
return err
}
track := query.Get("track")
index, err := strconv.Atoi(track)
if err != nil {
return err
}
if err = s.onSetup(source, index, headers); err != nil {
return err
}
} else if MethodPlay == method {
err = s.onPlay(source, headers)
} else if MethodTeardown == method {
s.onTeardown()
} else if MethodPause == method {
s.onPause()
}
return err
}
func (s *session) close() {
if s.sink_ != nil {
s.sink_.Close()
s.sink_ = nil
}
}
// 解析rtsp消息
func parseMessage(data []byte) (string, *url.URL, textproto.MIMEHeader, error) {
reader := bufio.NewReader(bytes.NewReader(data))
tp := textproto.NewReader(reader)
line, err := tp.ReadLine()
split := strings.Split(line, " ")
if len(split) < 3 {
panic(fmt.Errorf("unknow response line of response:%s", line))
}
method := strings.ToUpper(split[0])
//version
_ = split[2]
url_, err := url.Parse(split[1])
if err != nil {
return "", nil, nil, err
}
header, err := tp.ReadMIMEHeader()
if err != nil {
return "", nil, nil, err
}
return method, url_, header, nil
}
func NewSession(conn net.Conn) *session {
milli := int(time.Now().UnixMilli() & 0xFFFFFFFF)
return &session{
conn: conn,
sessionId: strconv.Itoa(milli),
writeBuffer: bytes.NewBuffer(make([]byte, 0, 1024*10)),
state: SessionStateOptions,
}
}
func NewResponse(code int, cseq string) *http.Response {
rep := http.Response{
Proto: Version,
StatusCode: code,
Status: http.StatusText(code),
Header: make(http.Header),
}
if cseq == "" {
cseq = "1"
}
rep.Header.Set("Cseq", cseq)
return &rep
}
func NewOKResponse(cseq string) *http.Response {
return NewResponse(http.StatusOK, cseq)
}

View File

@@ -3,6 +3,7 @@ package rtsp
import (
"fmt"
"github.com/pion/rtcp"
"github.com/yangjiechina/avformat/librtp"
"github.com/yangjiechina/avformat/transport"
"github.com/yangjiechina/avformat/utils"
"github.com/yangjiechina/live-server/log"
@@ -15,18 +16,17 @@ var (
TransportManger stream.TransportManager
)
// 对于UDP而言, 每个sink维护一对UDPTransport
// TCP直接单端口传输
// rtsp拉流sink
// 对于udp而言, 每个sink维护多个transport
// tcp直接单端口传输
type sink struct {
stream.SinkImpl
//一个rtsp源可能存在多个流, 每个流都需要拉取拉取
tracks []*rtspTrack
sdpCb func(sdp string)
senders []*librtp.RtpSender //一个rtsp源可能存在多个流, 每个流都需要拉取
sdpCb func(sdp string) //rtsp_stream生成sdp后使用该回调给rtsp_session, 响应describe
//是否是TCP拉流
tcp bool
playing bool
tcp bool //tcp拉流标记
playing bool //是否已经收到play请求
}
func NewSink(id stream.SinkId, sourceId string, conn net.Conn, cb func(sdp string)) stream.ISink {
@@ -39,21 +39,23 @@ func NewSink(id stream.SinkId, sourceId string, conn net.Conn, cb func(sdp strin
}
}
func (s *sink) setTrackCount(count int) {
s.tracks = make([]*rtspTrack, count)
func (s *sink) setSenderCount(count int) {
s.senders = make([]*librtp.RtpSender, count)
}
func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (uint16, uint16, error) {
utils.Assert(index < cap(s.tracks))
utils.Assert(s.tracks[index] == nil)
func (s *sink) addSender(index int, tcp bool, ssrc uint32) (uint16, uint16, error) {
utils.Assert(index < cap(s.senders))
utils.Assert(s.senders[index] == nil)
var err error
var rtpPort uint16
var rtcpPort uint16
track := rtspTrack{
ssrc: ssrc,
sender := librtp.RtpSender{
SSRC: ssrc,
}
//tcp使用信令链路
if tcp {
s.tcp = true
} else {
@@ -61,10 +63,12 @@ func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (uint16, uint16, error
//rtp port
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", "0.0.0.0", port))
if err == nil {
track.rtp = &transport.UDPTransport{}
track.rtp.SetHandler2(nil, track.onRTPPacket, nil)
err = track.rtp.Bind(addr)
//创建rtp udp server
sender.Rtp = &transport.UDPTransport{}
sender.Rtp.SetHandler2(nil, sender.OnRTPPacket, nil)
err = sender.Rtp.Bind(addr)
}
rtpPort = port
@@ -73,17 +77,18 @@ func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (uint16, uint16, error
//rtcp port
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", "0.0.0.0", port))
if err == nil {
track.rtcp = &transport.UDPTransport{}
track.rtcp.SetHandler2(nil, track.onRTCPPacket, nil)
err = track.rtcp.Bind(addr)
//创建rtcp udp server
sender.Rtcp = &transport.UDPTransport{}
sender.Rtcp.SetHandler2(nil, sender.OnRTCPPacket, nil)
err = sender.Rtcp.Bind(addr)
} else {
track.rtp.Close()
track.rtp = nil
sender.Rtp.Close()
sender.Rtp = nil
}
rtcpPort = port
return nil
})
}
@@ -92,34 +97,35 @@ func (s *sink) addTrack(index int, tcp bool, ssrc uint32) (uint16, uint16, error
return 0, 0, err
}
s.tracks[index] = &track
s.senders[index] = &sender
return rtpPort, rtcpPort, err
}
func (s *sink) input(index int, data []byte, rtpTime uint32) error {
//拉流方还没有连上来
utils.Assert(index < cap(s.tracks))
utils.Assert(index < cap(s.senders))
track := s.tracks[index]
track.pktCount++
track.octetCount += len(data)
sender := s.senders[index]
sender.PktCount++
sender.OctetCount += len(data)
if s.tcp {
s.Conn.Write(data)
} else {
track.rtpConn.Write(data)
//发送rtcp sr包
sender.RtpConn.Write(data)
if track.rtcpConn == nil || track.pktCount%100 != 0 {
if sender.RtcpConn == nil || sender.PktCount%100 != 0 {
return nil
}
nano := uint64(time.Now().UnixNano())
ntp := (nano/1000000000 + 2208988800<<32) | (nano % 1000000000)
sr := rtcp.SenderReport{
SSRC: track.ssrc,
SSRC: sender.SSRC,
NTPTime: ntp,
RTPTime: rtpTime,
PacketCount: uint32(track.pktCount),
OctetCount: uint32(track.octetCount),
PacketCount: uint32(sender.PktCount),
OctetCount: uint32(sender.OctetCount),
}
marshal, err := sr.Marshal()
@@ -127,42 +133,39 @@ func (s *sink) input(index int, data []byte, rtpTime uint32) error {
log.Sugar.Errorf("创建rtcp sr消息失败 err:%s msg:%v", err.Error(), sr)
}
track.rtcpConn.Write(marshal)
sender.RtcpConn.Write(marshal)
}
return nil
}
// 拉流链路是否已经连接上
// 拉流测发送了play请求, 并且对于udp而言, 还需要收到nat穿透包
func (s *sink) isConnected(index int) bool {
return s.playing && (s.tcp || (s.tracks[index] != nil && s.tracks[index].rtpConn != nil))
return s.playing && (s.tcp || (s.senders[index] != nil && s.senders[index].RtpConn != nil))
}
// 发送rtp包总数
func (s *sink) pktCount(index int) int {
return s.tracks[index].pktCount
return s.senders[index].PktCount
}
// SendHeader 回调rtsp的sdp信息
// SendHeader 回调rtsp stream的sdp信息
func (s *sink) SendHeader(data []byte) error {
s.sdpCb(string(data))
return nil
}
func (s *sink) TrackConnected(index int) bool {
utils.Assert(index < cap(s.tracks))
utils.Assert(s.tracks[index].rtp != nil)
return s.tracks[index].rtcpConn != nil
}
func (s *sink) Close() {
s.SinkImpl.Close()
for _, track := range s.tracks {
if track.rtp != nil {
track.rtp.Close()
for _, sender := range s.senders {
if sender.Rtp != nil {
sender.Rtp.Close()
}
if track.rtcp != nil {
track.rtcp.Close()
if sender.Rtcp != nil {
sender.Rtcp.Close()
}
}
}

View File

@@ -18,14 +18,15 @@ const (
OverTcpMagic = 0x24
)
// 低延迟是rtsp特性, 不考虑实现GOP缓存
// rtsp传输流封装
// 低延迟是rtsp特性, 所以不考虑实现GOP缓存
type tranStream struct {
stream.TransStreamImpl
addr net.IPAddr
addrType string
urlFormat string
rtpTracks []*rtpTrack
rtpTracks []*rtspTrack
sdp string
}
@@ -48,11 +49,13 @@ func NewTransStream(addr net.IPAddr, urlFormat string) stream.ITransStream {
func TransStreamFactory(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) (stream.ITransStream, error) {
trackFormat := source.Id() + "?track=%d"
return NewTransStream(net.IPAddr{
IP: net.IP{},
IP: net.ParseIP(stream.AppConfig.PublicIP),
Zone: "",
}, trackFormat), nil
}
// rtpMuxer申请输出流内存的回调
// 无论是tcp/udp拉流, 均使用同一块内存, 并且给tcp预留4字节的包长.
func (t *tranStream) onAllocBuffer(params interface{}) []byte {
return t.rtpTracks[params.(int)].buffer[OverTcpHeaderSize:]
}
@@ -63,27 +66,30 @@ func (t *tranStream) onRtpPacket(data []byte, timestamp uint32, params interface
index := params.(int)
track := t.rtpTracks[index]
//保存sps和ssp等数据
if track.cache && track.header == nil {
//保存带有sps和ssp等编码信息的rtp包, 对所有sink通用
if track.cache && track.extraDataBuffer == nil {
bytes := make([]byte, OverTcpHeaderSize+len(data))
copy(bytes[OverTcpHeaderSize:], data)
track.tmp = append(track.tmp, bytes)
track.tmpExtraDataBuffer = append(track.tmpExtraDataBuffer, bytes)
t.overTCP(bytes, index)
return
}
//将rtp包发送给各个sink
for _, value := range t.Sinks {
sink_ := value.(*sink)
if !sink_.isConnected(index) {
continue
}
//为刚刚连接上的sink, 发送sps和pps等rtp包
if sink_.pktCount(index) < 1 && utils.AVMediaTypeVideo == track.mediaType {
seq := binary.BigEndian.Uint16(data[2:])
count := len(track.header)
count := len(track.extraDataBuffer)
for i, rtp := range track.header {
for i, rtp := range track.extraDataBuffer {
//回滚rtp包的序号
librtp.RollbackSeq(rtp[OverTcpHeaderSize:], int(seq)-(count-i-1))
if sink_.tcp {
sink_.input(index, rtp, 0)
@@ -96,6 +102,7 @@ func (t *tranStream) onRtpPacket(data []byte, timestamp uint32, params interface
end := OverTcpHeaderSize + len(data)
t.overTCP(track.buffer[:end], index)
//发送rtp包
if sink_.tcp {
sink_.input(index, track.buffer[:end], 0)
} else {
@@ -117,7 +124,7 @@ func (t *tranStream) Input(packet utils.AVPacket) error {
} else if utils.AVMediaTypeVideo == packet.MediaType() {
//将sps和pps按照单一模式打包
if stream_.header == nil {
if stream_.extraDataBuffer == nil {
if !packet.KeyFrame() {
return nil
}
@@ -135,7 +142,7 @@ func (t *tranStream) Input(packet utils.AVPacket) error {
stream_.muxer.Input(spsBytes[0], uint32(packet.ConvertPts(stream_.rate)))
stream_.muxer.Input(ppsBytes[0], uint32(packet.ConvertPts(stream_.rate)))
stream_.header = stream_.tmp
stream_.extraDataBuffer = stream_.tmpExtraDataBuffer
}
data := libavc.RemoveStartCode(packet.AnnexBPacketData(t.TransStreamImpl.Tracks[packet.Index()]))
@@ -146,7 +153,7 @@ func (t *tranStream) Input(packet utils.AVPacket) error {
}
func (t *tranStream) AddSink(sink_ stream.ISink) error {
sink_.(*sink).setTrackCount(len(t.TransStreamImpl.Tracks))
sink_.(*sink).setSenderCount(len(t.TransStreamImpl.Tracks))
if err := sink_.(*sink).SendHeader([]byte(t.sdp)); err != nil {
return err
}
@@ -179,7 +186,7 @@ func (t *tranStream) AddTrack(stream utils.AVStream) error {
muxer.SetAllocHandler(t.onAllocBuffer)
muxer.SetWriteHandler(t.onRtpPacket)
t.rtpTracks = append(t.rtpTracks, NewRTPTrack(muxer, byte(payloadType.Pt), payloadType.ClockRate, stream.Type()))
t.rtpTracks = append(t.rtpTracks, NewRTSPTrack(muxer, byte(payloadType.Pt), payloadType.ClockRate, stream.Type()))
muxer.SetParams(len(t.rtpTracks) - 1)
return nil
}

View File

@@ -1,56 +1,32 @@
package rtsp
import (
"github.com/yangjiechina/avformat/transport"
"net"
"github.com/yangjiechina/avformat/librtp"
"github.com/yangjiechina/avformat/utils"
)
// 对rtsp每路输出流的封装
type rtspTrack struct {
rtp transport.ITransport
rtcp transport.ITransport
pt byte
rate int
mediaType utils.AVMediaType
rtpConn net.Conn
rtcpConn net.Conn
buffer []byte //buffer of rtp packet
muxer librtp.Muxer
cache bool
//rtcp
pktCount int
ssrc uint32
octetCount int
extraDataBuffer [][]byte //缓存带有编码信息的rtp包, 对所有sink通用
tmpExtraDataBuffer [][]byte //缓存带有编码信息的rtp包, 整个过程会多次回调(sps->pps->sei...), 先保存到临时区, 最后再缓存到extraDataBuffer
}
func (s *rtspTrack) onRTPPacket(conn net.Conn, data []byte) {
if s.rtpConn == nil {
s.rtpConn = conn
}
}
func (s *rtspTrack) onRTCPPacket(conn net.Conn, data []byte) {
if s.rtcpConn == nil {
s.rtcpConn = conn
func NewRTSPTrack(muxer librtp.Muxer, pt byte, rate int, mediaType utils.AVMediaType) *rtspTrack {
stream := &rtspTrack{
pt: pt,
rate: rate,
muxer: muxer,
buffer: make([]byte, 1500),
mediaType: mediaType,
}
//packs, err := rtcp.Unmarshal(data)
//if err != nil {
// log.Sugar.Warnf("解析rtcp包失败 err:%s conn:%s pkt:%s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data))
// return
//}
//
//for _, pkt := range packs {
// if _, ok := pkt.(*rtcp.ReceiverReport); ok {
// } else if _, ok := pkt.(*rtcp.SourceDescription); ok {
// } else if _, ok := pkt.(*rtcp.Goodbye); ok {
// }
//}
}
// tcp链接成功回调
func (s *rtspTrack) onTCPConnected(conn net.Conn) {
if s.rtcpConn != nil {
s.rtcpConn = conn
}
}
// tcp断开链接回调
func (s *rtspTrack) onTCPDisconnected(conn net.Conn, err error) {
return stream
}

View File

@@ -6,15 +6,22 @@ const (
DefaultMergeWriteLatency = 350
)
type TransportConfig struct {
Transport string //"UDP|TCP"
Port [2]uint16 //单端口模式[0]=port/多端口模式[0]=start port, [0]=end port.
}
type RtmpConfig struct {
Enable bool `json:"enable"`
Addr string `json:"addr"`
}
type RtspConfig struct {
RtmpConfig
TransportConfig
Addr string
Enable bool `json:"enable"`
Password string
Port [2]uint16
}
type RecordConfig struct {
@@ -44,20 +51,19 @@ type HttpConfig struct {
}
type GB28181Config struct {
Addr string
Transport string //"UDP|TCP"
Port [2]uint16 //单端口模式[0]=port/多端口模式[0]=start port, [0]=end port.
TransportConfig
Addr string
}
func (g GB28181Config) EnableTCP() bool {
func (g TransportConfig) EnableTCP() bool {
return strings.Contains(g.Transport, "TCP")
}
func (g GB28181Config) EnableUDP() bool {
func (g TransportConfig) EnableUDP() bool {
return strings.Contains(g.Transport, "UDP")
}
func (g GB28181Config) IsMultiPort() bool {
func (g TransportConfig) IsMultiPort() bool {
return g.Port[1] > 0 && g.Port[1] > g.Port[0]
}
@@ -127,14 +133,15 @@ func init() {
// AppConfig_ GOP缓存和合并写必须保持一致同时开启或关闭. 关闭GOP缓存是为了降低延迟很难理解又另外开启合并写.
type AppConfig_ struct {
GOPCache bool `json:"gop_cache"` //是否开启GOP缓存只缓存一组音视频
ProbeTimeout int `json:"probe_timeout"`
GOPCache bool `json:"gop_cache"` //是否开启GOP缓存只缓存一组音视频
ProbeTimeout int `json:"probe_timeout"`
PublicIP string `json:"public_ip"`
//缓存指定时长的包满了之后才发送给Sink. 可以降低用户态和内核态的交互频率,大幅提升性能.
//合并写的大小范围应当大于一帧的时长不超过一组GOP的时长在实际发送流的时候也会遵循此条例.
MergeWriteLatency int `json:"mw_latency"`
Rtmp RtmpConfig
Rtsp RtmpConfig
Rtsp RtspConfig
Hook HookConfig