add rtsp server

This commit is contained in:
notch
2020-12-21 12:13:59 +08:00
parent 5069121b4e
commit 7094607698
17 changed files with 3338 additions and 0 deletions

1
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/gorilla/websocket v1.4.2
github.com/kelindar/process v0.0.0-20170730150328-69a29e249ec3
github.com/kelindar/rate v1.0.0
github.com/kelindar/tcp v1.0.0
github.com/pion/rtp v1.6.1
github.com/pixelbender/go-sdp v1.1.0
github.com/stretchr/testify v1.6.1

2
go.sum
View File

@@ -24,6 +24,8 @@ github.com/kelindar/process v0.0.0-20170730150328-69a29e249ec3 h1:6If+E1dikQbdT7
github.com/kelindar/process v0.0.0-20170730150328-69a29e249ec3/go.mod h1:+lTCLnZFXOkqwD8sLPl6u4erAc0cP8wFegQHfipz7KE=
github.com/kelindar/rate v1.0.0 h1:JNZdufLjtDzr/E/rCtWkqo2OVU4yJSScZngJ8LuZ7kU=
github.com/kelindar/rate v1.0.0/go.mod h1:AjT4G+hTItNwt30lucEGZIz8y7Uk5zPho6vurIZ+1Es=
github.com/kelindar/tcp v1.0.0 h1:585JE7qmc6S5EQPYLAkRqfGo4PqDxalke98AXjxPmrE=
github.com/kelindar/tcp v1.0.0/go.mod h1:JB5hj1cshLU60XrLij2BBxW3JQ4hOye8vqbyvuKb52k=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=

76
service/rtsp/io.go Executable file
View File

@@ -0,0 +1,76 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"bufio"
"strings"
"github.com/cnotch/xlog"
)
// 接收处理器接口
type receiveHandler interface {
onRequest(req *Request) error
onResponse(resp *Response) error
onPack(pack *RTPPack) error
}
// 统一消息接受函数
func receive(logger *xlog.Logger, r *bufio.Reader,
channels []int, handler receiveHandler) error {
// 预取4个字节
sl, err := r.Peek(4)
if err != nil {
return err
}
// 如果是 RTP 流
if sl[0] == rtpPackPrefix {
pack, err := ReadPacket(r, channels)
if err != nil {
if pack != nil { // 通道不匹配
logger.Warn(err.Error())
return nil
}
logger.Errorf("decode rtp pack failed; %v.", err)
return err
}
return handler.onPack(pack)
}
i := 0
for ; i < 4; i++ {
if sl[i] != rtspProto[i] {
break
}
}
if i == 4 { // 比较完成并且相等是Response
resp, err := ReadResponse(r)
if err != nil {
logger.Errorf("decode response failed; %v.", err)
return err
}
if logger.LevelEnabled(xlog.DebugLevel) {
logger.Debugf("<<<===\r\n%s", strings.TrimSpace(resp.String()))
}
return handler.onResponse(resp)
}
// 是请求
req, err := ReadRequest(r)
if err != nil {
logger.Errorf("decode request failed; %v.", err)
return err
}
if logger.LevelEnabled(xlog.DebugLevel) {
logger.Debugf("<<<===\r\n%s", strings.TrimSpace(req.String()))
}
return handler.onRequest(req)
}

168
service/rtsp/multicast_proxy.go Executable file
View File

@@ -0,0 +1,168 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"fmt"
"io"
"net"
"sync"
"github.com/cnotch/ipchub/media"
"github.com/cnotch/xlog"
"github.com/emitter-io/address"
)
// 组播代理
type multicastProxy struct {
// 创建时设置
logger *xlog.Logger
path string
bufferSize int
multicastIP string
ports [rtpChannelCount]int
ttl int
sourceIP string
closed bool
udpConn *net.UDPConn
destAddr [rtpChannelCount]*net.UDPAddr
cid media.CID
multicastLock sync.Mutex
members []io.Closer
}
func (proxy *multicastProxy) AddMember(m io.Closer) {
proxy.multicastLock.Lock()
defer proxy.multicastLock.Unlock()
if len(proxy.members) == 0 {
stream := media.Get(proxy.path)
if stream == nil {
proxy.logger.Error("start multicast proxy failed.")
return
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
proxy.logger.Errorf("start multicast proxy failed. %s", err.Error())
return
}
proxy.udpConn = udpConn
err = udpConn.SetWriteBuffer(proxy.bufferSize)
for i, port := range proxy.ports {
if port > 0 {
proxy.destAddr[i], _ = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", proxy.multicastIP, proxy.ports[i]))
}
}
proxy.members = append(proxy.members, m)
proxy.cid = stream.StartConsume(proxy, media.RTPPacket,
"net = rtsp-multicast, "+proxy.multicastIP)
proxy.closed = false
proxy.logger.Info("multicast proxy started.")
}
}
func (proxy *multicastProxy) ReleaseMember(m io.Closer) {
proxy.multicastLock.Lock()
defer proxy.multicastLock.Unlock()
for i, m2 := range proxy.members {
if m == m2 {
proxy.members = append(proxy.members[:i], proxy.members[i+1:]...)
break
}
}
if len(proxy.members) == 0 {
// 停止组播代理
proxy.close()
}
}
func (proxy *multicastProxy) MulticastIP() string {
return proxy.multicastIP
}
func (proxy *multicastProxy) Port(index int) int {
if index < 0 || index > len(proxy.ports) {
return 0
}
return proxy.ports[index]
}
func (proxy *multicastProxy) TTL() int {
return proxy.ttl
}
func (proxy *multicastProxy) SourceIP() string {
if len(proxy.sourceIP) == 0 {
addrs, err := address.GetPublic()
if err != nil {
proxy.sourceIP = "Unknown"
} else {
proxy.sourceIP = addrs[0].IP.String()
}
}
return proxy.sourceIP
}
func (proxy *multicastProxy) Consume(p Pack) {
if proxy.closed {
return
}
p2 := p.(*RTPPack)
addr := proxy.destAddr[int(p2.Channel)]
if addr != nil {
_, err := proxy.udpConn.WriteToUDP(p2.Data, addr)
if err != nil {
proxy.logger.Error(err.Error())
return
}
}
}
func (proxy *multicastProxy) Close() error {
proxy.multicastLock.Lock()
defer proxy.multicastLock.Unlock()
proxy.close()
return nil
}
func (proxy *multicastProxy) close() {
if proxy.closed {
return
}
proxy.closed = true
stream := media.Get(proxy.path)
if stream != nil {
stream.StopConsume(proxy.cid)
}
if proxy.udpConn != nil {
proxy.udpConn.Close()
proxy.udpConn = nil
}
// 关闭所有的组播客户端
for _, m := range proxy.members {
m.Close()
}
proxy.members = nil
for i := range proxy.destAddr {
proxy.destAddr[i] = nil
}
proxy.logger.Info("multicast proxy stopped.")
}

653
service/rtsp/pull_client.go Executable file
View File

