add server tests

This commit is contained in:
aler9
2020-12-12 23:18:56 +01:00
parent 4b4d121088
commit 48c96be2b5
2 changed files with 277 additions and 8 deletions

View File

@@ -59,7 +59,7 @@ func (c *container) wait() int {
return int(code) return int(code)
} }
func TestDialRead(t *testing.T) { func TestClientDialRead(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -119,7 +119,7 @@ func TestDialRead(t *testing.T) {
} }
} }
func TestDialReadAutomaticProtocol(t *testing.T) { func TestClientDialReadAutomaticProtocol(t *testing.T) {
cnt1, err := newContainer("rtsp-simple-server", "server", []string{ cnt1, err := newContainer("rtsp-simple-server", "server", []string{
"protocols: [tcp]\n", "protocols: [tcp]\n",
}) })
@@ -160,7 +160,7 @@ func TestDialReadAutomaticProtocol(t *testing.T) {
<-done <-done
} }
func TestDialReadRedirect(t *testing.T) { func TestClientDialReadRedirect(t *testing.T) {
cnt1, err := newContainer("rtsp-simple-server", "server", []string{ cnt1, err := newContainer("rtsp-simple-server", "server", []string{
"paths:\n" + "paths:\n" +
" path1:\n" + " path1:\n" +
@@ -203,7 +203,7 @@ func TestDialReadRedirect(t *testing.T) {
<-done <-done
} }
func TestDialReadPause(t *testing.T) { func TestClientDialReadPause(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -274,7 +274,7 @@ func TestDialReadPause(t *testing.T) {
} }
} }
func TestDialPublishSerial(t *testing.T) { func TestClientDialPublishSerial(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -335,7 +335,7 @@ func TestDialPublishSerial(t *testing.T) {
} }
} }
func TestDialPublishParallel(t *testing.T) { func TestClientDialPublishParallel(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
proto string proto string
server string server string
@@ -448,7 +448,7 @@ func TestDialPublishParallel(t *testing.T) {
} }
} }
func TestDialPublishPauseSerial(t *testing.T) { func TestClientDialPublishPauseSerial(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -520,7 +520,7 @@ func TestDialPublishPauseSerial(t *testing.T) {
} }
} }
func TestDialPublishPauseParallel(t *testing.T) { func TestClientDialPublishPauseParallel(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",
"tcp", "tcp",

269
serverconf_test.go Normal file
View File

@@ -0,0 +1,269 @@
package gortsplib
import (
"fmt"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
)
type testServ struct {
s *Server
wg sync.WaitGroup
mutex sync.Mutex
publisher *ServerConn
sdp []byte
readers map[*ServerConn]struct{}
}
func newTestServ() (*testServ, error) {
s, err := Serve(":8554")
if err != nil {
return nil, err
}
ts := &testServ{
s: s,
readers: make(map[*ServerConn]struct{}),
}
ts.wg.Add(1)
go ts.run()
return ts, nil
}
func (ts *testServ) close() {
ts.s.Close()
ts.wg.Wait()
}
func (ts *testServ) run() {
defer ts.wg.Done()
for {
conn, err := ts.s.Accept()
if err != nil {
return
}
ts.wg.Add(1)
go ts.handleConn(conn)
}
}
func (ts *testServ) handleConn(conn *ServerConn) {
defer ts.wg.Done()
defer conn.Close()
// this is called when a request arrives
onRequest := func(req *base.Request) (*base.Response, error) {
switch req.Method {
case base.Options:
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
string(base.Announce),
string(base.Setup),
string(base.Play),
string(base.Record),
string(base.Teardown),
}, ", ")},
},
}, nil
case base.Describe:
ts.mutex.Lock()
defer ts.mutex.Unlock()
if ts.publisher == nil {
return &base.Response{
StatusCode: base.StatusNotFound,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Base": base.HeaderValue{req.URL.String() + "/"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Content: ts.sdp,
}, nil
case base.Announce:
ct, ok := req.Header["Content-Type"]
if !ok || len(ct) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("Content-Type header missing")
}
if ct[0] != "application/sdp" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unsupported Content-Type '%s'", ct)
}
tracks, err := ReadTracks(req.Content)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid SDP: %s", err)
}
if len(tracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks defined")
}
ts.mutex.Lock()
defer ts.mutex.Unlock()
if ts.publisher != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
ts.publisher = conn
ts.sdp = tracks.Write()
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Setup:
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header: %s", err)
}
if th.Protocol == StreamProtocolUDP {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": req.Header["Transport"],
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Play:
ts.mutex.Lock()
defer ts.mutex.Unlock()
ts.readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Record:
ts.mutex.Lock()
defer ts.mutex.Unlock()
if conn != ts.publisher {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Session": base.HeaderValue{"12345678"},
},
}, nil
case base.Teardown:
return &base.Response{
StatusCode: base.StatusOK,
}, fmt.Errorf("terminated")
}
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unhandled method: %v", req.Method)
}
onFrame := func(trackID int, typ StreamType, buf []byte) {
ts.mutex.Lock()
defer ts.mutex.Unlock()
if conn == ts.publisher {
for r := range ts.readers {
r.WriteFrame(trackID, typ, buf)
}
}
}
<-conn.Read(onRequest, onFrame)
ts.mutex.Lock()
defer ts.mutex.Unlock()
if conn == ts.publisher {
ts.publisher = nil
ts.sdp = nil
}
}
func TestServerPublishReadTCP(t *testing.T) {
ts, err := newTestServ()
require.NoError(t, err)
defer ts.close()
cnt1, err := newContainer("ffmpeg", "publish", []string{
"-re",
"-stream_loop", "-1",
"-i", "/emptyvideo.ts",
"-c", "copy",
"-f", "rtsp",
"-rtsp_transport", "tcp",
"rtsp://localhost:8554/teststream",
})
require.NoError(t, err)
defer cnt1.close()
time.Sleep(1 * time.Second)
cnt2, err := newContainer("ffmpeg", "read", []string{
"-rtsp_transport", "tcp",
"-i", "rtsp://localhost:8554/teststream",
"-vframes", "1",
"-f", "image2",
"-y", "/dev/null",
})
require.NoError(t, err)
defer cnt2.close()
require.Equal(t, 0, cnt2.wait())
}