mirror of
https://github.com/langhuihui/monibuca.git
synced 2025-12-24 13:48:04 +08:00
299 lines
7.8 KiB
Go
299 lines
7.8 KiB
Go
package plugin_rtsp
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"m7s.live/v5/pkg/util"
|
||
. "m7s.live/v5/plugin/rtsp/pkg"
|
||
)
|
||
|
||
type RTSPServer struct {
|
||
*NetConnection
|
||
conf *RTSPPlugin
|
||
}
|
||
|
||
func (task *RTSPServer) Go() (err error) {
|
||
var receiver *Receiver
|
||
var sender *Sender
|
||
var req *util.Request
|
||
var sendMode bool
|
||
|
||
// 添加延迟函数在方法结束时清理资源
|
||
defer func() {
|
||
if receiver != nil {
|
||
receiver.Dispose()
|
||
}
|
||
if sender != nil {
|
||
sender.Dispose()
|
||
}
|
||
// 确保任何残留资源被清理
|
||
if task.NetConnection != nil {
|
||
task.NetConnection.Dispose()
|
||
}
|
||
task.Info("RTSP connection closed and resources cleaned up")
|
||
}()
|
||
|
||
for {
|
||
req, err = task.ReadRequest()
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
if task.URL == nil {
|
||
task.URL = req.URL
|
||
task.Logger = task.With("url", task.URL.String())
|
||
task.UserAgent = req.Header.Get("User-Agent")
|
||
task.Info("connect", "userAgent", task.UserAgent)
|
||
}
|
||
|
||
//if !c.auth.Validate(req) {
|
||
// res := &tcp.Response{
|
||
// Status: "401 Unauthorized",
|
||
// Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}},
|
||
// Request: req,
|
||
// }
|
||
// if err = c.WriteResponse(res); err != nil {
|
||
// return err
|
||
// }
|
||
// continue
|
||
//}
|
||
|
||
// Receiver: OPTIONS > DESCRIBE > SETUP... > PLAY > TEARDOWN
|
||
// Sender: OPTIONS > ANNOUNCE > SETUP... > RECORD > TEARDOWN
|
||
switch req.Method {
|
||
case MethodOptions:
|
||
res := &util.Response{
|
||
Header: map[string][]string{
|
||
"Public": {"OPTIONS, SETUP, TEARDOWN, DESCRIBE, PLAY, PAUSE, ANNOUNCE, RECORD"},
|
||
},
|
||
Request: req,
|
||
}
|
||
if err = task.WriteResponse(res); err != nil {
|
||
return
|
||
}
|
||
|
||
case MethodAnnounce:
|
||
if req.Header.Get("Content-Type") != "application/sdp" {
|
||
err = errors.New("wrong content type")
|
||
return
|
||
}
|
||
|
||
task.SDP = string(req.Body) // for info
|
||
var medias []*Media
|
||
if medias, err = UnmarshalSDP(req.Body); err != nil {
|
||
return
|
||
}
|
||
|
||
receiver = &Receiver{}
|
||
receiver.NetConnection = task.NetConnection
|
||
if receiver.Publisher, err = task.conf.Publish(task, strings.TrimPrefix(task.URL.Path, "/")); err != nil {
|
||
receiver = nil
|
||
err = task.WriteResponse(&util.Response{
|
||
StatusCode: 500, Status: err.Error(),
|
||
})
|
||
return
|
||
}
|
||
receiver.Publisher.RemoteAddr = task.Conn.RemoteAddr().String()
|
||
if err = receiver.SetMedia(medias); err != nil {
|
||
return
|
||
}
|
||
res := &util.Response{Request: req}
|
||
if err = task.WriteResponse(res); err != nil {
|
||
return
|
||
}
|
||
task.Depend(receiver.Publisher)
|
||
case MethodDescribe:
|
||
sendMode = true
|
||
sender = &Sender{}
|
||
sender.NetConnection = task.NetConnection
|
||
rawQuery := req.URL.RawQuery
|
||
streamPath := strings.TrimPrefix(task.URL.Path, "/")
|
||
if rawQuery != "" {
|
||
streamPath += "?" + rawQuery
|
||
}
|
||
sender.Subscriber, err = task.conf.Subscribe(task, streamPath)
|
||
if err != nil {
|
||
res := &util.Response{
|
||
StatusCode: http.StatusBadRequest,
|
||
Status: err.Error(),
|
||
Request: req,
|
||
}
|
||
_ = task.WriteResponse(res)
|
||
return
|
||
}
|
||
sender.Subscriber.RemoteAddr = task.Conn.RemoteAddr().String()
|
||
res := &util.Response{
|
||
Header: map[string][]string{
|
||
"Content-Type": {"application/sdp"},
|
||
},
|
||
Request: req,
|
||
}
|
||
// convert tracks to real output medias
|
||
var medias []*Media
|
||
if medias, err = sender.GetMedia(); err != nil {
|
||
return
|
||
}
|
||
if res.Body, err = MarshalSDP(task.SessionName, medias); err != nil {
|
||
return
|
||
}
|
||
|
||
task.SDP = string(res.Body) // for info
|
||
|
||
if err = task.WriteResponse(res); err != nil {
|
||
return
|
||
}
|
||
|
||
case MethodSetup:
|
||
tr := req.Header.Get("Transport")
|
||
res := &util.Response{
|
||
Header: map[string][]string{},
|
||
Request: req,
|
||
}
|
||
|
||
// TCP传输模式
|
||
const tcpTransport = "RTP/AVP/TCP;unicast;interleaved="
|
||
// UDP传输模式前缀
|
||
const udpTransport = "RTP/AVP"
|
||
|
||
if strings.HasPrefix(tr, tcpTransport) {
|
||
task.Debug("into tcp play")
|
||
// 原有的TCP传输处理逻辑
|
||
task.Session = util.RandomString(10)
|
||
|
||
if sendMode {
|
||
if i := reqTrackID(req); i >= 0 {
|
||
tr = fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", i*2, i*2+1)
|
||
res.Header.Set("Transport", tr)
|
||
} else {
|
||
res.Status = "400 Bad Request"
|
||
}
|
||
} else {
|
||
res.Header.Set("Transport", tr[:len(tcpTransport)+3])
|
||
}
|
||
} else if strings.HasPrefix(tr, udpTransport) && strings.Contains(tr, "unicast") && strings.Contains(tr, "client_port=") {
|
||
task.Debug("into udp play")
|
||
|
||
// UDP传输处理逻辑
|
||
task.Session = util.RandomString(10)
|
||
|
||
if sendMode {
|
||
if i := reqTrackID(req); i >= 0 {
|
||
// 解析客户端请求的端口
|
||
clientPortsRe := regexp.MustCompile(`client_port=(\d+)-(\d+)`)
|
||
matches := clientPortsRe.FindStringSubmatch(tr)
|
||
|
||
if len(matches) == 3 {
|
||
clientRTPPort, _ := strconv.Atoi(matches[1])
|
||
clientRTCPPort, _ := strconv.Atoi(matches[2])
|
||
|
||
// 从端口池获取服务器端口
|
||
serverRTPPort, err := task.conf.GetUDPPort()
|
||
if err != nil {
|
||
task.Error("Failed to get UDP port from pool", "error", err)
|
||
res.Status = "500 Internal Server Error: No available UDP ports"
|
||
break
|
||
}
|
||
serverRTCPPort, err := task.conf.GetUDPPort()
|
||
if err != nil {
|
||
task.Error("Failed to get UDP port from pool", "error", err)
|
||
res.Status = "500 Internal Server Error: No available UDP ports"
|
||
break
|
||
}
|
||
// 在sender中记录这些端口信息
|
||
if sender.UDPPorts == nil {
|
||
sender.UDPPorts = make(map[int][]int)
|
||
}
|
||
// 保存 [clientRTP, clientRTCP, serverRTP, serverRTCP]
|
||
sender.UDPPorts[i] = []int{clientRTPPort, clientRTCPPort, int(serverRTPPort), int(serverRTCPPort)}
|
||
|
||
// 记录分配的端口,用于在连接结束时释放
|
||
if sender.AllocatedUDPPorts == nil {
|
||
sender.AllocatedUDPPorts = make([]uint16, 0)
|
||
}
|
||
sender.AllocatedUDPPorts = append(sender.AllocatedUDPPorts, serverRTPPort)
|
||
|
||
// 设置插件引用,用于在Dispose时释放端口
|
||
|
||
// 设置传输响应
|
||
udpResponse := fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d;server_port=%d-%d",
|
||
clientRTPPort, clientRTCPPort, serverRTPPort, serverRTCPPort)
|
||
res.Header.Set("Transport", udpResponse)
|
||
|
||
// 标记为UDP传输模式
|
||
task.Transport = "UDP"
|
||
} else {
|
||
res.Status = "400 Bad Request: Invalid client_port format"
|
||
}
|
||
} else {
|
||
res.Status = "400 Bad Request: Invalid track ID"
|
||
}
|
||
} else {
|
||
res.Status = "400 Bad Request: UDP only supported for PLAY mode"
|
||
}
|
||
} else {
|
||
res.Status = "461 Unsupported transport"
|
||
}
|
||
|
||
// 设置Session头
|
||
res.Header.Set("Session", task.Session)
|
||
|
||
if err = task.WriteResponse(res); err != nil {
|
||
return
|
||
}
|
||
|
||
case MethodRecord:
|
||
res := &util.Response{Request: req}
|
||
if err = task.WriteResponse(res); err != nil {
|
||
task.Error("Failed to write response", "error", err)
|
||
return
|
||
}
|
||
task.Info("Starting RTSP record session")
|
||
err = receiver.Receive()
|
||
if err != nil {
|
||
task.Error("RTSP receive error", "error", err)
|
||
}
|
||
return
|
||
case MethodPlay:
|
||
res := &util.Response{Request: req}
|
||
if err = task.WriteResponse(res); err != nil {
|
||
task.Error("Failed to write response", "error", err)
|
||
return
|
||
}
|
||
task.Info("Starting RTSP play session")
|
||
err = sender.Send()
|
||
if err != nil {
|
||
task.Error("RTSP send error", "error", err)
|
||
}
|
||
return
|
||
case MethodTeardown:
|
||
res := &util.Response{Request: req}
|
||
_ = task.WriteResponse(res)
|
||
task.Info("RTSP teardown received")
|
||
return
|
||
|
||
default:
|
||
task.Warn("unsupported method", "method", req.Method)
|
||
}
|
||
}
|
||
}
|
||
|
||
func reqTrackID(req *util.Request) int {
|
||
var s string
|
||
if req.URL.RawQuery != "" {
|
||
s = req.URL.RawQuery
|
||
} else {
|
||
s = req.URL.Path
|
||
}
|
||
if i := strings.LastIndexByte(s, '='); i > 0 {
|
||
if i, err := strconv.Atoi(s[i+1:]); err == nil {
|
||
return i
|
||
}
|
||
}
|
||
return -1
|
||
}
|