mirror of
https://github.com/aler9/gortsplib
synced 2025-10-08 16:40:09 +08:00
rewrite transport header
This commit is contained in:
@@ -357,12 +357,12 @@ func (c *ConnClient) urlForTrack(baseUrl *url.URL, track *Track) *url.URL {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnClient) setup(u *url.URL, track *Track, transport []string) (*Response, error) {
|
func (c *ConnClient) setup(u *url.URL, track *Track, ht *HeaderTransport) (*Response, error) {
|
||||||
res, err := c.Do(&Request{
|
res, err := c.Do(&Request{
|
||||||
Method: SETUP,
|
Method: SETUP,
|
||||||
Url: c.urlForTrack(u, track),
|
Url: c.urlForTrack(u, track),
|
||||||
Header: Header{
|
Header: Header{
|
||||||
"Transport": HeaderValue{strings.Join(transport, ";")},
|
"Transport": ht.Write(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -406,10 +406,13 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := c.setup(u, track, []string{
|
res, err := c.setup(u, track, &HeaderTransport{
|
||||||
"RTP/AVP/UDP",
|
Protocol: StreamProtocolUDP,
|
||||||
"unicast",
|
Cast: func() *StreamCast {
|
||||||
fmt.Sprintf("client_port=%d-%d", rtpPort, rtcpPort),
|
ret := StreamUnicast
|
||||||
|
return &ret
|
||||||
|
}(),
|
||||||
|
ClientPorts: &[2]int{rtpPort, rtcpPort},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rtpListener.close()
|
rtpListener.close()
|
||||||
@@ -424,8 +427,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
|
|||||||
return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err)
|
return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rtpServerPort, rtcpServerPort := th.Ports("server_port")
|
if th.ServerPorts == nil {
|
||||||
if rtpServerPort == 0 {
|
|
||||||
rtpListener.close()
|
rtpListener.close()
|
||||||
rtcpListener.close()
|
rtcpListener.close()
|
||||||
return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided")
|
return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided")
|
||||||
@@ -437,11 +439,11 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
|
|||||||
c.rtcpReceivers[track.Id] = NewRtcpReceiver()
|
c.rtcpReceivers[track.Id] = NewRtcpReceiver()
|
||||||
|
|
||||||
rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
|
rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
|
||||||
rtpListener.publisherPort = rtpServerPort
|
rtpListener.publisherPort = (*th.ServerPorts)[0]
|
||||||
c.rtpListeners[track.Id] = rtpListener
|
c.rtpListeners[track.Id] = rtpListener
|
||||||
|
|
||||||
rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
|
rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
|
||||||
rtcpListener.publisherPort = rtcpServerPort
|
rtcpListener.publisherPort = (*th.ServerPorts)[1]
|
||||||
c.rtcpListeners[track.Id] = rtcpListener
|
c.rtcpListeners[track.Id] = rtcpListener
|
||||||
|
|
||||||
return rtpListener.Read, rtcpListener.Read, res, nil
|
return rtpListener.Read, rtcpListener.Read, res, nil
|
||||||
@@ -462,11 +464,14 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) {
|
|||||||
return nil, fmt.Errorf("track has already been setup")
|
return nil, fmt.Errorf("track has already been setup")
|
||||||
}
|
}
|
||||||
|
|
||||||
interleaved := fmt.Sprintf("interleaved=%d-%d", (track.Id * 2), (track.Id*2)+1)
|
interleavedIds := &[2]int{(track.Id * 2), (track.Id * 2) + 1}
|
||||||
res, err := c.setup(u, track, []string{
|
res, err := c.setup(u, track, &HeaderTransport{
|
||||||
"RTP/AVP/TCP",
|
Protocol: StreamProtocolTCP,
|
||||||
"unicast",
|
Cast: func() *StreamCast {
|
||||||
interleaved,
|
ret := StreamUnicast
|
||||||
|
return &ret
|
||||||
|
}(),
|
||||||
|
InterleavedIds: interleavedIds,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -477,10 +482,10 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) {
|
|||||||
return nil, fmt.Errorf("SETUP: transport header: %s", err)
|
return nil, fmt.Errorf("SETUP: transport header: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := th[interleaved]
|
if th.InterleavedIds == nil || (*th.InterleavedIds)[0] != (*interleavedIds)[0] ||
|
||||||
if !ok {
|
(*th.InterleavedIds)[1] != (*interleavedIds)[1] {
|
||||||
return nil, fmt.Errorf("SETUP: transport header does not contain '%s' (%s)",
|
return nil, fmt.Errorf("SETUP: transport header does not have interleaved ids %v (%s)",
|
||||||
interleaved, res.Header["Transport"])
|
*interleavedIds, res.Header["Transport"])
|
||||||
}
|
}
|
||||||
|
|
||||||
c.streamUrl = u
|
c.streamUrl = u
|
||||||
|
@@ -8,9 +8,9 @@ import (
|
|||||||
|
|
||||||
var casesHeaderAuth = []struct {
|
var casesHeaderAuth = []struct {
|
||||||
name string
|
name string
|
||||||
dec HeaderValue
|
vin HeaderValue
|
||||||
enc HeaderValue
|
vout HeaderValue
|
||||||
ha *HeaderAuth
|
h *HeaderAuth
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"basic",
|
"basic",
|
||||||
@@ -112,9 +112,9 @@ var casesHeaderAuth = []struct {
|
|||||||
func TestHeaderAuthRead(t *testing.T) {
|
func TestHeaderAuthRead(t *testing.T) {
|
||||||
for _, c := range casesHeaderAuth {
|
for _, c := range casesHeaderAuth {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
req, err := ReadHeaderAuth(c.dec)
|
req, err := ReadHeaderAuth(c.vin)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, c.ha, req)
|
require.Equal(t, c.h, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -122,8 +122,8 @@ func TestHeaderAuthRead(t *testing.T) {
|
|||||||
func TestHeaderAuthWrite(t *testing.T) {
|
func TestHeaderAuthWrite(t *testing.T) {
|
||||||
for _, c := range casesHeaderAuth {
|
for _, c := range casesHeaderAuth {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
req := c.ha.Write()
|
req := c.h.Write()
|
||||||
require.Equal(t, c.enc, req)
|
require.Equal(t, c.vout, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -7,9 +7,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var casesHeaderSession = []struct {
|
var casesHeaderSession = []struct {
|
||||||
name string
|
name string
|
||||||
value HeaderValue
|
v HeaderValue
|
||||||
hs *HeaderSession
|
h *HeaderSession
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"base",
|
"base",
|
||||||
@@ -45,9 +45,9 @@ var casesHeaderSession = []struct {
|
|||||||
func TestHeaderSession(t *testing.T) {
|
func TestHeaderSession(t *testing.T) {
|
||||||
for _, c := range casesHeaderSession {
|
for _, c := range casesHeaderSession {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
req, err := ReadHeaderSession(c.value)
|
req, err := ReadHeaderSession(c.v)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, c.hs, req)
|
require.Equal(t, c.h, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -7,10 +7,47 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// HeaderTransport is a Transport header.
|
// HeaderTransport is a Transport header.
|
||||||
type HeaderTransport map[string]struct{}
|
type HeaderTransport struct {
|
||||||
|
// protocol of the stream
|
||||||
|
Protocol StreamProtocol
|
||||||
|
|
||||||
|
// cast of the stream
|
||||||
|
Cast *StreamCast
|
||||||
|
|
||||||
|
// client ports
|
||||||
|
ClientPorts *[2]int
|
||||||
|
|
||||||
|
// server ports
|
||||||
|
ServerPorts *[2]int
|
||||||
|
|
||||||
|
// interleaved frame ids
|
||||||
|
InterleavedIds *[2]int
|
||||||
|
|
||||||
|
// mode
|
||||||
|
Mode *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePorts(val string) (*[2]int, error) {
|
||||||
|
ports := strings.Split(val, "-")
|
||||||
|
if len(ports) != 2 {
|
||||||
|
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
port1, err := strconv.ParseInt(ports[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
port2, err := strconv.ParseInt(ports[1], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &[2]int{int(port1), int(port2)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ReadHeaderTransport parses a Transport header.
|
// ReadHeaderTransport parses a Transport header.
|
||||||
func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) {
|
func ReadHeaderTransport(v HeaderValue) (*HeaderTransport, error) {
|
||||||
if len(v) == 0 {
|
if len(v) == 0 {
|
||||||
return nil, fmt.Errorf("value not provided")
|
return nil, fmt.Errorf("value not provided")
|
||||||
}
|
}
|
||||||
@@ -19,63 +56,99 @@ func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) {
|
|||||||
return nil, fmt.Errorf("value provided multiple times (%v)", v)
|
return nil, fmt.Errorf("value provided multiple times (%v)", v)
|
||||||
}
|
}
|
||||||
|
|
||||||
ht := make(map[string]struct{})
|
ht := &HeaderTransport{}
|
||||||
|
protoSet := false
|
||||||
|
|
||||||
for _, t := range strings.Split(v[0], ";") {
|
for _, t := range strings.Split(v[0], ";") {
|
||||||
ht[t] = struct{}{}
|
if t == "RTP/AVP" || t == "RTP/AVP/UDP" {
|
||||||
|
ht.Protocol = StreamProtocolUDP
|
||||||
|
protoSet = true
|
||||||
|
|
||||||
|
} else if t == "RTP/AVP/TCP" {
|
||||||
|
ht.Protocol = StreamProtocolTCP
|
||||||
|
protoSet = true
|
||||||
|
|
||||||
|
} else if t == "unicast" {
|
||||||
|
ret := StreamUnicast
|
||||||
|
ht.Cast = &ret
|
||||||
|
|
||||||
|
} else if t == "multicast" {
|
||||||
|
ret := StreamMulticast
|
||||||
|
ht.Cast = &ret
|
||||||
|
|
||||||
|
} else if strings.HasPrefix(t, "client_port=") {
|
||||||
|
ports, err := parsePorts(t[len("client_port="):])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ht.ClientPorts = ports
|
||||||
|
|
||||||
|
} else if strings.HasPrefix(t, "server_port=") {
|
||||||
|
ports, err := parsePorts(t[len("server_port="):])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ht.ServerPorts = ports
|
||||||
|
|
||||||
|
} else if strings.HasPrefix(t, "interleaved=") {
|
||||||
|
ports, err := parsePorts(t[len("interleaved="):])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ht.InterleavedIds = ports
|
||||||
|
|
||||||
|
} else if strings.HasPrefix(t, "mode=") {
|
||||||
|
ret := strings.ToLower(t[len("mode="):])
|
||||||
|
ret = strings.TrimPrefix(ret, "\"")
|
||||||
|
ret = strings.TrimSuffix(ret, "\"")
|
||||||
|
ht.Mode = &ret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// protocol is the only mandatory field
|
||||||
|
if !protoSet {
|
||||||
|
return nil, fmt.Errorf("protocol not set (%v)", v)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ht, nil
|
return ht, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUDP check whether the header contains the UDP protocol.
|
// Write encodes a Transport header
|
||||||
func (ht HeaderTransport) IsUDP() bool {
|
func (ht *HeaderTransport) Write() HeaderValue {
|
||||||
if _, ok := ht["RTP/AVP"]; ok {
|
var vals []string
|
||||||
return true
|
|
||||||
}
|
|
||||||
if _, ok := ht["RTP/AVP/UDP"]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsTCP check whether the header contains the TCP protocol.
|
if ht.Protocol == StreamProtocolUDP {
|
||||||
func (ht HeaderTransport) IsTCP() bool {
|
vals = append(vals, "RTP/AVP")
|
||||||
_, ok := ht["RTP/AVP/TCP"]
|
} else {
|
||||||
return ok
|
vals = append(vals, "RTP/AVP/TCP")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value gets a value from the header.
|
if ht.Cast != nil {
|
||||||
func (ht HeaderTransport) Value(key string) string {
|
if *ht.Cast == StreamUnicast {
|
||||||
prefix := key + "="
|
vals = append(vals, "unicast")
|
||||||
for t := range ht {
|
} else {
|
||||||
if strings.HasPrefix(t, prefix) {
|
vals = append(vals, "multicast")
|
||||||
return t[len(prefix):]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
}
|
if ht.ClientPorts != nil {
|
||||||
|
ports := *ht.ClientPorts
|
||||||
// Ports gets a value from the header and parses its ports.
|
vals = append(vals, "client_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
|
||||||
func (ht HeaderTransport) Ports(key string) (int, int) {
|
}
|
||||||
val := ht.Value(key)
|
|
||||||
if val == "" {
|
if ht.ServerPorts != nil {
|
||||||
return 0, 0
|
ports := *ht.ServerPorts
|
||||||
}
|
vals = append(vals, "server_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
|
||||||
|
}
|
||||||
ports := strings.Split(val, "-")
|
|
||||||
if len(ports) != 2 {
|
if ht.InterleavedIds != nil {
|
||||||
return 0, 0
|
ports := *ht.InterleavedIds
|
||||||
}
|
vals = append(vals, "interleaved="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
|
||||||
|
}
|
||||||
port1, err := strconv.ParseInt(ports[0], 10, 64)
|
|
||||||
if err != nil {
|
if ht.Mode != nil {
|
||||||
return 0, 0
|
vals = append(vals, "mode="+*ht.Mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
port2, err := strconv.ParseInt(ports[1], 10, 64)
|
return HeaderValue{strings.Join(vals, ";")}
|
||||||
if err != nil {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(port1), int(port2)
|
|
||||||
}
|
}
|
||||||
|
86
header-transport_test.go
Normal file
86
header-transport_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package gortsplib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var casesHeaderTransport = []struct {
|
||||||
|
name string
|
||||||
|
vin HeaderValue
|
||||||
|
vout HeaderValue
|
||||||
|
h *HeaderTransport
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"udp unicast play request",
|
||||||
|
HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`},
|
||||||
|
HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`},
|
||||||
|
&HeaderTransport{
|
||||||
|
Protocol: StreamProtocolUDP,
|
||||||
|
Cast: func() *StreamCast {
|
||||||
|
ret := StreamUnicast
|
||||||
|
return &ret
|
||||||
|
}(),
|
||||||
|
ClientPorts: &[2]int{3456, 3457},
|
||||||
|
Mode: func() *string {
|
||||||
|
ret := "play"
|
||||||
|
return &ret
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp unicast play response",
|
||||||
|
HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`},
|
||||||
|
HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`},
|
||||||
|
&HeaderTransport{
|
||||||
|
Protocol: StreamProtocolUDP,
|
||||||
|
Cast: func() *StreamCast {
|
||||||
|
ret := StreamUnicast
|
||||||
|
return &ret
|
||||||
|
}(),
|
||||||
|
ClientPorts: &[2]int{3056, 3057},
|
||||||
|
ServerPorts: &[2]int{5000, 5001},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp multicast play request / response",
|
||||||
|
HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`},
|
||||||
|
HeaderValue{`RTP/AVP;multicast`},
|
||||||
|
&HeaderTransport{
|
||||||
|
Protocol: StreamProtocolUDP,
|
||||||
|
Cast: func() *StreamCast {
|
||||||
|
ret := StreamMulticast
|
||||||
|
return &ret
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tcp play request / response",
|
||||||
|
HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
|
||||||
|
HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
|
||||||
|
&HeaderTransport{
|
||||||
|
Protocol: StreamProtocolTCP,
|
||||||
|
InterleavedIds: &[2]int{0, 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderTransportRead(t *testing.T) {
|
||||||
|
for _, c := range casesHeaderTransport {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
req, err := ReadHeaderTransport(c.vin)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, c.h, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderTransportWrite(t *testing.T) {
|
||||||
|
for _, c := range casesHeaderTransport {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
req := c.h.Write()
|
||||||
|
require.Equal(t, c.vout, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
21
utils.go
21
utils.go
@@ -11,7 +11,7 @@ const (
|
|||||||
rtspMaxContentLength = 4096
|
rtspMaxContentLength = 4096
|
||||||
)
|
)
|
||||||
|
|
||||||
// StreamProtocol is the protocol of a stream
|
// StreamProtocol is the protocol of a stream.
|
||||||
type StreamProtocol int
|
type StreamProtocol int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -30,6 +30,25 @@ func (sp StreamProtocol) String() string {
|
|||||||
return "tcp"
|
return "tcp"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamCast is the cast of a stream.
|
||||||
|
type StreamCast int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Unicast means that the stream will be unicasted
|
||||||
|
StreamUnicast StreamCast = iota
|
||||||
|
|
||||||
|
// Multicast means that the stream will be multicasted
|
||||||
|
StreamMulticast
|
||||||
|
)
|
||||||
|
|
||||||
|
// String implements fmt.Stringer
|
||||||
|
func (sc StreamCast) String() string {
|
||||||
|
if sc == StreamUnicast {
|
||||||
|
return "unicast"
|
||||||
|
}
|
||||||
|
return "multicast"
|
||||||
|
}
|
||||||
|
|
||||||
// StreamType is the type of a stream.
|
// StreamType is the type of a stream.
|
||||||
type StreamType int
|
type StreamType int
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user