rewrite ServerConn read handler

This commit is contained in:
aler9
2020-12-13 12:33:09 +01:00
parent 4c942d33fe
commit 2a1af5a409
4 changed files with 383 additions and 325 deletions

View File

@@ -5,7 +5,6 @@ package main
import ( import (
"fmt" "fmt"
"log" "log"
"strings"
"sync" "sync"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
@@ -29,170 +28,109 @@ func handleConn(conn *gortsplib.ServerConn) {
log.Printf("client connected") log.Printf("client connected")
// this is called when a request arrives // called after receiving a DESCRIBE request.
onRequest := func(req *base.Request) (*base.Response, error) { onDescribe := func(req *base.Request) (*base.Response, error) {
switch req.Method { mutex.Lock()
// the Options method must return all available methods defer mutex.Unlock()
case base.Options:
// no one is publishing yet
if publisher == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusNotFound,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
string(base.Announce),
string(base.Setup),
string(base.Play),
string(base.Record),
string(base.Teardown),
}, ", ")},
},
}, nil }, nil
// the Describe method must return the SDP of the stream
case base.Describe:
mutex.Lock()
defer mutex.Unlock()
// no one is publishing yet
if publisher == nil {
return &base.Response{
StatusCode: base.StatusNotFound,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Base": base.HeaderValue{req.URL.String() + "/"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Content: sdp,
}, nil
// the Announce method is called by publishers
case base.Announce:
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("Content-Type header missing")
}
if ct[0] != "application/sdp" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unsupported Content-Type '%s'", ct)
}
tracks, err := gortsplib.ReadTracks(req.Content)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid SDP: %s", err)
}
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks defined")
}
mutex.Lock()
defer mutex.Unlock()
if publisher != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
publisher = conn
sdp = tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
// The Setup method is called
// * by publishers, after Announce
// * by readers
case base.Setup:
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header: %s", err)
}
// support TCP only
if th.Protocol == gortsplib.StreamProtocolUDP {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
// The Play method is called by readers, after Setup
case base.Play:
mutex.Lock()
defer mutex.Unlock()
readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
// The Record method is called by publishers, after Announce and Setup
case base.Record:
mutex.Lock()
defer mutex.Unlock()
if conn != publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
// The Teardown method is called to close a session
case base.Teardown:
return &base.Response{
StatusCode: base.StatusOK,
}, fmt.Errorf("terminated")
} }
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusOK,
}, fmt.Errorf("unhandled method: %v", req.Method) Header: base.Header{
"Content-Base": base.HeaderValue{req.URL.String() + "/"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Content: sdp,
}, nil
} }
// this is called when a frame arrives // called after receiving an ANNOUNCE request.
onAnnounce := func(req *base.Request, tracks gortsplib.Tracks) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
if publisher != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
publisher = conn
sdp = tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) {
// support TCP only
if th.Protocol == gortsplib.StreamProtocolUDP {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a PLAY request.
onPlay := func(req *base.Request) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a RECORD request.
onRecord := func(req *base.Request) (*base.Response, error) {
mutex.Lock()
defer mutex.Unlock()
if conn != publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
// called after receiving a Frame.
onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) { onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
@@ -205,7 +143,14 @@ func handleConn(conn *gortsplib.ServerConn) {
} }
} }
err := <-conn.Read(onRequest, onFrame) err := <-conn.Read(gortsplib.ServerConnReadHandlers{
OnDescribe: onDescribe,
OnAnnounce: onAnnounce,
OnSetup: onSetup,
OnPlay: onPlay,
OnRecord: onRecord,
OnFrame: onFrame,
})
log.Printf("client disconnected (%s)", err) log.Printf("client disconnected (%s)", err)
mutex.Lock() mutex.Lock()

View File