@@ -0,0 +1,653 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/url"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cnotch/ipchub/config"
"github.com/cnotch/ipchub/media"
"github.com/cnotch/ipchub/network/socket/buffered"
"github.com/cnotch/ipchub/stats"
"github.com/cnotch/ipchub/utils"
"github.com/cnotch/xlog"
"github.com/pixelbender/go-sdp/sdp"
)
const (
defaultUserAgent = config.Name + "-rstp-client/1.0"
)
// PullClient 负责拉流到服务器
type PullClient struct {
// 打开前设置
closed bool
url *url.URL
userName string
password string
md5password string
path string
rtpChannels [rtpChannelCount]int
logger *xlog.Logger
// 添加到流媒体中心后设置
stream *media.Stream
// 打开连接后设置
conn *buffered.Conn
lockW sync.Mutex
realm string
nonce string
rsession string
seq int64
rawSdp string
sdp *sdp.Session
aControl string
vControl string
aCodec string
vCodec string
}
// NewPullClient 创建拉流客户端
func NewPullClient(localPath, remoteURL string) (*PullClient, error) {
// 检查远端路径
url, err := url.Parse(remoteURL)
if err != nil {
return nil, err
}
if strings.ToLower(url.Scheme) != "rtsp" {
return nil, fmt.Errorf("RemoteURL '%s' is not RTSP url", remoteURL)
}
if strings.ToLower(url.Hostname()) == "" {
return nil, fmt.Errorf("RemoteURL '%s' is not RTSP url", remoteURL)
}
// 如果没有 port补上默认端口
port := url.Port()
if len(port) == 0 {
url.Host = url.Hostname() + ":554"
}
// 提取用户名和密码
var userName, password string
if url.User != nil {
userName = url.User.Username()
password, _ = url.User.Password()
url.User = nil
}
// 检查发布路径
path := utils.CanonicalPath(localPath)
if path == "" {
path = utils.CanonicalPath(url.Path)
} else {
_, err := url.Parse("rtsp://localhost" + path)
if err != nil {
return nil, fmt.Errorf("Path '%s' 不合法", localPath)
}
}
client := &PullClient{
closed: true,
url: url,
userName: userName,
password: password,
path: path,
}
for i := rtpChannelMin; i < rtpChannelCount; i++ {
client.rtpChannels[i] = int(i)
}
client.logger = xlog.L().With(xlog.Fields(
xlog.F("path", client.path),
xlog.F("rurl", client.url.String()),
xlog.F("type", "pull")))
return client, nil
}
// Ping 测试网络和服务器
func (c *PullClient) Ping() error {
if !c.closed {
return nil
}
defer func() {
c.disconnect()
c.conn = nil
c.stream = nil
}()
err := c.connect()
if err != nil {
return err
}
// OPTIONS 尝试握手
err = c.requestHandshake()
if err != nil {
return err
}
// DESCRIBE 获取 sdp看是否存在指定媒体
return c.requestSDP()
}
// Open 打开拉流客户端
// 依次发生请求OPTIONS、DESCRIBE、SETUP、PLAY
// 全部成功,启动接收 RTP流 go routine
func (c *PullClient) Open() (err error) {
if !c.closed {
return nil
}
defer func() {
if err != nil { // 出现任何错误执行断链操作
c.disconnect()
c.conn = nil
c.stream = nil
}
}()
// 连接
err = c.connect()
if err != nil {
return err
}
// 请求握手
err = c.requestHandshake()
if err != nil {
return err
}
// 获取流信息
err = c.requestSDP()
if err != nil {
return err
}
// 设置通讯通道
err = c.requestSetup()
if err != nil {
return err
}
// 请求播放
err = c.requestPlay()
if err != nil {
return err
}
return err
}
// Close 关闭客户端
func (c *PullClient) Close() error {
c.disconnect()
return nil
}
func (c *PullClient) requestHandshake() (err error) {
// 使用 OPTIONS 尝试握手
r := c.newRequest(MethodOptions, c.url)
r.Header.Set(FieldRequire, "implicit-play")
_, err = c.requestWithResponse(r)
return err
}
func (c *PullClient) requestSDP() (err error) {
// DESCRIBE 获取 sdp
r := c.newRequest(MethodDescribe, c.url)
r.Header.Set(FieldAccept, "application/sdp")
resp, err := c.requestWithResponse(r)
if err != nil {
return err
}
// 解析
c.rawSdp = resp.Body
c.sdp, err = sdp.ParseString(c.rawSdp)
if err != nil {
return err
}
for _, media := range c.sdp.Media {
switch media.Type {
case "video":
c.vControl = media.Attributes.Get("control")
c.vCodec = media.Format[0].Name
case "audio":
c.aControl = media.Attributes.Get("control")
c.aCodec = media.Format[0].Name
}
}
return err
}
func (c *PullClient) requestSetup() (err error) {
var respVS, respAS *Response
// 视频通道设置
if len(c.vControl) > 0 {
var setupURL *url.URL
setupURL, err = c.getSetupURL(c.vControl)
r := c.newRequest(MethodSetup, setupURL)
r.Header.Set(FieldTransport,
fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", c.rtpChannels[ChannelVideo], c.rtpChannels[ChannelVideoControl]))
respVS, err = c.requestWithResponse(r)
if err != nil {
return err
}
}
// 音频通道设置
if len(c.aControl) > 0 {
var setupURL *url.URL
setupURL, err = c.getSetupURL(c.aControl)
r := c.newRequest(MethodSetup, setupURL)
r.Header.Set(FieldTransport,
fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", c.rtpChannels[ChannelAudio], c.rtpChannels[ChannelAudioControl]))
respAS, err = c.requestWithResponse(r)
if err != nil {
return err
}
}
_ = respVS
_ = respAS
return
}
func (c *PullClient) requestPlay() (err error) {
r := c.newRequest(MethodPlay, c.url)
resp, err := c.requestWithResponse(r)
if err != nil {
return err
}
_ = resp
mproxy := &multicastProxy{
path: c.path,
bufferSize: config.NetBufferSize(),
multicastIP: utils.Multicast.NextIP(), // 设置组播IP
ttl: config.MulticastTTL(),
logger: c.logger,
}
for i := rtpChannelMin; i < rtpChannelCount; i++ {
mproxy.ports[i] = utils.Multicast.NextPort()
}
c.stream = media.NewStream(c.path, c.rawSdp,
media.Attr("addr", c.url.String()),
media.Multicast(mproxy))
go c.playStream()
return nil
}
func (c *PullClient) playStream() {
defer func() {
if r := recover(); r != nil {
c.logger.Errorf("pull stream panic; %v \n %s", r, debug.Stack())
}
stats.RtspConns.Release() // 减少RTSP连接计数
media.Unregist(c.stream) // 从媒体中心取消注册
c.disconnect() // 确保网络关闭
c.conn = nil // 通知GC尽早释放资源
c.stream = nil
c.logger.Infof("close pull stream")
}()
c.logger.Infof("open pull stream")
media.Regist(c.stream) // 向媒体中心注册流
stats.RtspConns.Add() // 增加一个 RTSP 连接计数
lastHeartbeat := time.Now()
reader := c.conn.Reader()
heartbeatInterval := config.NetHeartbeatInterval()
timeout := config.NetTimeout()
for !c.closed {
deadLine := time.Time{}
if timeout > 0 {
deadLine = time.Now().Add(timeout)
}
if err := c.conn.SetReadDeadline(deadLine); err != nil {
c.logger.Error(err.Error())
break
}
err := receive(c.logger, reader, c.rtpChannels[:], c)
if err != nil {
if err == io.EOF { // 如果对方断开
c.logger.Warn("The remote RTSP server is actively disconnected.")
} else if !c.closed { // 如果非主动关闭
c.logger.Error(err.Error())
}
break
}
if heartbeatInterval > 0 && time.Now().Sub(lastHeartbeat) > heartbeatInterval {
lastHeartbeat = time.Now()
// 心跳包
r := c.newRequest(MethodOptions, c.url)
err := c.request(r)
if err != nil {
c.logger.Error(err.Error())
break
}
}
}
reader = nil
}
func (c *PullClient) onPack(p *RTPPack) error {
return c.stream.WritePacket(p)
}
func (c *PullClient) onRequest(r *Request) (err error) {
// 只处理 Options 方法
switch r.Method {
case MethodOptions:
resp := &Response{
StatusCode: 200,
Header: r.Header,
}
resp.Header.Del(FieldUserAgent)
resp.Header.Set(FieldPublic, MethodOptions)
err = c.response(resp)
if err != nil {
return err
}
default:
resp := &Response{
StatusCode: StatusMethodNotAllowed,
Header: r.Header,
}
resp.Header.Del(FieldUserAgent)
err = c.response(resp)
if err != nil {
return err
}
}
return nil
}
func (c *PullClient) onResponse(resp *Response) (err error) {
// 忽略
return
}
func (c *PullClient) getSetupURL(ctrl string) (setupURL *url.URL, err error) {
if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) {
return url.Parse(ctrl)
}
setupURL = new(url.URL)
*setupURL = *c.url
if setupURL.Path[len(setupURL.Path)-1] == '/' {
setupURL.Path = setupURL.Path + ctrl
} else {
setupURL.Path = setupURL.Path + "/" + ctrl
}
return
}
func (c *PullClient) newRequest(method string, url *url.URL) *Request {
r := &Request{
Method: method,
Header: make(Header),
}
r.URL = url
if url == nil {
r.URL = c.url
}
r.Header.Set(FieldUserAgent, defaultUserAgent)
r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10))
if len(c.rsession) > 0 {
r.Header.Set(FieldSession, c.rsession)
}
// 和安全相关,已经收到安全作用域信息
if len(c.realm) > 0 {
pw := c.password
if len(c.md5password) > 0 {
pw = c.md5password
}
if len(c.nonce) > 0 {
// Digest 认证
r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw)
} else {
// Basic 认证
r.SetBasicAuth(c.userName, pw)
}
}
return r
}
func (c *PullClient) receiveResponse() (resp *Response, err error) {
resp, err = ReadResponse(c.conn.Reader())
if err != nil {
return nil, err
}
if c.logger.LevelEnabled(xlog.DebugLevel) {
c.logger.Debugf("<<<===\r\n%s", strings.TrimSpace(resp.String()))
}
return
}
func (c *PullClient) requestWithResponse(r *Request) (*Response, error) {
err := c.request(r)
if err != nil {
return nil, err
}
resp, err := c.receiveResponse()
if err != nil {
return nil, err
}
// 保存 session
c.rsession = resp.Header.Get(FieldSession)
// 如果需要安全信息,增加安全信息并再次请求
if resp.StatusCode == StatusUnauthorized {
if len(c.userName) == 0 {
return resp, errors.New("require username and password")
}
pw := c.password
auth := resp.Header.Get(FieldWWWAuthenticate)
if len(auth) > len(digestAuthPrefix) && strings.EqualFold(auth[:len(digestAuthPrefix)], digestAuthPrefix) {
ok := false
c.realm, c.nonce, ok = resp.DigestAuth()
if !ok {
return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
}
r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw)
} else if len(auth) > len(basicAuthPrefix) && strings.EqualFold(auth[:len(basicAuthPrefix)], basicAuthPrefix) {
ok := false
c.realm, ok = resp.BasicAuth()
if !ok {
return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
}
r.SetBasicAuth(c.userName, pw)
} else {
return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
}
// 修改请求序号
r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10))
err := c.request(r)
if err != nil {
return nil, err
}
resp, err = c.receiveResponse()
if err != nil {
return nil, err
}
// 保存 session
c.rsession = resp.Header.Get(FieldSession)
// TODO: 代码臃肿,需要优化
// 再试一次 password md5的情况
if resp.StatusCode == StatusUnauthorized {
md5Digest := md5.Sum([]byte(c.password))
c.md5password = hex.EncodeToString(md5Digest[:])
pw := c.md5password
auth := resp.Header.Get(FieldWWWAuthenticate)
if len(auth) > len(digestAuthPrefix) && strings.EqualFold(auth[:len(digestAuthPrefix)], digestAuthPrefix) {
ok := false
c.realm, c.nonce, ok = resp.DigestAuth()
if !ok {
return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
}
r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw)
} else if len(auth) > len(basicAuthPrefix) && strings.EqualFold(auth[:len(basicAuthPrefix)], basicAuthPrefix) {
ok := false
c.realm, ok = resp.BasicAuth()
if !ok {
return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
}
r.SetBasicAuth(c.userName, pw)
} else {
return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
}
// 修改请求序号
r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10))
err := c.request(r)
if err != nil {
return nil, err
}
resp, err = c.receiveResponse()
if err != nil {
return nil, err
}
// 保存 session
c.rsession = resp.Header.Get(FieldSession)
}
}
if !(resp.StatusCode >= 200 && resp.StatusCode <= 300) {
return resp, errors.New(resp.Status)
}
return resp, nil
}
func (c *PullClient) request(req *Request) error {
c.lockW.Lock()
err := req.Write(c.conn)
if err == nil {
_, err = c.conn.Flush()
}
c.lockW.Unlock()
if err != nil {
c.logger.Errorf("send request error = %v", err)
return err
}
if c.logger.LevelEnabled(xlog.DebugLevel) {
c.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(req.String()))
}
return err
}
func (c *PullClient) response(resp *Response) error {
c.lockW.Lock()
err := resp.Write(c.conn)
if err == nil {
_, err = c.conn.Flush()
}
c.lockW.Unlock()
if err != nil {
c.logger.Errorf("send response error = %v", err)
return err
}
if c.logger.LevelEnabled(xlog.DebugLevel) {
c.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(resp.String()))
}
return nil
}
func (c *PullClient) connect() error {
// 连接超时要更短
timeout := time.Duration(int64(config.NetTimeout()) / 3)
conn, err := net.DialTimeout("tcp", c.url.Host, timeout)
if err != nil {
c.logger.Errorf("connet remote server fail,err = %v", err)
return err
}
c.closed = false // 已经连接
c.conn = buffered.NewConn(conn,
buffered.FlushRate(config.NetFlushRate()),
buffered.BufferSize(config.NetBufferSize()))
c.logger.Infof("connect remote server success")
return nil
}
func (c *PullClient) disconnect() {
if c.closed {
return
}
c.closed = true
c.logger.Info("disconnec from remote server")
if c.conn != nil {
c.conn.Close()
}
c.rsession = ""
atomic.StoreInt64(&c.seq, 0)
c.realm = ""
c.sdp = nil
c.aControl = ""
c.vControl = ""
c.aCodec = ""
c.vCodec = ""
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"strings"
"github.com/cnotch/ipchub/media"
)
func init() {
// 注册拉流工厂
media.RegistPullStreamFactory(NewPullStreamFacotry())
}
type pullStreamFactory struct {
}
// NewPullStreamFacotry 创建拉流工厂
func NewPullStreamFacotry() media.PullStreamFactory {
return &pullStreamFactory{}
}
func (f *pullStreamFactory) Can(remoteURL string) bool {
if len(remoteURL) >= len(rtspURLPrefix) && strings.EqualFold(remoteURL[:len(rtspURLPrefix)], rtspURLPrefix) {
return true
}
return false
}
func (f *pullStreamFactory) Create(localPath, remoteURL string) (*media.Stream, error) {
client, err := NewPullClient(localPath, remoteURL)
if err != nil {
return nil, err
}
err = client.Open()
if err != nil {
return nil, err
}
return client.stream, nil
}

