Files
gortsplib/serverconn.go
2021-05-04 16:51:20 +02:00

1134 lines
28 KiB
Go

package gortsplib
import (
"bufio"
"crypto/tls"
"fmt"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/multibuffer"
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver"
)
// serverConnReadBufferSize is the size, in bytes, of the buffered reader
// placed in front of every connection.
const serverConnReadBufferSize = 4096

// serverConnWriteBufferSize is the size, in bytes, of the buffered writer
// placed in front of every connection.
const serverConnWriteBufferSize = 4096

// serverConnCheckStreamPeriod is how often a connection that is recording
// over UDP checks whether packets are still arriving.
const serverConnCheckStreamPeriod = 5 * time.Second
// stringsReverseIndex returns the index of the last occurrence of substr in s,
// or -1 if substr is not present.
//
// The previous hand-written loop started one position too early
// (len(s)-1-len(substr) instead of len(s)-len(substr)), so it never matched
// an occurrence that ends exactly at the end of s. It also returned -1 when
// s == substr. strings.LastIndex has the correct semantics. Note that for an
// empty substr it returns len(s).
func stringsReverseIndex(s, substr string) int {
	return strings.LastIndex(s, substr)
}
// setupGetTrackIDPathQuery extracts the track ID, path and query targeted by
// a SETUP request.
//
// When thMode is nil or play, the track ID comes from a trailing
// "/trackID=N" component of the URL. A URL without that component must end
// with a slash and refers to track 0. If setupPath is non-nil (the path was
// already fixed by a previous SETUP or by ANNOUNCE), the extracted path and
// query must equal *setupPath and *setupQuery. This check applies whether or
// not the URL carries an explicit trackID.
//
// In any other mode (record), the request URL is compared with the URLs of
// announcedTracks. On a match, the index of the track is returned together
// with *setupPath and *setupQuery.
func setupGetTrackIDPathQuery(url *base.URL,
	thMode *headers.TransportMode,
	announcedTracks []ServerConnAnnouncedTrack,
	setupPath *string, setupQuery *string) (int, string, string, error) {
	pathAndQuery, ok := url.RTSPPathAndQuery()
	if !ok {
		return 0, "", "", liberrors.ErrServerNoPath{}
	}

	if thMode == nil || *thMode == headers.TransportModePlay {
		trackID := 0

		i := stringsReverseIndex(pathAndQuery, "/trackID=")
		if i < 0 {
			// URL doesn't contain trackID - it's track zero
			if !strings.HasSuffix(pathAndQuery, "/") {
				return 0, "", "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
			}
			pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
		} else {
			// Atoi parses directly into int, so out-of-range values are
			// rejected instead of wrapping when narrowed on 32-bit platforms.
			tmp, err := strconv.Atoi(pathAndQuery[i+len("/trackID="):])
			if err != nil || tmp < 0 {
				return 0, "", "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
			}
			trackID = tmp
			pathAndQuery = pathAndQuery[:i]
		}

		path, query := base.PathSplitQuery(pathAndQuery)

		// all tracks of a connection must share the same path and query,
		// including track 0 addressed without an explicit trackID
		if setupPath != nil && (path != *setupPath || query != *setupQuery) {
			return 0, "", "", fmt.Errorf("can't setup tracks with different paths")
		}

		return trackID, path, query, nil
	}

	// record: without a path fixed by ANNOUNCE there is nothing to match
	if setupPath == nil || setupQuery == nil {
		return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
	}

	for trackID, track := range announcedTracks {
		u, err := track.track.URL()
		if err != nil {
			// a track whose URL cannot be generated cannot be addressed
			continue
		}

		if u.String() == url.String() {
			return trackID, *setupPath, *setupQuery, nil
		}
	}

	return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
// ServerConnSetuppedTrack is a setupped track of a ServerConn.
//
// It is stored in ServerConn.setuppedTracks, keyed by track ID, after a
// successful SETUP. With TCP it is stored as the zero value.
type ServerConnSetuppedTrack struct {
	// client port that receives RTP packets (taken from the Transport
	// header's client_port when the track is set up over UDP)
	udpRTPPort int

	// client port that receives RTCP packets (taken from the Transport
	// header's client_port when the track is set up over UDP)
	udpRTCPPort int
}
// ServerConnAnnouncedTrack is an announced track of a ServerConn.
//
// One entry is created per track of the SDP received with a successful
// ANNOUNCE; its index in ServerConn.announcedTracks is the track ID.
type ServerConnAnnouncedTrack struct {
	// track as parsed from the ANNOUNCE body
	track *Track

	// fed with incoming frames while recording and used to produce the
	// periodic RTCP receiver reports sent by backgroundRecord
	rtcpReceiver *rtcpreceiver.RTCPReceiver

	// Unix time (seconds) of the last UDP frame received for this track,
	// accessed atomically; initialized to the ANNOUNCE time.
	// NOTE(review): presumably updated by the UDP listener, which is not in
	// this file - confirm.
	udpLastFrameTime *int64
}
// ServerConnState is a state of a ServerConn.
type ServerConnState int

// standard states.
const (
	// no track has been announced or set up yet
	ServerConnStateInitial ServerConnState = iota

	// at least one track has been set up for reading; waiting for PLAY
	ServerConnStatePrePlay

	// PLAY succeeded; frames are being sent to the client
	ServerConnStatePlay

	// ANNOUNCE succeeded; waiting for SETUP of all tracks and RECORD
	ServerConnStatePreRecord

	// RECORD succeeded; frames are being received from the client
	ServerConnStateRecord
)
// String implements fmt.Stringer. States outside the standard set are
// rendered as "unknown".
func (s ServerConnState) String() string {
	var name string

	switch s {
	case ServerConnStateInitial:
		name = "initial"
	case ServerConnStatePrePlay:
		name = "prePlay"
	case ServerConnStatePlay:
		name = "play"
	case ServerConnStatePreRecord:
		name = "preRecord"
	case ServerConnStateRecord:
		name = "record"
	default:
		name = "unknown"
	}

	return name
}
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
	// server that accepted the connection (handlers, configuration,
	// UDP listeners)
	s *Server

	// run() calls Done on this group when it exits
	wg *sync.WaitGroup

	// raw connection; run() wraps it in TLS when the server has a TLSConfig
	nconn net.Conn

	// buffered reader and writer over the (possibly TLS-wrapped) connection
	br *bufio.Reader
	bw *bufio.Writer

	// current protocol state
	state ServerConnState

	// tracks set up with a successful SETUP, keyed by track ID
	setuppedTracks map[int]ServerConnSetuppedTrack

	// protocol shared by all setupped tracks; nil until the first
	// successful SETUP
	setupProtocol *StreamProtocol

	// path and query fixed by ANNOUNCE or by the first SETUP
	setupPath  *string
	setupQuery *string

	// TCP stream protocol
	// set by frameModeEnable; consumed by handleRequestOuter, which enables
	// TCP frame mode after writing the current response
	doEnableTCPFrame bool
	// when true, backgroundRead parses interleaved frames as well as
	// requests, and responses are queued in tcpFrameWriteBuffer
	tcpFrameEnabled bool
	// when true, backgroundRead applies ReadTimeout to every read
	// (set while recording over TCP)
	tcpFrameTimeout bool
	// buffers receiving the payloads of incoming interleaved frames
	tcpFrameBuffer *multibuffer.MultiBuffer
	// queue of outgoing frames and responses drained by tcpBackgroundWrite
	tcpFrameWriteBuffer *ringbuffer.RingBuffer
	// closed when tcpBackgroundWrite returns
	tcpBackgroundWriteDone chan struct{}

	// publish
	// tracks received with a successful ANNOUNCE, indexed by track ID
	announcedTracks []ServerConnAnnouncedTrack
	// closed to stop backgroundRecord
	backgroundRecordTerminate chan struct{}
	// closed by backgroundRecord when it returns
	backgroundRecordDone chan struct{}
	// set to 1 (atomically) by backgroundRecord when no UDP packet has
	// been received recently
	udpTimeout int32

	// in
	// signaled to stop run(); presumably closed by the Server - confirm
	terminate chan struct{}
}
// newServerConn creates a ServerConn bound to nconn, registers it in wg and
// starts its run loop in a new goroutine. run calls wg.Done when it exits.
func newServerConn(
	s *Server,
	wg *sync.WaitGroup,
	nconn net.Conn) *ServerConn {
	conn := new(ServerConn)
	conn.s = s
	conn.wg = wg
	conn.nconn = nconn
	conn.terminate = make(chan struct{})

	wg.Add(1)
	go conn.run()

	return conn
}
// State returns the current protocol state of the connection.
func (sc *ServerConn) State() ServerConnState {
	current := sc.state
	return current
}
// StreamProtocol returns the stream protocol shared by the setupped tracks,
// or nil if no track has been set up yet.
func (sc *ServerConn) StreamProtocol() *StreamProtocol {
	proto := sc.setupProtocol
	return proto
}
// SetuppedTracks returns the setupped tracks, keyed by track ID.
// The returned map is the connection's own map, not a copy.
func (sc *ServerConn) SetuppedTracks() map[int]ServerConnSetuppedTrack {
	tracks := sc.setuppedTracks
	return tracks
}
// AnnouncedTracks returns the announced tracks, indexed by track ID.
// The returned slice is the connection's own slice, not a copy.
func (sc *ServerConn) AnnouncedTracks() []ServerConnAnnouncedTrack {
	tracks := sc.announcedTracks
	return tracks
}
// NetConn returns the underlying net.Conn (the raw connection, not the TLS
// wrapper created when the server has a TLSConfig).
func (sc *ServerConn) NetConn() net.Conn {
	raw := sc.nconn
	return raw
}
// ip returns the IP address of the remote peer. The remote address is
// asserted to be a *net.TCPAddr.
func (sc *ServerConn) ip() net.IP {
	tcpAddr := sc.nconn.RemoteAddr().(*net.TCPAddr)
	return tcpAddr.IP
}
// zone returns the IPv6 zone of the remote peer's address. The remote
// address is asserted to be a *net.TCPAddr.
func (sc *ServerConn) zone() string {
	tcpAddr := sc.nconn.RemoteAddr().(*net.TCPAddr)
	return tcpAddr.Zone
}
// checkState returns nil if the current state belongs to allowed. Otherwise
// it returns an ErrServerWrongState carrying the current state and the
// allowed states, listed in map iteration order (unspecified).
func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error {
	_, isAllowed := allowed[sc.state]
	if isAllowed {
		return nil
	}

	allowedList := make([]fmt.Stringer, 0, len(allowed))
	for st := range allowed {
		allowedList = append(allowedList, st)
	}

	return liberrors.ErrServerWrongState{AllowedList: allowedList, State: sc.state}
}
// run is the main routine of the connection, started by newServerConn.
//
// It calls OnConnOpen and wraps the connection in TLS when the server has a
// TLSConfig. It then sets up buffered I/O, starts backgroundRead, and waits
// for one of two events:
//   - backgroundRead returns: the connection is closed, sc is sent to the
//     server through connClose, and run waits for terminate;
//   - terminate fires: the connection is closed, which is expected to make
//     the pending read fail, and run waits for backgroundRead to return.
//
// In both cases OnConnClose receives the error returned by backgroundRead,
// and wg.Done is called on exit.
func (sc *ServerConn) run() {
	defer sc.wg.Done()

	if h, ok := sc.s.Handler.(ServerHandlerOnConnOpen); ok {
		h.OnConnOpen(sc)
	}

	conn := func() net.Conn {
		if sc.s.TLSConfig != nil {
			return tls.Server(sc.nconn, sc.s.TLSConfig)
		}
		return sc.nconn
	}()

	sc.br = bufio.NewReaderSize(conn, serverConnReadBufferSize)
	sc.bw = bufio.NewWriterSize(conn, serverConnWriteBufferSize)

	// instantiate always to allow writing to this conn before Play()
	sc.tcpFrameWriteBuffer = ringbuffer.New(uint64(sc.s.ReadBufferCount))
	sc.tcpBackgroundWriteDone = make(chan struct{})

	// unbuffered: both branches of the select below receive from it exactly
	// once, so the reader goroutine never blocks forever on the send
	readDone := make(chan error)
	go func() {
		readDone <- sc.backgroundRead()
	}()

	var err error
	select {
	case err = <-readDone:
		sc.nconn.Close()
		sc.s.connClose <- sc
		<-sc.terminate

	case <-sc.terminate:
		sc.nconn.Close()
		err = <-readDone
	}

	if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok {
		h.OnConnClose(sc, err)
	}
}
// handleRequest processes a single RTSP request and returns the response to
// send back. A non-nil error is returned by handleRequestOuter to
// backgroundRead, which stops reading and lets run close the connection.
//
// Each method is dispatched to the matching optional handler interface
// implemented by sc.s.Handler, after validating the connection state and the
// request. Only OPTIONS, GET_PARAMETER and TEARDOWN have a built-in default
// when their handler is missing. Any other method whose handler is missing
// gets a 400 response and an "unhandled method" error.
//
// State transitions (and the related side effects) are applied only when the
// handler answers with 200 OK.
func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
	if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 {
		return &base.Response{
			StatusCode: base.StatusBadRequest,
			Header:     base.Header{},
		}, liberrors.ErrServerCSeqMissing{}
	}

	switch req.Method {
	case base.Options:
		if h, ok := sc.s.Handler.(ServerHandlerOnOptions); ok {
			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			path, query := base.PathSplitQuery(pathAndQuery)

			return h.OnOptions(&ServerHandlerOnOptionsCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})
		}

		// no custom handler: advertise the methods whose handler is
		// implemented, plus those with a built-in implementation
		var methods []string
		if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
			methods = append(methods, string(base.Describe))
		}
		if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
			methods = append(methods, string(base.Announce))
		}
		if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
			methods = append(methods, string(base.Setup))
		}
		if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
			methods = append(methods, string(base.Play))
		}
		if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
			methods = append(methods, string(base.Record))
		}
		if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
			methods = append(methods, string(base.Pause))
		}
		methods = append(methods, string(base.GetParameter))
		if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
			methods = append(methods, string(base.SetParameter))
		}
		methods = append(methods, string(base.Teardown))

		return &base.Response{
			StatusCode: base.StatusOK,
			Header: base.Header{
				"Public": base.HeaderValue{strings.Join(methods, ", ")},
			},
		}, nil

	case base.Describe:
		if h, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
			err := sc.checkState(map[ServerConnState]struct{}{
				ServerConnStateInitial: {},
			})
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			path, query := base.PathSplitQuery(pathAndQuery)

			res, sdp, err := h.OnDescribe(&ServerHandlerOnDescribeCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})

			if res.StatusCode == base.StatusOK && sdp != nil {
				if res.Header == nil {
					res.Header = make(base.Header)
				}

				// Content-Base is the request URL plus a trailing slash;
				// SETUP/PLAY/RECORD/PAUSE strip that slash again.
				res.Header["Content-Base"] = base.HeaderValue{req.URL.String() + "/"}
				res.Header["Content-Type"] = base.HeaderValue{"application/sdp"}
				res.Body = sdp
			}

			return res, err
		}

	case base.Announce:
		if h, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
			err := sc.checkState(map[ServerConnState]struct{}{
				ServerConnStateInitial: {},
			})
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			ct, ok := req.Header["Content-Type"]
			if !ok || len(ct) != 1 {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerContentTypeMissing{}
			}

			if ct[0] != "application/sdp" {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerContentTypeUnsupported{CT: ct}
			}

			tracks, err := ReadTracks(req.Body, req.URL)
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerSDPInvalid{Err: err}
			}

			if len(tracks) == 0 {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerSDPNoTracksDefined{}
			}

			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			path, query := base.PathSplitQuery(pathAndQuery)

			// every track URL must live under the announced path
			for _, track := range tracks {
				trackURL, err := track.URL()
				if err != nil {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, fmt.Errorf("unable to generate track URL")
				}

				trackPath, ok := trackURL.RTSPPath()
				if !ok {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, fmt.Errorf("invalid track URL (%v)", trackURL)
				}

				if !strings.HasPrefix(trackPath, path) {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'",
						path, trackPath)
				}
			}

			res, err := h.OnAnnounce(&ServerHandlerOnAnnounceCtx{
				Conn:   sc,
				Req:    req,
				Path:   path,
				Query:  query,
				Tracks: tracks,
			})

			if res.StatusCode == base.StatusOK {
				sc.state = ServerConnStatePreRecord
				sc.setupPath = &path
				sc.setupQuery = &query

				sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
				for trackID, track := range tracks {
					clockRate, _ := track.ClockRate()
					v := time.Now().Unix()

					sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
						track:            track,
						rtcpReceiver:     rtcpreceiver.New(nil, clockRate),
						udpLastFrameTime: &v,
					}
				}
			}

			return res, err
		}

	case base.Setup:
		if h, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
			err := sc.checkState(map[ServerConnState]struct{}{
				ServerConnStateInitial:   {},
				ServerConnStatePrePlay:   {},
				ServerConnStatePreRecord: {},
			})
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			var th headers.Transport
			err = th.Read(req.Header["Transport"])
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerTransportHeaderInvalid{Err: err}
			}

			if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast {
				return &base.Response{
					StatusCode: base.StatusUnsupportedTransport,
				}, nil
			}

			trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, th.Mode,
				sc.announcedTracks, sc.setupPath, sc.setupQuery)
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			if _, ok := sc.setuppedTracks[trackID]; ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID}
			}

			switch sc.state {
			case ServerConnStateInitial, ServerConnStatePrePlay: // play
				if th.Mode != nil && *th.Mode != headers.TransportModePlay {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, liberrors.ErrServerTransportHeaderWrongMode{Mode: th.Mode}
				}

			default: // record
				if th.Mode == nil || *th.Mode != headers.TransportModeRecord {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, liberrors.ErrServerTransportHeaderWrongMode{Mode: th.Mode}
				}
			}

			if th.Protocol == StreamProtocolUDP {
				if sc.s.udpRTPListener == nil {
					return &base.Response{
						StatusCode: base.StatusUnsupportedTransport,
					}, nil
				}

				if th.ClientPorts == nil {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, liberrors.ErrServerTransportHeaderNoClientPorts{}
				}
			} else {
				if th.InterleavedIDs == nil {
					return &base.Response{
						StatusCode: base.StatusBadRequest,
					}, liberrors.ErrServerTransportHeaderNoInterleavedIDs{}
				}

				// interleaved channels are fixed: RTP on 2*trackID,
				// RTCP on 2*trackID+1
				if th.InterleavedIDs[0] != (trackID*2) ||
					th.InterleavedIDs[1] != (1+trackID*2) {
					return &base.Response{
							StatusCode: base.StatusBadRequest,
						}, liberrors.ErrServerTransportHeaderWrongInterleavedIDs{
							Expected: [2]int{(trackID * 2), (1 + trackID*2)}, Value: *th.InterleavedIDs}
				}
			}

			if sc.setupProtocol != nil && *sc.setupProtocol != th.Protocol {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerTracksDifferentProtocols{}
			}

			res, err := h.OnSetup(&ServerHandlerOnSetupCtx{
				Conn:      sc,
				Req:       req,
				Path:      path,
				Query:     query,
				TrackID:   trackID,
				Transport: &th,
			})

			if res.StatusCode == base.StatusOK {
				sc.setupProtocol = &th.Protocol

				if sc.setuppedTracks == nil {
					sc.setuppedTracks = make(map[int]ServerConnSetuppedTrack)
				}

				if res.Header == nil {
					res.Header = make(base.Header)
				}

				if th.Protocol == StreamProtocolUDP {
					sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{
						udpRTPPort:  th.ClientPorts[0],
						udpRTCPPort: th.ClientPorts[1],
					}

					res.Header["Transport"] = headers.Transport{
						Protocol: StreamProtocolUDP,
						Delivery: func() *base.StreamDelivery {
							v := base.StreamDeliveryUnicast
							return &v
						}(),
						ClientPorts: th.ClientPorts,
						ServerPorts: &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()},
					}.Write()
				} else {
					sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{}

					res.Header["Transport"] = headers.Transport{
						Protocol:       StreamProtocolTCP,
						InterleavedIDs: th.InterleavedIDs,
					}.Write()
				}

				// The first successful SETUP fixes the path and query of the
				// connection. This happens only on success: a rejected SETUP
				// must not move the connection to PrePlay or pin its path.
				if sc.state == ServerConnStateInitial {
					sc.state = ServerConnStatePrePlay
					sc.setupPath = &path
					sc.setupQuery = &query
				}
			}

			// workaround to prevent a bug in rtspclientsink
			// that makes it impossible for the client to receive the response
			// and send frames.
			// this was causing problems during unit tests.
			if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
				strings.HasPrefix(ua[0], "GStreamer") {
				select {
				case <-time.After(1 * time.Second):
				case <-sc.terminate:
				}
			}

			return res, err
		}

	case base.Play:
		if h, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
			// play can be sent twice, allow calling it even if we're already playing
			err := sc.checkState(map[ServerConnState]struct{}{
				ServerConnStatePrePlay: {},
				ServerConnStatePlay:    {},
			})
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			if len(sc.setuppedTracks) == 0 {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoTracksSetup{}
			}

			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			// path can end with a slash due to Content-Base, remove it
			pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")

			path, query := base.PathSplitQuery(pathAndQuery)

			res, err := h.OnPlay(&ServerHandlerOnPlayCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})

			if res.StatusCode == base.StatusOK && sc.state != ServerConnStatePlay {
				sc.state = ServerConnStatePlay
				sc.frameModeEnable()
			}

			return res, err
		}

	case base.Record:
		if h, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
			err := sc.checkState(map[ServerConnState]struct{}{
				ServerConnStatePreRecord: {},
			})
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			if len(sc.setuppedTracks) == 0 {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoTracksSetup{}
			}

			if len(sc.setuppedTracks) != len(sc.announcedTracks) {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNotAllAnnouncedTracksSetup{}
			}

			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			// path can end with a slash due to Content-Base, remove it
			pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")

			path, query := base.PathSplitQuery(pathAndQuery)

			res, err := h.OnRecord(&ServerHandlerOnRecordCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})

			if res.StatusCode == base.StatusOK {
				sc.state = ServerConnStateRecord
				sc.frameModeEnable()
			}

			return res, err
		}

	case base.Pause:
		if h, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
			err := sc.checkState(map[ServerConnState]struct{}{
				ServerConnStatePrePlay:   {},
				ServerConnStatePlay:      {},
				ServerConnStatePreRecord: {},
				ServerConnStateRecord:    {},
			})
			if err != nil {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, err
			}

			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			// path can end with a slash due to Content-Base, remove it
			pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")

			path, query := base.PathSplitQuery(pathAndQuery)

			res, err := h.OnPause(&ServerHandlerOnPauseCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})

			if res.StatusCode == base.StatusOK {
				// frame mode is disabled while the state still says
				// Play/Record, since frameModeDisable switches on it
				switch sc.state {
				case ServerConnStatePlay:
					sc.frameModeDisable()
					sc.state = ServerConnStatePrePlay

				case ServerConnStateRecord:
					sc.frameModeDisable()
					sc.state = ServerConnStatePreRecord
				}
			}

			return res, err
		}

	case base.GetParameter:
		if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {
			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			path, query := base.PathSplitQuery(pathAndQuery)

			return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})
		}

		// GET_PARAMETER is used like a ping
		return &base.Response{
			StatusCode: base.StatusOK,
			Header: base.Header{
				"Content-Type": base.HeaderValue{"text/parameters"},
			},
			Body: []byte("\n"),
		}, nil

	case base.SetParameter:
		if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			path, query := base.PathSplitQuery(pathAndQuery)

			return h.OnSetParameter(&ServerHandlerOnSetParameterCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})
		}

	case base.Teardown:
		if h, ok := sc.s.Handler.(ServerHandlerOnTeardown); ok {
			pathAndQuery, ok := req.URL.RTSPPath()
			if !ok {
				return &base.Response{
					StatusCode: base.StatusBadRequest,
				}, liberrors.ErrServerNoPath{}
			}

			path, query := base.PathSplitQuery(pathAndQuery)

			return h.OnTeardown(&ServerHandlerOnTeardownCtx{
				Conn:  sc,
				Req:   req,
				Path:  path,
				Query: query,
			})
		}

		// default: reply OK and close the connection via the returned error
		return &base.Response{
			StatusCode: base.StatusOK,
		}, liberrors.ErrServerTeardown{}
	}

	return &base.Response{
		StatusCode: base.StatusBadRequest,
	}, fmt.Errorf("unhandled method: %v", req.Method)
}
// handleRequestOuter handles req through handleRequest and writes the
// response.
//
// It calls the optional OnRequest/OnResponse handlers. It copies the CSeq
// header from the request, unless the request lacked a valid one, and sets
// the Server header. The response is then written in one of three ways:
//   - if this request has just asked for TCP frame mode (doEnableTCPFrame),
//     it is written directly and only then is tcpBackgroundWrite started, so
//     the response reaches the client before any frame;
//   - if TCP frame mode is already on, it is queued in tcpFrameWriteBuffer,
//     so that tcpBackgroundWrite serializes it with outgoing frames;
//   - otherwise it is written directly.
//
// The error from handleRequest is returned unchanged; write errors are
// ignored.
func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
	if h, ok := sc.s.Handler.(ServerHandlerOnRequest); ok {
		h.OnRequest(req)
	}

	res, err := sc.handleRequest(req)

	if res.Header == nil {
		res.Header = base.Header{}
	}

	// add cseq
	if _, ok := err.(liberrors.ErrServerCSeqMissing); !ok {
		res.Header["CSeq"] = req.Header["CSeq"]
	}

	// add server
	res.Header["Server"] = base.HeaderValue{"gortsplib"}

	if h, ok := sc.s.Handler.(ServerHandlerOnResponse); ok {
		h.OnResponse(res)
	}

	switch {
	case sc.doEnableTCPFrame: // start background write
		sc.doEnableTCPFrame = false
		sc.tcpFrameEnabled = true

		if sc.state == ServerConnStateRecord {
			sc.tcpFrameBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize))
		} else {
			// when playing, tcpFrameBuffer is only used to receive RTCP receiver reports,
			// which are much smaller than RTP frames and are sent at a fixed interval
			// (about 2 frames every 10 secs).
			// decrease RAM consumption by allocating fewer buffers.
			sc.tcpFrameBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize))
		}

		// write response before frames
		sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
		res.Write(sc.bw)

		// start background write
		sc.tcpBackgroundWriteDone = make(chan struct{})
		go sc.tcpBackgroundWrite()

	case sc.tcpFrameEnabled: // write to background write
		sc.tcpFrameWriteBuffer.Push(res)

	default: // write directly
		sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
		res.Write(sc.bw)
	}

	return err
}
// backgroundRead is the read loop of the connection. It runs until a read
// fails or handleRequestOuter returns an error, and returns that error.
//
// While TCP frame mode is off, it reads plain RTSP requests. While it is on,
// it reads either interleaved frames or requests. Frames of setupped tracks
// are fed to the track's RTCP receiver (when recording) and passed to the
// OnFrame handler. Frames of other tracks are discarded.
//
// If a request read fails after backgroundRecord has flagged a UDP timeout
// (udpTimeout == 1), ErrServerNoUDPPacketsRecently is returned instead of
// the read error. frameModeDisable is called on exit.
func (sc *ServerConn) backgroundRead() error {
	defer sc.frameModeDisable()

	// reused across iterations; frame payloads come from tcpFrameBuffer
	var req base.Request
	var frame base.InterleavedFrame

	for {
		if sc.tcpFrameEnabled {
			if sc.tcpFrameTimeout {
				sc.nconn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
			}

			frame.Payload = sc.tcpFrameBuffer.Next()
			what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
			if err != nil {
				return err
			}

			switch what.(type) {
			case *base.InterleavedFrame:
				// forward frame only if it has been set up
				if _, ok := sc.setuppedTracks[frame.TrackID]; ok {
					if sc.state == ServerConnStateRecord {
						sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
							frame.StreamType, frame.Payload)
					}

					if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok {
						h.OnFrame(&ServerHandlerOnFrameCtx{
							Conn:       sc,
							TrackID:    frame.TrackID,
							StreamType: frame.StreamType,
							Payload:    frame.Payload,
						})
					}
				}

			case *base.Request:
				err := sc.handleRequestOuter(&req)
				if err != nil {
					return err
				}
			}

		} else {
			err := req.Read(sc.br)
			if err != nil {
				// backgroundRecord closes nconn after setting udpTimeout;
				// report the timeout rather than the resulting read error
				if atomic.LoadInt32(&sc.udpTimeout) == 1 {
					return liberrors.ErrServerNoUDPPacketsRecently{}
				}
				return err
			}

			err = sc.handleRequestOuter(&req)
			if err != nil {
				return err
			}
		}
	}
}
// backgroundRecord runs while the connection is recording. It closes
// backgroundRecordDone on exit and performs two periodic duties:
//   - every serverConnCheckStreamPeriod, when the setup protocol is UDP, it
//     checks whether any announced track received a UDP frame within
//     ReadTimeout. If none did, it sets udpTimeout and closes the connection;
//   - every receiverReportPeriod, it sends an RTCP receiver report for each
//     announced track.
//
// It returns when backgroundRecordTerminate is closed.
func (sc *ServerConn) backgroundRecord() {
	defer close(sc.backgroundRecordDone)

	checkStreamTicker := time.NewTicker(serverConnCheckStreamPeriod)
	defer checkStreamTicker.Stop()

	receiverReportTicker := time.NewTicker(sc.s.receiverReportPeriod)
	defer receiverReportTicker.Stop()

	for {
		select {
		case <-checkStreamTicker.C:
			// with TCP, inactivity is detected by the read deadline instead
			if *sc.setupProtocol != StreamProtocolUDP {
				continue
			}

			now := time.Now()
			receiving := false
			for _, track := range sc.announcedTracks {
				last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
				if now.Sub(last) < sc.s.ReadTimeout {
					receiving = true
					break
				}
			}

			if !receiving {
				atomic.StoreInt32(&sc.udpTimeout, 1)
				sc.nconn.Close()
				return
			}

		case <-receiverReportTicker.C:
			now := time.Now()
			for trackID := range sc.announcedTracks {
				report := sc.announcedTracks[trackID].rtcpReceiver.Report(now)
				sc.WriteFrame(trackID, StreamTypeRTCP, report)
			}

		case <-sc.backgroundRecordTerminate:
			return
		}
	}
}
// tcpBackgroundWrite drains tcpFrameWriteBuffer. It writes each queued
// interleaved frame or response to the connection, after setting a fresh
// write deadline, and stops when Pull returns ok == false. frameModeDisable
// relies on this after closing the buffer. tcpBackgroundWriteDone is closed
// on exit; write errors are ignored.
func (sc *ServerConn) tcpBackgroundWrite() {
	defer close(sc.tcpBackgroundWriteDone)

	for {
		item, ok := sc.tcpFrameWriteBuffer.Pull()
		if !ok {
			return
		}

		if frame, isFrame := item.(*base.InterleavedFrame); isFrame {
			sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
			frame.Write(sc.bw)
		} else if res, isRes := item.(*base.Response); isRes {
			sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
			res.Write(sc.bw)
		}
	}
}
// frameModeEnable starts frame exchange after a successful PLAY or RECORD,
// according to the current state and setup protocol:
//   - Play over TCP: sets doEnableTCPFrame; handleRequestOuter enables TCP
//     frame mode after writing the PLAY response.
//   - Play over UDP: registers the client's RTCP ports on the server RTCP
//     listener.
//   - Record over TCP: sets doEnableTCPFrame and enables read timeouts.
//   - Record over UDP: registers the client's RTP and RTCP ports and sends
//     one dummy RTP and one dummy RTCP packet per track.
//   - Record (both protocols): starts backgroundRecord.
//
// NOTE(review): the last argument of addClient presumably marks publishing
// clients - confirm in the UDP listener.
func (sc *ServerConn) frameModeEnable() {
	switch sc.state {
	case ServerConnStatePlay:
		if *sc.setupProtocol == StreamProtocolTCP {
			sc.doEnableTCPFrame = true
		} else {
			// readers can send RTCP frames, they cannot send RTP frames
			for trackID, track := range sc.setuppedTracks {
				sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, false)
			}
		}

	case ServerConnStateRecord:
		if *sc.setupProtocol == StreamProtocolTCP {
			sc.doEnableTCPFrame = true
			sc.tcpFrameTimeout = true

		} else {
			for trackID, track := range sc.setuppedTracks {
				sc.s.udpRTPListener.addClient(sc.ip(), track.udpRTPPort, sc, trackID, true)
				sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, true)

				// open the firewall by sending packets to the counterpart:
				// a bare 12-byte RTP header (version 2, other fields zero)
				// and an empty RTCP receiver report (packet type 201)
				sc.WriteFrame(trackID, StreamTypeRTP,
					[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
				sc.WriteFrame(trackID, StreamTypeRTCP,
					[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
			}
		}

		sc.backgroundRecordTerminate = make(chan struct{})
		sc.backgroundRecordDone = make(chan struct{})

		go sc.backgroundRecord()
	}
}
// frameModeDisable undoes frameModeEnable for the current state. It is
// called on a successful PAUSE (before the state changes) and when
// backgroundRead exits. In states other than Play and Record it does
// nothing.
//
// With TCP, it turns TCP frame mode off and closes tcpFrameWriteBuffer so
// that tcpBackgroundWrite returns. It waits for the writer to exit, then
// resets the buffer for reuse. With UDP, it removes the client's ports from
// the server listeners. When recording, it first stops backgroundRecord and
// waits for it; with TCP it also clears the read deadline.
func (sc *ServerConn) frameModeDisable() {
	switch sc.state {
	case ServerConnStatePlay:
		if *sc.setupProtocol == StreamProtocolTCP {
			sc.tcpFrameEnabled = false
			sc.tcpFrameWriteBuffer.Close()
			<-sc.tcpBackgroundWriteDone
			sc.tcpFrameWriteBuffer.Reset()

		} else {
			for _, track := range sc.setuppedTracks {
				sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort)
			}
		}

	case ServerConnStateRecord:
		close(sc.backgroundRecordTerminate)
		<-sc.backgroundRecordDone

		if *sc.setupProtocol == StreamProtocolTCP {
			sc.tcpFrameTimeout = false
			sc.nconn.SetReadDeadline(time.Time{})

			sc.tcpFrameEnabled = false
			sc.tcpFrameWriteBuffer.Close()
			<-sc.tcpBackgroundWriteDone
			sc.tcpFrameWriteBuffer.Reset()

		} else {
			for _, track := range sc.setuppedTracks {
				sc.s.udpRTPListener.removeClient(sc.ip(), track.udpRTPPort)
				sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort)
			}
		}
	}
}
// WriteFrame writes a frame to the client on the given track.
//
// With UDP, the payload is sent to the client port recorded at SETUP time
// for that track: the RTP port for StreamTypeRTP, the RTCP port otherwise.
// With TCP, the frame is queued in tcpFrameWriteBuffer as an interleaved
// frame.
//
// The frame is dropped when no track has been set up yet (previously a nil
// dereference panic). With UDP it is also dropped when trackID is not a
// setupped track (previously a send to port 0). In both cases there is no
// destination to send it to.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
	if sc.setupProtocol == nil {
		return
	}

	if *sc.setupProtocol == StreamProtocolUDP {
		track, ok := sc.setuppedTracks[trackID]
		if !ok {
			return
		}

		if streamType == StreamTypeRTP {
			sc.s.udpRTPListener.write(payload, &net.UDPAddr{
				IP:   sc.ip(),
				Zone: sc.zone(),
				Port: track.udpRTPPort,
			})
			return
		}

		sc.s.udpRTCPListener.write(payload, &net.UDPAddr{
			IP:   sc.ip(),
			Zone: sc.zone(),
			Port: track.udpRTCPPort,
		})
		return
	}

	sc.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
		TrackID:    trackID,
		StreamType: streamType,
		Payload:    payload,
	})
}