From e7ab15750c056d682f825a4e94712a7a1fbd06ae Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 22 Oct 2021 17:40:18 +0200 Subject: [PATCH] server: replace SetuppedProtocol() with SetuppedTransport() --- examples/client-publish-options/main.go | 2 +- server_publish_test.go | 34 +-- server_read_test.go | 48 +++-- serverhandler.go | 3 +- serversession.go | 270 +++++++++++++----------- serverstream.go | 70 +++--- 6 files changed, 233 insertions(+), 194 deletions(-) diff --git a/examples/client-publish-options/main.go b/examples/client-publish-options/main.go index 0ac771db..c9d2ceb4 100644 --- a/examples/client-publish-options/main.go +++ b/examples/client-publish-options/main.go @@ -43,7 +43,7 @@ func main() { // Client allows to set additional client options c := &gortsplib.Client{ - // the stream transport (UDP, Multicast or TCP). If nil, it is chosen automatically + // the stream transport (UDP or TCP). If nil, it is chosen automatically Transport: nil, // timeout of read operations ReadTimeout: 10 * time.Second, diff --git a/server_publish_test.go b/server_publish_test.go index 0e5309d2..75242bf8 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -663,12 +663,12 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { } func TestServerPublish(t *testing.T) { - for _, proto := range []string{ + for _, transport := range []string{ "udp", "tcp", "tls", } { - t.Run(proto, func(t *testing.T) { + t.Run(transport, func(t *testing.T) { connOpened := make(chan struct{}) connClosed := make(chan struct{}) sessionOpened := make(chan struct{}) @@ -720,7 +720,7 @@ func TestServerPublish(t *testing.T) { }, } - switch proto { + switch transport { case "udp": s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" @@ -740,7 +740,7 @@ func TestServerPublish(t *testing.T) { defer nconn.Close() conn := func() net.Conn { - if proto == "tls" { + if transport == "tls" { return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) } return nconn @@ -785,7 +785,7 @@ func TestServerPublish(t *testing.T) { }(), } - if proto == "udp" { + if transport == "udp" { inTH.Protocol = base.StreamProtocolUDP inTH.ClientPorts = &[2]int{35466, 35467} } else { @@ -811,7 +811,7 @@ func TestServerPublish(t *testing.T) { var l1 net.PacketConn var l2 net.PacketConn - if proto == "udp" { + if transport == "udp" { l1, err = net.ListenPacket("udp", "localhost:35466") require.NoError(t, err) defer l1.Close() @@ -833,7 +833,7 @@ func TestServerPublish(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) // client -> server - if proto == "udp" { + if transport == "udp" { time.Sleep(1 * time.Second) l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ @@ -863,7 +863,7 @@ func TestServerPublish(t *testing.T) { } // server -> client (RTCP) - if proto == "udp" { + if transport == "udp" { // skip firewall opening buf := make([]byte, 2048) _, _, err := l2.ReadFrom(buf) @@ -1148,11 +1148,11 @@ func TestServerPublishRTCPReport(t *testing.T) { } func TestServerPublishTimeout(t *testing.T) { - for _, proto := range []string{ + for _, transport := range []string{ "udp", "tcp", } { - t.Run(proto, func(t *testing.T) { + t.Run(transport, func(t *testing.T) { connClosed := make(chan struct{}) sessionClosed := make(chan struct{}) @@ -1183,7 +1183,7 @@ func TestServerPublishTimeout(t *testing.T) { ReadTimeout: 1 * time.Second, } - if proto == "udp" { + if transport == "udp" { s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" } @@ -1231,7 +1231,7 @@ func TestServerPublishTimeout(t *testing.T) { }(), } - if proto == "udp" { + if transport == "udp" { inTH.Protocol = base.StreamProtocolUDP inTH.ClientPorts = &[2]int{35466, 35467} } else { @@ -1268,7 +1268,7 @@ func TestServerPublishTimeout(t *testing.T) { <-sessionClosed - if proto == "tcp" { + if transport == "tcp" { <-connClosed } }) @@ -1276,11 +1276,11 @@ func TestServerPublishTimeout(t *testing.T) { } func TestServerPublishWithoutTeardown(t *testing.T) { - for _, proto := range []string{ + for _, transport := range []string{ "udp", "tcp", } { - t.Run(proto, func(t *testing.T) { + t.Run(transport, func(t *testing.T) { connClosed := make(chan struct{}) sessionClosed := make(chan struct{}) @@ -1311,7 +1311,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) { ReadTimeout: 1 * time.Second, } - if proto == "udp" { + if transport == "udp" { s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" } @@ -1358,7 +1358,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) { }(), } - if proto == "udp" { + if transport == "udp" { inTH.Protocol = base.StreamProtocolUDP inTH.ClientPorts = &[2]int{35466, 35467} } else { diff --git a/server_read_test.go b/server_read_test.go index 020314c4..15a819d0 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -287,13 +287,13 @@ func TestServerReadErrorSetupTrackTwice(t *testing.T) { } func TestServerRead(t *testing.T) { - for _, proto := range []string{ + for _, transport := range []string{ "udp", "tcp", "tls", "multicast", } { - t.Run(proto, func(t *testing.T) { + t.Run(transport, func(t *testing.T) { connOpened := make(chan struct{}) connClosed := make(chan struct{}) sessionOpened := make(chan struct{}) @@ -339,7 +339,7 @@ func TestServerRead(t *testing.T) { }, onFrame: func(ctx *ServerHandlerOnFrameCtx) { // skip multicast loopback - if proto == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { + if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { return } @@ -356,7 +356,7 @@ func TestServerRead(t *testing.T) { }, } - switch proto { + switch transport { case "udp": s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" @@ -381,7 +381,7 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) conn := func() net.Conn { - if proto == "tls" { + if transport == "tls" { return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) } return nconn @@ -397,7 +397,7 @@ func TestServerRead(t *testing.T) { }(), } - switch proto { + switch transport { case "udp": v := base.StreamDeliveryUnicast inTH.Delivery = &v @@ -431,11 +431,25 @@ func TestServerRead(t *testing.T) { err = th.Read(res.Header["Transport"]) require.NoError(t, err) + switch transport { + case "udp": + require.Equal(t, base.StreamProtocolUDP, th.Protocol) + require.Equal(t, base.StreamDeliveryUnicast, *th.Delivery) + + case "multicast": + require.Equal(t, base.StreamProtocolUDP, th.Protocol) + require.Equal(t, base.StreamDeliveryMulticast, *th.Delivery) + + default: + require.Equal(t, base.StreamProtocolTCP, th.Protocol) + require.Equal(t, base.StreamDeliveryUnicast, *th.Delivery) + } + <-sessionOpened var l1 net.PacketConn var l2 net.PacketConn - switch proto { + switch transport { case "udp": l1, err = net.ListenPacket("udp", listenIP+":35466") require.NoError(t, err) @@ -487,14 +501,14 @@ func TestServerRead(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) // server -> client - if proto == "udp" || proto == "multicast" { + if transport == "udp" || transport == "multicast" { buf := make([]byte, 2048) n, _, err := l1.ReadFrom(buf) require.NoError(t, err) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n]) // skip firewall opening - if proto == "udp" { + if transport == "udp" { buf := make([]byte, 2048) _, _, err := l2.ReadFrom(buf) require.NoError(t, err) @@ -520,7 +534,7 @@ func TestServerRead(t *testing.T) { } // client -> server (RTCP) - switch proto { + switch transport { case "udp": l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -544,7 +558,7 @@ func TestServerRead(t *testing.T) { <-framesReceived } - if proto == "udp" || proto == "multicast" { + if transport == "udp" || transport == "multicast" { // ping with OPTIONS res, err = writeReqReadRes(bconn, base.Request{ Method: base.Options, @@ -1001,11 +1015,11 @@ func TestServerReadPlayPausePause(t *testing.T) { } func TestServerReadTimeout(t *testing.T) { - for _, proto := range []string{ + for _, transport := range []string{ "udp", // there's no timeout when reading with TCP } { - t.Run(proto, func(t *testing.T) { + t.Run(transport, func(t *testing.T) { sessionClosed := make(chan struct{}) track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) @@ -1092,11 +1106,11 @@ func TestServerReadTimeout(t *testing.T) { } func TestServerReadWithoutTeardown(t *testing.T) { - for _, proto := range []string{ + for _, transport := range []string{ "udp", "tcp", } { - t.Run(proto, func(t *testing.T) { + t.Run(transport, func(t *testing.T) { connClosed := make(chan struct{}) sessionClosed := make(chan struct{}) @@ -1133,7 +1147,7 @@ func TestServerReadWithoutTeardown(t *testing.T) { closeSessionAfterNoRequestsFor: 1 * time.Second, } - if proto == "udp" { + if transport == "udp" { s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" } @@ -1158,7 +1172,7 @@ func TestServerReadWithoutTeardown(t *testing.T) { }(), } - if proto == "udp" { + if transport == "udp" { inTH.Protocol = base.StreamProtocolUDP inTH.ClientPorts = &[2]int{35466, 35467} } else { diff --git a/serverhandler.go b/serverhandler.go index 50c9739a..d20aa8be 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -2,7 +2,6 @@ package gortsplib import ( "github.com/aler9/gortsplib/pkg/base" - "github.com/aler9/gortsplib/pkg/headers" ) // ServerHandler is the interface implemented by all the server handlers. @@ -99,7 +98,7 @@ type ServerHandlerOnSetupCtx struct { Path string Query string TrackID int - Transport *headers.Transport + Transport ClientTransport } // ServerHandlerOnSetup can be implemented by a ServerHandler. diff --git a/serversession.go b/serversession.go index c3bb8859..38e9c333 100644 --- a/serversession.go +++ b/serversession.go @@ -75,6 +75,29 @@ func setupGetTrackIDPathQuery( return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery) } +func setupGetTransport(th headers.Transport) (ClientTransport, bool) { + delivery := func() base.StreamDelivery { + if th.Delivery != nil { + return *th.Delivery + } + return base.StreamDeliveryUnicast + }() + + switch th.Protocol { + case base.StreamProtocolUDP: + if delivery == base.StreamDeliveryUnicast { + return ClientTransportUDP, true + } + return ClientTransportUDPMulticast, true + + default: // TCP + if delivery != base.StreamDeliveryUnicast { + return 0, false + } + return ClientTransportTCP, true + } +} + // ServerSessionState is a state of a ServerSession. type ServerSessionState int @@ -129,8 +152,7 @@ type ServerSession struct { state ServerSessionState setuppedTracks map[int]ServerSessionSetuppedTrack setuppedTracksByChannel map[int]int // tcp - setuppedProtocol *base.StreamProtocol - setuppedDelivery *base.StreamDelivery + setuppedTransport *ClientTransport setuppedBaseURL *base.URL // publish setuppedStream *ServerStream // read setuppedPath *string @@ -186,14 +208,9 @@ func (ss *ServerSession) SetuppedTracks() map[int]ServerSessionSetuppedTrack { return ss.setuppedTracks } -// SetuppedProtocol returns the stream protocol of the setupped tracks. -func (ss *ServerSession) SetuppedProtocol() *base.StreamProtocol { - return ss.setuppedProtocol -} - -// SetuppedDelivery returns the delivery method of the setupped tracks. -func (ss *ServerSession) SetuppedDelivery() *base.StreamDelivery { - return ss.setuppedDelivery +// SetuppedTransport returns the transport of the setupped tracks. +func (ss *ServerSession) SetuppedTransport() *ClientTransport { + return ss.setuppedTransport } // AnnouncedTracks returns the announced tracks. @@ -279,10 +296,10 @@ func (ss *ServerSession) run() { } } - // if session is not in state RECORD or PLAY, or protocol is TCP + // if session is not in state RECORD or PLAY, or transport is TCP if (ss.state != ServerSessionStatePublish && ss.state != ServerSessionStateRead) || - *ss.setuppedProtocol == base.StreamProtocolTCP { + *ss.setuppedTransport == ClientTransportTCP { // close if there are no active connections if len(ss.conns) == 0 { @@ -293,7 +310,8 @@ func (ss *ServerSession) run() { case <-checkTimeoutTicker.C: switch { // in case of RECORD and UDP, timeout happens when no frames are being received - case ss.state == ServerSessionStatePublish && *ss.setuppedProtocol == base.StreamProtocolUDP: + case ss.state == ServerSessionStatePublish && (*ss.setuppedTransport == ClientTransportUDP || + *ss.setuppedTransport == ClientTransportUDPMulticast): now := time.Now() lft := atomic.LoadInt64(ss.udpLastFrameTime) if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { @@ -301,7 +319,8 @@ func (ss *ServerSession) run() { } // in case of PLAY and UDP, timeout happens when no request arrives - case ss.state == ServerSessionStateRead && *ss.setuppedProtocol == base.StreamProtocolUDP: + case ss.state == ServerSessionStateRead && (*ss.setuppedTransport == ClientTransportUDP || + *ss.setuppedTransport == ClientTransportUDPMulticast): now := time.Now() if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { return liberrors.ErrServerSessionTimedOut{} @@ -333,13 +352,12 @@ func (ss *ServerSession) run() { case ServerSessionStateRead: ss.setuppedStream.readerSetInactive(ss) - if *ss.setuppedProtocol == base.StreamProtocolUDP && - *ss.setuppedDelivery == base.StreamDeliveryUnicast { + if *ss.setuppedTransport == ClientTransportUDP { ss.s.udpRTCPListener.removeClient(ss) } case ServerSessionStatePublish: - if *ss.setuppedProtocol == base.StreamProtocolUDP { + if *ss.setuppedTransport == ClientTransportUDP { ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss) } @@ -550,60 +568,35 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, liberrors.ErrServerTrackAlreadySetup{TrackID: trackID} } - delivery := func() base.StreamDelivery { - if inTH.Delivery != nil { - return *inTH.Delivery - } - return base.StreamDeliveryUnicast - }() - - switch ss.state { - case ServerSessionStateInitial, ServerSessionStatePreRead: // play - if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} - } - - default: // record - if delivery == base.StreamDeliveryMulticast { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - - if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} - } + transport, ok := setupGetTransport(inTH) + if !ok { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil } - if inTH.Protocol == base.StreamProtocolUDP { - if delivery == base.StreamDeliveryUnicast { - if ss.s.udpRTPListener == nil { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - - if inTH.ClientPorts == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderNoClientPorts{} - } - } else if ss.s.MulticastIPRange == "" { + switch transport { + case ClientTransportUDP: + if inTH.ClientPorts == nil { return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerTransportHeaderNoClientPorts{} } - } else { - if delivery == base.StreamDeliveryMulticast { + + if ss.s.udpRTPListener == nil { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, }, nil } + case ClientTransportUDPMulticast: + if ss.s.MulticastIPRange == "" { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + + default: // TCP if inTH.InterleavedIDs == nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -624,13 +617,34 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } } - if ss.setuppedProtocol != nil && - (*ss.setuppedProtocol != inTH.Protocol || *ss.setuppedDelivery != delivery) { + if ss.setuppedTransport != nil && *ss.setuppedTransport != transport { return &base.Response{ StatusCode: base.StatusBadRequest, }, liberrors.ErrServerTracksDifferentProtocols{} } + switch ss.state { + case ServerSessionStateInitial, ServerSessionStatePreRead: // play + if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} + } + + default: // record + if transport == ClientTransportUDPMulticast { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + + if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode} + } + } + res, stream, err := ss.s.Handler.(ServerHandlerOnSetup).OnSetup(&ServerHandlerOnSetupCtx{ Server: ss.s, Session: ss, @@ -639,14 +653,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base Path: path, Query: query, TrackID: trackID, - Transport: &inTH, + Transport: transport, }) if res.StatusCode == base.StatusOK { if ss.state == ServerSessionStateInitial { err := stream.readerAdd(ss, - inTH.Protocol, - delivery, + transport, inTH.ClientPorts, ) if err != nil { @@ -670,8 +683,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } } - ss.setuppedProtocol = &inTH.Protocol - ss.setuppedDelivery = &delivery + ss.setuppedTransport = &transport if res.Header == nil { res.Header = make(base.Header) @@ -679,8 +691,18 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base sst := ServerSessionSetuppedTrack{} - switch { - case delivery == base.StreamDeliveryMulticast: + switch transport { + case ClientTransportUDP: + sst.udpRTPPort = inTH.ClientPorts[0] + sst.udpRTCPPort = inTH.ClientPorts[1] + + th.Protocol = base.StreamProtocolUDP + de := base.StreamDeliveryUnicast + th.Delivery = &de + th.ClientPorts = inTH.ClientPorts + th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()} + + case ClientTransportUDPMulticast: th.Protocol = base.StreamProtocolUDP de := base.StreamDeliveryMulticast th.Delivery = &de @@ -693,16 +715,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base stream.multicastListeners[trackID].rtcpListener.port(), } - case inTH.Protocol == base.StreamProtocolUDP: - sst.udpRTPPort = inTH.ClientPorts[0] - sst.udpRTCPPort = inTH.ClientPorts[1] - - th.Protocol = base.StreamProtocolUDP - de := base.StreamDeliveryUnicast - th.Delivery = &de - th.ClientPorts = inTH.ClientPorts - th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()} - default: // TCP sst.tcpChannel = inTH.InterleavedIDs[0] @@ -790,7 +802,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if res.StatusCode == base.StatusOK { ss.state = ServerSessionStateRead - if *ss.setuppedProtocol == base.StreamProtocolTCP { + if *ss.setuppedTransport == ClientTransportTCP { ss.tcpConn = sc } @@ -833,22 +845,26 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.setuppedStream.readerSetActive(ss) - if *ss.setuppedProtocol == base.StreamProtocolUDP { - if *ss.setuppedDelivery == base.StreamDeliveryUnicast { - for trackID, track := range ss.setuppedTracks { - // readers can send RTCP packets - sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false) + switch *ss.setuppedTransport { + case ClientTransportUDP: + for trackID, track := range ss.setuppedTracks { + // readers can send RTCP packets + sc.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, false) - // open the firewall by sending packets to the counterpart - ss.WriteFrame(trackID, StreamTypeRTCP, - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) - } + // open the firewall by sending packets to the counterpart + ss.WriteFrame(trackID, StreamTypeRTCP, + []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } return res, err + + case ClientTransportUDPMulticast: + + default: // TCP + err = liberrors.ErrServerTCPFramesEnable{} } - return res, liberrors.ErrServerTCPFramesEnable{} + return res, err } } @@ -883,7 +899,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base path, query := base.PathSplitQuery(pathAndQuery) // allow to use WriteFrame() before response - if *ss.setuppedProtocol == base.StreamProtocolTCP { + if *ss.setuppedTransport == ClientTransportTCP { ss.tcpConn = sc } @@ -904,7 +920,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if res.StatusCode == base.StatusOK { ss.state = ServerSessionStatePublish - if *ss.setuppedProtocol == base.StreamProtocolUDP { + switch *ss.setuppedTransport { + case ClientTransportUDP: for trackID, track := range ss.setuppedTracks { ss.s.udpRTPListener.addClient(ss.ip(), track.udpRTPPort, ss, trackID, true) ss.s.udpRTCPListener.addClient(ss.ip(), track.udpRTCPPort, ss, trackID, true) @@ -916,10 +933,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } - return res, err + case ClientTransportUDPMulticast: + + default: // TCP + err = liberrors.ErrServerTCPFramesEnable{} } - return res, liberrors.ErrServerTCPFramesEnable{} + return res, err } ss.tcpConn = nil @@ -967,23 +987,29 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStatePreRead ss.tcpConn = nil - if *ss.setuppedProtocol == base.StreamProtocolUDP { - if *ss.setuppedDelivery == base.StreamDeliveryUnicast { - ss.s.udpRTCPListener.removeClient(ss) - } - } else { - return res, liberrors.ErrServerTCPFramesDisable{} + switch *ss.setuppedTransport { + case ClientTransportUDP: + ss.s.udpRTCPListener.removeClient(ss) + + case ClientTransportUDPMulticast: + + default: // TCP + err = liberrors.ErrServerTCPFramesDisable{} } case ServerSessionStatePublish: ss.state = ServerSessionStatePrePublish ss.tcpConn = nil - if *ss.setuppedProtocol == base.StreamProtocolUDP { + switch *ss.setuppedTransport { + case ClientTransportUDP: ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss) - } else { - return res, liberrors.ErrServerTCPFramesDisable{} + + case ClientTransportUDPMulticast: + + default: // TCP + err = liberrors.ErrServerTCPFramesDisable{} } } } @@ -1037,25 +1063,25 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload return } - if *ss.setuppedProtocol == base.StreamProtocolUDP { - if *ss.setuppedDelivery == base.StreamDeliveryUnicast { - track := ss.setuppedTracks[trackID] + switch *ss.setuppedTransport { + case ClientTransportUDP: + track := ss.setuppedTracks[trackID] - if streamType == StreamTypeRTP { - ss.s.udpRTPListener.write(payload, &net.UDPAddr{ - IP: ss.ip(), - Zone: ss.zone(), - Port: track.udpRTPPort, - }) - } else { - ss.s.udpRTCPListener.write(payload, &net.UDPAddr{ - IP: ss.ip(), - Zone: ss.zone(), - Port: track.udpRTCPPort, - }) - } + if streamType == StreamTypeRTP { + ss.s.udpRTPListener.write(payload, &net.UDPAddr{ + IP: ss.ip(), + Zone: ss.zone(), + Port: track.udpRTPPort, + }) + } else { + ss.s.udpRTCPListener.write(payload, &net.UDPAddr{ + IP: ss.ip(), + Zone: ss.zone(), + Port: track.udpRTCPPort, + }) } - } else { + + case ClientTransportTCP: channel := ss.setuppedTracks[trackID].tcpChannel if streamType == base.StreamTypeRTCP { channel++ diff --git a/serverstream.go b/serverstream.go index b5a5f699..42264ef9 100644 --- a/serverstream.go +++ b/serverstream.go @@ -7,7 +7,6 @@ import ( "sync/atomic" "time" - "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/liberrors" ) @@ -114,8 +113,7 @@ func (st *ServerStream) lastSequenceNumber(trackID int) uint16 { func (st *ServerStream) readerAdd( ss *ServerSession, - protocol base.StreamProtocol, - delivery base.StreamDelivery, + transport ClientTransport, clientPorts *[2]int, ) error { st.mutex.Lock() @@ -129,12 +127,11 @@ func (st *ServerStream) readerAdd( } } - // if new reader is a UDP-unicast reader, check that its port are not already - // in use by another reader. - if protocol == base.StreamProtocolUDP && delivery == base.StreamDeliveryUnicast { + switch transport { + case ClientTransportUDP: + // check whether client ports are already in use by another reader. for r := range st.readersUnicast { - if *r.setuppedProtocol == base.StreamProtocolUDP && - *r.setuppedDelivery == base.StreamDeliveryUnicast && + if *r.setuppedTransport == ClientTransportUDP && r.ip().Equal(ss.ip()) && r.zone() == ss.zone() { for _, rt := range r.setuppedTracks { @@ -144,30 +141,29 @@ func (st *ServerStream) readerAdd( } } } - } - // allocate multicast listeners - if protocol == base.StreamProtocolUDP && - delivery == base.StreamDeliveryMulticast && - st.multicastListeners == nil { - st.multicastListeners = make([]*listenerPair, len(st.tracks)) + case ClientTransportUDPMulticast: + // allocate multicast listeners + if st.multicastListeners == nil { + st.multicastListeners = make([]*listenerPair, len(st.tracks)) - for i := range st.tracks { - rtpListener, rtcpListener, err := newServerUDPListenerMulticastPair(st.s) - if err != nil { - for _, l := range st.multicastListeners { - if l != nil { - l.rtpListener.close() - l.rtcpListener.close() + for i := range st.tracks { + rtpListener, rtcpListener, err := newServerUDPListenerMulticastPair(st.s) + if err != nil { + for _, l := range st.multicastListeners { + if l != nil { + l.rtpListener.close() + l.rtcpListener.close() + } } + st.multicastListeners = nil + return err } - st.multicastListeners = nil - return err - } - st.multicastListeners[i] = &listenerPair{ - rtpListener: rtpListener, - rtcpListener: rtcpListener, + st.multicastListeners[i] = &listenerPair{ + rtpListener: rtpListener, + rtcpListener: rtcpListener, + } } } } @@ -196,9 +192,11 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() - if *ss.setuppedDelivery == base.StreamDeliveryUnicast { + switch *ss.setuppedTransport { + case ClientTransportUDP, ClientTransportTCP: st.readersUnicast[ss] = struct{}{} - } else { + + default: // UDPMulticast for trackID := range ss.setuppedTracks { st.multicastListeners[trackID].rtcpListener.addClient( ss.ip(), st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false) @@ -210,11 +208,15 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() - if *ss.setuppedDelivery == base.StreamDeliveryUnicast { + switch *ss.setuppedTransport { + case ClientTransportUDP, ClientTransportTCP: delete(st.readersUnicast, ss) - } else if st.multicastListeners != nil { - for trackID := range ss.setuppedTracks { - st.multicastListeners[trackID].rtcpListener.removeClient(ss) + + default: // UDPMulticast + if st.multicastListeners != nil { + for trackID := range ss.setuppedTracks { + st.multicastListeners[trackID].rtcpListener.removeClient(ss) + } } } } @@ -248,13 +250,11 @@ func (st *ServerStream) WriteFrame(trackID int, streamType StreamType, payload [ if streamType == StreamTypeRTP { st.multicastListeners[trackID].rtpListener.write(payload, &net.UDPAddr{ IP: st.multicastListeners[trackID].rtpListener.ip(), - Zone: "", Port: st.multicastListeners[trackID].rtpListener.port(), }) } else { st.multicastListeners[trackID].rtcpListener.write(payload, &net.UDPAddr{ IP: st.multicastListeners[trackID].rtpListener.ip(), - Zone: "", Port: st.multicastListeners[trackID].rtcpListener.port(), }) }