164
service/rtsp/rtptransport.go Executable file
View File

@@ -0,0 +1,164 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"errors"
"strconv"
"strings"
"github.com/cnotch/ipchub/utils/scan"
)
// RTPTransportType RTP 传输模式
type RTPTransportType int
// SessionMode 会话模式
type SessionMode int
// 通讯类型
const (
RTPUnknownTrans RTPTransportType = iota
RTPTCPUnicast // TCP
RTPUDPUnicast // UDP
RTPMulticast // 组播
)
// 会话类型
const (
UnknownSession SessionMode = iota
PlaySession // 播放
RecordSession // 录像
)
// RTPTransport RTP传输设置
type RTPTransport struct {
Mode SessionMode
Append bool
Type RTPTransportType
Channels [rtpChannelCount]int
ClientPorts [rtpChannelCount]int
ServerPorts [rtpChannelCount]int
// 组播相关设置
Ports [rtpChannelCount]int // 组播端口
MulticastIP string // 组播地址 224.0.1.0238.255.255.255
TTL int
Source string // 组播源地址
}
func parseRange(p string) (begin int, end int) {
begin = -1
end = -1
var s1, s2 string
index := strings.IndexByte(p, '-')
if index < 0 {
s1 = p
} else {
s1 = strings.TrimSpace(p[:index])
s2 = strings.TrimSpace(p[index+1:])
}
var err error
if len(s1) > 0 {
begin, err = strconv.Atoi(s1)
if err != nil {
begin = -1
}
}
if len(s2) > 0 {
end, err = strconv.Atoi(s2)
if err != nil {
end = -1
}
}
return
}
// ParseTransport 解析Setup中的传输配置
func (t *RTPTransport) ParseTransport(rtpType int, ts string) (err error) {
if t.Mode == UnknownSession {
t.Mode = PlaySession
}
// 确定传输类型
index := strings.IndexByte(ts, ';')
if index < 0 {
return errors.New("malformed trannsport")
}
transportSpec := strings.TrimSpace(ts[:index])
ts = ts[index+1:]
if transportSpec == "RTP/AVP/TCP" {
t.Type = RTPTCPUnicast
} else if transportSpec == "RTP/AVP" || transportSpec == "RTP/AVP/UDP" {
t.Type = RTPMulticast // 默认组播
} else {
return errors.New("malformed trannsport")
}
// 扫描参数
tailing := ts
substr := ""
continueScan := true
for continueScan {
tailing, substr, continueScan = scan.Semicolon.Scan(tailing)
if substr == "unicast" && t.Type == RTPMulticast {
t.Type = RTPUDPUnicast
continue
}
if substr == "multicast" && t.Type == RTPTCPUnicast {
err = errors.New("malformed trannsport")
continue
}
if substr == "append" {
t.Append = true
continue
}
k, v, _ := scan.EqualPair.Scan(substr)
switch k {
case "mode":
if v == "record" {
t.Mode = RecordSession
} else {
t.Mode = PlaySession
}
case "interleaved":
begin, end := parseRange(v)
if begin >= 0 {
t.Channels[rtpType] = begin
}
if end >= 0 {
t.Channels[rtpType+1] = end
}
if begin < 0 {
err = errors.New("malformed trannsport")
}
case "client_port":
t.ClientPorts[rtpType], t.ClientPorts[rtpType+1] = parseRange(v)
if t.ClientPorts[rtpType] < 0 {
err = errors.New("malformed trannsport")
}
case "server_port":
t.ServerPorts[rtpType], t.ServerPorts[rtpType+1] = parseRange(v)
if t.ServerPorts[rtpType] < 0 {
err = errors.New("malformed trannsport")
}
case "port":
t.Ports[rtpType], t.Ports[rtpType+1] = parseRange(v)
if t.Ports[rtpType] < 0 {
err = errors.New("malformed trannsport")
}
case "destination":
t.MulticastIP = v
case "source":
t.Source = v
case "ttl":
t.TTL, _ = strconv.Atoi(v)
}
}
return
}

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRTPTransport_parseTransport(t *testing.T) {
tests := []struct {
name string
rtpType int
ts string
wantErr bool
}{
{
"test1",
int(ChannelVideo),
"RTP/AVP;multicast;client_port=18888-18889",
false,
},
{
"test2",
int(ChannelAudio),
"RTP/AVP;multicast;destination=232.248.88.236;source=192.168.1.154;port=16666-0;ttl=255",
false,
},
}
var ts RTPTransport
for _, tt := range tests {
if err := ts.ParseTransport(tt.rtpType, tt.ts); (err != nil) != tt.wantErr {
t.Errorf("RTPTransport.parseTransport() error = %v, wantErr %v", err, tt.wantErr)
}
}
assert.Equal(t, 18888, ts.ClientPorts[int(ChannelVideo)])
assert.Equal(t, 18889, ts.ClientPorts[int(ChannelVideo)+1])
assert.Equal(t, PlaySession, ts.Mode, "play")
assert.Equal(t, RTPMulticast, ts.Type, "multicast")
assert.Equal(t, "232.248.88.236", ts.MulticastIP)
assert.Equal(t, "192.168.1.154", ts.Source)
assert.Equal(t, 255, ts.TTL)
assert.Equal(t, 16666, ts.Ports[int(ChannelAudio)])
assert.Equal(t, 0, ts.Ports[int(ChannelAudio)+1])
}
func TestRTPTransport_parseTransport_error(t *testing.T) {
tests := []struct {
name string
rtpType int
ts string
wantErr bool
}{
{
"error",
int(ChannelVideo),
"RTP/AVP/TCP;multicast;client_port=18888-18889",
true,
},
}
var ts RTPTransport
for _, tt := range tests {
if err := ts.ParseTransport(tt.rtpType, tt.ts); (err != nil) != tt.wantErr {
t.Errorf("RTPTransport.parseTransport() error = %v, wantErr %v", err, tt.wantErr)
}
}
}
func Benchmark_ParseTransport(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var ts RTPTransport
ts.ParseTransport(int(ChannelVideo), "RTP/AVP;multicast;destination=232.248.88.236;source=192.168.1.154;port=16666-0;ttl=255")
}
})
}

46
service/rtsp/rtsp.go Executable file
View File

@@ -0,0 +1,46 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"net"
"sync"
"github.com/cnotch/ipchub/network/socket/listener"
"github.com/kelindar/tcp"
"github.com/cnotch/xlog"
)
// MatchRTSP 仅匹配 RTSP 请求方法
// 注意由于RTSP 和 HTTP 都有 OPTIONS 方法,因此对 RTSP 的 OPTIONS 做了进一步细化
func MatchRTSP() listener.Matcher {
return listener.MatchPrefix("OPTIONS * RTSP", "OPTIONS * rtsp",
"OPTIONS rtsp://", "OPTIONS RTSP://",
MethodDescribe, MethodAnnounce, MethodSetup,
MethodPlay, MethodPause, MethodTeardown,
MethodGetParameter, MethodSetParameter,
MethodRecord, MethodRedirect)
}
// Server rtsp 服务器
type Server struct {
logger *xlog.Logger
sessions sync.Map
}
// CreateAcceptHandler 创建连接接入处理器
func CreateAcceptHandler() tcp.OnAccept {
svr := &Server{
logger: xlog.L(),
}
return svr.onAcceptConn
}
// onAcceptConn 当新连接接入时触发
func (svr *Server) onAcceptConn(c net.Conn) {
s := newSession(svr, c)
go s.process()
}

127
service/rtsp/sdp_test.go Executable file
View File

@@ -0,0 +1,127 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"testing"
"github.com/pixelbender/go-sdp/sdp"
)
const sdpRaw = `v=0
o=- 0 0 IN IP4 127.0.0.1
s=No Name
c=IN IP4 127.0.0.1
t=0 0
a=tool:libavformat 58.20.100
m=video 0 RTP/AVP 96
b=AS:2500
a=rtpmap:96 H264/90000
a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAH6zZQFAFuhAAAAMAEAAAAwPI8YMZYA==,aO+8sA==; profile-level-id=64001F
a=control:streamid=0
m=audio 0 RTP/AVP 97
b=AS:160
a=rtpmap:97 MPEG4-GENERIC/44100/2
a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=121056E500
a=control:streamid=1
`
const sdpRaw2 = `v=0
o=- 946684871882903 1 IN IP4 192.168.1.154
s=RTSP/RTP stream from IPNC
i=h264
t=0 0
a=tool:LIVE555 Streaming Media v2008.04.02
a=type:broadcast
a=control:*
a=source-filter: incl IN IP4 * 192.168.1.154
a=rtcp-unicast: reflection
a=range:npt=0-
a=x-qt-text-nam:RTSP/RTP stream from IPNC
a=x-qt-text-inf:h264
m=audio 18888 RTP/AVP 0
c=IN IP4 232.190.161.0/255
a=control:track1
m=video 16666 RTP/AVP 96
c=IN IP4 232.248.88.236/255
a=rtpmap:96 H264/90000
a=fmtp:96 packetization-mode=1;profile-level-id=EE3CB0;sprop-parameter-sets=H264
a=control:track2
`
const sdpRaw3 = `v=0
o=- 0 0 IN IP6 ::1
s=No Name
c=IN IP6 ::1
t=0 0
a=tool:libavformat 58.20.100
m=video 0 RTP/AVP 96
a=rtpmap:96 H264/90000
a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z3oAH7y0AoAt0IAAAAMAgAAAHkeMGVA=,aO8Pyw==; profile-level-id=7A001F
a=control:streamid=0
m=audio 0 RTP/AVP 97
b=AS:128
a=rtpmap:97 MPEG4-GENERIC/44100/2
a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=121056E500`
// 4k mp4
const sdpRaw4 = `v=0
o=- 0 0 IN IP6 ::1
s=No Name
c=IN IP6 ::1
t=0 0
a=tool:libavformat 58.20.100
m=video 0 RTP/AVP 96
b=AS:31998
a=rtpmap:96 H264/90000
a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAM6wspADwAQ+wFSAgICgAAB9IAAdTBO0LFok=,aOtzUlA=; profile-level-id=640033
a=control:streamid=0
m=audio 0 RTP/AVP 97
b=AS:317
a=rtpmap:97 MPEG4-GENERIC/48000/2
a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=1190
a=control:streamid=1`
const sdpH265Raw = `v=0
o=- 0 0 IN IP6 ::1
s=No Name
c=IN IP6 ::1
t=0 0
a=tool:libavformat 58.20.100
m=video 0 RTP/AVP 96
a=rtpmap:96 H265/90000
a=fmtp:96 sprop-vps=QAEMAf//BAgAAAMAnQgAAAMAAF26AkA=; sprop-sps=QgEBBAgAAAMAnQgAAAMAAF2wAoCALRZbqSTK4BAAAAMAEAAAAwHggA==; sprop-pps=RAHBcrRiQA==
a=control:streamid=0
m=audio 0 RTP/AVP 97
b=AS:128
a=rtpmap:97 MPEG4-GENERIC/44100/2
a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=121056E500
a=control:streamid=1
`
func Benchmark_ThirdSdpParse(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sdp.ParseString(sdpRaw2)
}
})
}
func Test_SDPParse(t *testing.T) {
t.Run("Test_SDPParse", func(t *testing.T) {
s1, err := sdp.ParseString(sdpRaw)
if err != nil {
t.Errorf("sdp.ParseString() error = %v", err)
}
_ = s1
s2, err := sdp.ParseString(sdpRaw2)
if err != nil {
t.Errorf("sdp.ParseString() error = %v", err)
}
_ = s2
})
}

