Files
gortsplib/serverconn.go
2021-01-06 12:26:07 +01:00

620 lines
15 KiB
Go

package gortsplib
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/multibuffer"
)
// Sizes, in bytes, of the buffered reader and writer that wrap every
// server-side connection.
const (
	serverReadBufferSize  = 4 * 1024
	serverWriteBufferSize = 4 * 1024
)
// ErrServerTeardown is the error with which a connection's read loop stops
// after a TEARDOWN request has been answered by the built-in handler
// (i.e. when ServerConnReadHandlers.OnTeardown is nil).
var ErrServerTeardown = errors.New("teardown")
// serverConnState is the phase of the RTSP session carried by a ServerConn.
type serverConnState int

// Session phases. A connection starts in serverConnStateInitial (the zero
// value), moves to Play or Record after a successful PLAY or RECORD request,
// and returns to Initial after a successful PAUSE.
const (
	serverConnStateInitial serverConnState = 0
	serverConnStatePlay    serverConnState = 1
	serverConnStateRecord  serverConnState = 2
)
// serverConnTrack stores the transport parameters negotiated for a single
// track during SETUP.
type serverConnTrack struct {
	// transport protocol chosen for this track.
	proto StreamProtocol

	// client-side RTP and RTCP ports taken from the Transport header;
	// only filled in when proto is UDP, zero otherwise.
	rtpPort, rtcpPort int
}
// extractTrackID determines which track a SETUP request refers to.
//
// When mode is nil or play, the ID is parsed from a control attribute of the
// form "trackID=<n>", where n must be a non-negative base-10 integer; any
// other control attribute is rejected. For every other mode the control
// attribute is ignored and trackLen is returned, so tracks are numbered in
// the order in which they are set up.
func extractTrackID(controlPath string, mode *headers.TransportMode, trackLen int) (int, error) {
	isPlay := mode == nil || *mode == headers.TransportModePlay
	if !isPlay {
		return trackLen, nil
	}

	const prefix = "trackID="
	if !strings.HasPrefix(controlPath, prefix) {
		return 0, fmt.Errorf("invalid control attribute (%s)", controlPath)
	}

	id, err := strconv.ParseInt(controlPath[len(prefix):], 10, 64)
	if err != nil || id < 0 {
		return 0, fmt.Errorf("invalid track id (%s)", controlPath)
	}

	return int(id), nil
}
// ServerConnReadHandlers groups the callbacks used by ServerConn.Read.
// All fields are optional.
type ServerConnReadHandlers struct {
	// OnRequest is invoked with every incoming request before it is dispatched.
	OnRequest func(req *base.Request)

	// OnResponse is invoked with every outgoing response right before it is written.
	OnResponse func(res *base.Response)

	// OnOptions handles OPTIONS requests. When nil, a response whose Public
	// header lists the methods that can be served is built automatically.
	OnOptions func(req *base.Request) (*base.Response, error)

	// OnDescribe handles DESCRIBE requests. When nil, they are rejected with
	// 400 Bad Request.
	OnDescribe func(req *base.Request) (*base.Response, error)

	// OnAnnounce handles ANNOUNCE requests, receiving the tracks parsed from
	// the SDP body. Requests without a single application/sdp Content-Type,
	// with invalid SDP or without tracks are rejected before reaching it.
	// When nil, ANNOUNCE is rejected with 400 Bad Request.
	OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error)

	// OnSetup handles SETUP requests, receiving the parsed Transport header
	// after it has been validated. When nil, SETUP is rejected with
	// 400 Bad Request.
	OnSetup func(req *base.Request, th *headers.Transport) (*base.Response, error)

	// OnPlay handles PLAY requests. When nil, they are rejected with
	// 400 Bad Request.
	OnPlay func(req *base.Request) (*base.Response, error)

	// OnRecord handles RECORD requests. When nil, they are rejected with
	// 400 Bad Request.
	OnRecord func(req *base.Request) (*base.Response, error)

	// OnPause handles PAUSE requests. When nil, they are rejected with
	// 400 Bad Request.
	OnPause func(req *base.Request) (*base.Response, error)

	// OnGetParameter handles GET_PARAMETER requests. When nil, they are
	// answered with 200 OK and a text/parameters body containing a single
	// newline, so that clients can use the method as a keep-alive.
	OnGetParameter func(req *base.Request) (*base.Response, error)

	// OnSetParameter handles SET_PARAMETER requests. When nil, they are
	// rejected with 400 Bad Request.
	OnSetParameter func(req *base.Request) (*base.Response, error)

	// OnTeardown handles TEARDOWN requests. When nil, they are answered with
	// 200 OK and the read loop ends with ErrServerTeardown.
	OnTeardown func(req *base.Request) (*base.Response, error)

	// OnFrame is invoked after receiving a frame.
	OnFrame func(trackID int, streamType StreamType, payload []byte)
}
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
	s     *Server
	nconn net.Conn // connection as accepted; br/bw may sit on a TLS layer above it
	br    *bufio.Reader
	bw    *bufio.Writer

	state       serverConnState
	tracks      map[int]serverConnTrack // set-up tracks, keyed by track ID
	tracksProto *StreamProtocol         // protocol shared by all tracks; nil until a SETUP succeeds

	writeMutex   sync.Mutex // serializes writes of responses and frames to bw
	readHandlers ServerConnReadHandlers

	nextFramesEnabled  bool // value copied into framesEnabled after the next response is written
	framesEnabled      bool // whether interleaved frames are read and written on the connection
	readTimeoutEnabled bool // whether the read loop applies the configured read timeout

	// closed by Close to abort waits performed while handling requests.
	terminate chan struct{}
}
// newServerConn wraps an accepted connection. When the server has a TLS
// configuration, requests and responses are read and written through a TLS
// layer, while the unwrapped connection is kept in nconn.
func newServerConn(s *Server, nconn net.Conn) *ServerConn {
	var rw net.Conn = nconn
	if s.conf.TLSConfig != nil {
		rw = tls.Server(nconn, s.conf.TLSConfig)
	}

	sc := &ServerConn{
		s:         s,
		nconn:     nconn,
		br:        bufio.NewReaderSize(rw, serverReadBufferSize),
		bw:        bufio.NewWriterSize(rw, serverWriteBufferSize),
		tracks:    make(map[int]serverConnTrack),
		terminate: make(chan struct{}),
	}
	return sc
}
// Close closes the underlying connection and then signals termination to any
// wait in progress, returning the error produced by closing the connection.
// It must be called at most once: terminate is closed unconditionally, and
// closing an already-closed channel panics.
func (sc *ServerConn) Close() error {
	// deferred so that it runs after nconn.Close, as in the original order.
	defer close(sc.terminate)
	return sc.nconn.Close()
}
// NetConn returns the connection this ServerConn was created with. When TLS
// is configured, this is the raw transport below the TLS layer.
func (sc *ServerConn) NetConn() net.Conn {
	raw := sc.nconn
	return raw
}
// ip returns the IP address of the remote peer.
// The type assertion panics if RemoteAddr is not a *net.TCPAddr.
func (sc *ServerConn) ip() net.IP {
	remote := sc.nconn.RemoteAddr().(*net.TCPAddr)
	return remote.IP
}
// zone returns the IPv6 zone of the remote peer's address (empty when none).
// The type assertion panics if RemoteAddr is not a *net.TCPAddr.
func (sc *ServerConn) zone() string {
	remote := sc.nconn.RemoteAddr().(*net.TCPAddr)
	return remote.Zone
}
// frameModeEnable prepares the connection to exchange frames after a
// successful PLAY or RECORD, according to the current state:
//
//   - play over TCP: frames are enabled once the next response is written;
//   - record over TCP: same, and the read loop starts applying the read timeout;
//   - record over UDP: every set-up track is registered as a publisher on
//     the server's RTP and RTCP UDP listeners.
//
// When no track has been set up yet (tracksProto is nil) there is nothing
// to enable and the call is a no-op.
func (sc *ServerConn) frameModeEnable() {
	// PLAY or RECORD may arrive before any successful SETUP; dereferencing
	// the nil tracksProto would panic on a client-controlled request.
	if sc.tracksProto == nil {
		return
	}

	switch sc.state {
	case serverConnStatePlay:
		if *sc.tracksProto == StreamProtocolTCP {
			sc.nextFramesEnabled = true
		}

	case serverConnStateRecord:
		if *sc.tracksProto == StreamProtocolTCP {
			sc.nextFramesEnabled = true
			sc.readTimeoutEnabled = true
			return
		}

		for trackID, track := range sc.tracks {
			sc.s.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
			sc.s.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
		}
	}
}
// frameModeDisable reverts what frameModeEnable set up for the current state.
// In play or record state, frames stop being exchanged once the next
// response is written. In record state, the read timeout is also lifted and
// the publishers registered for UDP tracks are removed from the UDP
// listeners. In the initial state it does nothing.
func (sc *ServerConn) frameModeDisable() {
	if sc.state != serverConnStatePlay && sc.state != serverConnStateRecord {
		return
	}

	sc.nextFramesEnabled = false

	if sc.state != serverConnStateRecord {
		return
	}

	sc.readTimeoutEnabled = false
	for _, t := range sc.tracks {
		if t.proto != StreamProtocolUDP {
			continue
		}
		sc.s.conf.UDPRTPListener.removePublisher(sc.ip(), t.rtpPort)
		sc.s.conf.UDPRTCPListener.removePublisher(sc.ip(), t.rtcpPort)
	}
}
// handleRequest dispatches a single RTSP request to the matching read
// handler and returns the response to send back.
//
// OnRequest, if set, is invoked first for every request. OPTIONS,
// GET_PARAMETER and TEARDOWN get built-in responses when their handler is
// nil; any other method without a handler is answered with 400 Bad Request
// and an "unhandled method" error. ANNOUNCE and SETUP requests are validated
// before their handlers run, and successful (200) SETUP, PLAY, RECORD and
// PAUSE responses update the connection state.
//
// The caller writes the returned response and stops reading from the
// connection when the returned error is non-nil.
func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
	if sc.readHandlers.OnRequest != nil {
		sc.readHandlers.OnRequest(req)
	}

	switch req.Method {
	case base.Options:
		if sc.readHandlers.OnOptions != nil {
			return sc.readHandlers.OnOptions(req)
		}

		// advertise only the methods that can actually be served.
		var methods []string
		if sc.readHandlers.OnDescribe != nil {
			methods = append(methods, string(base.Describe))
		}
		if sc.readHandlers.OnAnnounce != nil {
			methods = append(methods, string(base.Announce))
		}
		if sc.readHandlers.OnSetup != nil {
			methods = append(methods, string(base.Setup))
		}
		if sc.readHandlers.OnPlay != nil {
			methods = append(methods, string(base.Play))
		}
		if sc.readHandlers.OnRecord != nil {
			methods = append(methods, string(base.Record))
		}
		if sc.readHandlers.OnPause != nil {
			methods = append(methods, string(base.Pause))
		}
		methods = append(methods, string(base.GetParameter))
		if sc.readHandlers.OnSetParameter != nil {
			methods = append(methods, string(base.SetParameter))
		}
		methods = append(methods, string(base.Teardown))

		return &base.Response{
			StatusCode: base.StatusOK,
			Header: base.Header{
				"Public": base.HeaderValue{strings.Join(methods, ", ")},
			},
		}, nil

	case base.Describe:
		if sc.readHandlers.OnDescribe != nil {
			return sc.readHandlers.OnDescribe(req)
		}

	case base.Announce:
		if sc.readHandlers.OnAnnounce != nil {
			ct, ok := req.Header["Content-Type"]
			if !ok || len(ct) != 1 {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, errors.New("Content-Type header is missing")
			}

			if ct[0] != "application/sdp" {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, fmt.Errorf("unsupported Content-Type '%s'", ct[0])
			}

			tracks, err := ReadTracks(req.Content)
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, fmt.Errorf("invalid SDP: %s", err)
			}

			if len(tracks) == 0 {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, errors.New("no tracks defined")
			}

			return sc.readHandlers.OnAnnounce(req, tracks)
		}

	case base.Setup:
		if sc.readHandlers.OnSetup != nil {
			_, controlPath, ok := req.URL.BasePathControlAttr()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, fmt.Errorf("unable to find control attribute (%s)", req.URL)
			}

			th, err := headers.ReadTransport(req.Header["Transport"])
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, fmt.Errorf("transport header: %s", err)
			}

			trackID, err := extractTrackID(controlPath, th.Mode, len(sc.tracks))
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			if _, ok := sc.tracks[trackID]; ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, fmt.Errorf("track %d has already been setup", trackID)
			}

			if sc.tracksProto != nil && *sc.tracksProto != th.Protocol {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, fmt.Errorf("can't receive tracks with different protocols")
			}

			if th.Protocol == StreamProtocolUDP {
				// both listeners are used below (server ports) and when
				// frames are exchanged, so both must be available.
				if sc.s.conf.UDPRTPListener == nil || sc.s.conf.UDPRTCPListener == nil {
					return &base.Response{
						StatusCode: base.StatusUnsupportedTransport,
					}, nil
				}

				if th.ClientPorts == nil {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"])
				}
			} else {
				if th.InterleavedIds == nil {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, fmt.Errorf("transport header does not contain the interleaved field")
				}

				// channels are fixed: RTP on 2*trackID, RTCP on 2*trackID+1.
				if (*th.InterleavedIds)[0] != (trackID*2) ||
					(*th.InterleavedIds)[1] != (1+trackID*2) {
					return &base.Response{
							StatusCode: base.StatusBadRequest,
						}, fmt.Errorf("wrong interleaved ids, expected [%v %v], got %v",
							(trackID * 2), (1 + trackID*2), *th.InterleavedIds)
				}
			}

			res, err := sc.readHandlers.OnSetup(req, th)

			if res.StatusCode == base.StatusOK {
				sc.tracksProto = &th.Protocol

				// the handler may legitimately return a response without
				// headers; writing the Transport header into a nil map
				// would panic.
				if res.Header == nil {
					res.Header = base.Header{}
				}

				if th.Protocol == StreamProtocolUDP {
					res.Header["Transport"] = headers.Transport{
						Protocol: StreamProtocolUDP,
						Delivery: func() *base.StreamDelivery {
							v := base.StreamDeliveryUnicast
							return &v
						}(),
						ClientPorts: th.ClientPorts,
						ServerPorts: &[2]int{sc.s.conf.UDPRTPListener.port(), sc.s.conf.UDPRTCPListener.port()},
					}.Write()

					sc.tracks[trackID] = serverConnTrack{
						proto:    StreamProtocolUDP,
						rtpPort:  th.ClientPorts[0],
						rtcpPort: th.ClientPorts[1],
					}
				} else {
					res.Header["Transport"] = headers.Transport{
						Protocol:       StreamProtocolTCP,
						InterleavedIds: th.InterleavedIds,
					}.Write()

					sc.tracks[trackID] = serverConnTrack{
						proto: StreamProtocolTCP,
					}
				}
			}

			// workaround to prevent a bug in rtspclientsink
			// that makes impossible for the client to receive the response
			// and send frames.
			// this was causing problems during unit tests.
			if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
				strings.HasPrefix(ua[0], "GStreamer") {
				t := time.NewTimer(1 * time.Second)
				defer t.Stop()
				select {
				case <-t.C:
				case <-sc.terminate:
				}
			}

			return res, err
		}

	case base.Play:
		if sc.readHandlers.OnPlay != nil {
			res, err := sc.readHandlers.OnPlay(req)
			if res.StatusCode == base.StatusOK {
				sc.state = serverConnStatePlay
				sc.frameModeEnable()
			}
			return res, err
		}

	case base.Record:
		if sc.readHandlers.OnRecord != nil {
			res, err := sc.readHandlers.OnRecord(req)
			if res.StatusCode == base.StatusOK {
				sc.state = serverConnStateRecord
				sc.frameModeEnable()
			}
			return res, err
		}

	case base.Pause:
		if sc.readHandlers.OnPause != nil {
			res, err := sc.readHandlers.OnPause(req)
			if res.StatusCode == base.StatusOK {
				sc.frameModeDisable()
				sc.state = serverConnStateInitial
			}
			return res, err
		}

	case base.GetParameter:
		if sc.readHandlers.OnGetParameter != nil {
			return sc.readHandlers.OnGetParameter(req)
		}

		// GET_PARAMETER is used like a ping
		return &base.Response{
			StatusCode: base.StatusOK,
			Header: base.Header{
				"Content-Type": base.HeaderValue{"text/parameters"},
			},
			Content: []byte("\n"),
		}, nil

	case base.SetParameter:
		if sc.readHandlers.OnSetParameter != nil {
			return sc.readHandlers.OnSetParameter(req)
		}

	case base.Teardown:
		if sc.readHandlers.OnTeardown != nil {
			return sc.readHandlers.OnTeardown(req)
		}

		return &base.Response{
			StatusCode: base.StatusOK,
		}, ErrServerTeardown
	}

	return &base.Response{
		StatusCode: base.StatusBadRequest,
	}, fmt.Errorf("unhandled method: %v", req.Method)
}
// backgroundRead is the connection's read loop. It reads requests, and once
// frames are enabled also interleaved frames, until an error occurs, and
// returns that error.
//
// Every request is answered with a response carrying its CSeq and a Server
// header. A request without exactly one CSeq value is answered with
// 400 Bad Request and ends the loop, as does any error returned by the
// request handling or by writing the response. frameModeDisable is always
// called before returning.
func (sc *ServerConn) backgroundRead() error {
	defer sc.frameModeDisable()

	handleRequestOuter := func(req *base.Request) error {
		// check cseq
		cseq, ok := req.Header["CSeq"]
		if !ok || len(cseq) != 1 {
			sc.writeMutex.Lock()
			sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
			base.Response{
				StatusCode: base.StatusBadRequest,
				Header:     base.Header{},
			}.Write(sc.bw)
			sc.writeMutex.Unlock()
			return errors.New("CSeq is missing")
		}

		res, err := sc.handleRequest(req)

		if res.Header == nil {
			res.Header = base.Header{}
		}

		// add cseq
		res.Header["CSeq"] = cseq

		// add server
		res.Header["Server"] = base.HeaderValue{"gortsplib"}

		if sc.readHandlers.OnResponse != nil {
			sc.readHandlers.OnResponse(res)
		}

		sc.writeMutex.Lock()
		sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
		writeErr := res.Write(sc.bw)

		// set framesEnabled after sending the response
		// in order to start sending frames after the response, never before
		sc.framesEnabled = sc.nextFramesEnabled
		sc.writeMutex.Unlock()

		// the handler error takes precedence; otherwise a failed write
		// means the connection is unusable and the loop must stop.
		if err != nil {
			return err
		}
		return writeErr
	}

	var req base.Request
	var frame base.InterleavedFrame
	tcpFrameBuffer := multibuffer.New(sc.s.conf.ReadBufferCount, clientTCPFrameReadBufferSize)

	for {
		if sc.readTimeoutEnabled {
			sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout))
		} else {
			sc.nconn.SetReadDeadline(time.Time{})
		}

		if !sc.framesEnabled {
			if err := req.Read(sc.br); err != nil {
				return err
			}
			if err := handleRequestOuter(&req); err != nil {
				return err
			}
			continue
		}

		frame.Content = tcpFrameBuffer.Next()
		what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
		if err != nil {
			return err
		}

		switch what.(type) {
		case *base.InterleavedFrame:
			// OnFrame is optional, like every other read handler.
			if sc.readHandlers.OnFrame != nil {
				sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content)
			}

		case *base.Request:
			if err := handleRequestOuter(&req); err != nil {
				return err
			}
		}
	}
}
// Read installs the given handlers and runs the read loop in a new
// goroutine. The returned channel receives the error that stopped the loop.
func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
	sc.readHandlers = readHandlers

	// one slot of buffering lets the goroutine deliver its result and exit
	// even when nobody receives from the channel.
	done := make(chan error, 1)
	go func() { done <- sc.backgroundRead() }()

	return done
}
// WriteFrame sends a frame on the given track.
//
// For UDP tracks the payload is sent through the server's RTP or RTCP UDP
// listener to the client port negotiated during SETUP. For TCP tracks it is
// written as an interleaved frame on the RTSP connection, which is allowed
// only once frames have been enabled by a successful PLAY or RECORD response.
// An error is returned if the track has not been set up.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error {
	sc.writeMutex.Lock()
	defer sc.writeMutex.Unlock()

	// a missing track would otherwise yield a zero-value serverConnTrack and
	// the frame would be sent with bogus parameters (e.g. to port 0).
	track, ok := sc.tracks[trackID]
	if !ok {
		return fmt.Errorf("track %d has not been setup", trackID)
	}

	if track.proto == StreamProtocolUDP {
		if streamType == StreamTypeRTP {
			return sc.s.conf.UDPRTPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
				IP:   sc.ip(),
				Zone: sc.zone(),
				Port: track.rtpPort,
			})
		}

		return sc.s.conf.UDPRTCPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
			IP:   sc.ip(),
			Zone: sc.zone(),
			Port: track.rtcpPort,
		})
	}

	// StreamProtocolTCP
	if !sc.framesEnabled {
		return errors.New("frames are disabled")
	}

	sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
	frame := base.InterleavedFrame{
		TrackID:    trackID,
		StreamType: streamType,
		Content:    payload,
	}
	return frame.Write(sc.bw)
}