From 2a1af5a409acaf6f2f787f14a8d622614c66a0a5 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 13 Dec 2020 12:33:09 +0100 Subject: [PATCH] rewrite ServerConn read handler --- examples/server.go | 263 ++++++++++++++++++-------------------------- pkg/base/request.go | 2 - serverconf_test.go | 237 +++++++++++++++------------------------ serverconn.go | 206 +++++++++++++++++++++++++++++++--- 4 files changed, 383 insertions(+), 325 deletions(-) diff --git a/examples/server.go b/examples/server.go index 8adcb268..22afc197 100644 --- a/examples/server.go +++ b/examples/server.go @@ -5,7 +5,6 @@ package main import ( "fmt" "log" - "strings" "sync" "github.com/aler9/gortsplib" @@ -29,170 +28,109 @@ func handleConn(conn *gortsplib.ServerConn) { log.Printf("client connected") - // this is called when a request arrives - onRequest := func(req *base.Request) (*base.Response, error) { - switch req.Method { - // the Options method must return all available methods - case base.Options: + // called after receiving a DESCRIBE request. + onDescribe := func(req *base.Request) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + // no one is publishing yet + if publisher == nil { return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Announce), - string(base.Setup), - string(base.Play), - string(base.Record), - string(base.Teardown), - }, ", ")}, - }, + StatusCode: base.StatusNotFound, }, nil - - // the Describe method must return the SDP of the stream - case base.Describe: - mutex.Lock() - defer mutex.Unlock() - - // no one is publishing yet - if publisher == nil { - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Base": base.HeaderValue{req.URL.String() + "/"}, - "Content-Type": base.HeaderValue{"application/sdp"}, - }, - Content: sdp, - }, nil - - // the Announce method is called by publishers - case base.Announce: - ct, ok := req.Header["Content-Type"] - if !ok || len(ct) != 1 { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("Content-Type header missing") - } - - if ct[0] != "application/sdp" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unsupported Content-Type '%s'", ct) - } - - tracks, err := gortsplib.ReadTracks(req.Content) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid SDP: %s", err) - } - - if len(tracks) == 0 { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("no tracks defined") - } - - mutex.Lock() - defer mutex.Unlock() - - if publisher != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - publisher = conn - sdp = tracks.Write() - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - // The Setup method is called - // * by publishers, after Announce - // * by readers - case base.Setup: - th, err := headers.ReadTransport(req.Header["Transport"]) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header: %s", err) - } - - // support TCP only - if th.Protocol == gortsplib.StreamProtocolUDP { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": req.Header["Transport"], - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - // The Play method is called by readers, after Setup - case base.Play: - mutex.Lock() - defer mutex.Unlock() - - readers[conn] = struct{}{} - - conn.EnableReadFrames(true) - conn.EnableReadTimeout(false) - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - // The Record method is called by publishers, after Announce and Setup - case base.Record: - mutex.Lock() - defer mutex.Unlock() - - if conn != publisher { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - conn.EnableReadFrames(true) - conn.EnableReadTimeout(true) - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - // The Teardown method is called to close a session - case base.Teardown: - return &base.Response{ - StatusCode: base.StatusOK, - }, fmt.Errorf("terminated") } return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unhandled method: %v", req.Method) + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Base": base.HeaderValue{req.URL.String() + "/"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Content: sdp, + }, nil } - // this is called when a frame arrives + // called after receiving an ANNOUNCE request. + onAnnounce := func(req *base.Request, tracks gortsplib.Tracks) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + if publisher != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + publisher = conn + sdp = tracks.Write() + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a SETUP request. + onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) { + // support TCP only + if th.Protocol == gortsplib.StreamProtocolUDP { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": req.Header["Transport"], + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a PLAY request. + onPlay := func(req *base.Request) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + readers[conn] = struct{}{} + + conn.EnableReadFrames(true) + conn.EnableReadTimeout(false) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a RECORD request. + onRecord := func(req *base.Request) (*base.Response, error) { + mutex.Lock() + defer mutex.Unlock() + + if conn != publisher { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + conn.EnableReadFrames(true) + conn.EnableReadTimeout(true) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + // called after receiving a Frame. onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) { mutex.Lock() defer mutex.Unlock() @@ -205,7 +143,14 @@ func handleConn(conn *gortsplib.ServerConn) { } } - err := <-conn.Read(onRequest, onFrame) + err := <-conn.Read(gortsplib.ServerConnReadHandlers{ + OnDescribe: onDescribe, + OnAnnounce: onAnnounce, + OnSetup: onSetup, + OnPlay: onPlay, + OnRecord: onRecord, + OnFrame: onFrame, + }) log.Printf("client disconnected (%s)", err) mutex.Lock() diff --git a/pkg/base/request.go b/pkg/base/request.go index 73847bf9..32085638 100644 --- a/pkg/base/request.go +++ b/pkg/base/request.go @@ -26,9 +26,7 @@ const ( Options Method = "OPTIONS" Pause Method = "PAUSE" Play Method = "PLAY" - PlayNotify Method = "PLAY_NOTIFY" Record Method = "RECORD" - Redirect Method = "REDIRECT" Setup Method = "SETUP" SetParameter Method = "SET_PARAMETER" Teardown Method = "TEARDOWN" diff --git a/serverconf_test.go b/serverconf_test.go index 5adbf2b8..2ef2097a 100644 --- a/serverconf_test.go +++ b/serverconf_test.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net" - "strings" "sync" "testing" "time" @@ -65,156 +64,93 @@ func (ts *testServ) handleConn(conn *ServerConn) { defer ts.wg.Done() defer conn.Close() - // this is called when a request arrives - onRequest := func(req *base.Request) (*base.Response, error) { - switch req.Method { - case base.Options: + onDescribe := func(req *base.Request) (*base.Response, error) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + + if ts.publisher == nil { return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Announce), - string(base.Setup), - string(base.Play), - string(base.Record), - string(base.Teardown), - }, ", ")}, - }, + StatusCode: base.StatusNotFound, }, nil - - case base.Describe: - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if ts.publisher == nil { - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Base": base.HeaderValue{req.URL.String() + "/"}, - "Content-Type": base.HeaderValue{"application/sdp"}, - }, - Content: ts.sdp, - }, nil - - case base.Announce: - ct, ok := req.Header["Content-Type"] - if !ok || len(ct) != 1 { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("Content-Type header missing") - } - - if ct[0] != "application/sdp" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unsupported Content-Type '%s'", ct) - } - - tracks, err := ReadTracks(req.Content) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid SDP: %s", err) - } - - if len(tracks) == 0 { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("no tracks defined") - } - - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if ts.publisher != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - ts.publisher = conn - ts.sdp = tracks.Write() - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - case base.Setup: - th, err := headers.ReadTransport(req.Header["Transport"]) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("transport header: %s", err) - } - - if th.Protocol == StreamProtocolUDP { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": req.Header["Transport"], - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - case base.Play: - ts.mutex.Lock() - defer ts.mutex.Unlock() - - ts.readers[conn] = struct{}{} - - conn.EnableReadFrames(true) - conn.EnableReadTimeout(false) - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - case base.Record: - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if conn != ts.publisher { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - conn.EnableReadFrames(true) - conn.EnableReadTimeout(true) - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - - case base.Teardown: - return &base.Response{ - StatusCode: base.StatusOK, - }, fmt.Errorf("terminated") } return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unhandled method: %v", req.Method) + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Base": base.HeaderValue{req.URL.String() + "/"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Content: ts.sdp, + }, nil + } + + onAnnounce := func(req *base.Request, tracks Tracks) (*base.Response, error) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + + if ts.publisher != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + ts.publisher = conn + ts.sdp = tracks.Write() + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": req.Header["Transport"], + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + onPlay := func(req *base.Request) (*base.Response, error) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + + ts.readers[conn] = struct{}{} + + conn.EnableReadFrames(true) + conn.EnableReadTimeout(false) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + } + + onRecord := func(req *base.Request) (*base.Response, error) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + + if conn != ts.publisher { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + conn.EnableReadFrames(true) + conn.EnableReadTimeout(true) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil } onFrame := func(trackID int, typ StreamType, buf []byte) { @@ -228,7 +164,14 @@ func (ts *testServ) handleConn(conn *ServerConn) { } } - <-conn.Read(onRequest, onFrame) + <-conn.Read(ServerConnReadHandlers{ + OnDescribe: onDescribe, + OnAnnounce: onAnnounce, + OnSetup: onSetup, + OnPlay: onPlay, + OnRecord: onRecord, + OnFrame: onFrame, + }) ts.mutex.Lock() defer ts.mutex.Unlock() diff --git a/serverconn.go b/serverconn.go index b903b58f..331f3996 100644 --- a/serverconn.go +++ b/serverconn.go @@ -2,12 +2,15 @@ package gortsplib import ( "bufio" + "errors" "fmt" "net" + "strings" "sync" "time" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/multibuffer" ) @@ -16,6 +19,14 @@ const ( serverWriteBufferSize = 4096 ) +// server errors. +var ( + ErrServerTeardown = errors.New("teardown") + ErrServerContentTypeMissing = errors.New("Content-Type header is missing") + ErrServerNoTracksDefined = errors.New("no tracks defined") + ErrServerMissingCseq = errors.New("CSeq is missing") +) + // ServerConn is a server-side RTSP connection. type ServerConn struct { s *Server @@ -47,12 +58,172 @@ func (sc *ServerConn) EnableReadTimeout(v bool) { sc.readTimeout = v } -func (sc *ServerConn) backgroundRead( - onRequest func(req *base.Request) (*base.Response, error), - onFrame func(trackID int, streamType StreamType, content []byte), - done chan error, -) { - handleRequest := func(req *base.Request) error { +// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. +type ServerConnReadHandlers struct { + // called after receiving a OPTIONS request. + // if nil, it is generated automatically. + OnOptions func(req *base.Request) (*base.Response, error) + + // called after receiving a DESCRIBE request. + OnDescribe func(req *base.Request) (*base.Response, error) + + // called after receiving an ANNOUNCE request. + OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error) + + // called after receiving a SETUP request. + OnSetup func(req *base.Request, th *headers.Transport) (*base.Response, error) + + // called after receiving a PLAY request. + OnPlay func(req *base.Request) (*base.Response, error) + + // called after receiving a RECORD request. + OnRecord func(req *base.Request) (*base.Response, error) + + // called after receiving a GET_PARAMETER request. + // if nil, it is generated automatically. + OnGetParameter func(req *base.Request) (*base.Response, error) + + // called after receiving a SET_PARAMETER request. + OnSetParameter func(req *base.Request) (*base.Response, error) + + // called after receiving a TEARDOWN request. + // if nil, it is generated automatically. + OnTeardown func(req *base.Request) (*base.Response, error) + + // called after receiving a Frame. + OnFrame func(trackID int, streamType StreamType, content []byte) +} + +func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan error) { + handleRequest := func(req *base.Request) (*base.Response, error) { + switch req.Method { + case base.Options: + if handlers.OnOptions != nil { + return handlers.OnOptions(req) + } + + var methods []string + if handlers.OnDescribe != nil { + methods = append(methods, string(base.Describe)) + } + if handlers.OnAnnounce != nil { + methods = append(methods, string(base.Announce)) + } + if handlers.OnSetup != nil { + methods = append(methods, string(base.Setup)) + } + if handlers.OnPlay != nil { + methods = append(methods, string(base.Play)) + } + if handlers.OnRecord != nil { + methods = append(methods, string(base.Record)) + } + methods = append(methods, string(base.GetParameter)) + if handlers.OnSetParameter != nil { + methods = append(methods, string(base.SetParameter)) + } + methods = append(methods, string(base.Teardown)) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join(methods, ", ")}, + }, + }, nil + + case base.Describe: + if handlers.OnDescribe != nil { + return handlers.OnDescribe(req) + } + + case base.Announce: + if handlers.OnAnnounce != nil { + ct, ok := req.Header["Content-Type"] + if !ok || len(ct) != 1 { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, ErrServerContentTypeMissing + } + + if ct[0] != "application/sdp" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unsupported Content-Type '%s'", ct) + } + + tracks, err := ReadTracks(req.Content) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid SDP: %s", err) + } + + if len(tracks) == 0 { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, ErrServerNoTracksDefined + } + + return handlers.OnAnnounce(req, tracks) + } + + case base.Setup: + if handlers.OnSetup != nil { + th, err := headers.ReadTransport(req.Header["Transport"]) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header: %s", err) + } + + return handlers.OnSetup(req, th) + } + + case base.Play: + if handlers.OnPlay != nil { + return handlers.OnPlay(req) + } + + case base.Record: + if handlers.OnRecord != nil { + return handlers.OnRecord(req) + } + + case base.GetParameter: + if handlers.OnGetParameter != nil { + return handlers.OnGetParameter(req) + } + + // GET_PARAMETER is used like a ping + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"text/parameters"}, + }, + Content: []byte("\n"), + }, nil + + case base.SetParameter: + if handlers.OnSetParameter != nil { + return handlers.OnSetParameter(req) + } + + case base.Teardown: + if handlers.OnTeardown != nil { + return handlers.OnTeardown(req) + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, ErrServerTeardown + } + + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unhandled method: %v", req.Method) + } + + handleRequestOuter := func(req *base.Request) error { sc.mutex.Lock() defer sc.mutex.Unlock() @@ -64,17 +235,21 @@ func (sc *ServerConn) backgroundRead( StatusCode: base.StatusBadRequest, Header: base.Header{}, }.Write(sc.bw) - return fmt.Errorf("cseq is missing") + return ErrServerMissingCseq } - res, err := onRequest(req) + res, err := handleRequest(req) - // add cseq to response if res.Header == nil { res.Header = base.Header{} } + + // add cseq res.Header["CSeq"] = cseq + // add server + res.Header["Server"] = base.HeaderValue{"gortsplib"} + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) res.Write(sc.bw) @@ -104,10 +279,10 @@ outer: switch what.(type) { case *base.InterleavedFrame: - onFrame(frame.TrackID, frame.StreamType, frame.Content) + handlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content) case *base.Request: - err := handleRequest(&req) + err := handleRequestOuter(&req) if err != nil { errRet = err break outer @@ -121,7 +296,7 @@ outer: break outer } - err = handleRequest(&req) + err = handleRequestOuter(&req) if err != nil { errRet = err break outer @@ -134,14 +309,11 @@ outer: // Read starts reading requests and frames. // it returns a channel that is written when the reading stops. -func (sc *ServerConn) Read( - onRequest func(req *base.Request) (*base.Response, error), - onFrame func(trackID int, streamType StreamType, content []byte), -) chan error { +func (sc *ServerConn) Read(handlers ServerConnReadHandlers) chan error { // channel is buffered, since listening to it is not mandatory done := make(chan error, 1) - go sc.backgroundRead(onRequest, onFrame, done) + go sc.backgroundRead(handlers, done) return done }