659
service/rtsp/session.go Executable file
View File

@@ -0,0 +1,659 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/url"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/cnotch/ipchub/config"
"github.com/cnotch/ipchub/media"
"github.com/cnotch/ipchub/network/socket/buffered"
"github.com/cnotch/ipchub/network/websocket"
"github.com/cnotch/ipchub/provider/auth"
"github.com/cnotch/ipchub/provider/security"
"github.com/cnotch/ipchub/stats"
"github.com/cnotch/ipchub/utils"
"github.com/cnotch/xlog"
"github.com/emitter-io/address"
"github.com/pixelbender/go-sdp/sdp"
)
const (
realm = config.Name
)
const (
statusInit = iota
statusReady
statusPlaying
statusRecording
)
var buffers = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, 1024*2))
},
}
// Session RTSP 会话
type Session struct {
// 创建时设置
svr *Server
logger *xlog.Logger
closed bool
lsession string // 本地会话标识
timeout time.Duration
conn *buffered.Conn
lockW sync.Mutex
wsconn websocket.Conn
authMode auth.Mode
nonce string
user *auth.User
// DESCRIBE或 ANNOUNCE 后设置
url *url.URL
path string
rawSdp string
sdp *sdp.Session
aControl string
vControl string
aCodec string
vCodec string
mode SessionMode
// Setup 后设置
transport RTPTransport
// 启动流媒体传输后设置
status int // session状态
stream mediaStream // 媒体流
consumer media.Consumer // 消费者
}
func newSession(svr *Server, conn net.Conn) *Session {
session := &Session{
svr: svr,
lsession: security.NewID().Base64(),
timeout: config.NetTimeout(),
conn: buffered.NewConn(conn,
buffered.FlushRate(config.NetFlushRate()),
buffered.BufferSize(config.NetBufferSize())),
mode: UnknownSession,
transport: RTPTransport{
Mode: PlaySession, // 默认为播放
Type: RTPUnknownTrans,
},
authMode: config.RtspAuthMode(),
nonce: security.NewID().MD5(),
status: statusInit,
stream: defaultStream,
consumer: defaultConsumer,
}
if wsc, ok := conn.(websocket.Conn); ok { // 如果是WebSocket有http进行验证
session.authMode = auth.NoneAuth
session.wsconn = wsc
session.path = wsc.Path()
session.user = auth.Get(wsc.Username())
}
ipaddr, _ := address.Parse(conn.RemoteAddr().String(), 80)
// 如果是本机IP不验证以便ffmpeg本机rtsp->rtmp
if utils.IsLocalhostIP(ipaddr.IP) {
session.authMode = auth.NoneAuth
}
for i := rtpChannelMin; i < rtpChannelCount; i++ {
session.transport.Channels[i] = -1
session.transport.ClientPorts[i] = -1
}
session.logger = svr.logger.With(xlog.Fields(
xlog.F("session", session.lsession)))
return session
}
// Addr Session地址
func (s *Session) Addr() string {
return s.conn.RemoteAddr().String()
}
// Consume 消费媒体包
func (s *Session) Consume(p Pack) {
s.consumer.Consume(p)
}
// Close 关闭会话
func (s *Session) Close() error {
if s.closed {
return nil
}
s.closed = true
s.conn.Close()
return nil
}
func (s *Session) process() {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("session panic; %v \n %s", r, debug.Stack())
}
stats.RtspConns.Release()
s.Close()
s.consumer.Close()
s.stream.Close()
// 重置到初始状态
s.conn = nil
s.status = statusInit
s.stream = defaultStream
s.consumer = defaultConsumer
s.logger.Infof("close rtsp session")
}()
s.logger.Infof("open rtsp session")
stats.RtspConns.Add() // 增加一个 RTSP 连接计数
reader := s.conn.Reader()
for !s.closed {
deadLine := time.Time{}
if s.timeout > 0 {
deadLine = time.Now().Add(s.timeout)
}
if err := s.conn.SetReadDeadline(deadLine); err != nil {
s.logger.Error(err.Error())
break
}
err := receive(s.logger, reader, s.transport.Channels[:], s)
if err != nil {
if err == io.EOF { // 如果客户端断开提醒
s.logger.Warn("The client actively disconnects")
} else if !s.closed { // 如果主动关闭,不提示
s.logger.Error(err.Error())
}
break
}
}
}
// receiveHandler.onPack
func (s *Session) onPack(pack *RTPPack) (err error) {
return s.stream.WritePacket(pack)
}
// receiveHandler.onResponse
func (s *Session) onResponse(resp *Response) (err error) {
// 忽略,服务器不会主动发起请求
return
}
// receiveHandler.onRequest
func (s *Session) onRequest(req *Request) (err error) {
resp := s.newResponse(StatusOK, req)
// 预处理
continueProcess, err := s.onPreprocess(resp, req)
if !continueProcess {
return err
}
switch req.Method {
case MethodDescribe:
s.onDescribe(resp, req)
case MethodAnnounce:
s.onAnnounce(resp, req)
case MethodSetup:
s.onSetup(resp, req)
case MethodRecord:
s.onRecord(resp, req)
case MethodPlay:
return s.onPlay(resp, req) // play 发送流媒体不在当前 routine需要先回复
default:
// 状态不支持的方法
resp.StatusCode = StatusMethodNotValidInThisState
}
// 发送响应
err = s.response(resp)
return err
}
func (s *Session) onDescribe(resp *Response, req *Request) {
// TODO: 检查 accept 中的类型是否包含 sdp
s.url = req.URL
if s.wsconn == nil { // websocket访问的路径有ws://路径表示
s.path = utils.CanonicalPath(req.URL.Path)
}
stream := media.GetOrCreate(s.path)
if stream == nil {
resp.StatusCode = StatusNotFound
return
}
if !s.checkPermission(auth.PullRight) {
resp.StatusCode = StatusForbidden
return
}
// 从流中取 sdp
sdpRaw := stream.Attr("sdp")
if len(sdpRaw) == 0 {
resp.StatusCode = StatusNotFound
return
}
err := s.parseSdp(sdpRaw)
if err != nil { // TODO需要更好的处理方式
resp.StatusCode = StatusNotFound
return
}
resp.Header.Set(FieldContentType, "application/sdp")
resp.Body = s.rawSdp
s.mode = PlaySession // 标记为播放会话
}
func (s *Session) onAnnounce(resp *Response, req *Request) {
// 检查 Content-Type: application/sdp
if req.Header.Get(FieldContentType) != "application/sdp" {
resp.StatusCode = StatusBadRequest // TODO:更合适的代码
return
}
s.url = req.URL
s.path = utils.CanonicalPath(req.URL.Path)
if !s.checkPermission(auth.PushRight) {
resp.StatusCode = StatusForbidden
return
}
// 从流中取 sdp
err := s.parseSdp(req.Body)
if err != nil {
resp.StatusCode = StatusBadRequest
return
}
s.mode = RecordSession // 标记为录像会话
}
func (s *Session) onSetup(resp *Response, req *Request) {
// a=control:streamid=1
// a=control:rtsp://192.168.1.165/trackID=1
// a=control:?ctype=video
setupURL := &url.URL{}
*setupURL = *req.URL
if setupURL.Port() == "" {
setupURL.Host = fmt.Sprintf("%s:554", setupURL.Host)
}
setupPath := setupURL.String()
//setupPath = setupPath[strings.LastIndex(setupPath, "/")+1:]
vPath, err := getControlPath(s.vControl)
if err != nil {
resp.StatusCode = StatusInternalServerError
resp.Status = "Invalid VControl"
return
}
aPath, err := getControlPath(s.aControl)
if err != nil {
resp.StatusCode = StatusInternalServerError
resp.Status = "Invalid AControl"
return
}
ts := req.Header.Get(FieldTransport)
resp.Header.Set(FieldTransport, ts) // 先回写transport
// 检查控制路径
chindex := -1
if setupPath == aPath || (aPath != "" && strings.LastIndex(setupPath, aPath) == len(setupPath)-len(aPath)) {
chindex = int(ChannelAudio)
} else if setupPath == vPath || (vPath != "" && strings.LastIndex(setupPath, vPath) == len(setupPath)-len(vPath)) {
chindex = int(ChannelVideo)
} else { // 找不到被 Setup 的资源
resp.StatusCode = StatusInternalServerError
resp.Status = fmt.Sprintf("SETUP Unkown control:%s", setupPath)
return
}
err = s.transport.ParseTransport(chindex, ts)
if err != nil {
resp.StatusCode = StatusInvalidParameter
resp.Status = err.Error()
return
}
// 检查和以前的命令是否一致
if s.mode == UnknownSession {
s.mode = s.transport.Mode
}
if s.mode != s.transport.Mode {
resp.StatusCode = StatusInvalidParameter
if s.mode == PlaySession {
resp.Status = "Current state can't setup as record"
} else {
resp.Status = "Current state can't setup as play"
}
return
}
// record 只支持 TCP 单播
if s.mode == RecordSession {
// 检查用户权限
if !s.checkPermission(auth.PushRight) {
resp.StatusCode = StatusForbidden
return
}
if s.transport.Type != RTPTCPUnicast {
resp.StatusCode = StatusUnsupportedTransport
resp.Status = "when mode = recordonly support tcp unicast"
} else {
if s.status < statusReady { // 初始状态切换到Ready
s.status = statusReady
}
}
return
}
// 检查用户权限,播放
if !s.checkPermission(auth.PullRight) {
resp.StatusCode = StatusForbidden
return
}
if s.transport.Type == RTPMulticast { // 需要修改回复的transport
st := media.GetOrCreate(s.path)
if st == nil { // 没有找到源
resp.StatusCode = StatusNotFound
return
}
ma := st.Multicastable()
if ma == nil { // 不支持组播
resp.StatusCode = StatusUnsupportedTransport
return
}
ts = fmt.Sprintf("%s;destination=%s;port=%d-%d;source=%s;ttl=%d",
ts, ma.MulticastIP(),
ma.Port(chindex), ma.Port(chindex+1),
ma.SourceIP(), ma.TTL())
resp.Header.Set(FieldTransport, ts)
}
if s.status < statusReady { // 初始状态切换到Ready
s.status = statusReady
}
}
func (s *Session) onRecord(resp *Response, req *Request) {
if s.status == statusRecording {
return
}
// 传输模式、会话模式判断
if s.mode != RecordSession || s.transport.Type != RTPTCPUnicast {
resp.StatusCode = StatusMethodNotValidInThisState
return
}
if !s.checkPermission(auth.PushRight) {
resp.StatusCode = StatusForbidden
return
}
s.asTCPPusher()
s.status = statusRecording
}
func (s *Session) onPlay(resp *Response, req *Request) (err error) {
if s.status == statusPlaying {
return
}
// 传输模式、会话模式判断
if s.mode != PlaySession || s.transport.Type == RTPUnknownTrans {
resp.StatusCode = StatusMethodNotValidInThisState
return s.response(resp)
}
stream := media.GetOrCreate( s.path)
if stream == nil {
resp.StatusCode = StatusNotFound
return s.response(resp)
}
if !s.checkPermission(auth.PullRight) {
resp.StatusCode = StatusForbidden
return s.response(resp)
}
resp.Header.Set(FieldRange, req.Header.Get(FieldRange))
switch s.transport.Type {
case RTPTCPUnicast:
err = s.asTCPConsumer(stream, resp)
case RTPUDPUnicast:
err = s.asUDPConsumer(stream, resp)
default:
err = s.asMulticastConsumer(stream, resp)
}
if err == nil {
s.status = statusPlaying
}
return
}
func (s *Session) checkPermission(right auth.AccessRight) bool {
if s.authMode == auth.NoneAuth {
return true
}
if s.user == nil {
return false
}
return s.user.ValidatePermission(s.path, right)
}
func (s *Session) checkAuth(r *Request) (user *auth.User, err error) {
switch s.authMode {
case auth.BasicAuth:
username, password, has := r.BasicAuth()
if !has {
return nil, errors.New("require legal Authorization field")
}
user := auth.Get(username)
if user == nil {
return nil, errors.New("user not exist")
}
err = user.ValidatePassword(password)
if err != nil {
return nil, err
}
return user, nil
case auth.DigestAuth:
username, response, has := r.DigestAuth()
if !has {
return nil, errors.New("require legal Authorization field")
}
user := auth.Get(username)
if user == nil {
return nil, errors.New("user not exist")
}
resp2 := formatDigestAuthResponse(realm, s.nonce, r.Method, r.URL.String(), username, user.Password)
if resp2 == response {
return user, nil
}
resp2 = formatDigestAuthResponse(realm, s.nonce, r.Method, r.URL.String(), username, user.PasswordMD5())
if resp2 == response {
return user, nil
}
s.nonce = security.NewID().MD5()
return nil, errors.New("require legal Authorization field")
default: // 无需验证
return nil, nil
}
}
func (s *Session) onPreprocess(resp *Response, req *Request) (continueProcess bool, err error) {
// Options 方法无需验证,直接回复
if req.Method == MethodOptions {
resp.Header.Set(FieldPublic, "DESCRIBE, SETUP, TEARDOWN, PLAY, OPTIONS, ANNOUNCE, RECORD")
err = s.response(resp)
return false, err
}
// 关闭请求
if req.Method == MethodTeardown {
// 发送响应
err = s.response(resp)
s.Close()
return false, err
}
// 检查状态下的方法
switch s.status {
case statusReady:
continueProcess = req.Method == MethodSetup ||
req.Method == MethodPlay || req.Method == MethodRecord
case statusPlaying:
continueProcess = req.Method == MethodPlay
case statusRecording:
continueProcess = req.Method == MethodRecord
default:
continueProcess = !(req.Method == MethodPlay || req.Method == MethodRecord)
}
if !continueProcess {
resp.StatusCode = StatusMethodNotValidInThisState
err = s.response(resp)
return false, err
}
// 检查认证
user, err2 := s.checkAuth(req)
if err2 != nil {
resp.StatusCode = StatusUnauthorized
if err2 != nil {
resp.Status = err2.Error()
}
err = s.response(resp)
return false, err
}
s.user = user
return true, nil
}
func (s *Session) response(resp *Response) error {
s.lockW.Lock()
var err error
if s.wsconn != nil { // websocket 客户端
buf := buffers.Get().(*bytes.Buffer)
buf.Reset()
defer buffers.Put(buf)
err = resp.Write(buf) // 保证写入包的完整性,简化前端分包
_, err = s.wsconn.Write(buf.Bytes())
} else {
err = resp.Write(s.conn)
if err == nil {
_, err = s.conn.Flush()
}
}
s.lockW.Unlock()
if err != nil {
s.logger.Errorf("send response error = %v", err)
return err
}
if s.logger.LevelEnabled(xlog.DebugLevel) {
s.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(resp.String()))
}
return nil
}
func (s *Session) newResponse(code int, req *Request) *Response {
resp := &Response{
StatusCode: code,
Header: make(Header),
Request: req,
}
resp.Header.Set(FieldCSeq, req.Header.Get(FieldCSeq))
resp.Header.Set(FieldSession, s.lsession)
// 根据认证模式增加认证所需的字段
switch s.authMode {
case auth.BasicAuth:
resp.SetBasicAuth(realm)
case auth.DigestAuth:
resp.SetDigestAuth(realm, s.nonce)
}
return resp
}
func (s *Session) parseSdp(rawSdp string) (err error) {
// 从流中取 sdp
s.rawSdp = rawSdp
// 解析
s.sdp, err = sdp.ParseString(s.rawSdp)
if err != nil {
return
}
for _, media := range s.sdp.Media {
switch media.Type {
case "video":
s.vControl = media.Attributes.Get("control")
s.vCodec = media.Format[0].Name
case "audio":
s.aControl = media.Attributes.Get("control")
s.aCodec = media.Format[0].Name
}
}
return
}
func getControlPath(ctrl string) (path string, err error) {
if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) {
var ctrlURL *url.URL
ctrlURL, err = url.Parse(ctrl)
if err != nil {
return "", err
}
if ctrlURL.Port() == "" {
ctrlURL.Host = fmt.Sprintf("%s:554", ctrlURL.Hostname())
}
return ctrlURL.String(), nil
}
return ctrl, nil
}