@@ -26,9 +26,7 @@ const (
Options Method = "OPTIONS" Options Method = "OPTIONS"
Pause Method = "PAUSE" Pause Method = "PAUSE"
Play Method = "PLAY" Play Method = "PLAY"
PlayNotify Method = "PLAY_NOTIFY"
Record Method = "RECORD" Record Method = "RECORD"
Redirect Method = "REDIRECT"
Setup Method = "SETUP" Setup Method = "SETUP"
SetParameter Method = "SET_PARAMETER" SetParameter Method = "SET_PARAMETER"
Teardown Method = "TEARDOWN" Teardown Method = "TEARDOWN"

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -65,156 +64,93 @@ func (ts *testServ) handleConn(conn *ServerConn) {
defer ts.wg.Done() defer ts.wg.Done()
defer conn.Close() defer conn.Close()
// this is called when a request arrives onDescribe := func(req *base.Request) (*base.Response, error) {
onRequest := func(req *base.Request) (*base.Response, error) { ts.mutex.Lock()
switch req.Method { defer ts.mutex.Unlock()
case base.Options:
if ts.publisher == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusNotFound,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
string(base.Announce),
string(base.Setup),
string(base.Play),
string(base.Record),
string(base.Teardown),
}, ", ")},
},
}, nil }, nil
case base.Describe:
ts.mutex.Lock()
defer ts.mutex.Unlock()
if ts.publisher == nil {
return &base.Response{
StatusCode: base.StatusNotFound,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Base": base.HeaderValue{req.URL.String() + "/"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Content: ts.sdp,
}, nil
case base.Announce:
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("Content-Type header missing")
}
if ct[0] != "application/sdp" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unsupported Content-Type '%s'", ct)
}
tracks, err := ReadTracks(req.Content)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid SDP: %s", err)
}
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks defined")
}
ts.mutex.Lock()
defer ts.mutex.Unlock()
if ts.publisher != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
ts.publisher = conn
ts.sdp = tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Setup:
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header: %s", err)
}
if th.Protocol == StreamProtocolUDP {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Play:
ts.mutex.Lock()
defer ts.mutex.Unlock()
ts.readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Record:
ts.mutex.Lock()
defer ts.mutex.Unlock()
if conn != ts.publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Teardown:
return &base.Response{
StatusCode: base.StatusOK,
}, fmt.Errorf("terminated")
} }
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusOK,
}, fmt.Errorf("unhandled method: %v", req.Method) Header: base.Header{
"Content-Base": base.HeaderValue{req.URL.String() + "/"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Content: ts.sdp,
}, nil
}
onAnnounce := func(req *base.Request, tracks Tracks) (*base.Response, error) {
ts.mutex.Lock()
defer ts.mutex.Unlock()
if ts.publisher != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
ts.publisher = conn
ts.sdp = tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
onSetup := func(req *base.Request, th *headers.Transport) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
onPlay := func(req *base.Request) (*base.Response, error) {
ts.mutex.Lock()
defer ts.mutex.Unlock()
ts.readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
}
onRecord := func(req *base.Request) (*base.Response, error) {
ts.mutex.Lock()
defer ts.mutex.Unlock()
if conn != ts.publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
} }
onFrame := func(trackID int, typ StreamType, buf []byte) { onFrame := func(trackID int, typ StreamType, buf []byte) {
@@ -228,7 +164,14 @@ func (ts *testServ) handleConn(conn *ServerConn) {
} }
} }
<-conn.Read(onRequest, onFrame) <-conn.Read(ServerConnReadHandlers{
OnDescribe: onDescribe,
OnAnnounce: onAnnounce,
OnSetup: onSetup,
OnPlay: onPlay,
OnRecord: onRecord,
OnFrame: onFrame,
})
ts.mutex.Lock() ts.mutex.Lock()
defer ts.mutex.Unlock() defer ts.mutex.Unlock()

View File

