Files
monibuca/plugin/rtsp/server.go

299 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package plugin_rtsp
import (
"errors"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"m7s.live/v5/pkg/util"
. "m7s.live/v5/plugin/rtsp/pkg"
)
type RTSPServer struct {
*NetConnection
conf *RTSPPlugin
}
func (task *RTSPServer) Go() (err error) {
var receiver *Receiver
var sender *Sender
var req *util.Request
var sendMode bool
// 添加延迟函数在方法结束时清理资源
defer func() {
if receiver != nil {
receiver.Dispose()
}
if sender != nil {
sender.Dispose()
}
// 确保任何残留资源被清理
if task.NetConnection != nil {
task.NetConnection.Dispose()
}
task.Info("RTSP connection closed and resources cleaned up")
}()
for {
req, err = task.ReadRequest()
if err != nil {
return
}
if task.URL == nil {
task.URL = req.URL
task.Logger = task.With("url", task.URL.String())
task.UserAgent = req.Header.Get("User-Agent")
task.Info("connect", "userAgent", task.UserAgent)
}
//if !c.auth.Validate(req) {
// res := &tcp.Response{
// Status: "401 Unauthorized",
// Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}},
// Request: req,
// }
// if err = c.WriteResponse(res); err != nil {
// return err
// }
// continue
//}
// Receiver: OPTIONS > DESCRIBE > SETUP... > PLAY > TEARDOWN
// Sender: OPTIONS > ANNOUNCE > SETUP... > RECORD > TEARDOWN
switch req.Method {
case MethodOptions:
res := &util.Response{
Header: map[string][]string{
"Public": {"OPTIONS, SETUP, TEARDOWN, DESCRIBE, PLAY, PAUSE, ANNOUNCE, RECORD"},
},
Request: req,
}
if err = task.WriteResponse(res); err != nil {
return
}
case MethodAnnounce:
if req.Header.Get("Content-Type") != "application/sdp" {
err = errors.New("wrong content type")
return
}
task.SDP = string(req.Body) // for info
var medias []*Media
if medias, err = UnmarshalSDP(req.Body); err != nil {
return
}
receiver = &Receiver{}
receiver.NetConnection = task.NetConnection
if receiver.Publisher, err = task.conf.Publish(task, strings.TrimPrefix(task.URL.Path, "/")); err != nil {
receiver = nil
err = task.WriteResponse(&util.Response{
StatusCode: 500, Status: err.Error(),
})
return
}
receiver.Publisher.RemoteAddr = task.Conn.RemoteAddr().String()
if err = receiver.SetMedia(medias); err != nil {
return
}
res := &util.Response{Request: req}
if err = task.WriteResponse(res); err != nil {
return
}
task.Depend(receiver.Publisher)
case MethodDescribe:
sendMode = true
sender = &Sender{}
sender.NetConnection = task.NetConnection
rawQuery := req.URL.RawQuery
streamPath := strings.TrimPrefix(task.URL.Path, "/")
if rawQuery != "" {
streamPath += "?" + rawQuery
}
sender.Subscriber, err = task.conf.Subscribe(task, streamPath)
if err != nil {
res := &util.Response{
StatusCode: http.StatusBadRequest,
Status: err.Error(),
Request: req,
}
_ = task.WriteResponse(res)
return
}
sender.Subscriber.RemoteAddr = task.Conn.RemoteAddr().String()
res := &util.Response{
Header: map[string][]string{
"Content-Type": {"application/sdp"},
},
Request: req,
}
// convert tracks to real output medias
var medias []*Media
if medias, err = sender.GetMedia(); err != nil {
return
}
if res.Body, err = MarshalSDP(task.SessionName, medias); err != nil {
return
}
task.SDP = string(res.Body) // for info
if err = task.WriteResponse(res); err != nil {
return
}
case MethodSetup:
tr := req.Header.Get("Transport")
res := &util.Response{
Header: map[string][]string{},
Request: req,
}
// TCP传输模式
const tcpTransport = "RTP/AVP/TCP;unicast;interleaved="
// UDP传输模式前缀
const udpTransport = "RTP/AVP"
if strings.HasPrefix(tr, tcpTransport) {
task.Debug("into tcp play")
// 原有的TCP传输处理逻辑
task.Session = util.RandomString(10)
if sendMode {
if i := reqTrackID(req); i >= 0 {
tr = fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", i*2, i*2+1)
res.Header.Set("Transport", tr)
} else {
res.Status = "400 Bad Request"
}
} else {
res.Header.Set("Transport", tr[:len(tcpTransport)+3])
}
} else if strings.HasPrefix(tr, udpTransport) && strings.Contains(tr, "unicast") && strings.Contains(tr, "client_port=") {
task.Debug("into udp play")
// UDP传输处理逻辑
task.Session = util.RandomString(10)
if sendMode {
if i := reqTrackID(req); i >= 0 {
// 解析客户端请求的端口
clientPortsRe := regexp.MustCompile(`client_port=(\d+)-(\d+)`)
matches := clientPortsRe.FindStringSubmatch(tr)
if len(matches) == 3 {
clientRTPPort, _ := strconv.Atoi(matches[1])
clientRTCPPort, _ := strconv.Atoi(matches[2])
// 从端口池获取服务器端口
serverRTPPort, err := task.conf.GetUDPPort()
if err != nil {
task.Error("Failed to get UDP port from pool", "error", err)
res.Status = "500 Internal Server Error: No available UDP ports"
break
}
serverRTCPPort, err := task.conf.GetUDPPort()
if err != nil {
task.Error("Failed to get UDP port from pool", "error", err)
res.Status = "500 Internal Server Error: No available UDP ports"
break
}
// 在sender中记录这些端口信息
if sender.UDPPorts == nil {
sender.UDPPorts = make(map[int][]int)
}
// 保存 [clientRTP, clientRTCP, serverRTP, serverRTCP]
sender.UDPPorts[i] = []int{clientRTPPort, clientRTCPPort, int(serverRTPPort), int(serverRTCPPort)}
// 记录分配的端口,用于在连接结束时释放
if sender.AllocatedUDPPorts == nil {
sender.AllocatedUDPPorts = make([]uint16, 0)
}
sender.AllocatedUDPPorts = append(sender.AllocatedUDPPorts, serverRTPPort)
// 设置插件引用用于在Dispose时释放端口
// 设置传输响应
udpResponse := fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d;server_port=%d-%d",
clientRTPPort, clientRTCPPort, serverRTPPort, serverRTCPPort)
res.Header.Set("Transport", udpResponse)
// 标记为UDP传输模式
task.Transport = "UDP"
} else {
res.Status = "400 Bad Request: Invalid client_port format"
}
} else {
res.Status = "400 Bad Request: Invalid track ID"
}
} else {
res.Status = "400 Bad Request: UDP only supported for PLAY mode"
}
} else {
res.Status = "461 Unsupported transport"
}
// 设置Session头
res.Header.Set("Session", task.Session)
if err = task.WriteResponse(res); err != nil {
return
}
case MethodRecord:
res := &util.Response{Request: req}
if err = task.WriteResponse(res); err != nil {
task.Error("Failed to write response", "error", err)
return
}
task.Info("Starting RTSP record session")
err = receiver.Receive()
if err != nil {
task.Error("RTSP receive error", "error", err)
}
return
case MethodPlay:
res := &util.Response{Request: req}
if err = task.WriteResponse(res); err != nil {
task.Error("Failed to write response", "error", err)
return
}
task.Info("Starting RTSP play session")
err = sender.Send()
if err != nil {
task.Error("RTSP send error", "error", err)
}
return
case MethodTeardown:
res := &util.Response{Request: req}
_ = task.WriteResponse(res)
task.Info("RTSP teardown received")
return
default:
task.Warn("unsupported method", "method", req.Method)
}
}
}
func reqTrackID(req *util.Request) int {
var s string
if req.URL.RawQuery != "" {
s = req.URL.RawQuery
} else {
s = req.URL.Path
}
if i := strings.LastIndexByte(s, '='); i > 0 {
if i, err := strconv.Atoi(s[i+1:]); err == nil {
return i
}
}
return -1
}