312
service/rtsp/session_roles.go Executable file
View File

@@ -0,0 +1,312 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"bytes"
"errors"
"fmt"
"net"
"github.com/cnotch/ipchub/config"
"github.com/cnotch/ipchub/media"
"github.com/cnotch/ipchub/utils"
"github.com/cnotch/xlog"
)
var (
errModeBehavior = errors.New("Play mode can't send rtp pack")
defaultStream mediaStream = emptyStream{}
defaultConsumer media.Consumer = emptyConsumer{}
)
// 媒体流
type mediaStream interface {
Close() error
WritePacket(pack *RTPPack) error
}
// 占位流,简化判断
type emptyStream struct {
}
func (s emptyStream) Close() error { return nil }
func (s emptyStream) WritePacket(*RTPPack) error { return errModeBehavior }
// 占位消费者,简化判断
type emptyConsumer struct {
}
func (c emptyConsumer) Consume(p Pack) {}
func (c emptyConsumer) Close() error { return nil }
type tcpPushStream struct {
closed bool
stream *media.Stream
}
func (s *tcpPushStream) Close() error {
if s.closed {
return nil
}
s.closed = true
media.Unregist(s.stream)
s.stream = nil
return nil
}
func (s *tcpPushStream) WritePacket(p *RTPPack) error {
return s.stream.WritePacket(p)
}
type tcpConsumer struct {
*Session
closed bool
source *media.Stream
cid media.CID
}
func (c *tcpConsumer) Consume(p Pack) {
if c.closed {
return
}
p2 := p.(*RTPPack)
var err error
if c.wsconn != nil {
buf := buffers.Get().(*bytes.Buffer)
buf.Reset()
defer buffers.Put(buf)
p2.Write(buf, c.transport.Channels[:])
c.lockW.Lock()
_, err = c.wsconn.Write(buf.Bytes())
c.lockW.Unlock()
} else {
c.lockW.Lock()
err = p2.Write(c.conn, c.transport.Channels[:])
c.lockW.Unlock()
}
if err != nil {
c.logger.Errorf("send pack error = %v , close socket", err)
c.Close()
return
}
}
func (c *tcpConsumer) Close() error {
if c.closed {
return nil
}
c.closed = true
c.source.StopConsume(c.cid)
c.source = nil
return nil
}
type udpConsumer struct {
*Session
closed bool
source *media.Stream
cid media.CID
udpConn *net.UDPConn // 用于Player的UDP单播
destAddr [rtpChannelCount]*net.UDPAddr
}
func (c *udpConsumer) Consume(p Pack) {
if c.closed {
return
}
p2 := p.(*RTPPack)
addr := c.destAddr[int(p2.Channel)]
if addr != nil {
_, err := c.udpConn.WriteToUDP(p2.Data, addr)
if err != nil {
c.logger.Warn(err.Error())
return
}
}
}
func (c *udpConsumer) Close() error {
if c.closed {
return nil
}
c.closed = true
c.source.StopConsume(c.cid)
c.udpConn.Close()
c.source = nil
return nil
}
func (c *udpConsumer) prepareUDP(destIP string, destPorts [rtpChannelCount]int) error {
// 如果还没准备 Socket
if c.udpConn == nil {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return err
}
c.udpConn = udpConn
err = udpConn.SetWriteBuffer(config.NetBufferSize())
}
for i, port := range destPorts {
if port > 0 {
c.destAddr[i], _ = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", destIP, port))
}
}
return nil
}
type multicastConsumer struct {
*Session
closed bool
source *media.Stream
}
func (c *multicastConsumer) Consume(p Pack) {}
func (c *multicastConsumer) Close() error {
if c.closed {
return nil
}
c.closed = true
c.source.Multicastable().ReleaseMember(c.Session)
c.source = nil
c.Session = nil
return nil
}
// 将Session作为Pusher角色
func (s *Session) asTCPPusher() {
pusher := &tcpPushStream{}
mproxy := &multicastProxy{
path: s.path,
bufferSize: config.NetBufferSize(),
multicastIP: utils.Multicast.NextIP(), // 设置组播IP
ttl: config.MulticastTTL(),
logger: s.logger.With(xlog.Fields(
xlog.F("path", s.path),
xlog.F("type", "multicast-proxy"))),
}
s.logger = s.logger.With(xlog.Fields(
xlog.F("path", s.path),
xlog.F("type", "pusher")))
for i := rtpChannelMin; i < rtpChannelCount; i++ {
mproxy.ports[i] = utils.Multicast.NextPort()
}
pusher.stream = media.NewStream(s.path, s.rawSdp,
media.Attr("addr", s.conn.RemoteAddr().String()),
media.Multicast(mproxy))
media.Regist(pusher.stream)
// 设置Session字段
s.stream = pusher
s.logger.Infof("specify session type")
}
func (s *Session) asTCPConsumer(stream *media.Stream, resp *Response) (err error) {
if s.wsconn != nil {
s.logger = s.logger.With(xlog.Fields(
xlog.F("path", s.path),
xlog.F("type", "websocket-player")))
} else {
s.logger = s.logger.With(xlog.Fields(
xlog.F("path", s.path),
xlog.F("type", "tcp-player")))
}
c := &tcpConsumer{
Session: s,
source: stream,
}
err = s.response(resp)
if err != nil {
return err
}
s.timeout = 0 // play 只需发送不用接收,因此设置不超时
s.consumer = c
if s.wsconn != nil {
c.cid = stream.StartConsumeNoGopCache(s, media.RTPPacket, "net=rtsp-websocket")
} else {
c.cid = stream.StartConsume(s, media.RTPPacket, "net=rtsp-tcp")
}
s.logger.Infof("specify session type")
return
}
func (s *Session) asUDPConsumer(stream *media.Stream, resp *Response) (err error) {
c := &udpConsumer{
Session: s,
}
// 创建udp连接
err = c.prepareUDP(utils.GetIP(s.conn.RemoteAddr()), s.transport.ClientPorts)
if err != nil {
resp.StatusCode = StatusInternalServerError
err = s.response(resp)
if err != nil {
return err
}
return nil
}
s.logger = s.logger.With(xlog.Fields(
xlog.F("path", s.path),
xlog.F("type", "udp-player")))
err = s.response(resp)
if err != nil {
return err
}
s.timeout = 0 // play 只需发送不用接收,因此设置不超时
s.consumer = c
c.cid = stream.StartConsume(s, media.RTPPacket, "net=rtsp-udp")
s.logger.Infof("specify session type")
return nil
}
func (s *Session) asMulticastConsumer(stream *media.Stream, resp *Response) (err error) {
c := &multicastConsumer{
Session: s,
source: stream,
}
ma := stream.Multicastable()
if ma == nil { // 不支持组播
resp.StatusCode = StatusUnsupportedTransport
err = s.response(resp)
if err != nil {
return err
}
return nil
}
s.logger = s.logger.With(xlog.Fields(
xlog.F("path", s.path),
xlog.F("type", "multicast-player")))
err = s.response(resp)
if err != nil {
return err
}
c.timeout = 0 // play 只需发送不用接收,因此设置不超时
s.consumer = c
ma.AddMember(s)
s.logger.Infof("specify session type")
return nil
}

