diff --git a/README.md b/README.md index cc58a390..ae95c5f2 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,8 @@ Features: * Pause reading or publishing without disconnecting from the server * Server * Handle requests from clients - * Read and write streams with TCP + * Accept streams from clients with UDP or TCP + * Send streams to clients with UDP or TCP * Encrypt streams with TLS (RTSPS) ## Table of contents @@ -38,6 +39,7 @@ Features: * [client-publish-options](examples/client-publish-options.go) * [client-publish-pause](examples/client-publish-pause.go) * [server](examples/server.go) +* [server-udp](examples/server-udp.go) * [server-tls](examples/server-tls.go) ## API Documentation diff --git a/examples/server-tls.go b/examples/server-tls.go index edf78a47..62af3cd3 100644 --- a/examples/server-tls.go +++ b/examples/server-tls.go @@ -75,18 +75,10 @@ func handleConn(conn *gortsplib.ServerConn) { // called after receiving a SETUP request. onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) { - // support TCP only - if th.Protocol == gortsplib.StreamProtocolUDP { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "Transport": req.Header["Transport"], - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{"12345678"}, }, }, nil } @@ -98,8 +90,6 @@ func handleConn(conn *gortsplib.ServerConn) { readers[conn] = struct{}{} - conn.EnableFrames(true) - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ @@ -119,9 +109,6 @@ func handleConn(conn *gortsplib.ServerConn) { }, fmt.Errorf("someone is already publishing") } - conn.EnableFrames(true) - conn.EnableReadTimeout(true) - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/examples/server-udp.go b/examples/server-udp.go new file mode 100644 index 00000000..5821051d --- /dev/null +++ b/examples/server-udp.go @@ -0,0 +1,184 @@ +// +build ignore + +package main + +import ( + "fmt" + "log" + "sync" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/headers" +) + +// This example shows how to +// 1. create a RTSP server which accepts plain connections +// 2. allow a single client to publish a stream with TCP or UDP +// 3. allow multiple clients to read that stream with TCP or UDP + +var mutex sync.Mutex +var publisher *gortsplib.ServerConn +var sdp []byte +var readers = make(map[*gortsplib.ServerConn]struct{}) + +// this is called for each incoming connection +func handleConn(conn *gortsplib.ServerConn) { + defer conn.Close() + + log.Printf("client connected") + + // called after receiving a DESCRIBE request. + onDescribe := func(req *base.Request) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + // no one is publishing yet + if publisher == nil { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Base": base.HeaderValue{req.URL.String() + "/"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Content: sdp, + }, nil + } + + // called after receiving an ANNOUNCE request. + onAnnounce := func(req *base.Request, tracks gortsplib.Tracks) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + if publisher != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + publisher = conn + sdp = tracks.Write() + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a SETUP request. + onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a PLAY request. + onPlay := func(req *base.Request) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + readers[conn] = struct{}{} + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a RECORD request. + onRecord := func(req *base.Request) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + if conn != publisher { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a Frame. + onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) { + mutex.Lock() + defer mutex.Unlock() + + // if we are the publisher, route frames to readers + if conn == publisher { + for r := range readers { + r.WriteFrame(trackID, typ, buf) + } + } + } + + err := <-conn.Read(gortsplib.ServerConnReadHandlers{ + OnDescribe: onDescribe, + OnAnnounce: onAnnounce, + OnSetup: onSetup, + OnPlay: onPlay, + OnRecord: onRecord, + OnFrame: onFrame, + }) + log.Printf("client disconnected (%s)", err) + + mutex.Lock() + defer mutex.Unlock() + + if conn == publisher { + publisher = nil + sdp = nil + } +} + +func main() { + // to publish or read UDP streams, two UDP listeners must be created + udpRTPListener, err := gortsplib.NewServerUDPListener(":8000") + if err != nil { + panic(err) + } + udpRTCPListener, err := gortsplib.NewServerUDPListener(":8001") + if err != nil { + panic(err) + } + + // create configuration + conf := gortsplib.ServerConf{ + UDPRTPListener: udpRTPListener, + UDPRTCPListener: udpRTCPListener, + } + + // create server + s, err := conf.Serve(":8554") + if err != nil { + panic(err) + } + log.Printf("server is ready") + + // accept connections + for { + conn, err := s.Accept() + if err != nil { + panic(err) + } + + go handleConn(conn) + } +} diff --git a/examples/server.go b/examples/server.go index b038f6cb..759a6c18 100644 --- a/examples/server.go +++ b/examples/server.go @@ -13,7 +13,7 @@ import ( ) // This example shows how to -// 1. create a RTSP server +// 1. create a RTSP server which accepts plain connections // 2. allow a single client to publish a stream with TCP // 3. allow multiple clients to read that stream with TCP @@ -74,18 +74,10 @@ func handleConn(conn *gortsplib.ServerConn) { // called after receiving a SETUP request. onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) { - // support TCP only - if th.Protocol == gortsplib.StreamProtocolUDP { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "Transport": req.Header["Transport"], - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{"12345678"}, }, }, nil } @@ -97,8 +89,6 @@ func handleConn(conn *gortsplib.ServerConn) { readers[conn] = struct{}{} - conn.EnableFrames(true) - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ @@ -118,9 +108,6 @@ func handleConn(conn *gortsplib.ServerConn) { }, fmt.Errorf("someone is already publishing") } - conn.EnableFrames(true) - conn.EnableReadTimeout(true) - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/server.go b/server.go index 162aeb96..cff87338 100644 --- a/server.go +++ b/server.go @@ -1,8 +1,6 @@ package gortsplib import ( - "bufio" - "crypto/tls" "net" ) @@ -24,18 +22,5 @@ func (s *Server) Accept() (*ServerConn, error) { return nil, err } - conn := func() net.Conn { - if s.conf.TLSConfig != nil { - return tls.Server(nconn, s.conf.TLSConfig) - } - return nconn - }() - - return &ServerConn{ - s: s, - nconn: nconn, - br: bufio.NewReaderSize(conn, serverReadBufferSize), - bw: bufio.NewWriterSize(conn, serverWriteBufferSize), - terminate: make(chan struct{}), - }, nil + return newServerConn(s, nconn), nil } diff --git a/serverconf.go b/serverconf.go index 87a6cfbe..1eb8e7b2 100644 --- a/serverconf.go +++ b/serverconf.go @@ -2,6 +2,7 @@ package gortsplib import ( "crypto/tls" + "fmt" "net" "time" ) @@ -17,24 +18,32 @@ func Serve(address string) (*Server, error) { // ServerConf allows to configure a Server. // All fields are optional. type ServerConf struct { - // a TLS configuration to accept TLS (RTSPS) connections. + // A TLS configuration to accept TLS (RTSPS) connections. TLSConfig *tls.Config - // timeout of read operations. + // A ServerUDPListener to send and receive UDP/RTP packets. + // If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams. + UDPRTPListener *ServerUDPListener + + // A ServerUDPListener to send and receive UDP/RTCP packets. + // If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams. + UDPRTCPListener *ServerUDPListener + + // Timeout of read operations. // It defaults to 10 seconds ReadTimeout time.Duration - // timeout of write operations. + // Timeout of write operations. // It defaults to 10 seconds WriteTimeout time.Duration - // read buffer count. + // Read buffer count. // If greater than 1, allows to pass buffers to routines different than the one // that is reading frames. // It defaults to 1 ReadBufferCount int - // function used to initialize the TCP listener. + // Function used to initialize the TCP listener. // It defaults to net.Listen Listen func(network string, address string) (net.Listener, error) } @@ -54,6 +63,15 @@ func (c ServerConf) Serve(address string) (*Server, error) { c.Listen = net.Listen } + if c.TLSConfig != nil && c.UDPRTPListener != nil { + return nil, fmt.Errorf("TLS can't be used together with UDP") + } + + if (c.UDPRTPListener != nil && c.UDPRTCPListener == nil) || + (c.UDPRTPListener == nil && c.UDPRTCPListener != nil) { + return nil, fmt.Errorf("UDPRTPListener and UDPRTPListener must be used together") + } + listener, err := c.Listen("tcp", address) if err != nil { return nil, err diff --git a/serverconf_test.go b/serverconf_test.go index 9df4ef96..db9237a7 100644 --- a/serverconf_test.go +++ b/serverconf_test.go @@ -17,17 +17,41 @@ import ( ) type testServ struct { - s *Server - wg sync.WaitGroup - mutex sync.Mutex - publisher *ServerConn - sdp []byte - readers map[*ServerConn]struct{} + s *Server + udpRTPListener *ServerUDPListener + udpRTCPListener *ServerUDPListener + wg sync.WaitGroup + mutex sync.Mutex + publisher *ServerConn + sdp []byte + readers map[*ServerConn]struct{} } func newTestServ(tlsConf *tls.Config) (*testServ, error) { - conf := ServerConf{ - TLSConfig: tlsConf, + var conf ServerConf + var udpRTPListener *ServerUDPListener + var udpRTCPListener *ServerUDPListener + if tlsConf != nil { + conf = ServerConf{ + TLSConfig: tlsConf, + } + + } else { + var err error + udpRTPListener, err = NewServerUDPListener(":8000") + if err != nil { + return nil, err + } + + udpRTCPListener, err = NewServerUDPListener(":8001") + if err != nil { + return nil, err + } + + conf = ServerConf{ + UDPRTPListener: udpRTPListener, + UDPRTCPListener: udpRTCPListener, + } } s, err := conf.Serve(":8554") @@ -36,8 +60,10 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) { } ts := &testServ{ - s: s, - readers: make(map[*ServerConn]struct{}), + s: s, + udpRTPListener: udpRTPListener, + udpRTCPListener: udpRTCPListener, + readers: make(map[*ServerConn]struct{}), } ts.wg.Add(1) @@ -49,6 +75,12 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) { func (ts *testServ) close() { ts.s.Close() ts.wg.Wait() + if ts.udpRTPListener != nil { + ts.udpRTPListener.Close() + } + if ts.udpRTCPListener != nil { + ts.udpRTCPListener.Close() + } } func (ts *testServ) run() { @@ -114,8 +146,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "Transport": req.Header["Transport"], - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{"12345678"}, }, }, nil } @@ -126,8 +157,6 @@ func (ts *testServ) handleConn(conn *ServerConn) { ts.readers[conn] = struct{}{} - conn.EnableFrames(true) - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ @@ -146,9 +175,6 @@ func (ts *testServ) handleConn(conn *ServerConn) { }, fmt.Errorf("someone is already publishing") } - conn.EnableFrames(true) - conn.EnableReadTimeout(true) - return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ @@ -238,20 +264,31 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD -----END RSA PRIVATE KEY----- `) -func TestServerPublishReadTCP(t *testing.T) { +func TestServerPublishRead(t *testing.T) { for _, ca := range []struct { - encrypted bool - publisher string - reader string + encrypted bool + publisherSoft string + publisherProto string + readerSoft string + readerProto string }{ - {false, "ffmpeg", "ffmpeg"}, - {false, "ffmpeg", "gstreamer"}, - {false, "gstreamer", "ffmpeg"}, - {false, "gstreamer", "gstreamer"}, - {true, "ffmpeg", "ffmpeg"}, - {true, "ffmpeg", "gstreamer"}, - {true, "gstreamer", "ffmpeg"}, - {true, "gstreamer", "gstreamer"}, + {false, "ffmpeg", "udp", "ffmpeg", "udp"}, + {false, "ffmpeg", "udp", "gstreamer", "udp"}, + {false, "gstreamer", "udp", "ffmpeg", "udp"}, + {false, "gstreamer", "udp", "gstreamer", "udp"}, + + {false, "ffmpeg", "tcp", "ffmpeg", "tcp"}, + {false, "ffmpeg", "tcp", "gstreamer", "tcp"}, + {false, "gstreamer", "tcp", "ffmpeg", "tcp"}, + {false, "gstreamer", "tcp", "gstreamer", "tcp"}, + + {false, "ffmpeg", "tcp", "ffmpeg", "udp"}, + {false, "ffmpeg", "udp", "ffmpeg", "tcp"}, + + {true, "ffmpeg", "tcp", "ffmpeg", "tcp"}, + {true, "ffmpeg", "tcp", "gstreamer", "tcp"}, + {true, "gstreamer", "tcp", "ffmpeg", "tcp"}, + {true, "gstreamer", "tcp", "gstreamer", "tcp"}, } { encryptedStr := func() string { if ca.encrypted { @@ -260,7 +297,8 @@ func TestServerPublishReadTCP(t *testing.T) { return "plain" }() - t.Run(encryptedStr+"_"+ca.publisher+"_"+ca.reader, func(t *testing.T) { + t.Run(encryptedStr+"_"+ca.publisherSoft+"_"+ca.publisherProto+"_"+ + ca.readerSoft+"_"+ca.readerProto, func(t *testing.T) { var proto string var tlsConf *tls.Config if !ca.encrypted { @@ -278,7 +316,7 @@ func TestServerPublishReadTCP(t *testing.T) { require.NoError(t, err) defer ts.close() - switch ca.publisher { + switch ca.publisherSoft { case "ffmpeg": cnt1, err := newContainer("ffmpeg", "publish", []string{ "-re", @@ -286,7 +324,7 @@ func TestServerPublishReadTCP(t *testing.T) { "-i", "emptyvideo.ts", "-c", "copy", "-f", "rtsp", - "-rtsp_transport", "tcp", + "-rtsp_transport", ca.publisherProto, proto + "://localhost:8554/teststream", }) require.NoError(t, err) @@ -295,7 +333,7 @@ func TestServerPublishReadTCP(t *testing.T) { case "gstreamer": cnt1, err := newContainer("gstreamer", "publish", []string{ "filesrc location=emptyvideo.ts ! tsdemux ! video/x-h264 ! rtspclientsink " + - "location=" + proto + "://127.0.0.1:8554/teststream protocols=tcp tls-validation-flags=0 latency=0 timeout=0 rtx-time=0", + "location=" + proto + "://127.0.0.1:8554/teststream protocols=" + ca.publisherProto + " tls-validation-flags=0 latency=0 timeout=0 rtx-time=0", }) require.NoError(t, err) defer cnt1.close() @@ -305,10 +343,10 @@ func TestServerPublishReadTCP(t *testing.T) { time.Sleep(1 * time.Second) - switch ca.reader { + switch ca.readerSoft { case "ffmpeg": cnt2, err := newContainer("ffmpeg", "read", []string{ - "-rtsp_transport", "tcp", + "-rtsp_transport", ca.readerProto, "-i", proto + "://localhost:8554/teststream", "-vframes", "1", "-f", "image2", @@ -320,7 +358,7 @@ func TestServerPublishReadTCP(t *testing.T) { case "gstreamer": cnt2, err := newContainer("gstreamer", "read", []string{ - "rtspsrc location=" + proto + "://127.0.0.1:8554/teststream protocols=tcp tls-validation-flags=0 latency=0 " + + "rtspsrc location=" + proto + "://127.0.0.1:8554/teststream protocols=" + ca.readerProto + " tls-validation-flags=0 latency=0 " + "! application/x-rtp,media=video ! decodebin ! exitafterframe ! fakesink", }) require.NoError(t, err) @@ -399,6 +437,7 @@ func TestServerResponseBeforeFrames(t *testing.T) { v := headers.TransportModePlay return &v }(), + InterleavedIds: &[2]int{0, 1}, }.Write(), }, }.Write(bconn.Writer) diff --git a/serverconn.go b/serverconn.go index cd7d3d0c..abab3a49 100644 --- a/serverconn.go +++ b/serverconn.go @@ -2,9 +2,11 @@ package gortsplib import ( "bufio" + "crypto/tls" "errors" "fmt" "net" + "strconv" "strings" "sync" "time" @@ -21,48 +23,39 @@ const ( // server errors. var ( - ErrServerTeardown = errors.New("teardown") - ErrServerContentTypeMissing = errors.New("Content-Type header is missing") - ErrServerNoTracksDefined = errors.New("no tracks defined") - ErrServerMissingCseq = errors.New("CSeq is missing") - ErrServerFramesDisabled = errors.New("frames are disabled") + ErrServerTeardown = errors.New("teardown") ) -// ServerConn is a server-side RTSP connection. -type ServerConn struct { - s *Server - nconn net.Conn - br *bufio.Reader - bw *bufio.Writer - writeMutex sync.Mutex - nextFramesEnabled bool - framesEnabled bool - readTimeoutEnabled bool +type serverConnState int - // in - terminate chan struct{} +const ( + serverConnStateInitial serverConnState = iota + serverConnStatePlay + serverConnStateRecord +) + +type serverConnTrack struct { + proto StreamProtocol + rtpPort int + rtcpPort int } -// Close closes all the connection resources. -func (sc *ServerConn) Close() error { - err := sc.nconn.Close() - close(sc.terminate) - return err -} +func extractTrackID(controlPath string, mode *headers.TransportMode, trackLen int) (int, error) { + if mode == nil || *mode == headers.TransportModePlay { + if !strings.HasPrefix(controlPath, "trackID=") { + return 0, fmt.Errorf("invalid control attribute (%s)", controlPath) + } -// NetConn returns the underlying net.Conn. -func (sc *ServerConn) NetConn() net.Conn { - return sc.nconn -} + tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64) + if err != nil || tmp < 0 { + return 0, fmt.Errorf("invalid track id (%s)", controlPath) + } + trackID := int(tmp) -// EnableFrames allows reading and writing TCP frames. -func (sc *ServerConn) EnableFrames(v bool) { - sc.nextFramesEnabled = v -} + return trackID, nil + } -// EnableReadTimeout sets or removes the timeout on incoming packets. -func (sc *ServerConn) EnableReadTimeout(v bool) { - sc.readTimeoutEnabled = v + return trackLen, nil } // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. @@ -108,164 +101,371 @@ type ServerConnReadHandlers struct { OnTeardown func(req *base.Request) (*base.Response, error) // called after receiving a Frame. - OnFrame func(trackID int, streamType StreamType, content []byte) + OnFrame func(trackID int, streamType StreamType, payload []byte) } -func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan error) { - handleRequest := func(req *base.Request) (*base.Response, error) { - if handlers.OnRequest != nil { - handlers.OnRequest(req) +// ServerConn is a server-side RTSP connection. +type ServerConn struct { + s *Server + nconn net.Conn + br *bufio.Reader + bw *bufio.Writer + state serverConnState + tracks map[int]serverConnTrack + tracksProto *StreamProtocol + writeMutex sync.Mutex + readHandlers ServerConnReadHandlers + nextFramesEnabled bool + framesEnabled bool + readTimeoutEnabled bool + + // in + terminate chan struct{} +} + +func newServerConn(s *Server, nconn net.Conn) *ServerConn { + conn := func() net.Conn { + if s.conf.TLSConfig != nil { + return tls.Server(nconn, s.conf.TLSConfig) + } + return nconn + }() + + return &ServerConn{ + s: s, + nconn: nconn, + br: bufio.NewReaderSize(conn, serverReadBufferSize), + bw: bufio.NewWriterSize(conn, serverWriteBufferSize), + tracks: make(map[int]serverConnTrack), + terminate: make(chan struct{}), + } +} + +// Close closes all the connection resources. +func (sc *ServerConn) Close() error { + err := sc.nconn.Close() + close(sc.terminate) + return err +} + +// NetConn returns the underlying net.Conn. +func (sc *ServerConn) NetConn() net.Conn { + return sc.nconn +} + +func (sc *ServerConn) ip() net.IP { + return sc.nconn.RemoteAddr().(*net.TCPAddr).IP +} + +func (sc *ServerConn) zone() string { + return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone +} + +func (sc *ServerConn) frameModeEnable() { + switch sc.state { + case serverConnStatePlay: + if *sc.tracksProto == StreamProtocolTCP { + sc.nextFramesEnabled = true } - switch req.Method { - case base.Options: - if handlers.OnOptions != nil { - return handlers.OnOptions(req) + case serverConnStateRecord: + if *sc.tracksProto == StreamProtocolTCP { + sc.nextFramesEnabled = true + sc.readTimeoutEnabled = true + + } else { + for trackID, track := range sc.tracks { + sc.s.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) + sc.s.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) + } + } + } +} + +func (sc *ServerConn) frameModeDisable() { + switch sc.state { + case serverConnStatePlay: + sc.nextFramesEnabled = false + + case serverConnStateRecord: + sc.nextFramesEnabled = false + sc.readTimeoutEnabled = false + + for _, track := range sc.tracks { + if track.proto == StreamProtocolUDP { + sc.s.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) + sc.s.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) + } + } + } +} + +func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { + if sc.readHandlers.OnRequest != nil { + sc.readHandlers.OnRequest(req) + } + + switch req.Method { + case base.Options: + if sc.readHandlers.OnOptions != nil { + return sc.readHandlers.OnOptions(req) + } + + var methods []string + if sc.readHandlers.OnDescribe != nil { + methods = append(methods, string(base.Describe)) + } + if sc.readHandlers.OnAnnounce != nil { + methods = append(methods, string(base.Announce)) + } + if sc.readHandlers.OnSetup != nil { + methods = append(methods, string(base.Setup)) + } + if sc.readHandlers.OnPlay != nil { + methods = append(methods, string(base.Play)) + } + if sc.readHandlers.OnRecord != nil { + methods = append(methods, string(base.Record)) + } + if sc.readHandlers.OnPause != nil { + methods = append(methods, string(base.Pause)) + } + methods = append(methods, string(base.GetParameter)) + if sc.readHandlers.OnSetParameter != nil { + methods = append(methods, string(base.SetParameter)) + } + methods = append(methods, string(base.Teardown)) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join(methods, ", ")}, + }, + }, nil + + case base.Describe: + if sc.readHandlers.OnDescribe != nil { + return sc.readHandlers.OnDescribe(req) + } + + case base.Announce: + if sc.readHandlers.OnAnnounce != nil { + ct, ok := req.Header["Content-Type"] + if !ok || len(ct) != 1 { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, errors.New("Content-Type header is missing") } - var methods []string - if handlers.OnDescribe != nil { - methods = append(methods, string(base.Describe)) - } - if handlers.OnAnnounce != nil { - methods = append(methods, string(base.Announce)) - } - if handlers.OnSetup != nil { - methods = append(methods, string(base.Setup)) - } - if handlers.OnPlay != nil { - methods = append(methods, string(base.Play)) - } - if handlers.OnRecord != nil { - methods = append(methods, string(base.Record)) - } - if handlers.OnPause != nil { - methods = append(methods, string(base.Pause)) - } - methods = append(methods, string(base.GetParameter)) - if handlers.OnSetParameter != nil { - methods = append(methods, string(base.SetParameter)) - } - methods = append(methods, string(base.Teardown)) - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join(methods, ", ")}, - }, - }, nil - - case base.Describe: - if handlers.OnDescribe != nil { - return handlers.OnDescribe(req) + if ct[0] != "application/sdp" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unsupported Content-Type '%s'", ct) } - case base.Announce: - if handlers.OnAnnounce != nil { - ct, ok := req.Header["Content-Type"] - if !ok || len(ct) != 1 { + tracks, err := ReadTracks(req.Content) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid SDP: %s", err) + } + + if len(tracks) == 0 { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, errors.New("no tracks defined") + } + + res, err := sc.readHandlers.OnAnnounce(req, tracks) + return res, err + } + + case base.Setup: + if sc.readHandlers.OnSetup != nil { + _, controlPath, ok := req.URL.BasePathControlAttr() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to find control attribute (%s)", req.URL) + } + + th, err := headers.ReadTransport(req.Header["Transport"]) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header: %s", err) + } + + trackID, err := extractTrackID(controlPath, th.Mode, len(sc.tracks)) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + + if _, ok := sc.tracks[trackID]; ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("track %d has already been setup", trackID) + } + + if sc.tracksProto != nil && *sc.tracksProto != th.Protocol { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("can't receive tracks with different protocols") + } + + if th.Protocol == StreamProtocolUDP { + if sc.s.conf.UDPRTPListener == nil { return &base.Response{ - StatusCode: base.StatusBadRequest, - }, ErrServerContentTypeMissing + StatusCode: base.StatusUnsupportedTransport, + }, nil } - if ct[0] != "application/sdp" { + if th.ClientPorts == nil { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unsupported Content-Type '%s'", ct) + }, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"]) } - tracks, err := ReadTracks(req.Content) - if err != nil { + } else { + if th.InterleavedIds == nil { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid SDP: %s", err) + }, fmt.Errorf("transport header does not contain the interleaved field") } - if len(tracks) == 0 { + if (*th.InterleavedIds)[0] != (trackID*2) || + (*th.InterleavedIds)[1] != (1+trackID*2) { return &base.Response{ - StatusCode: base.StatusBadRequest, - }, ErrServerNoTracksDefined + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("wrong interleaved ids, expected [%v %v], got %v", + (trackID * 2), (1 + trackID*2), *th.InterleavedIds) } - - return handlers.OnAnnounce(req, tracks) } - case base.Setup: - if handlers.OnSetup != nil { - th, err := headers.ReadTransport(req.Header["Transport"]) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header: %s", err) - } + res, err := sc.readHandlers.OnSetup(req, th) - // workaround to prevent a bug in rtspclientsink - // that makes impossible for the client to receive the response - // and send frames. - // this was causing problems during unit tests. - if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 && - strings.HasPrefix(ua[0], "GStreamer") { - t := time.NewTimer(1 * time.Second) - defer t.Stop() - select { - case <-t.C: - case <-sc.terminate: + if res.StatusCode == 200 { + sc.tracksProto = &th.Protocol + + if th.Protocol == StreamProtocolUDP { + res.Header["Transport"] = headers.Transport{ + Protocol: StreamProtocolUDP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + ClientPorts: th.ClientPorts, + ServerPorts: &[2]int{sc.s.conf.UDPRTPListener.port(), sc.s.conf.UDPRTCPListener.port()}, + }.Write() + + sc.tracks[trackID] = serverConnTrack{ + proto: StreamProtocolUDP, + rtpPort: th.ClientPorts[0], + rtcpPort: th.ClientPorts[1], + } + + } else { + res.Header["Transport"] = headers.Transport{ + Protocol: StreamProtocolTCP, + InterleavedIds: th.InterleavedIds, + }.Write() + + sc.tracks[trackID] = serverConnTrack{ + proto: StreamProtocolTCP, } } - - return handlers.OnSetup(req, th) } - case base.Play: - if handlers.OnPlay != nil { - return handlers.OnPlay(req) + // workaround to prevent a bug in rtspclientsink + // that makes impossible for the client to receive the response + // and send frames. + // this was causing problems during unit tests. + if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 && + strings.HasPrefix(ua[0], "GStreamer") { + t := time.NewTimer(1 * time.Second) + defer t.Stop() + select { + case <-t.C: + case <-sc.terminate: + } } - case base.Record: - if handlers.OnRecord != nil { - return handlers.OnRecord(req) + return res, err + } + + case base.Play: + if sc.readHandlers.OnPlay != nil { + res, err := sc.readHandlers.OnPlay(req) + + if res.StatusCode == 200 { + sc.state = serverConnStatePlay + sc.frameModeEnable() } - case base.Pause: - if handlers.OnPause != nil { - return handlers.OnPause(req) + return res, err + } + + case base.Record: + if sc.readHandlers.OnRecord != nil { + res, err := sc.readHandlers.OnRecord(req) + + if res.StatusCode == 200 { + sc.state = serverConnStateRecord + sc.frameModeEnable() } - case base.GetParameter: - if handlers.OnGetParameter != nil { - return handlers.OnGetParameter(req) + return res, err + } + + case base.Pause: + if sc.readHandlers.OnPause != nil { + res, err := sc.readHandlers.OnPause(req) + + if res.StatusCode == 200 { + sc.frameModeDisable() + sc.state = serverConnStateInitial } - // GET_PARAMETER is used like a ping - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"text/parameters"}, - }, - Content: []byte("\n"), - }, nil + return res, err + } - case base.SetParameter: - if handlers.OnSetParameter != nil { - return handlers.OnSetParameter(req) - } + case base.GetParameter: + if sc.readHandlers.OnGetParameter != nil { + return sc.readHandlers.OnGetParameter(req) + } - case base.Teardown: - if handlers.OnTeardown != nil { - return handlers.OnTeardown(req) - } + // GET_PARAMETER is used like a ping + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"text/parameters"}, + }, + Content: []byte("\n"), + }, nil - return &base.Response{ - StatusCode: base.StatusOK, - }, ErrServerTeardown + case base.SetParameter: + if sc.readHandlers.OnSetParameter != nil { + return sc.readHandlers.OnSetParameter(req) + } + + case base.Teardown: + if sc.readHandlers.OnTeardown != nil { + return sc.readHandlers.OnTeardown(req) } return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unhandled method: %v", req.Method) + StatusCode: base.StatusOK, + }, ErrServerTeardown } + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unhandled method: %v", req.Method) +} + +func (sc *ServerConn) backgroundRead() error { handleRequestOuter := func(req *base.Request) error { // check cseq cseq, ok := req.Header["CSeq"] @@ -277,10 +477,10 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan Header: base.Header{}, }.Write(sc.bw) sc.writeMutex.Unlock() - return ErrServerMissingCseq + return errors.New("CSeq is missing") } - res, err := handleRequest(req) + res, err := sc.handleRequest(req) if res.Header == nil { res.Header = base.Header{} @@ -292,8 +492,8 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan // add server res.Header["Server"] = base.HeaderValue{"gortsplib"} - if handlers.OnResponse != nil { - handlers.OnResponse(res) + if sc.readHandlers.OnResponse != nil { + sc.readHandlers.OnResponse(res) } sc.writeMutex.Lock() @@ -302,7 +502,7 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan res.Write(sc.bw) // set framesEnabled after sending the response - // in order to start sending frames after the response + // in order to start sending frames after the response, never before if sc.framesEnabled != sc.nextFramesEnabled { sc.framesEnabled = sc.nextFramesEnabled } @@ -335,7 +535,7 @@ outer: switch what.(type) { case *base.InterleavedFrame: - handlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content) + sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content) case *base.Request: err := handleRequestOuter(&req) @@ -360,34 +560,60 @@ outer: } } - done <- errRet + sc.frameModeDisable() + + return errRet } // Read starts reading requests and frames. // it returns a channel that is written when the reading stops. -func (sc *ServerConn) Read(handlers ServerConnReadHandlers) chan error { +func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error { // channel is buffered, since listening to it is not mandatory done := make(chan error, 1) - go sc.backgroundRead(handlers, done) + sc.readHandlers = readHandlers + + go func() { + done <- sc.backgroundRead() + }() return done } // WriteFrame writes a frame. -func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []byte) error { +func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error { sc.writeMutex.Lock() defer sc.writeMutex.Unlock() + track := sc.tracks[trackID] + + if track.proto == StreamProtocolUDP { + if streamType == StreamTypeRtp { + return sc.s.conf.UDPRTPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{ + IP: sc.ip(), + Zone: sc.zone(), + Port: track.rtpPort, + }) + } + + return sc.s.conf.UDPRTCPListener.write(sc.s.conf.WriteTimeout, payload, &net.UDPAddr{ + IP: sc.ip(), + Zone: sc.zone(), + Port: track.rtcpPort, + }) + } + + // StreamProtocolTCP + if !sc.framesEnabled { - return ErrServerFramesDisabled + return errors.New("frames are disabled") } sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) frame := base.InterleavedFrame{ TrackID: trackID, StreamType: streamType, - Content: content, + Content: payload, } return frame.Write(sc.bw) } diff --git a/serverudpl.go b/serverudpl.go new file mode 100644 index 00000000..69e7c445 --- /dev/null +++ b/serverudpl.go @@ -0,0 +1,145 @@ +package gortsplib + +import ( + "net" + "sync" + "time" + + "github.com/aler9/gortsplib/pkg/multibuffer" +) + +const ( + // use the same buffer size as gstreamer's rtspsrc + kernelReadBufferSize = 0x80000 + + readBufferSize = 2048 +) + +type publisherData struct { + publisher *ServerConn + trackID int +} + +type publisherAddr struct { + ip [net.IPv6len]byte // use a fixed-size array to enable the equality operator + port int +} + +func (p *publisherAddr) fill(ip net.IP, port int) { + p.port = port + + if len(ip) == net.IPv4len { + copy(p.ip[0:], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}) // v4InV6Prefix + copy(p.ip[12:], ip) + } else { + copy(p.ip[:], ip) + } +} + +// ServerUDPListener is a UDP server that can be used to send and receive RTP and RTCP packets. +type ServerUDPListener struct { + streamType StreamType + + pc *net.UDPConn + readBuf *multibuffer.MultiBuffer + publishersMutex sync.RWMutex + publishers map[publisherAddr]*publisherData + writeMutex sync.Mutex + + // out + done chan struct{} +} + +// NewServerUDPListener allocates a ServerUDPListener. +func NewServerUDPListener(address string) (*ServerUDPListener, error) { + tmp, err := net.ListenPacket("udp", address) + if err != nil { + return nil, err + } + pc := tmp.(*net.UDPConn) + + err = pc.SetReadBuffer(kernelReadBufferSize) + if err != nil { + return nil, err + } + + s := &ServerUDPListener{ + pc: pc, + readBuf: multibuffer.New(1, readBufferSize), + publishers: make(map[publisherAddr]*publisherData), + done: make(chan struct{}), + } + + go s.run() + + return s, nil +} + +// Close closes the listener. +func (s *ServerUDPListener) Close() { + s.pc.Close() + <-s.done +} + +func (s *ServerUDPListener) run() { + defer close(s.done) + + for { + buf := s.readBuf.Next() + n, addr, err := s.pc.ReadFromUDP(buf) + if err != nil { + break + } + + func() { + s.publishersMutex.RLock() + defer s.publishersMutex.RUnlock() + + // find publisher data + var pubAddr publisherAddr + pubAddr.fill(addr.IP, addr.Port) + pubData, ok := s.publishers[pubAddr] + if !ok { + return + } + + pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n]) + }() + } +} + +func (s *ServerUDPListener) port() int { + return s.pc.LocalAddr().(*net.UDPAddr).Port +} + +func (s *ServerUDPListener) write(writeTimeout time.Duration, buf []byte, addr *net.UDPAddr) error { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + s.pc.SetWriteDeadline(time.Now().Add(writeTimeout)) + _, err := s.pc.WriteTo(buf, addr) + return err +} + +func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) { + s.publishersMutex.Lock() + defer s.publishersMutex.Unlock() + + var addr publisherAddr + addr.fill(ip, port) + + s.publishers[addr] = &publisherData{ + publisher: sc, + trackID: trackID, + } +} + +func (s *ServerUDPListener) removePublisher(ip net.IP, port int) { + s.publishersMutex.Lock() + defer s.publishersMutex.Unlock() + + var addr publisherAddr + addr.fill(ip, port) + + delete(s.publishers, addr) +}