server: support UDP

This commit is contained in:
aler9
2021-01-04 22:32:49 +01:00
parent a53ba70dbc
commit 85e7127cfe
9 changed files with 820 additions and 247 deletions

View File

@@ -18,7 +18,8 @@ Features:
* Pause reading or publishing without disconnecting from the server
* Server
* Handle requests from clients
* Read and write streams with TCP
* Accept streams from clients with UDP or TCP
* Send streams to clients with UDP or TCP
* Encrypt streams with TLS (RTSPS)
## Table of contents
@@ -38,6 +39,7 @@ Features:
* [client-publish-options](examples/client-publish-options.go)
* [client-publish-pause](examples/client-publish-pause.go)
* [server](examples/server.go)
* [server-udp](examples/server-udp.go)
* [server-tls](examples/server-tls.go)
## API Documentation

View File

@@ -75,17 +75,9 @@ func handleConn(conn *gortsplib.ServerConn) {
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) {
// support TCP only
if th.Protocol == gortsplib.StreamProtocolUDP {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
@@ -98,8 +90,6 @@ func handleConn(conn *gortsplib.ServerConn) {
readers[conn] = struct{}{}
conn.EnableFrames(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
@@ -119,9 +109,6 @@ func handleConn(conn *gortsplib.ServerConn) {
}, fmt.Errorf("someone is already publishing")
}
conn.EnableFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{

184
examples/server-udp.go Normal file
View File

@@ -0,0 +1,184 @@
// +build ignore
package main
import (
"fmt"
"log"
"sync"
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
)
// This example shows how to
// 1. create a RTSP server which accepts plain connections
// 2. allow a single client to publish a stream with TCP or UDP
// 3. allow multiple clients to read that stream with TCP or UDP
var mutex sync.Mutex
var publisher *gortsplib.ServerConn
var sdp []byte
var readers = make(map[*gortsplib.ServerConn]struct{})
// this is called for each incoming connection
func handleConn(conn *gortsplib.ServerConn) {
defer conn.Close()
log.Printf("client connected")
// called after receiving a DESCRIBE request.
onDescribe := func(req *base.Request) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
// no one is publishing yet
if publisher == nil {
return &base.Response{
StatusCode: base.StatusNotFound,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Base": base.HeaderValue{req.URL.String() + "/"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Content: sdp,
}, nil
}
// called after receiving an ANNOUNCE request.
onAnnounce := func(req *base.Request, tracks gortsplib.Tracks) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
if publisher != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
publisher = conn
sdp = tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a PLAY request.
onPlay := func(req *base.Request) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
readers[conn] = struct{}{}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a RECORD request.
onRecord := func(req *base.Request) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
if conn != publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a Frame.
onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) {
mutex.Lock()
defer mutex.Unlock()
// if we are the publisher, route frames to readers
if conn == publisher {
for r := range readers {
r.WriteFrame(trackID, typ, buf)
}
}
}
err := <-conn.Read(gortsplib.ServerConnReadHandlers{
OnDescribe: onDescribe,
OnAnnounce: onAnnounce,
OnSetup: onSetup,
OnPlay: onPlay,
OnRecord: onRecord,
OnFrame: onFrame,
})
log.Printf("client disconnected (%s)", err)
mutex.Lock()
defer mutex.Unlock()
if conn == publisher {
publisher = nil
sdp = nil
}
}
func main() {
// to publish or read UDP streams, two UDP listeners must be created
udpRTPListener, err := gortsplib.NewServerUDPListener(":8000")
if err != nil {
panic(err)
}
udpRTCPListener, err := gortsplib.NewServerUDPListener(":8001")
if err != nil {
panic(err)
}
// create configuration
conf := gortsplib.ServerConf{
UDPRTPListener: udpRTPListener,
UDPRTCPListener: udpRTCPListener,
}
// create server
s, err := conf.Serve(":8554")
if err != nil {
panic(err)
}
log.Printf("server is ready")
// accept connections
for {
conn, err := s.Accept()
if err != nil {
panic(err)
}
go handleConn(conn)
}
}

View File

@@ -13,7 +13,7 @@ import (
)
// This example shows how to
// 1. create a RTSP server
// 1. create a RTSP server which accepts plain connections
// 2. allow a single client to publish a stream with TCP
// 3. allow multiple clients to read that stream with TCP
@@ -74,17 +74,9 @@ func handleConn(conn *gortsplib.ServerConn) {
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) {
// support TCP only
if th.Protocol == gortsplib.StreamProtocolUDP {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
@@ -97,8 +89,6 @@ func handleConn(conn *gortsplib.ServerConn) {
readers[conn] = struct{}{}
conn.EnableFrames(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
@@ -118,9 +108,6 @@ func handleConn(conn *gortsplib.ServerConn) {
}, fmt.Errorf("someone is already publishing")
}
conn.EnableFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{

View File

@@ -1,8 +1,6 @@
package gortsplib
import (
"bufio"
"crypto/tls"
"net"
)
@@ -24,18 +22,5 @@ func (s *Server) Accept() (*ServerConn, error) {
return nil, err
}
conn := func() net.Conn {
if s.conf.TLSConfig != nil {
return tls.Server(nconn, s.conf.TLSConfig)
}
return nconn
}()
return &ServerConn{
s: s,
nconn: nconn,
br: bufio.NewReaderSize(conn, serverReadBufferSize),
bw: bufio.NewWriterSize(conn, serverWriteBufferSize),
terminate: make(chan struct{}),
}, nil
return newServerConn(s, nconn), nil
}

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"crypto/tls"
"fmt"
"net"
"time"
)
@@ -17,24 +18,32 @@ func Serve(address string) (*Server, error) {
// ServerConf allows to configure a Server.
// All fields are optional.
type ServerConf struct {
// a TLS configuration to accept TLS (RTSPS) connections.
// A TLS configuration to accept TLS (RTSPS) connections.
TLSConfig *tls.Config
// timeout of read operations.
// A ServerUDPListener to send and receive UDP/RTP packets.
// If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams.
UDPRTPListener *ServerUDPListener
// A ServerUDPListener to send and receive UDP/RTCP packets.
// If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams.
UDPRTCPListener *ServerUDPListener
// Timeout of read operations.
// It defaults to 10 seconds
ReadTimeout time.Duration
// timeout of write operations.
// Timeout of write operations.
// It defaults to 10 seconds
WriteTimeout time.Duration
// read buffer count.
// Read buffer count.
// If greater than 1, allows to pass buffers to routines different than the one
// that is reading frames.
// It defaults to 1
ReadBufferCount int
// function used to initialize the TCP listener.
// Function used to initialize the TCP listener.
// It defaults to net.Listen
Listen func(network string, address string) (net.Listener, error)
}
@@ -54,6 +63,15 @@ func (c ServerConf) Serve(address string) (*Server, error) {
c.Listen = net.Listen
}
if c.TLSConfig != nil && c.UDPRTPListener != nil {
return nil, fmt.Errorf("TLS can't be used together with UDP")
}
if (c.UDPRTPListener != nil && c.UDPRTCPListener == nil) ||
(c.UDPRTPListener == nil && c.UDPRTCPListener != nil) {
return nil, fmt.Errorf("UDPRTPListener and UDPRTPListener must be used together")
}
listener, err := c.Listen("tcp", address)
if err != nil {
return nil, err

View File

@@ -18,6 +18,8 @@ import (
type testServ struct {
s *Server
udpRTPListener *ServerUDPListener
udpRTCPListener *ServerUDPListener
wg sync.WaitGroup
mutex sync.Mutex
publisher *ServerConn
@@ -26,10 +28,32 @@ type testServ struct {
}
func newTestServ(tlsConf *tls.Config) (*testServ, error) {
conf := ServerConf{
var conf ServerConf
var udpRTPListener *ServerUDPListener
var udpRTCPListener *ServerUDPListener
if tlsConf != nil {
conf = ServerConf{
TLSConfig: tlsConf,
}
} else {
var err error
udpRTPListener, err = NewServerUDPListener(":8000")
if err != nil {
return nil, err
}
udpRTCPListener, err = NewServerUDPListener(":8001")
if err != nil {
return nil, err
}
conf = ServerConf{
UDPRTPListener: udpRTPListener,
UDPRTCPListener: udpRTCPListener,
}
}
s, err := conf.Serve(":8554")
if err != nil {
return nil, err
@@ -37,6 +61,8 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) {
ts := &testServ{
s: s,
udpRTPListener: udpRTPListener,
udpRTCPListener: udpRTCPListener,
readers: make(map[*ServerConn]struct{}),
}
@@ -49,6 +75,12 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) {
func (ts *testServ) close() {
ts.s.Close()
ts.wg.Wait()
if ts.udpRTPListener != nil {
ts.udpRTPListener.Close()
}
if ts.udpRTCPListener != nil {
ts.udpRTCPListener.Close()
}
}
func (ts *testServ) run() {
@@ -114,7 +146,6 @@ func (ts *testServ) handleConn(conn *ServerConn) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
@@ -126,8 +157,6 @@ func (ts *testServ) handleConn(conn *ServerConn) {
ts.readers[conn] = struct{}{}
conn.EnableFrames(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
@@ -146,9 +175,6 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}, fmt.Errorf("someone is already publishing")
}
conn.EnableFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
@@ -238,20 +264,31 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD
-----END RSA PRIVATE KEY-----
`)
func TestServerPublishReadTCP(t *testing.T) {
func TestServerPublishRead(t *testing.T) {
for _, ca := range []struct {
encrypted bool
publisher string
reader string
publisherSoft string
publisherProto string
readerSoft string
readerProto string
}{
{false, "ffmpeg", "ffmpeg"},
{false, "ffmpeg", "gstreamer"},
{false, "gstreamer", "ffmpeg"},
{false, "gstreamer", "gstreamer"},
{true, "ffmpeg", "ffmpeg"},
{true, "ffmpeg", "gstreamer"},
{true, "gstreamer", "ffmpeg"},
{true, "gstreamer", "gstreamer"},
{false, "ffmpeg", "udp", "ffmpeg", "udp"},
{false, "ffmpeg", "udp", "gstreamer", "udp"},
{false, "gstreamer", "udp", "ffmpeg", "udp"},
{false, "gstreamer", "udp", "gstreamer", "udp"},
{false, "ffmpeg", "tcp", "ffmpeg", "tcp"},
{false, "ffmpeg", "tcp", "gstreamer", "tcp"},
{false, "gstreamer", "tcp", "ffmpeg", "tcp"},
{false, "gstreamer", "tcp", "gstreamer", "tcp"},
{false, "ffmpeg", "tcp", "ffmpeg", "udp"},
{false, "ffmpeg", "udp", "ffmpeg", "tcp"},
{true, "ffmpeg", "tcp", "ffmpeg", "tcp"},
{true, "ffmpeg", "tcp", "gstreamer", "tcp"},
{true, "gstreamer", "tcp", "ffmpeg", "tcp"},
{true, "gstreamer", "tcp", "gstreamer", "tcp"},
} {
encryptedStr := func() string {
if ca.encrypted {
@@ -260,7 +297,8 @@ func TestServerPublishReadTCP(t *testing.T) {
return "plain"
}()
t.Run(encryptedStr+"_"+ca.publisher+"_"+ca.reader, func(t *testing.T) {
t.Run(encryptedStr+"_"+ca.publisherSoft+"_"+ca.publisherProto+"_"+
ca.readerSoft+"_"+ca.readerProto, func(t *testing.T) {
var proto string
var tlsConf *tls.Config
if !ca.encrypted {
@@ -278,7 +316,7 @@ func TestServerPublishReadTCP(t *testing.T) {
require.NoError(t, err)
defer ts.close()
switch ca.publisher {
switch ca.publisherSoft {
case "ffmpeg":
cnt1, err := newContainer("ffmpeg", "publish", []string{
"-re",
@@ -286,7 +324,7 @@ func TestServerPublishReadTCP(t *testing.T) {
"-i", "emptyvideo.ts",
"-c", "copy",
"-f", "rtsp",
"-rtsp_transport", "tcp",
"-rtsp_transport", ca.publisherProto,
proto + "://localhost:8554/teststream",
})
require.NoError(t, err)
@@ -295,7 +333,7 @@ func TestServerPublishReadTCP(t *testing.T) {
case "gstreamer":
cnt1, err := newContainer("gstreamer", "publish", []string{
"filesrc location=emptyvideo.ts ! tsdemux ! video/x-h264 ! rtspclientsink " +
"location=" + proto + "://127.0.0.1:8554/teststream protocols=tcp tls-validation-flags=0 latency=0 timeout=0 rtx-time=0",
"location=" + proto + "://127.0.0.1:8554/teststream protocols=" + ca.publisherProto + " tls-validation-flags=0 latency=0 timeout=0 rtx-time=0",
})
require.NoError(t, err)
defer cnt1.close()
@@ -305,10 +343,10 @@ func TestServerPublishReadTCP(t *testing.T) {
time.Sleep(1 * time.Second)
switch ca.reader {
switch ca.readerSoft {
case "ffmpeg":
cnt2, err := newContainer("ffmpeg", "read", []string{
"-rtsp_transport", "tcp",
"-rtsp_transport", ca.readerProto,
"-i", proto + "://localhost:8554/teststream",
"-vframes", "1",
"-f", "image2",
@@ -320,7 +358,7 @@ func TestServerPublishReadTCP(t *testing.T) {
case "gstreamer":
cnt2, err := newContainer("gstreamer", "read", []string{
"rtspsrc location=" + proto + "://127.0.0.1:8554/teststream protocols=tcp tls-validation-flags=0 latency=0 " +
"rtspsrc location=" + proto + "://127.0.0.1:8554/teststream protocols=" + ca.readerProto + " tls-validation-flags=0 latency=0 " +
"! application/x-rtp,media=video ! decodebin ! exitafterframe ! fakesink",
})
require.NoError(t, err)
@@ -399,6 +437,7 @@ func TestServerResponseBeforeFrames(t *testing.T) {
v := headers.TransportModePlay
return &v
}(),
InterleavedIds: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)

View File

@@ -2,9 +2,11 @@ package gortsplib
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
@@ -22,47 +24,38 @@ const (
// server errors.
var (
ErrServerTeardown = errors.New("teardown")
ErrServerContentTypeMissing = errors.New("Content-Type header is missing")
ErrServerNoTracksDefined = errors.New("no tracks defined")
ErrServerMissingCseq = errors.New("CSeq is missing")
ErrServerFramesDisabled = errors.New("frames are disabled")
)
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
s *Server
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
writeMutex sync.Mutex
nextFramesEnabled bool
framesEnabled bool
readTimeoutEnabled bool
type serverConnState int
// in
terminate chan struct{}
const (
serverConnStateInitial serverConnState = iota
serverConnStatePlay
serverConnStateRecord
)
type serverConnTrack struct {
proto StreamProtocol
rtpPort int
rtcpPort int
}
// Close closes all the connection resources.
func (sc *ServerConn) Close() error {
err := sc.nconn.Close()
close(sc.terminate)
return err
func extractTrackID(controlPath string, mode *headers.TransportMode, trackLen int) (int, error) {
if mode == nil || *mode == headers.TransportModePlay {
if !strings.HasPrefix(controlPath, "trackID=") {
return 0, fmt.Errorf("invalid control attribute (%s)", controlPath)
}
// NetConn returns the underlying net.Conn.
func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn
tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, fmt.Errorf("invalid track id (%s)", controlPath)
}
trackID := int(tmp)
return trackID, nil
}
// EnableFrames allows reading and writing TCP frames.
func (sc *ServerConn) EnableFrames(v bool) {
sc.nextFramesEnabled = v
}
// EnableReadTimeout sets or removes the timeout on incoming packets.
func (sc *ServerConn) EnableReadTimeout(v bool) {
sc.readTimeoutEnabled = v
return trackLen, nil
}
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -108,42 +101,137 @@ type ServerConnReadHandlers struct {
OnTeardown func(req *base.Request) (*base.Response, error)
// called after receiving a Frame.
OnFrame func(trackID int, streamType StreamType, content []byte)
OnFrame func(trackID int, streamType StreamType, payload []byte)
}
func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan error) {
handleRequest := func(req *base.Request) (*base.Response, error) {
if handlers.OnRequest != nil {
handlers.OnRequest(req)
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
s *Server
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
state serverConnState
tracks map[int]serverConnTrack
tracksProto *StreamProtocol
writeMutex sync.Mutex
readHandlers ServerConnReadHandlers
nextFramesEnabled bool
framesEnabled bool
readTimeoutEnabled bool
// in
terminate chan struct{}
}
func newServerConn(s *Server, nconn net.Conn) *ServerConn {
conn := func() net.Conn {
if s.conf.TLSConfig != nil {
return tls.Server(nconn, s.conf.TLSConfig)
}
return nconn
}()
return &ServerConn{
s: s,
nconn: nconn,
br: bufio.NewReaderSize(conn, serverReadBufferSize),
bw: bufio.NewWriterSize(conn, serverWriteBufferSize),
tracks: make(map[int]serverConnTrack),
terminate: make(chan struct{}),
}
}
// Close closes all the connection resources.
func (sc *ServerConn) Close() error {
err := sc.nconn.Close()
close(sc.terminate)
return err
}
// NetConn returns the underlying net.Conn.
func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn
}
func (sc *ServerConn) ip() net.IP {
return sc.nconn.RemoteAddr().(*net.TCPAddr).IP
}
func (sc *ServerConn) zone() string {
return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone
}
func (sc *ServerConn) frameModeEnable() {
switch sc.state {
case serverConnStatePlay:
if *sc.tracksProto == StreamProtocolTCP {
sc.nextFramesEnabled = true
}
case serverConnStateRecord:
if *sc.tracksProto == StreamProtocolTCP {
sc.nextFramesEnabled = true
sc.readTimeoutEnabled = true
} else {
for trackID, track := range sc.tracks {
sc.s.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
sc.s.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
}
}
}
}
func (sc *ServerConn) frameModeDisable() {
switch sc.state {
case serverConnStatePlay:
sc.nextFramesEnabled = false
case serverConnStateRecord:
sc.nextFramesEnabled = false
sc.readTimeoutEnabled = false
for _, track := range sc.tracks {
if track.proto == StreamProtocolUDP {
sc.s.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.s.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
}
}
}
}
func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if sc.readHandlers.OnRequest != nil {
sc.readHandlers.OnRequest(req)
}
switch req.Method {
case base.Options:
if handlers.OnOptions != nil {
return handlers.OnOptions(req)
if sc.readHandlers.OnOptions != nil {
return sc.readHandlers.OnOptions(req)
}
var methods []string
if handlers.OnDescribe != nil {
if sc.readHandlers.OnDescribe != nil {
methods = append(methods, string(base.Describe))
}
if handlers.OnAnnounce != nil {
if sc.readHandlers.OnAnnounce != nil {
methods = append(methods, string(base.Announce))
}
if handlers.OnSetup != nil {
if sc.readHandlers.OnSetup != nil {
methods = append(methods, string(base.Setup))
}
if handlers.OnPlay != nil {
if sc.readHandlers.OnPlay != nil {
methods = append(methods, string(base.Play))
}
if handlers.OnRecord != nil {
if sc.readHandlers.OnRecord != nil {
methods = append(methods, string(base.Record))
}
if handlers.OnPause != nil {
if sc.readHandlers.OnPause != nil {
methods = append(methods, string(base.Pause))
}
methods = append(methods, string(base.GetParameter))
if handlers.OnSetParameter != nil {
if sc.readHandlers.OnSetParameter != nil {
methods = append(methods, string(base.SetParameter))
}
methods = append(methods, string(base.Teardown))
@@ -156,17 +244,17 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
}, nil
case base.Describe:
if handlers.OnDescribe != nil {
return handlers.OnDescribe(req)
if sc.readHandlers.OnDescribe != nil {
return sc.readHandlers.OnDescribe(req)
}
case base.Announce:
if handlers.OnAnnounce != nil {
if sc.readHandlers.OnAnnounce != nil {
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, ErrServerContentTypeMissing
}, errors.New("Content-Type header is missing")
}
if ct[0] != "application/sdp" {
@@ -185,14 +273,22 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, ErrServerNoTracksDefined
}, errors.New("no tracks defined")
}
return handlers.OnAnnounce(req, tracks)
res, err := sc.readHandlers.OnAnnounce(req, tracks)
return res, err
}
case base.Setup:
if handlers.OnSetup != nil {
if sc.readHandlers.OnSetup != nil {
_, controlPath, ok := req.URL.BasePathControlAttr()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unable to find control attribute (%s)", req.URL)
}
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
@@ -200,6 +296,88 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
}, fmt.Errorf("transport header: %s", err)
}
trackID, err := extractTrackID(controlPath, th.Mode, len(sc.tracks))
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if _, ok := sc.tracks[trackID]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("track %d has already been setup", trackID)
}
if sc.tracksProto != nil && *sc.tracksProto != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("can't receive tracks with different protocols")
}
if th.Protocol == StreamProtocolUDP {
if sc.s.conf.UDPRTPListener == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if th.ClientPorts == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"])
}
} else {
if th.InterleavedIds == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header does not contain the interleaved field")
}
if (*th.InterleavedIds)[0] != (trackID*2) ||
(*th.InterleavedIds)[1] != (1+trackID*2) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("wrong interleaved ids, expected [%v %v], got %v",
(trackID * 2), (1 + trackID*2), *th.InterleavedIds)
}
}
res, err := sc.readHandlers.OnSetup(req, th)
if res.StatusCode == 200 {
sc.tracksProto = &th.Protocol
if th.Protocol == StreamProtocolUDP {
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
ClientPorts: th.ClientPorts,
ServerPorts: &[2]int{sc.s.conf.UDPRTPListener.port(), sc.s.conf.UDPRTCPListener.port()},
}.Write()
sc.tracks[trackID] = serverConnTrack{
proto: StreamProtocolUDP,
rtpPort: th.ClientPorts[0],
rtcpPort: th.ClientPorts[1],
}
} else {
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolTCP,
InterleavedIds: th.InterleavedIds,
}.Write()
sc.tracks[trackID] = serverConnTrack{
proto: StreamProtocolTCP,
}
}
}
// workaround to prevent a bug in rtspclientsink
// that makes impossible for the client to receive the response
// and send frames.
@@ -214,27 +392,48 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
}
}
return handlers.OnSetup(req, th)
return res, err
}
case base.Play:
if handlers.OnPlay != nil {
return handlers.OnPlay(req)
if sc.readHandlers.OnPlay != nil {
res, err := sc.readHandlers.OnPlay(req)
if res.StatusCode == 200 {
sc.state = serverConnStatePlay
sc.frameModeEnable()
}
return res, err
}
case base.Record:
if handlers.OnRecord != nil {
return handlers.OnRecord(req)
if sc.readHandlers.OnRecord != nil {
res, err := sc.readHandlers.OnRecord(req)
if res.StatusCode == 200 {
sc.state = serverConnStateRecord
sc.frameModeEnable()
}
return res, err
}
case base.Pause:
if handlers.OnPause != nil {
return handlers.OnPause(req)
if sc.readHandlers.OnPause != nil {
res, err := sc.readHandlers.OnPause(req)
if res.StatusCode == 200 {
sc.frameModeDisable()
sc.state = serverConnStateInitial
}
return res, err
}
case base.GetParameter:
if handlers.OnGetParameter != nil {
return handlers.OnGetParameter(req)
if sc.readHandlers.OnGetParameter != nil {
return sc.readHandlers.OnGetParameter(req)
}
// GET_PARAMETER is used like a ping
@@ -247,13 +446,13 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
}, nil
case base.SetParameter:
if handlers.OnSetParameter != nil {
return handlers.OnSetParameter(req)
if sc.readHandlers.OnSetParameter != nil {
return sc.readHandlers.OnSetParameter(req)
}
case base.Teardown:
if handlers.OnTeardown != nil {
return handlers.OnTeardown(req)
if sc.readHandlers.OnTeardown != nil {
return sc.readHandlers.OnTeardown(req)
}
return &base.Response{
@@ -266,6 +465,7 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
}, fmt.Errorf("unhandled method: %v", req.Method)
}
func (sc *ServerConn) backgroundRead() error {
handleRequestOuter := func(req *base.Request) error {
// check cseq
cseq, ok := req.Header["CSeq"]
@@ -277,10 +477,10 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
Header: base.Header{},
}.Write(sc.bw)
sc.writeMutex.Unlock()
return ErrServerMissingCseq
return errors.New("CSeq is missing")
}
res, err := handleRequest(req)
res, err := sc.handleRequest(req)
if res.Header == nil {
res.Header = base.Header{}
@@ -292,8 +492,8 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
// add server
res.Header["Server"] = base.HeaderValue{"gortsplib"}
if handlers.OnResponse != nil {
handlers.OnResponse(res)
if sc.readHandlers.OnResponse != nil {
sc.readHandlers.OnResponse(res)
}
sc.writeMutex.Lock()
@@ -302,7 +502,7 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
res.Write(sc.bw)
// set framesEnabled after sending the response
// in order to start sending frames after the response
// in order to start sending frames after the response, never before
if sc.framesEnabled != sc.nextFramesEnabled {
sc.framesEnabled = sc.nextFramesEnabled
}
@@ -335,7 +535,7 @@ outer:
switch what.(type) {
case *base.InterleavedFrame:
handlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content)
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content)
case *base.Request:
err := handleRequestOuter(&req)
@@ -360,34 +560,60 @@ outer:
}
}
done <- errRet
sc.frameModeDisable()
return errRet
}
// Read starts reading requests and frames.
// it returns a channel that is written when the reading stops.
func (sc *ServerConn) Read(handlers ServerConnReadHandlers) chan error {
func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
// channel is buffered, since listening to it is not mandatory
done := make(chan error, 1)
go sc.backgroundRead(handlers, done)
sc.readHandlers = readHandlers
go func() {
done <- sc.backgroundRead()
}()
return done
}
// WriteFrame writes a frame.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []byte) error {
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error {
sc.writeMutex.Lock()
defer sc.writeMutex.Unlock()
track := sc.tracks[trackID]
if track.proto == StreamProtocolUDP {
if streamType == StreamTypeRtp {
return sc.s.conf.UDPRTPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
IP: sc.ip(),
Zone: sc.zone(),
Port: track.rtpPort,
})
}
return sc.s.conf.UDPRTCPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{
IP: sc.ip(),
Zone: sc.zone(),
Port: track.rtcpPort,
})
}
// StreamProtocolTCP
if !sc.framesEnabled {
return ErrServerFramesDisabled
return errors.New("frames are disabled")
}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
frame := base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Content: content,
Content: payload,
}
return frame.Write(sc.bw)
}

145
serverudpl.go Normal file
View File

@@ -0,0 +1,145 @@
package gortsplib
import (
"net"
"sync"
"time"
"github.com/aler9/gortsplib/pkg/multibuffer"
)
const (
// use the same buffer size as gstreamer's rtspsrc
kernelReadBufferSize = 0x80000
readBufferSize = 2048
)
type publisherData struct {
publisher *ServerConn
trackID int
}
type publisherAddr struct {
ip [net.IPv6len]byte // use a fixed-size array to enable the equality operator
port int
}
func (p *publisherAddr) fill(ip net.IP, port int) {
p.port = port
if len(ip) == net.IPv4len {
copy(p.ip[0:], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}) // v4InV6Prefix
copy(p.ip[12:], ip)
} else {
copy(p.ip[:], ip)
}
}
// ServerUDPListener is a UDP server that can be used to send and receive RTP and RTCP packets.
type ServerUDPListener struct {
streamType StreamType
pc *net.UDPConn
readBuf *multibuffer.MultiBuffer
publishersMutex sync.RWMutex
publishers map[publisherAddr]*publisherData
writeMutex sync.Mutex
// out
done chan struct{}
}
// NewServerUDPListener allocates a ServerUDPListener.
func NewServerUDPListener(address string) (*ServerUDPListener, error) {
tmp, err := net.ListenPacket("udp", address)
if err != nil {
return nil, err
}
pc := tmp.(*net.UDPConn)
err = pc.SetReadBuffer(kernelReadBufferSize)
if err != nil {
return nil, err
}
s := &ServerUDPListener{
pc: pc,
readBuf: multibuffer.New(1, readBufferSize),
publishers: make(map[publisherAddr]*publisherData),
done: make(chan struct{}),
}
go s.run()
return s, nil
}
// Close closes the listener.
func (s *ServerUDPListener) Close() {
s.pc.Close()
<-s.done
}
func (s *ServerUDPListener) run() {
defer close(s.done)
for {
buf := s.readBuf.Next()
n, addr, err := s.pc.ReadFromUDP(buf)
if err != nil {
break
}
func() {
s.publishersMutex.RLock()
defer s.publishersMutex.RUnlock()
// find publisher data
var pubAddr publisherAddr
pubAddr.fill(addr.IP, addr.Port)
pubData, ok := s.publishers[pubAddr]
if !ok {
return
}
pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n])
}()
}
}
func (s *ServerUDPListener) port() int {
return s.pc.LocalAddr().(*net.UDPAddr).Port
}
func (s *ServerUDPListener) write(writeTimeout time.Duration, buf []byte, addr *net.UDPAddr) error {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
s.pc.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := s.pc.WriteTo(buf, addr)
return err
}
func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) {
s.publishersMutex.Lock()
defer s.publishersMutex.Unlock()
var addr publisherAddr
addr.fill(ip, port)
s.publishers[addr] = &publisherData{
publisher: sc,
trackID: trackID,
}
}
func (s *ServerUDPListener) removePublisher(ip net.IP, port int) {
s.publishersMutex.Lock()
defer s.publishersMutex.Unlock()
var addr publisherAddr
addr.fill(ip, port)
delete(s.publishers, addr)
}