@@ -2,12 +2,15 @@ package gortsplib
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/multibuffer" "github.com/aler9/gortsplib/pkg/multibuffer"
) )
@@ -16,6 +19,14 @@ const (
serverWriteBufferSize = 4096 serverWriteBufferSize = 4096
) )
// server errors.
var (
ErrServerTeardown = errors.New("teardown")
ErrServerContentTypeMissing = errors.New("Content-Type header is missing")
ErrServerNoTracksDefined = errors.New("no tracks defined")
ErrServerMissingCseq = errors.New("CSeq is missing")
)
// ServerConn is a server-side RTSP connection. // ServerConn is a server-side RTSP connection.
type ServerConn struct { type ServerConn struct {
s *Server s *Server
@@ -47,12 +58,172 @@ func (sc *ServerConn) EnableReadTimeout(v bool) {
sc.readTimeout = v sc.readTimeout = v
} }
func (sc *ServerConn) backgroundRead( // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
onRequest func(req *base.Request) (*base.Response, error), type ServerConnReadHandlers struct {
onFrame func(trackID int, streamType StreamType, content []byte), // called after receiving a OPTIONS request.
done chan error, // if nil, it is generated automatically.
) { OnOptions func(req *base.Request) (*base.Response, error)
handleRequest := func(req *base.Request) error {
// called after receiving a DESCRIBE request.
OnDescribe func(req *base.Request) (*base.Response, error)
// called after receiving an ANNOUNCE request.
OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error)
// called after receiving a SETUP request.
OnSetup func(req *base.Request, th *headers.Transport) (*base.Response, error)
// called after receiving a PLAY request.
OnPlay func(req *base.Request) (*base.Response, error)
// called after receiving a RECORD request.
OnRecord func(req *base.Request) (*base.Response, error)
// called after receiving a GET_PARAMETER request.
// if nil, it is generated automatically.
OnGetParameter func(req *base.Request) (*base.Response, error)
// called after receiving a SET_PARAMETER request.
OnSetParameter func(req *base.Request) (*base.Response, error)
// called after receiving a TEARDOWN request.
// if nil, it is generated automatically.
OnTeardown func(req *base.Request) (*base.Response, error)
// called after receiving a Frame.
OnFrame func(trackID int, streamType StreamType, content []byte)
}
func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan error) {
handleRequest := func(req *base.Request) (*base.Response, error) {
switch req.Method {
case base.Options:
if handlers.OnOptions != nil {
return handlers.OnOptions(req)
}
var methods []string
if handlers.OnDescribe != nil {
methods = append(methods, string(base.Describe))
}
if handlers.OnAnnounce != nil {
methods = append(methods, string(base.Announce))
}
if handlers.OnSetup != nil {
methods = append(methods, string(base.Setup))
}
if handlers.OnPlay != nil {
methods = append(methods, string(base.Play))
}
if handlers.OnRecord != nil {
methods = append(methods, string(base.Record))
}
methods = append(methods, string(base.GetParameter))
if handlers.OnSetParameter != nil {
methods = append(methods, string(base.SetParameter))
}
methods = append(methods, string(base.Teardown))
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join(methods, ", ")},
},
}, nil
case base.Describe:
if handlers.OnDescribe != nil {
return handlers.OnDescribe(req)
}
case base.Announce:
if handlers.OnAnnounce != nil {
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, ErrServerContentTypeMissing
}
if ct[0] != "application/sdp" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unsupported Content-Type '%s'", ct)
}
tracks, err := ReadTracks(req.Content)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid SDP: %s", err)
}
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, ErrServerNoTracksDefined
}
return handlers.OnAnnounce(req, tracks)
}
case base.Setup:
if handlers.OnSetup != nil {
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header: %s", err)
}
return handlers.OnSetup(req, th)
}
case base.Play:
if handlers.OnPlay != nil {
return handlers.OnPlay(req)
}
case base.Record:
if handlers.OnRecord != nil {
return handlers.OnRecord(req)
}
case base.GetParameter:
if handlers.OnGetParameter != nil {
return handlers.OnGetParameter(req)
}
// GET_PARAMETER is used like a ping
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"text/parameters"},
},
Content: []byte("\n"),
}, nil
case base.SetParameter:
if handlers.OnSetParameter != nil {
return handlers.OnSetParameter(req)
}
case base.Teardown:
if handlers.OnTeardown != nil {
return handlers.OnTeardown(req)
}
return &base.Response{
StatusCode: base.StatusOK,
}, ErrServerTeardown
}
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unhandled method: %v", req.Method)
}
handleRequestOuter := func(req *base.Request) error {
sc.mutex.Lock() sc.mutex.Lock()
defer sc.mutex.Unlock() defer sc.mutex.Unlock()
@@ -64,17 +235,21 @@ func (sc *ServerConn) backgroundRead(
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
Header: base.Header{}, Header: base.Header{},
}.Write(sc.bw) }.Write(sc.bw)
return fmt.Errorf("cseq is missing") return ErrServerMissingCseq
} }
res, err := onRequest(req) res, err := handleRequest(req)
// add cseq to response
if res.Header == nil { if res.Header == nil {
res.Header = base.Header{} res.Header = base.Header{}
} }
// add cseq
res.Header["CSeq"] = cseq res.Header["CSeq"] = cseq
// add server
res.Header["Server"] = base.HeaderValue{"gortsplib"}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
res.Write(sc.bw) res.Write(sc.bw)
@@ -104,10 +279,10 @@ outer:
switch what.(type) { switch what.(type) {
case *base.InterleavedFrame: case *base.InterleavedFrame:
onFrame(frame.TrackID, frame.StreamType, frame.Content) handlers.OnFrame(frame.TrackID, frame.StreamType, frame.Content)
case *base.Request: case *base.Request:
err := handleRequest(&req) err := handleRequestOuter(&req)
if err != nil { if err != nil {
errRet = err errRet = err
break outer break outer
@@ -121,7 +296,7 @@ outer:
break outer break outer
} }
err = handleRequest(&req) err = handleRequestOuter(&req)
if err != nil { if err != nil {
errRet = err errRet = err
break outer break outer
@@ -134,14 +309,11 @@ outer:
// Read starts reading requests and frames. // Read starts reading requests and frames.
// it returns a channel that is written when the reading stops. // it returns a channel that is written when the reading stops.
func (sc *ServerConn) Read( func (sc *ServerConn) Read(handlers ServerConnReadHandlers) chan error {
onRequest func(req *base.Request) (*base.Response, error),
onFrame func(trackID int, streamType StreamType, content []byte),
) chan error {
// channel is buffered, since listening to it is not mandatory // channel is buffered, since listening to it is not mandatory
done := make(chan error, 1) done := make(chan error, 1)
go sc.backgroundRead(onRequest, onFrame, done) go sc.backgroundRead(handlers, done)
return done return done
} }