176
service/rtsp/types.go Normal file
View File

@@ -0,0 +1,176 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package rtsp
import (
"github.com/cnotch/ipchub/media/cache"
"github.com/cnotch/ipchub/protos/rtp"
"github.com/cnotch/ipchub/protos/rtsp"
)
// Pack .
type Pack = cache.Pack
// .
var (
ReadPacket = rtp.ReadPacket
ReadResponse = rtsp.ReadResponse
ReadRequest = rtsp.ReadRequest
)
// Request .
type Request = rtsp.Request
//Response .
type Response = rtsp.Response
// Header .
type Header = rtsp.Header
// RTPPack .
type RTPPack = rtp.Packet
const (
rtpPackPrefix = rtp.TransferPrefix
rtspProto = "RTSP/1.0" // RTSP协议版本
rtspURLPrefix = "rtsp://" // RTSP地址前缀
basicAuthPrefix = "Basic " // 用户基础验证前缀
digestAuthPrefix = "Digest " // 摘要认证前缀
)
// 预定义 RTP 包类型
const (
ChannelVideo = rtp.ChannelVideo
ChannelVideoControl = rtp.ChannelVideoControl
ChannelAudio = rtp.ChannelAudio
ChannelAudioControl = rtp.ChannelAudioControl
rtpChannelCount = rtp.ChannelCount
rtpChannelMin = rtp.ChannelMin
)
// 通用的 RTSP 方法。
//
// 除非特别说明,这些定义在 RFC2326 规范的 10 章中。
// 未实现的方法需要返回 "501 Not Implemented"
const (
MethodOptions = rtsp.MethodOptions // 查询命令支持情况(C->S, S->C)
MethodDescribe = rtsp.MethodDescribe // 获取媒体信息(C->S)
MethodAnnounce = rtsp.MethodAnnounce // 声明要push的媒体信息(方向C->S, S->C)
MethodSetup = rtsp.MethodSetup // 构建传输会话,也可以调整传输参数(C->S);如果不允许调整,可以返回 455 错误
MethodPlay = rtsp.MethodPlay // 开始发送媒体数据(C->S)
MethodPause = rtsp.MethodPause // 暂停发送媒体数据(C->S)
MethodTeardown = rtsp.MethodTeardown // 关闭发送通道;关闭后需要重新执行 Setup 方法(C->S)
MethodGetParameter = rtsp.MethodGetParameter // 获取参数空body可作为心跳ping(C->S, S->C)
MethodSetParameter = rtsp.MethodSetParameter // 设置参数,应该每次只设置一个参数(C->S, S->C)
MethodRecord = rtsp.MethodRecord // 启动录像(C->S)
MethodRedirect = rtsp.MethodRedirect // 跳转(S->C)
)
// RTSP 头部域定义
// type:support:methods
// type: "g"通用的请求头部;"R"请求头部;"r"响应头部;"e"实体Body头部域。
// support: opt. 可选; req. 必须
// methods: 头部域应用范围
const (
FieldAccept = rtsp.FieldAccept // (R:opt.:entity)
FieldAcceptEncoding = rtsp.FieldAcceptEncoding // (R:opt.:entity)
FieldAcceptLanguage = rtsp.FieldAcceptLanguage // (R:opt.:all)
FieldAllow = rtsp.FieldAllow // (R:opt.:all)
FieldAuthorization = rtsp.FieldAuthorization // (R:opt.:all)
FieldBandwidth = rtsp.FieldBandwidth // (R:opt.all)
FieldBlocksize = rtsp.FieldBlocksize // (R:opt.:all but OPTIONS, TEARDOWN)
FieldCacheControl = rtsp.FieldCacheControl // (g:opt.:SETUP)
FieldConference = rtsp.FieldConference // (R:opt.:SETUP)
FieldConnection = rtsp.FieldConnection // (g:req.:all)
FieldContentBase = rtsp.FieldContentBase // (e:opt.:entity)
FieldContentEncoding = rtsp.FieldContentEncoding // (e:req.:SET_PARAMETER ; e:req.:DESCRIBE, ANNOUNCE )
FieldContentLanguage = rtsp.FieldContentLanguage // (e:req.:DESCRIBE, ANNOUNCE)
FieldContentLength = rtsp.FieldContentLength // (e:req.:SET_PARAMETER, ANNOUNCE; e:req.:entity)
FieldContentLocation = rtsp.FieldContentLocation // (e:opt.:entity)
FieldContentType = rtsp.FieldContentType // (e:req.:SET_PARAMETER, ANNOUNCE; r:req.:entity )
FieldCSeq = rtsp.FieldCSeq // (g:req.:all)
FieldDate = rtsp.FieldDate // (g:opt.:all)
FieldExpires = rtsp.FieldExpires // (e:opt.:DESCRIBE, ANNOUNCE)
FieldFrom = rtsp.FieldFrom // (R:opt.:all)
FieldIfModifiedSince = rtsp.FieldIfModifiedSince // (R:opt.:DESCRIBE, SETUP)
FieldLastModified = rtsp.FieldLastModified // (e:opt.:entity)
FieldProxyAuthenticate = rtsp.FieldProxyAuthenticate //
FieldProxyRequire = rtsp.FieldProxyRequire // (R:req.:all)
FieldPublic = rtsp.FieldPublic // (r:opt.:all)
FieldRange = rtsp.FieldRange // (R:opt.:PLAY, PAUSE, RECORD; r:opt.:PLAY, PAUSE, RECORD)
FieldReferer = rtsp.FieldReferer // (R:opt.:all)
FieldRequire = rtsp.FieldRequire // (R:req.:all)
FieldRetryAfter = rtsp.FieldRetryAfter // (r:opt.:all)
FieldRTPInfo = rtsp.FieldRTPInfo // (r:req.:PLAY)
FieldScale = rtsp.FieldScale // (Rr:opt.:PLAY, RECORD)
FieldSession = rtsp.FieldSession // (Rr:req.:all but SETUP, OPTIONS)
FieldServer = rtsp.FieldServer // (r:opt.:all)
FieldSpeed = rtsp.FieldSpeed // (Rr:opt.:PLAY)
FieldTransport = rtsp.FieldTransport // (Rr:req.:SETUP)
FieldUnsupported = rtsp.FieldUnsupported // (r:req.:all)
FieldUserAgent = rtsp.FieldUserAgent // (R:opt.:all)
FieldVia = rtsp.FieldVia // (g:opt.:all)
FieldWWWAuthenticate = rtsp.FieldWWWAuthenticate // (r:opt.:all)
)
// RTSP 响应状态码.
// See: https://tools.ietf.org/html/rfc2326#page-19
const (
StatusContinue = rtsp.StatusContinue
//======Success 2xx
StatusOK = rtsp.StatusOK
StatusCreated = rtsp.StatusCreated // only for RECORD
StatusLowOnStorageSpace = rtsp.StatusLowOnStorageSpace //only for RECORD
//======Redirection 3xx
StatusMultipleChoices = rtsp.StatusMultipleChoices
StatusMovedPermanently = rtsp.StatusMovedPermanently
StatusMovedTemporarily = rtsp.StatusMovedTemporarily // 和http不同
StatusSeeOther = rtsp.StatusSeeOther
StatusNotModified = rtsp.StatusNotModified
StatusUseProxy = rtsp.StatusUseProxy
//======Client Error 4xx
StatusBadRequest = rtsp.StatusBadRequest
StatusUnauthorized = rtsp.StatusUnauthorized
StatusPaymentRequired = rtsp.StatusPaymentRequired
StatusForbidden = rtsp.StatusForbidden
StatusNotFound = rtsp.StatusNotFound
StatusMethodNotAllowed = rtsp.StatusMethodNotAllowed
StatusNotAcceptable = rtsp.StatusNotAcceptable
StatusProxyAuthRequired = rtsp.StatusProxyAuthRequired
StatusRequestTimeout = rtsp.StatusRequestTimeout
StatusGone = rtsp.StatusGone
StatusLengthRequired = rtsp.StatusLengthRequired
StatusPreconditionFailed = rtsp.StatusPreconditionFailed // only for DESCRIBE, SETUP
StatusRequestEntityTooLarge = rtsp.StatusRequestEntityTooLarge
StatusRequestURITooLong = rtsp.StatusRequestURITooLong
StatusUnsupportedMediaType = rtsp.StatusUnsupportedMediaType
StatusInvalidParameter = rtsp.StatusInvalidParameter // only for SETUP
StatusConferenceNotFound = rtsp.StatusConferenceNotFound // only for SETUP
StatusNotEnoughBandwidth = rtsp.StatusNotEnoughBandwidth // only for SETUP
StatusSessionNotFound = rtsp.StatusSessionNotFound
StatusMethodNotValidInThisState = rtsp.StatusMethodNotValidInThisState
StatusHeaderFieldNotValid = rtsp.StatusHeaderFieldNotValid
StatusInvalidRange = rtsp.StatusInvalidRange // only for PLAY
StatusParameterIsReadOnly = rtsp.StatusParameterIsReadOnly // only for SET_PARAMETER
StatusAggregateOpNotAllowed = rtsp.StatusAggregateOpNotAllowed
StatusOnlyAggregateOpAllowed = rtsp.StatusOnlyAggregateOpAllowed
StatusUnsupportedTransport = rtsp.StatusUnsupportedTransport
StatusDestinationUnreachable = rtsp.StatusDestinationUnreachable
StatusInternalServerError = rtsp.StatusInternalServerError
StatusNotImplemented = rtsp.StatusNotImplemented
StatusBadGateway = rtsp.StatusBadGateway
StatusServiceUnavailable = rtsp.StatusServiceUnavailable
StatusGatewayTimeout = rtsp.StatusGatewayTimeout
StatusRTSPVersionNotSupported = rtsp.StatusRTSPVersionNotSupported
StatusOptionNotSupported = rtsp.StatusOptionNotSupported // 和 http 不同
)
// StatusText .
var StatusText = rtsp.StatusText
var formatDigestAuthResponse = rtsp.FormatDigestAuthResponse

