server: replace SetuppedProtocol() with SetuppedTransport()

This commit is contained in:
aler9
2021-10-22 17:40:18 +02:00
parent 7a000bed0e
commit e7ab15750c
6 changed files with 233 additions and 194 deletions

View File

@@ -43,7 +43,7 @@ func main() {
// Client allows to set additional client options // Client allows to set additional client options
c := &gortsplib.Client{ c := &gortsplib.Client{
// the stream transport (UDP, Multicast or TCP). If nil, it is chosen automatically // the stream transport (UDP or TCP). If nil, it is chosen automatically
Transport: nil, Transport: nil,
// timeout of read operations // timeout of read operations
ReadTimeout: 10 * time.Second, ReadTimeout: 10 * time.Second,

View File

@@ -663,12 +663,12 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
} }
func TestServerPublish(t *testing.T) { func TestServerPublish(t *testing.T) {
for _, proto := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
"tls", "tls",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(transport, func(t *testing.T) {
connOpened := make(chan struct{}) connOpened := make(chan struct{})
connClosed := make(chan struct{}) connClosed := make(chan struct{})
sessionOpened := make(chan struct{}) sessionOpened := make(chan struct{})
@@ -720,7 +720,7 @@ func TestServerPublish(t *testing.T) {
}, },
} }
switch proto { switch transport {
case "udp": case "udp":
s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001" s.UDPRTCPAddress = "127.0.0.1:8001"
@@ -740,7 +740,7 @@ func TestServerPublish(t *testing.T) {
defer nconn.Close() defer nconn.Close()
conn := func() net.Conn { conn := func() net.Conn {
if proto == "tls" { if transport == "tls" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
} }
return nconn return nconn
@@ -785,7 +785,7 @@ func TestServerPublish(t *testing.T) {
}(), }(),
} }
if proto == "udp" { if transport == "udp" {
inTH.Protocol = base.StreamProtocolUDP inTH.Protocol = base.StreamProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466, 35467}
} else { } else {
@@ -811,7 +811,7 @@ func TestServerPublish(t *testing.T) {
var l1 net.PacketConn var l1 net.PacketConn
var l2 net.PacketConn var l2 net.PacketConn
if proto == "udp" { if transport == "udp" {
l1, err = net.ListenPacket("udp", "localhost:35466") l1, err = net.ListenPacket("udp", "localhost:35466")
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1.Close()
@@ -833,7 +833,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
// client -> server // client -> server
if proto == "udp" { if transport == "udp" {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{
@@ -863,7 +863,7 @@ func TestServerPublish(t *testing.T) {
} }
// server -> client (RTCP) // server -> client (RTCP)
if proto == "udp" { if transport == "udp" {
// skip firewall opening // skip firewall opening
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf) _, _, err := l2.ReadFrom(buf)
@@ -1148,11 +1148,11 @@ func TestServerPublishRTCPReport(t *testing.T) {
} }
func TestServerPublishTimeout(t *testing.T) { func TestServerPublishTimeout(t *testing.T) {
for _, proto := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{}) connClosed := make(chan struct{})
sessionClosed := make(chan struct{}) sessionClosed := make(chan struct{})
@@ -1183,7 +1183,7 @@ func TestServerPublishTimeout(t *testing.T) {
ReadTimeout: 1 * time.Second, ReadTimeout: 1 * time.Second,
} }
if proto == "udp" { if transport == "udp" {
s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001" s.UDPRTCPAddress = "127.0.0.1:8001"
} }
@@ -1231,7 +1231,7 @@ func TestServerPublishTimeout(t *testing.T) {
}(), }(),
} }
if proto == "udp" { if transport == "udp" {
inTH.Protocol = base.StreamProtocolUDP inTH.Protocol = base.StreamProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466, 35467}
} else { } else {
@@ -1268,7 +1268,7 @@ func TestServerPublishTimeout(t *testing.T) {
<-sessionClosed <-sessionClosed
if proto == "tcp" { if transport == "tcp" {
<-connClosed <-connClosed
} }
}) })
@@ -1276,11 +1276,11 @@ func TestServerPublishTimeout(t *testing.T) {
} }
func TestServerPublishWithoutTeardown(t *testing.T) { func TestServerPublishWithoutTeardown(t *testing.T) {
for _, proto := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{}) connClosed := make(chan struct{})
sessionClosed := make(chan struct{}) sessionClosed := make(chan struct{})
@@ -1311,7 +1311,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
ReadTimeout: 1 * time.Second, ReadTimeout: 1 * time.Second,
} }
if proto == "udp" { if transport == "udp" {
s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001" s.UDPRTCPAddress = "127.0.0.1:8001"
} }
@@ -1358,7 +1358,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
}(), }(),
} }
if proto == "udp" { if transport == "udp" {
inTH.Protocol = base.StreamProtocolUDP inTH.Protocol = base.StreamProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466, 35467}
} else { } else {

View File

@@ -287,13 +287,13 @@ func TestServerReadErrorSetupTrackTwice(t *testing.T) {
} }
func TestServerRead(t *testing.T) { func TestServerRead(t *testing.T) {
for _, proto := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
"tls", "tls",
"multicast", "multicast",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(transport, func(t *testing.T) {
connOpened := make(chan struct{}) connOpened := make(chan struct{})
connClosed := make(chan struct{}) connClosed := make(chan struct{})
sessionOpened := make(chan struct{}) sessionOpened := make(chan struct{})
@@ -339,7 +339,7 @@ func TestServerRead(t *testing.T) {
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) { onFrame: func(ctx *ServerHandlerOnFrameCtx) {
// skip multicast loopback // skip multicast loopback
if proto == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 {
return return
} }
@@ -356,7 +356,7 @@ func TestServerRead(t *testing.T) {
}, },
} }
switch proto { switch transport {
case "udp": case "udp":
s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001" s.UDPRTCPAddress = "127.0.0.1:8001"
@@ -381,7 +381,7 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
conn := func() net.Conn { conn := func() net.Conn {
if proto == "tls" { if transport == "tls" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
} }
return nconn return nconn
@@ -397,7 +397,7 @@ func TestServerRead(t *testing.T) {
}(), }(),
} }
switch proto { switch transport {
case "udp": case "udp":
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
inTH.Delivery = &v inTH.Delivery = &v
@@ -431,11 +431,25 @@ func TestServerRead(t *testing.T) {
err = th.Read(res.Header["Transport"]) err = th.Read(res.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
switch transport {
case "udp":
require.Equal(t, base.StreamProtocolUDP, th.Protocol)
require.Equal(t, base.StreamDeliveryUnicast, *th.Delivery)
case "multicast":
require.Equal(t, base.StreamProtocolUDP, th.Protocol)
require.Equal(t, base.StreamDeliveryMulticast, *th.Delivery)
default:
require.Equal(t, base.StreamProtocolTCP, th.Protocol)
require.Equal(t, base.StreamDeliveryUnicast, *th.Delivery)
}
<-sessionOpened <-sessionOpened
var l1 net.PacketConn var l1 net.PacketConn
var l2 net.PacketConn var l2 net.PacketConn
switch proto { switch transport {
case "udp": case "udp":
l1, err = net.ListenPacket("udp", listenIP+":35466") l1, err = net.ListenPacket("udp", listenIP+":35466")
require.NoError(t, err) require.NoError(t, err)
@@ -487,14 +501,14 @@ func TestServerRead(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
// server -> client // server -> client
if proto == "udp" || proto == "multicast" { if transport == "udp" || transport == "multicast" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, _, err := l1.ReadFrom(buf) n, _, err := l1.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n]) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n])
// skip firewall opening // skip firewall opening
if proto == "udp" { if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf) _, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
@@ -520,7 +534,7 @@ func TestServerRead(t *testing.T) {
} }
// client -> server (RTCP) // client -> server (RTCP)
switch proto { switch transport {
case "udp": case "udp":
l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
@@ -544,7 +558,7 @@ func TestServerRead(t *testing.T) {
<-framesReceived <-framesReceived
} }
if proto == "udp" || proto == "multicast" { if transport == "udp" || transport == "multicast" {
// ping with OPTIONS // ping with OPTIONS
res, err = writeReqReadRes(bconn, base.Request{ res, err = writeReqReadRes(bconn, base.Request{
Method: base.Options, Method: base.Options,
@@ -1001,11 +1015,11 @@ func TestServerReadPlayPausePause(t *testing.T) {
} }
func TestServerReadTimeout(t *testing.T) { func TestServerReadTimeout(t *testing.T) {
for _, proto := range []string{ for _, transport := range []string{
"udp", "udp",
// there's no timeout when reading with TCP // there's no timeout when reading with TCP
} { } {
t.Run(proto, func(t *testing.T) { t.Run(transport, func(t *testing.T) {
sessionClosed := make(chan struct{}) sessionClosed := make(chan struct{})
track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}})
@@ -1092,11 +1106,11 @@ func TestServerReadTimeout(t *testing.T) {
} }
func TestServerReadWithoutTeardown(t *testing.T) { func TestServerReadWithoutTeardown(t *testing.T) {
for _, proto := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
} { } {
t.Run(proto, func(t *testing.T) { t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{}) connClosed := make(chan struct{})
sessionClosed := make(chan struct{}) sessionClosed := make(chan struct{})
@@ -1133,7 +1147,7 @@ func TestServerReadWithoutTeardown(t *testing.T) {
closeSessionAfterNoRequestsFor: 1 * time.Second, closeSessionAfterNoRequestsFor: 1 * time.Second,
} }
if proto == "udp" { if transport == "udp" {
s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001" s.UDPRTCPAddress = "127.0.0.1:8001"
} }
@@ -1158,7 +1172,7 @@ func TestServerReadWithoutTeardown(t *testing.T) {
}(), }(),
} }
if proto == "udp" { if transport == "udp" {
inTH.Protocol = base.StreamProtocolUDP inTH.Protocol = base.StreamProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466, 35467}
} else { } else {

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
) )
// ServerHandler is the interface implemented by all the server handlers. // ServerHandler is the interface implemented by all the server handlers.
@@ -99,7 +98,7 @@ type ServerHandlerOnSetupCtx struct {
Path string Path string
Query string Query string
TrackID int TrackID int
Transport *headers.Transport Transport ClientTransport
} }
// ServerHandlerOnSetup can be implemented by a ServerHandler. // ServerHandlerOnSetup can be implemented by a ServerHandler.

View File

@@ -75,6 +75,29 @@ func setupGetTrackIDPathQuery(
return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery) return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
} }
func setupGetTransport(th headers.Transport) (ClientTransport, bool) {
delivery := func() base.StreamDelivery {
if th.Delivery != nil {
return *th.Delivery
}
return base.StreamDeliveryUnicast
}()
switch th.Protocol {
case base.StreamProtocolUDP:
if delivery == base.StreamDeliveryUnicast {
return ClientTransportUDP, true
}
return ClientTransportUDPMulticast, true
default: // TCP
if delivery != base.StreamDeliveryUnicast {
return 0, false
}
return ClientTransportTCP, true
}
}
// ServerSessionState is a state of a ServerSession. // ServerSessionState is a state of a ServerSession.
type ServerSessionState int type ServerSessionState int
@@ -129,8 +152,7 @@ type ServerSession struct {
state ServerSessionState state ServerSessionState
setuppedTracks map[int]ServerSessionSetuppedTrack setuppedTracks map[int]ServerSessionSetuppedTrack
setuppedTracksByChannel map[int]int // tcp setuppedTracksByChannel map[int]int // tcp
setuppedProtocol *base.StreamProtocol setuppedTransport *ClientTransport
setuppedDelivery *base.StreamDelivery
setuppedBaseURL *base.URL // publish setuppedBaseURL *base.URL // publish
setuppedStream *ServerStream // read setuppedStream *ServerStream // read
setuppedPath *string setuppedPath *string
@@ -186,14 +208,9 @@ func (ss *ServerSession) SetuppedTracks() map[int]ServerSessionSetuppedTrack {
return ss.setuppedTracks return ss.setuppedTracks
} }
// SetuppedProtocol returns the stream protocol of the setupped tracks. // SetuppedTransport returns the transport of the setupped tracks.
func (ss *ServerSession) SetuppedProtocol() *base.StreamProtocol { func (ss *ServerSession) SetuppedTransport() *ClientTransport {
return ss.setuppedProtocol return ss.setuppedTransport
}
// SetuppedDelivery returns the delivery method of the setupped tracks.
func (ss *ServerSession) SetuppedDelivery() *base.StreamDelivery {
return ss.setuppedDelivery
} }
// AnnouncedTracks returns the announced tracks. // AnnouncedTracks returns the announced tracks.
@@ -279,10 +296,10 @@ func (ss *ServerSession) run() {
} }
} }
// if session is not in state RECORD or PLAY, or protocol is TCP // if session is not in state RECORD or PLAY, or transport is TCP
if (ss.state != ServerSessionStatePublish && if (ss.state != ServerSessionStatePublish &&
ss.state != ServerSessionStateRead) || ss.state != ServerSessionStateRead) ||
*ss.setuppedProtocol == base.StreamProtocolTCP { *ss.setuppedTransport == ClientTransportTCP {
// close if there are no active connections // close if there are no active connections
if len(ss.conns) == 0 { if len(ss.conns) == 0 {
@@ -293,7 +310,8 @@ func (ss *ServerSession) run() {
case <-checkTimeoutTicker.C: case <-checkTimeoutTicker.C:
switch { switch {
// in case of RECORD and UDP, timeout happens when no frames are being received // in case of RECORD and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStatePublish && *ss.setuppedProtocol == base.StreamProtocolUDP: case ss.state == ServerSessionStatePublish && (*ss.setuppedTransport == ClientTransportUDP ||
*ss.setuppedTransport == ClientTransportUDPMulticast):
now := time.Now() now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime) lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
@@ -301,7 +319,8 @@ func (ss *ServerSession) run() {
} }
// in case of PLAY and UDP, timeout happens when no request arrives // in case of PLAY and UDP, timeout happens when no request arrives
case ss.state == ServerSessionStateRead && *ss.setuppedProtocol == base.StreamProtocolUDP: case ss.state == ServerSessionStateRead && (*ss.setuppedTransport == ClientTransportUDP ||
*ss.setuppedTransport == ClientTransportUDPMulticast):
now := time.Now() now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{} return liberrors.ErrServerSessionTimedOut{}
@@ -333,13 +352,12 @@ func (ss *ServerSession) run() {
case ServerSessionStateRead: case ServerSessionStateRead:
ss.setuppedStream.readerSetInactive(ss) ss.setuppedStream.readerSetInactive(ss)
if *ss.setuppedProtocol == base.StreamProtocolUDP && if *ss.setuppedTransport == ClientTransportUDP {
*ss.setuppedDelivery == base.StreamDeliveryUnicast {
ss.s.udpRTCPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss)
} }
case ServerSessionStatePublish: case ServerSessionStatePublish:
if *ss.setuppedProtocol == base.StreamProtocolUDP { if *ss.setuppedTransport == ClientTransportUDP {
ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss)
} }
@@ -550,60 +568,35 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID} }, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID}
} }
delivery := func() base.StreamDelivery { transport, ok := setupGetTransport(inTH)
if inTH.Delivery != nil { if !ok {
return *inTH.Delivery
}
return base.StreamDeliveryUnicast
}()
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePreRead: // play
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
default: // record
if delivery == base.StreamDeliveryMulticast {
return &base.Response{ return &base.Response{
StatusCode: base.StatusUnsupportedTransport, StatusCode: base.StatusUnsupportedTransport,
}, nil }, nil
} }
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord { switch transport {
case ClientTransportUDP:
if inTH.ClientPorts == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} }, liberrors.ErrServerTransportHeaderNoClientPorts{}
}
} }
if inTH.Protocol == base.StreamProtocolUDP {
if delivery == base.StreamDeliveryUnicast {
if ss.s.udpRTPListener == nil { if ss.s.udpRTPListener == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusUnsupportedTransport, StatusCode: base.StatusUnsupportedTransport,
}, nil }, nil
} }
if inTH.ClientPorts == nil { case ClientTransportUDPMulticast:
return &base.Response{ if ss.s.MulticastIPRange == "" {
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderNoClientPorts{}
}
} else if ss.s.MulticastIPRange == "" {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
} else {
if delivery == base.StreamDeliveryMulticast {
return &base.Response{ return &base.Response{
StatusCode: base.StatusUnsupportedTransport, StatusCode: base.StatusUnsupportedTransport,
}, nil }, nil
} }
default: // TCP
if inTH.InterleavedIDs == nil { if inTH.InterleavedIDs == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -624,13 +617,34 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
} }
} }
if ss.setuppedProtocol != nil && if ss.setuppedTransport != nil && *ss.setuppedTransport != transport {
(*ss.setuppedProtocol != inTH.Protocol || *ss.setuppedDelivery != delivery) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTracksDifferentProtocols{} }, liberrors.ErrServerTracksDifferentProtocols{}
} }
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePreRead: // play
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
default: // record
if transport == ClientTransportUDPMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
}
res, stream, err := ss.s.Handler.(ServerHandlerOnSetup).OnSetup(&ServerHandlerOnSetupCtx{ res, stream, err := ss.s.Handler.(ServerHandlerOnSetup).OnSetup(&ServerHandlerOnSetupCtx{
Server: ss.s, Server: ss.s,
Session: ss, Session: ss,
@@ -639,14 +653,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
Path: path, Path: path,
Query: query, Query: query,
TrackID: trackID, TrackID: trackID,
Transport: &inTH, Transport: transport,
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
if ss.state == ServerSessionStateInitial { if ss.state == ServerSessionStateInitial {
err := stream.readerAdd(ss, err := stream.readerAdd(ss,
inTH.Protocol, transport,
delivery,
inTH.ClientPorts, inTH.ClientPorts,
) )
if err != nil { if err != nil {
@@ -670,8 +683,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
} }
} }
ss.setuppedProtocol = &inTH.Protocol ss.setuppedTransport = &transport
ss.setuppedDelivery = &delivery
if res.Header == nil { if res.Header == nil {
res.Header = make(base.Header) res.Header = make(base.Header)
@@ -679,8 +691,18 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
sst := ServerSessionSetuppedTrack{} sst := ServerSessionSetuppedTrack{}
switch { switch transport {
case delivery == base.StreamDeliveryMulticast: case ClientTransportUDP:
sst.udpRTPPort = inTH.ClientPorts[0]
sst.udpRTCPPort = inTH.ClientPorts[1]
th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
case ClientTransportUDPMulticast:
th.Protocol = base.StreamProtocolUDP th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryMulticast de := base.StreamDeliveryMulticast
th.Delivery = &de th.Delivery = &de
@@ -693,16 +715,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
stream.multicastListeners[trackID].rtcpListener.port(), stream.multicastListeners[trackID].rtcpListener.port(),
} }
case inTH.Protocol == base.StreamProtocolUDP:
sst.udpRTPPort = inTH.ClientPorts[0]
sst.udpRTCPPort = inTH.ClientPorts[1]
th.Protocol = base.StreamProtocolUDP
de := base.StreamDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
default: // TCP default: // TCP
sst.tcpChannel = inTH.InterleavedIDs[0] sst.tcpChannel = inTH.InterleavedIDs[0]
@@ -790,7 +802,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStateRead ss.state = ServerSessionStateRead
if *ss.setuppedProtocol == base.StreamProtocolTCP { if *ss.setuppedTransport == ClientTransportTCP {
ss.tcpConn = sc ss.tcpConn = sc
} }
@@ -833,8 +845,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.setuppedStream.readerSetActive(ss) ss.setuppedStream.readerSetActive(ss)
if *ss.setuppedProtocol == base.StreamProtocolUDP { switch *ss.setuppedTransport {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast { case ClientTransportUDP:
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
// readers can send RTCP packets // readers can send RTCP packets
sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false) sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false)
@@ -843,13 +855,17 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.WriteFrame(trackID, StreamTypeRTCP, ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
} }
return res, err
case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesEnable{}
} }
return res, err return res, err
} }
return res, liberrors.ErrServerTCPFramesEnable{}
}
} }
return res, err return res, err
@@ -883,7 +899,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
path, query := base.PathSplitQuery(pathAndQuery) path, query := base.PathSplitQuery(pathAndQuery)
// allow to use WriteFrame() before response // allow to use WriteFrame() before response
if *ss.setuppedProtocol == base.StreamProtocolTCP { if *ss.setuppedTransport == ClientTransportTCP {
ss.tcpConn = sc ss.tcpConn = sc
} }
@@ -904,7 +920,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStatePublish ss.state = ServerSessionStatePublish
if *ss.setuppedProtocol == base.StreamProtocolUDP { switch *ss.setuppedTransport {
case ClientTransportUDP:
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true) ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true)
ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true) ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true)
@@ -916,10 +933,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
} }
return res, err case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesEnable{}
} }
return res, liberrors.ErrServerTCPFramesEnable{} return res, err
} }
ss.tcpConn = nil ss.tcpConn = nil
@@ -967,23 +987,29 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStatePreRead ss.state = ServerSessionStatePreRead
ss.tcpConn = nil ss.tcpConn = nil
if *ss.setuppedProtocol == base.StreamProtocolUDP { switch *ss.setuppedTransport {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast { case ClientTransportUDP:
ss.s.udpRTCPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss)
}
} else { case ClientTransportUDPMulticast:
return res, liberrors.ErrServerTCPFramesDisable{}
default: // TCP
err = liberrors.ErrServerTCPFramesDisable{}
} }
case ServerSessionStatePublish: case ServerSessionStatePublish:
ss.state = ServerSessionStatePrePublish ss.state = ServerSessionStatePrePublish
ss.tcpConn = nil ss.tcpConn = nil
if *ss.setuppedProtocol == base.StreamProtocolUDP { switch *ss.setuppedTransport {
case ClientTransportUDP:
ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss)
} else {
return res, liberrors.ErrServerTCPFramesDisable{} case ClientTransportUDPMulticast:
default: // TCP
err = liberrors.ErrServerTCPFramesDisable{}
} }
} }
} }
@@ -1037,8 +1063,8 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
return return
} }
if *ss.setuppedProtocol == base.StreamProtocolUDP { switch *ss.setuppedTransport {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast { case ClientTransportUDP:
track := ss.setuppedTracks[trackID] track := ss.setuppedTracks[trackID]
if streamType == StreamTypeRTP { if streamType == StreamTypeRTP {
@@ -1054,8 +1080,8 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
Port: track.udpRTCPPort, Port: track.udpRTCPPort,
}) })
} }
}
} else { case ClientTransportTCP:
channel := ss.setuppedTracks[trackID].tcpChannel channel := ss.setuppedTracks[trackID].tcpChannel
if streamType == base.StreamTypeRTCP { if streamType == base.StreamTypeRTCP {
channel++ channel++

View File

@@ -7,7 +7,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
) )
@@ -114,8 +113,7 @@ func (st *ServerStream) lastSequenceNumber(trackID int) uint16 {
func (st *ServerStream) readerAdd( func (st *ServerStream) readerAdd(
ss *ServerSession, ss *ServerSession,
protocol base.StreamProtocol, transport ClientTransport,
delivery base.StreamDelivery,
clientPorts *[2]int, clientPorts *[2]int,
) error { ) error {
st.mutex.Lock() st.mutex.Lock()
@@ -129,12 +127,11 @@ func (st *ServerStream) readerAdd(
} }
} }
// if new reader is a UDP-unicast reader, check that its port are not already switch transport {
// in use by another reader. case ClientTransportUDP:
if protocol == base.StreamProtocolUDP && delivery == base.StreamDeliveryUnicast { // check whether client ports are already in use by another reader.
for r := range st.readersUnicast { for r := range st.readersUnicast {
if *r.setuppedProtocol == base.StreamProtocolUDP && if *r.setuppedTransport == ClientTransportUDP &&
*r.setuppedDelivery == base.StreamDeliveryUnicast &&
r.ip().Equal(ss.ip()) && r.ip().Equal(ss.ip()) &&
r.zone() == ss.zone() { r.zone() == ss.zone() {
for _, rt := range r.setuppedTracks { for _, rt := range r.setuppedTracks {
@@ -144,12 +141,10 @@ func (st *ServerStream) readerAdd(
} }
} }
} }
}
case ClientTransportUDPMulticast:
// allocate multicast listeners // allocate multicast listeners
if protocol == base.StreamProtocolUDP && if st.multicastListeners == nil {
delivery == base.StreamDeliveryMulticast &&
st.multicastListeners == nil {
st.multicastListeners = make([]*listenerPair, len(st.tracks)) st.multicastListeners = make([]*listenerPair, len(st.tracks))
for i := range st.tracks { for i := range st.tracks {
@@ -171,6 +166,7 @@ func (st *ServerStream) readerAdd(
} }
} }
} }
}
st.readers[ss] = struct{}{} st.readers[ss] = struct{}{}
@@ -196,9 +192,11 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
if *ss.setuppedDelivery == base.StreamDeliveryUnicast { switch *ss.setuppedTransport {
case ClientTransportUDP, ClientTransportTCP:
st.readersUnicast[ss] = struct{}{} st.readersUnicast[ss] = struct{}{}
} else {
default: // UDPMulticast
for trackID := range ss.setuppedTracks { for trackID := range ss.setuppedTracks {
st.multicastListeners[trackID].rtcpListener.addClient( st.multicastListeners[trackID].rtcpListener.addClient(
ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false) ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
@@ -210,13 +208,17 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
if *ss.setuppedDelivery == base.StreamDeliveryUnicast { switch *ss.setuppedTransport {
case ClientTransportUDP, ClientTransportTCP:
delete(st.readersUnicast, ss) delete(st.readersUnicast, ss)
} else if st.multicastListeners != nil {
default: // UDPMulticast
if st.multicastListeners != nil {
for trackID := range ss.setuppedTracks { for trackID := range ss.setuppedTracks {
st.multicastListeners[trackID].rtcpListener.removeClient(ss) st.multicastListeners[trackID].rtcpListener.removeClient(ss)
} }
} }
}
} }
// WriteFrame writes a frame to all the readers of the stream. // WriteFrame writes a frame to all the readers of the stream.
@@ -248,13 +250,11 @@ func (st *ServerStream) WriteFrame(trackID int, streamType StreamType, payload [
if streamType == StreamTypeRTP { if streamType == StreamTypeRTP {
st.multicastListeners[trackID].rtpListener.write(payload, &net.UDPAddr{ st.multicastListeners[trackID].rtpListener.write(payload, &net.UDPAddr{
IP: st.multicastListeners[trackID].rtpListener.ip(), IP: st.multicastListeners[trackID].rtpListener.ip(),
Zone: "",
Port: st.multicastListeners[trackID].rtpListener.port(), Port: st.multicastListeners[trackID].rtpListener.port(),
}) })
} else { } else {
st.multicastListeners[trackID].rtcpListener.write(payload, &net.UDPAddr{ st.multicastListeners[trackID].rtcpListener.write(payload, &net.UDPAddr{
IP: st.multicastListeners[trackID].rtpListener.ip(), IP: st.multicastListeners[trackID].rtpListener.ip(),
Zone: "",
Port: st.multicastListeners[trackID].rtcpListener.port(), Port: st.multicastListeners[trackID].rtcpListener.port(),
}) })
} }