mirror of
https://github.com/aler9/gortsplib
synced 2025-10-04 23:02:45 +08:00
573 lines
12 KiB
Go
573 lines
12 KiB
Go
package gortsplib
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pion/rtcp"
|
|
|
|
"github.com/aler9/gortsplib/pkg/base"
|
|
"github.com/aler9/gortsplib/pkg/liberrors"
|
|
"github.com/aler9/gortsplib/pkg/multibuffer"
|
|
)
|
|
|
|
func getSessionID(header base.Header) string {
|
|
if h, ok := header["Session"]; ok && len(h) == 1 {
|
|
return h[0]
|
|
}
|
|
return ""
|
|
}
|
|
|
|
type readReq struct {
|
|
req *base.Request
|
|
res chan error
|
|
}
|
|
|
|
// ServerConn is a server-side RTSP connection.
|
|
type ServerConn struct {
|
|
s *Server
|
|
conn net.Conn
|
|
|
|
ctx context.Context
|
|
ctxCancel func()
|
|
remoteAddr *net.TCPAddr
|
|
br *bufio.Reader
|
|
sessions map[string]*ServerSession
|
|
tcpFrameEnabled bool
|
|
tcpSession *ServerSession
|
|
tcpFrameTimeout bool
|
|
tcpReadBuffer *multibuffer.MultiBuffer
|
|
tcpRTPPacketBuffer *rtpPacketMultiBuffer
|
|
tcpProcessFunc func(int, bool, []byte)
|
|
tcpWriterRunning bool
|
|
|
|
// in
|
|
sessionRemove chan *ServerSession
|
|
|
|
// out
|
|
done chan struct{}
|
|
}
|
|
|
|
func newServerConn(
|
|
s *Server,
|
|
nconn net.Conn) *ServerConn {
|
|
ctx, ctxCancel := context.WithCancel(s.ctx)
|
|
|
|
conn := func() net.Conn {
|
|
if s.TLSConfig != nil {
|
|
return tls.Server(nconn, s.TLSConfig)
|
|
}
|
|
return nconn
|
|
}()
|
|
|
|
sc := &ServerConn{
|
|
s: s,
|
|
conn: conn,
|
|
ctx: ctx,
|
|
ctxCancel: ctxCancel,
|
|
remoteAddr: conn.RemoteAddr().(*net.TCPAddr),
|
|
sessionRemove: make(chan *ServerSession),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
s.wg.Add(1)
|
|
go sc.run()
|
|
|
|
return sc
|
|
}
|
|
|
|
// Close closes the ServerConn.
|
|
func (sc *ServerConn) Close() error {
|
|
sc.ctxCancel()
|
|
return nil
|
|
}
|
|
|
|
// NetConn returns the underlying net.Conn.
|
|
func (sc *ServerConn) NetConn() net.Conn {
|
|
return sc.conn
|
|
}
|
|
|
|
func (sc *ServerConn) ip() net.IP {
|
|
return sc.remoteAddr.IP
|
|
}
|
|
|
|
func (sc *ServerConn) zone() string {
|
|
return sc.remoteAddr.Zone
|
|
}
|
|
|
|
func (sc *ServerConn) run() {
|
|
defer sc.s.wg.Done()
|
|
defer close(sc.done)
|
|
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnConnOpen); ok {
|
|
h.OnConnOpen(&ServerHandlerOnConnOpenCtx{
|
|
Conn: sc,
|
|
})
|
|
}
|
|
|
|
sc.br = bufio.NewReaderSize(sc.conn, serverReadBufferSize)
|
|
sc.sessions = make(map[string]*ServerSession)
|
|
|
|
readRequest := make(chan readReq)
|
|
readErr := make(chan error)
|
|
readDone := make(chan struct{})
|
|
go func() {
|
|
defer close(readDone)
|
|
err := func() error {
|
|
var req base.Request
|
|
var frame base.InterleavedFrame
|
|
|
|
for {
|
|
if sc.tcpFrameEnabled {
|
|
if sc.tcpFrameTimeout {
|
|
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
|
|
}
|
|
|
|
frame.Payload = sc.tcpReadBuffer.Next()
|
|
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch what.(type) {
|
|
case *base.InterleavedFrame:
|
|
channel := frame.Channel
|
|
isRTP := true
|
|
if (channel % 2) != 0 {
|
|
channel--
|
|
isRTP = false
|
|
}
|
|
|
|
// forward frame only if it has been set up
|
|
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
|
|
sc.tcpProcessFunc(trackID, isRTP, frame.Payload)
|
|
}
|
|
|
|
case *base.Request:
|
|
cres := make(chan error)
|
|
select {
|
|
case readRequest <- readReq{req: &req, res: cres}:
|
|
err := <-cres
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case <-sc.ctx.Done():
|
|
return liberrors.ErrServerTerminated{}
|
|
}
|
|
}
|
|
} else {
|
|
err := req.Read(sc.br)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
cres := make(chan error)
|
|
select {
|
|
case readRequest <- readReq{req: &req, res: cres}:
|
|
err = <-cres
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case <-sc.ctx.Done():
|
|
return liberrors.ErrServerTerminated{}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case readErr <- err:
|
|
case <-sc.ctx.Done():
|
|
}
|
|
}()
|
|
|
|
err := func() error {
|
|
for {
|
|
select {
|
|
case req := <-readRequest:
|
|
req.res <- sc.handleRequestOuter(req.req)
|
|
|
|
case err := <-readErr:
|
|
return err
|
|
|
|
case ss := <-sc.sessionRemove:
|
|
if _, ok := sc.sessions[ss.secretID]; ok {
|
|
delete(sc.sessions, ss.secretID)
|
|
|
|
select {
|
|
case ss.connRemove <- sc:
|
|
case <-ss.ctx.Done():
|
|
}
|
|
}
|
|
|
|
case <-sc.ctx.Done():
|
|
return liberrors.ErrServerTerminated{}
|
|
}
|
|
}
|
|
}()
|
|
|
|
sc.ctxCancel()
|
|
|
|
sc.conn.Close()
|
|
<-readDone
|
|
|
|
for _, ss := range sc.sessions {
|
|
select {
|
|
case ss.connRemove <- sc:
|
|
case <-ss.ctx.Done():
|
|
}
|
|
}
|
|
|
|
select {
|
|
case sc.s.connClose <- sc:
|
|
case <-sc.s.ctx.Done():
|
|
}
|
|
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok {
|
|
h.OnConnClose(&ServerHandlerOnConnCloseCtx{
|
|
Conn: sc,
|
|
Error: err,
|
|
})
|
|
}
|
|
}
|
|
|
|
func (sc *ServerConn) tcpProcessPlay(trackID int, isRTP bool, payload []byte) {
|
|
if !isRTP {
|
|
packets, err := rtcp.Unmarshal(payload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
|
for _, pkt := range packets {
|
|
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
|
Session: sc.tcpSession,
|
|
TrackID: trackID,
|
|
Packet: pkt,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sc *ServerConn) tcpProcessRecord(trackID int, isRTP bool, payload []byte) {
|
|
if isRTP {
|
|
pkt := sc.tcpRTPPacketBuffer.next()
|
|
err := pkt.Unmarshal(payload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
|
|
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
|
|
Session: sc.tcpSession,
|
|
TrackID: trackID,
|
|
Packet: pkt,
|
|
})
|
|
}
|
|
} else {
|
|
packets, err := rtcp.Unmarshal(payload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
|
for _, pkt := range packets {
|
|
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
|
Session: sc.tcpSession,
|
|
TrackID: trackID,
|
|
Packet: pkt,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
|
if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 {
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
Header: base.Header{},
|
|
}, liberrors.ErrServerCSeqMissing{}
|
|
}
|
|
|
|
sxID := getSessionID(req.Header)
|
|
|
|
// the connection can't communicate with another session
|
|
// if it's receiving or sending TCP frames.
|
|
if sc.tcpSession != nil &&
|
|
sxID != sc.tcpSession.secretID {
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerLinkedToOtherSession{}
|
|
}
|
|
|
|
switch req.Method {
|
|
case base.Options:
|
|
// handle request in session
|
|
if sxID != "" {
|
|
return sc.handleRequestInSession(sxID, req, false)
|
|
}
|
|
|
|
// handle request here
|
|
var methods []string
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
|
|
methods = append(methods, string(base.Describe))
|
|
}
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
|
|
methods = append(methods, string(base.Announce))
|
|
}
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
|
|
methods = append(methods, string(base.Setup))
|
|
}
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
|
methods = append(methods, string(base.Play))
|
|
}
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
|
methods = append(methods, string(base.Record))
|
|
}
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
|
methods = append(methods, string(base.Pause))
|
|
}
|
|
methods = append(methods, string(base.GetParameter))
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
|
|
methods = append(methods, string(base.SetParameter))
|
|
}
|
|
methods = append(methods, string(base.Teardown))
|
|
|
|
return &base.Response{
|
|
StatusCode: base.StatusOK,
|
|
Header: base.Header{
|
|
"Public": base.HeaderValue{strings.Join(methods, ", ")},
|
|
},
|
|
}, nil
|
|
|
|
case base.Describe:
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
|
|
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
|
|
if !ok {
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerInvalidPath{}
|
|
}
|
|
|
|
path, query := base.PathSplitQuery(pathAndQuery)
|
|
|
|
res, stream, err := h.OnDescribe(&ServerHandlerOnDescribeCtx{
|
|
Conn: sc,
|
|
Req: req,
|
|
Path: path,
|
|
Query: query,
|
|
})
|
|
|
|
if res.StatusCode == base.StatusOK {
|
|
if res.Header == nil {
|
|
res.Header = make(base.Header)
|
|
}
|
|
|
|
res.Header["Content-Base"] = base.HeaderValue{req.URL.String() + "/"}
|
|
res.Header["Content-Type"] = base.HeaderValue{"application/sdp"}
|
|
|
|
// VLC uses multicast if the SDP contains a multicast address.
|
|
// therefore, we introduce a special query (vlcmulticast) that allows
|
|
// to return a SDP that contains a multicast address.
|
|
multicast := false
|
|
if sc.s.MulticastIPRange != "" {
|
|
if q, err := url.ParseQuery(query); err == nil {
|
|
if _, ok := q["vlcmulticast"]; ok {
|
|
multicast = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if stream != nil {
|
|
res.Body = stream.Tracks().Write(multicast)
|
|
}
|
|
}
|
|
|
|
return res, err
|
|
}
|
|
|
|
case base.Announce:
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
|
|
return sc.handleRequestInSession(sxID, req, true)
|
|
}
|
|
|
|
case base.Setup:
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
|
|
return sc.handleRequestInSession(sxID, req, true)
|
|
}
|
|
|
|
case base.Play:
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
|
return sc.handleRequestInSession(sxID, req, false)
|
|
}
|
|
|
|
case base.Record:
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
|
return sc.handleRequestInSession(sxID, req, false)
|
|
}
|
|
|
|
case base.Pause:
|
|
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
|
return sc.handleRequestInSession(sxID, req, false)
|
|
}
|
|
|
|
case base.Teardown:
|
|
return sc.handleRequestInSession(sxID, req, false)
|
|
|
|
case base.GetParameter:
|
|
// handle request in session
|
|
if sxID != "" {
|
|
return sc.handleRequestInSession(sxID, req, false)
|
|
}
|
|
|
|
// handle request here
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
|
|
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
|
|
if !ok {
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerInvalidPath{}
|
|
}
|
|
|
|
path, query := base.PathSplitQuery(pathAndQuery)
|
|
|
|
return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{
|
|
Conn: sc,
|
|
Req: req,
|
|
Path: path,
|
|
Query: query,
|
|
})
|
|
}
|
|
|
|
case base.SetParameter:
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
|
|
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
|
|
if !ok {
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerInvalidPath{}
|
|
}
|
|
|
|
path, query := base.PathSplitQuery(pathAndQuery)
|
|
|
|
return h.OnSetParameter(&ServerHandlerOnSetParameterCtx{
|
|
Conn: sc,
|
|
Req: req,
|
|
Path: path,
|
|
Query: query,
|
|
})
|
|
}
|
|
}
|
|
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerUnhandledRequest{Req: req}
|
|
}
|
|
|
|
func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnRequest); ok {
|
|
h.OnRequest(sc, req)
|
|
}
|
|
|
|
res, err := sc.handleRequest(req)
|
|
|
|
if res.Header == nil {
|
|
res.Header = make(base.Header)
|
|
}
|
|
|
|
// add cseq
|
|
if _, ok := err.(liberrors.ErrServerCSeqMissing); !ok {
|
|
res.Header["CSeq"] = req.Header["CSeq"]
|
|
}
|
|
|
|
// add server
|
|
res.Header["Server"] = base.HeaderValue{"gortsplib"}
|
|
|
|
if h, ok := sc.s.Handler.(ServerHandlerOnResponse); ok {
|
|
h.OnResponse(sc, res)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
res.Write(&buf)
|
|
|
|
sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
|
|
sc.conn.Write(buf.Bytes())
|
|
|
|
// start writer after sending the response
|
|
if sc.tcpFrameEnabled && !sc.tcpWriterRunning {
|
|
sc.tcpWriterRunning = true
|
|
select {
|
|
case sc.tcpSession.startWriter <- struct{}{}:
|
|
case <-sc.tcpSession.ctx.Done():
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (sc *ServerConn) handleRequestInSession(
|
|
sxID string,
|
|
req *base.Request,
|
|
create bool,
|
|
) (*base.Response, error) {
|
|
// if the session is already linked to this conn, communicate directly with it
|
|
if sxID != "" {
|
|
if ss, ok := sc.sessions[sxID]; ok {
|
|
cres := make(chan sessionRequestRes)
|
|
sreq := sessionRequestReq{
|
|
sc: sc,
|
|
req: req,
|
|
id: sxID,
|
|
create: create,
|
|
res: cres,
|
|
}
|
|
|
|
select {
|
|
case ss.request <- sreq:
|
|
res := <-cres
|
|
return res.res, res.err
|
|
|
|
case <-ss.ctx.Done():
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerTerminated{}
|
|
}
|
|
}
|
|
}
|
|
|
|
// otherwise, pass through Server
|
|
cres := make(chan sessionRequestRes)
|
|
sreq := sessionRequestReq{
|
|
sc: sc,
|
|
req: req,
|
|
id: sxID,
|
|
create: create,
|
|
res: cres,
|
|
}
|
|
|
|
select {
|
|
case sc.s.sessionRequest <- sreq:
|
|
res := <-cres
|
|
if res.ss != nil {
|
|
sc.sessions[res.ss.secretID] = res.ss
|
|
}
|
|
|
|
return res.res, res.err
|
|
|
|
case <-sc.s.ctx.Done():
|
|
return &base.Response{
|
|
StatusCode: base.StatusBadRequest,
|
|
}, liberrors.ErrServerTerminated{}
|
|
}
|
|
}
|