170
service/wsp/protocol.go Executable file
View File

@@ -0,0 +1,170 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package wsp
import (
"bytes"
"fmt"
"io"
"strings"
"sync"
"unicode"
"github.com/cnotch/ipchub/utils/scan"
"github.com/cnotch/xlog"
)
const (
wspProto = "WSP/1.1" // WSP协议版本
prefixBody = "\r\n\r\n" // Header和Body分割符
)
// WSP 协议命令
const (
CmdInit = "INIT" // 初始化建立通道
CmdJoin = "JOIN" // 数据通道使用
CmdWrap = "WRAP" // 包装其他协议的命令
CmdGetInfo = "GET_INFO" // 获取客户及license信息
)
// WSP 协议字段
const (
FieldProto = "proto" // 初始化的协议 如rtsp
FieldSeq = "seq" // 命令序列
FieldHost = "host" // 需要代理服务访问的远端host
FieldPort = "port" // 需要代理服务访问的远端port
FieldClient = "client" // 客户信息
FieldChannel = "channel" // 数据通道编号相当于一个session
FieldSocket = "socket" // 代替上面的host和port
)
type badStringError struct {
what string
str string
}
func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
// Request WSP 协议请求
type Request struct {
Cmd string
Header map[string]string
Body string
}
var (
spacePair = scan.NewPair(' ',
func(r rune) bool {
return unicode.IsSpace(r)
})
validCmds = map[string]bool{
CmdGetInfo: true,
CmdInit: true,
CmdJoin: true,
CmdWrap: true,
}
bspool = &sync.Pool{
New: func() interface{} {
return make([]byte, 8*1024)
},
}
buffers = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, 1024*2))
},
}
)
// DecodeStringRequest 解码字串请求
func DecodeStringRequest(input string) (*Request, error) {
index := strings.Index(input, prefixBody)
if index < 0 {
return nil, &badStringError{"malformed WSP request,missing '\\r\\n\\r\\n'", input}
}
req := &Request{
Body: input[index+4:],
Header: make(map[string]string, 4),
}
scanner := scan.Line
// 先取首行
tailing, substr, ok := scanner.Scan(input[:index])
if !ok {
return nil, &badStringError{"malformed WSP request first line", substr}
}
proto, cmd, ok := spacePair.Scan(substr)
if proto != wspProto {
return nil, &badStringError{"malformed WSP request proto ", proto}
}
if _, ok := validCmds[cmd]; !ok {
return nil, &badStringError{"malformed WSP request command ", cmd}
}
req.Cmd = cmd
// 循环取header
for ok {
tailing, substr, ok = scanner.Scan(tailing)
k, v, found := scan.ColonPair.Scan(substr)
if found {
req.Header[k] = v
}
}
return req, nil
}
// DecodeRequest 解码请求
func DecodeRequest(r io.Reader, logger *xlog.Logger) (*Request, error) {
buf := bspool.Get().([]byte)
defer bspool.Put(buf)
n, err := r.Read(buf)
if n == 0 && err == nil { // 上一个报文结束,再读一次
n, err = r.Read(buf)
}
if err != nil {
return nil, err
}
input := string(buf[:n])
logger.Debugf("wsp <<<=== \r\n%s", input)
return DecodeStringRequest(input)
}
// IsWrap 是否是包装协议如果是可以从Body提取被包装的协议
func (req *Request) IsWrap() bool {
return req.Cmd == CmdWrap
}
// ResponseOK 响应请求成功
func (req *Request) ResponseOK(buf *bytes.Buffer, header map[string]string, payload string) {
req.ResponseTo(buf, 200, "OK", header, payload)
}
// ResponseTo 响应请求到buf
func (req *Request) ResponseTo(buf *bytes.Buffer, statusCode int, statusText string, header map[string]string, payload string) {
// 写首行
buf.WriteString(fmt.Sprintf("%s %d %s\r\n", wspProto, statusCode, statusText))
// 写头
header[FieldSeq] = req.Header[FieldSeq]
for k, v := range header {
buf.WriteString(k)
buf.WriteString(": ")
buf.WriteString(v)
buf.WriteString("\r\n")
}
// 写header和body分割
buf.WriteString("\r\n")
// 写payload
if len(payload) > 0 {
buf.WriteString(payload)
}
}

69
service/wsp/protocol_test.go Executable file
View File

@@ -0,0 +1,69 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package wsp
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"github.com/cnotch/xlog"
)
func TestDecodeRequest(t *testing.T) {
var reqStr = "WSP/1.1 GET_INFO\r\nproto: rtsp\r\nhost: 192.168.1.1\r\nport: 554\r\nclient: \r\nseq: 1\r\n\r\n"
t.Run("decode", func(t *testing.T) {
r := bytes.NewBufferString(reqStr)
got, err := DecodeRequest(r, xlog.L())
if err != nil {
t.Errorf("DecodeRequest() error = %v", err)
return
}
assert.Equal(t, CmdGetInfo, got.Cmd)
assert.Equal(t, "1", got.Header[FieldSeq])
})
}
func TestRequest_ResponseOK(t *testing.T) {
respStr1 := "WSP/1.1 200 OK\r\nchannel: 334\r\nseq: 1\r\n\r\n"
respStr2 := "WSP/1.1 404 NOT FOUND\r\nchannel: 334\r\nseq: 1\r\n\r\n123"
t.Run("no payload", func(t *testing.T) {
req := &Request{
Header: make(map[string]string),
}
req.Header[FieldSeq] = "1"
buf := &bytes.Buffer{}
header := make(map[string]string)
header[FieldChannel] = "334"
req.ResponseOK(buf, header, "")
resp := buf.String()
assert.Equal(t, respStr1, resp)
})
t.Run("payload", func(t *testing.T) {
req := &Request{
Header: make(map[string]string),
}
req.Header[FieldSeq] = "1"
buf := &bytes.Buffer{}
header := make(map[string]string)
header[FieldChannel] = "334"
req.ResponseTo(buf, 404, "NOT FOUND", header, "123")
resp := buf.String()
assert.Equal(t, respStr2, resp)
})
}
func Benchmark_DecodeRequest(b *testing.B) {
var reqStr = "WSP/1.1 GET_INFO\r\nproto: rtsp\r\nhost: 192.168.1.1\r\nport: 554\r\nclient: \r\nseq: 1\r\n\r\n"
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
got, _ := DecodeStringRequest(reqStr)
_ = got
}
})
}

465
service/wsp/session.go Executable file
View File

@@ -0,0 +1,465 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package wsp
import (
"bufio"
"bytes"
"fmt"
"io"
"net/url"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/cnotch/ipchub/config"
"github.com/cnotch/ipchub/media"
"github.com/cnotch/ipchub/media/cache"
"github.com/cnotch/ipchub/network/websocket"
"github.com/cnotch/ipchub/provider/security"
"github.com/cnotch/ipchub/service/rtsp"
"github.com/cnotch/ipchub/stats"
"github.com/cnotch/xlog"
"github.com/pixelbender/go-sdp/sdp"
)
// Pack .
type Pack = cache.Pack
const (
statusInit = iota
statusReady
statusPlaying
)
const (
rtspURLPrefix = "rtsp://" // RTSP地址前缀
)
// Session RTSP 会话
type Session struct {
// 创建时设置
svr *Server
channelID string
logger *xlog.Logger
closed bool
paused bool
lsession string // 本地会话标识
timeout time.Duration
conn websocket.Conn
lockW sync.Mutex
dataChannel websocket.Conn
// DESCRIBE 后设置
url *url.URL
path string
rawSdp string
sdp *sdp.Session
aControl string
vControl string
aCodec string
vCodec string
// Setup 后设置
transport rtsp.RTPTransport
// 启动流媒体传输后设置
status int // session状态
source *media.Stream
cid *media.CID
}
func newSession(svr *Server, conn websocket.Conn, channelID string) *Session {
session := &Session{
svr: svr,
channelID: channelID,
lsession: security.NewID().Base64(),
timeout: config.NetTimeout() * time.Duration(2),
conn: conn,
transport: rtsp.RTPTransport{
Mode: rtsp.PlaySession, // 默认为播放
Type: rtsp.RTPUnknownTrans,
},
status: statusInit,
}
for i := 0; i < 4; i++ {
session.transport.Channels[i] = -1
session.transport.ClientPorts[i] = -1
}
session.logger = svr.logger.With(xlog.Fields(
xlog.F("channel", channelID),
xlog.F("path", conn.Path())))
return session
}
// 设置rtp数据通道
func (s *Session) setDataChannel(dc websocket.Conn) {
s.lockW.Lock()
s.dataChannel = dc
s.lockW.Unlock()
}
// Addr Session地址
func (s *Session) Addr() string {
return s.conn.RemoteAddr().String()
}
// Consume 消费媒体包
func (s *Session) Consume(p Pack) {
if s.closed || s.paused {
return
}
buf := buffers.Get().(*bytes.Buffer)
buf.Reset()
defer buffers.Put(buf)
p2 := p.(*rtsp.RTPPack)
p2.Write(buf, s.transport.Channels[:])
var err error
s.lockW.Lock()
if s.dataChannel != nil {
_, err = s.dataChannel.Write(buf.Bytes())
}
s.lockW.Unlock()
if err != nil {
s.logger.Errorf("send pack error = %v , close socket", err)
s.Close()
return
}
}
// Close 关闭会话
func (s *Session) Close() error {
if s.closed {
return nil
}
s.closed = true
s.paused = false
s.conn.Close()
s.lockW.Lock()
if s.dataChannel != nil {
s.dataChannel.Close()
}
s.lockW.Unlock()
return nil
}
func (s *Session) process() {
var err error
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("wsp channel panic, %v \n %s", r, debug.Stack())
}
if err != nil {
if err == io.EOF { // 如果客户端断开提醒
s.logger.Warn("websocket disconnect actively")
} else if !s.closed { // 如果主动关闭,不提示
s.logger.Error(err.Error())
}
}
// 删除通道
s.svr.sessions.Delete(s.channelID)
// 停止消费
if s.cid != nil {
s.source.StopConsume(*s.cid)
s.cid = nil
s.source = nil
}
// 关闭连接
s.Close()
// 重置到初始状态
s.conn = nil
s.dataChannel = nil
s.status = statusInit
stats.WspConns.Release()
s.logger.Info("close wsp channel")
}()
s.logger.Info("open wsp channel")
stats.WspConns.Add() // 增加一个 RTSP 连接计数
for !s.closed {
deadLine := time.Time{}
if s.timeout > 0 {
deadLine = time.Now().Add(s.timeout)
}
if err = s.conn.SetReadDeadline(deadLine); err != nil {
break
}
var req *Request
req, err = DecodeRequest(s.conn, s.logger)
if err != nil {
break
}
if req.Cmd != CmdWrap {
s.logger.Error("must is WRAP command request")
break
}
// 从包装命令中提取 rtsp 请求
var rtspReq *rtsp.Request
rtspReq, err = rtsp.ReadRequest(bufio.NewReader(bytes.NewBufferString(req.Body)))
if err != nil {
break
}
// 处理请求,并获得响应
rtspResp := s.onRequest(rtspReq)
// 发送响应
buf := buffers.Get().(*bytes.Buffer)
buf.Reset()
defer buffers.Put(buf)
req.ResponseOK(buf, map[string]string{FieldChannel: s.channelID}, "")
rtspResp.Write(buf)
_, err = s.conn.Write(buf.Bytes())
if err != nil {
break
}
s.logger.Debugf("wsp ===>>>\r\n%s", buf.String())
// 关闭通道
if rtspReq.Method == rtsp.MethodTeardown {
break
}
}
}
func (s *Session) onRequest(req *rtsp.Request) *rtsp.Response {
resp := s.newResponse(rtsp.StatusOK, req)
// 预处理
continueProcess := s.onPreprocess(resp, req)
if !continueProcess {
return resp
}
switch req.Method {
case rtsp.MethodDescribe:
s.onDescribe(resp, req)
case rtsp.MethodSetup:
s.onSetup(resp, req)
case rtsp.MethodPlay:
s.onPlay(resp, req)
case rtsp.MethodPause:
s.onPause(resp, req)
default:
// 状态不支持的方法
resp.StatusCode = rtsp.StatusMethodNotValidInThisState
}
return resp
}
func (s *Session) onDescribe(resp *rtsp.Response, req *rtsp.Request) {
// TODO: 检查 accept 中的类型是否包含 sdp
s.url = req.URL
s.path = s.conn.Path() // 使用websocket路径
// s.path = utils.CanonicalPath(req.URL.Path)
stream := media.GetOrCreate(s.path)
if stream == nil {
resp.StatusCode = rtsp.StatusNotFound
return
}
// 从流中取 sdp
sdpRaw := stream.Attr("sdp")
if len(sdpRaw) == 0 {
resp.StatusCode = rtsp.StatusNotFound
return
}
err := s.parseSdp(sdpRaw)
if err != nil { // TODO需要更好的处理方式
resp.StatusCode = rtsp.StatusNotFound
return
}
resp.Header.Set(rtsp.FieldContentType, "application/sdp")
resp.Body = s.rawSdp
}
func (s *Session) onSetup(resp *rtsp.Response, req *rtsp.Request) {
// a=control:streamid=1
// a=control:rtsp://192.168.1.165/trackID=1
// a=control:?ctype=video
setupURL := &url.URL{}
*setupURL = *req.URL
if setupURL.Port() == "" {
setupURL.Host = fmt.Sprintf("%s:554", setupURL.Host)
}
setupPath := setupURL.String()
//setupPath = setupPath[strings.LastIndex(setupPath, "/")+1:]
vPath := getControlPath(s.vControl)
if vPath == "" {
resp.StatusCode = rtsp.StatusInternalServerError
resp.Status = "Invalid VControl"
return
}
aPath := getControlPath(s.aControl)
ts := req.Header.Get(rtsp.FieldTransport)
resp.Header.Set(rtsp.FieldTransport, ts) // 先回写transport
// 检查控制路径
chindex := -1
if setupPath == aPath || (aPath != "" && strings.LastIndex(setupPath, aPath) == len(setupPath)-len(aPath)) {
chindex = int(rtsp.ChannelAudio)
} else if setupPath == vPath || (vPath != "" && strings.LastIndex(setupPath, vPath) == len(setupPath)-len(vPath)) {
chindex = int(rtsp.ChannelVideo)
} else { // 找不到被 Setup 的资源
resp.StatusCode = rtsp.StatusInternalServerError
resp.Status = fmt.Sprintf("SETUP Unkown control:%s", setupPath)
return
}
err := s.transport.ParseTransport(chindex, ts)
if err != nil {
resp.StatusCode = rtsp.StatusInvalidParameter
resp.Status = err.Error()
return
}
// 检查必须是play模式
if rtsp.PlaySession != s.transport.Mode {
resp.StatusCode = rtsp.StatusInvalidParameter
resp.Status = "can't setup as record"
return
}
if s.transport.Type != rtsp.RTPTCPUnicast { // 需要修改回复的transport
resp.StatusCode = rtsp.StatusUnsupportedTransport
resp.Status = "websocket only support tcp unicast"
return
}
if s.status < statusReady { // 初始状态切换到Ready
s.status = statusReady
}
}
func (s *Session) onPlay(resp *rtsp.Response, req *rtsp.Request) {
if s.status == statusPlaying {
s.paused = false
return
}
stream := media.GetOrCreate(s.path)
if stream == nil {
resp.StatusCode = rtsp.StatusNotFound
return
}
resp.Header.Set(rtsp.FieldRange, req.Header.Get(rtsp.FieldRange))
if s.cid == nil {
s.source = stream
// cid := stream.StartConsume(s)
cid := stream.StartConsumeNoGopCache(s, media.RTPPacket, "wsp")
s.cid = &cid
}
s.status = statusPlaying
s.paused = false
return
}
func (s *Session) onPause(resp *rtsp.Response, req *rtsp.Request) {
if s.status == statusPlaying {
s.paused = true
}
}
func (s *Session) onPreprocess(resp *rtsp.Response, req *rtsp.Request) (continueProcess bool) {
// Options
if req.Method == rtsp.MethodOptions {
resp.Header.Set(rtsp.FieldPublic, "DESCRIBE, SETUP, TEARDOWN, PLAY, OPTIONS, ANNOUNCE")
return false
}
// 关闭请求
if req.Method == rtsp.MethodTeardown {
return false
}
// 检查状态下的方法
switch s.status {
case statusReady:
continueProcess = req.Method == rtsp.MethodSetup ||
req.Method == rtsp.MethodPlay
case statusPlaying:
continueProcess = (req.Method == rtsp.MethodPlay ||
req.Method == rtsp.MethodPause)
default:
continueProcess = !(req.Method == rtsp.MethodPlay ||
req.Method == rtsp.MethodRecord)
}
if !continueProcess {
resp.StatusCode = rtsp.StatusMethodNotValidInThisState
return false
}
return true
}
func (s *Session) newResponse(code int, req *rtsp.Request) *rtsp.Response {
resp := &rtsp.Response{
StatusCode: code,
Header: make(rtsp.Header),
Request: req,
}
resp.Header.Set(rtsp.FieldCSeq, req.Header.Get(rtsp.FieldCSeq))
resp.Header.Set(rtsp.FieldSession, s.lsession)
return resp
}
func (s *Session) parseSdp(rawSdp string) (err error) {
// 从流中取 sdp
s.rawSdp = rawSdp
// 解析
s.sdp, err = sdp.ParseString(s.rawSdp)
if err != nil {
return
}
for _, media := range s.sdp.Media {
switch media.Type {
case "video":
s.vControl = media.Attributes.Get("control")
s.vCodec = media.Format[0].Name
case "audio":
s.aControl = media.Attributes.Get("control")
s.aCodec = media.Format[0].Name
}
}
return
}
func getControlPath(ctrl string) (path string) {
if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) {
ctrlURL, err := url.Parse(ctrl)
if err != nil {
return
}
if ctrlURL.Port() == "" {
ctrlURL.Host = fmt.Sprintf("%s:554", ctrlURL.Hostname())
}
return ctrlURL.String()
}
return ctrl
}

124
service/wsp/wsp.go Executable file
View File

@@ -0,0 +1,124 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package wsp
import (
"bytes"
"net"
"sync"
"github.com/cnotch/ipchub/network/websocket"
"github.com/cnotch/ipchub/provider/security"
"github.com/cnotch/xlog"
"github.com/kelindar/tcp"
)
// Server https://github.com/Streamedian/html5_rtsp_player 客户端配套的服务器
type Server struct {
logger *xlog.Logger
sessions sync.Map
}
// CreateAcceptHandler 创建连接接入处理器
func CreateAcceptHandler() tcp.OnAccept {
svr := &Server{
logger: xlog.L(),
}
return svr.onAcceptConn
}
// onAcceptConn 当新连接接入时触发
func (svr *Server) onAcceptConn(c net.Conn) {
wsc := c.(websocket.Conn)
if wsc.Subprotocol() == "control" {
go svr.handshakeControlChannel(wsc)
} else {
go svr.handshakeDataChannel(wsc)
}
}
func (svr *Server) handshakeControlChannel(wsc websocket.Conn) {
svr.logger.Info("wsp control channel handshake.")
wsc = wsc.TextTransport()
for {
req, err := DecodeRequest(wsc, svr.logger)
if err != nil {
svr.logger.Error(err.Error())
wsc.Close()
break
}
if req.Cmd == CmdGetInfo {
continue
}
if req.Cmd != CmdInit {
svr.logger.Errorf("wsp control channel handshake failed, malformed WSP request command: %s.", req.Cmd)
wsc.Close()
break
}
// 初始化
channelID := security.NewID().String()
buf := buffers.Get().(*bytes.Buffer)
buf.Reset()
defer buffers.Put(buf)
req.ResponseOK(buf, map[string]string{FieldChannel: channelID}, "")
_, err = wsc.Write(buf.Bytes())
if err != nil {
svr.logger.Error(err.Error())
wsc.Close()
break
}
session := newSession(svr, wsc, channelID)
svr.sessions.Store(channelID, session)
svr.logger.Debugf("wsp ===>>> \r\n%s", buf.String())
go session.process()
break
}
}
func (svr *Server) handshakeDataChannel(wsc websocket.Conn) {
tc := wsc.TextTransport()
req, err := DecodeRequest(tc, svr.logger)
if err != nil {
svr.logger.Error(err.Error())
tc.Close()
return
}
channelID := req.Header[FieldChannel]
code := 200
text := "OK"
var session *Session
si, ok := svr.sessions.Load(channelID)
if ok {
session = si.(*Session)
} else {
code = 404
text = "NOT FOUND"
}
buf := buffers.Get().(*bytes.Buffer)
buf.Reset()
defer buffers.Put(buf)
req.ResponseTo(buf, code, text, map[string]string{}, "")
_, err = tc.Write(buf.Bytes())
if err != nil {
svr.logger.Error(err.Error())
tc.Close()
return
}
svr.logger.Debugf("wsp ===>>> \r\n%s", buf.String())
if session == nil {
tc.Close()
return
}
// 添加到session
session.setDataChannel(wsc)
}