headers: rewrite initializers as members of their structs

This commit is contained in:
aler9
2021-03-20 09:55:04 +01:00
parent 8936db52e4
commit 25772271db
16 changed files with 117 additions and 110 deletions

View File

@@ -266,7 +266,8 @@ func (c *ClientConn) Do(req *base.Request) (*base.Response, error) {
// get session from response // get session from response
if v, ok := res.Header["Session"]; ok { if v, ok := res.Header["Session"]; ok {
sx, err := headers.ReadSession(v) var sx headers.Session
err := sx.Read(v)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse session header: %s", err) return nil, fmt.Errorf("unable to parse session header: %s", err)
} }
@@ -568,7 +569,8 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track,
return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage)
} }
thRes, err := headers.ReadTransport(res.Header["Transport"]) var thRes headers.Transport
err = thRes.Read(res.Header["Transport"])
if err != nil { if err != nil {
if proto == StreamProtocolUDP { if proto == StreamProtocolUDP {
rtpListener.close() rtpListener.close()

View File

@@ -32,8 +32,7 @@ func (c *ClientConn) Play() (*base.Response, error) {
} }
if v, ok := res.Header["RTP-Info"]; ok { if v, ok := res.Header["RTP-Info"]; ok {
var err error err := c.rtpInfo.Read(v)
c.rtpInfo, err = headers.ReadRTPInfo(v)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse RTP-Info: %v", err) return nil, fmt.Errorf("unable to parse RTP-Info: %v", err)
} }

View File

@@ -102,7 +102,8 @@ func TestClientConnRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
th, err := headers.ReadTransport(req.Header["Transport"]) var th headers.Transport
err = th.Read(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
if ca.proto == "udp" { if ca.proto == "udp" {
@@ -259,7 +260,8 @@ func TestClientConnReadNoServerPorts(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
th, err := headers.ReadTransport(req.Header["Transport"]) var th headers.Transport
err = th.Read(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
err = base.Response{ err = base.Response{
@@ -517,7 +519,8 @@ func TestClientConnReadRedirect(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
th, err := headers.ReadTransport(req.Header["Transport"]) var th headers.Transport
err = th.Read(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
err = base.Response{ err = base.Response{
@@ -669,7 +672,8 @@ func TestClientConnReadPause(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
inTH, err := headers.ReadTransport(req.Header["Transport"]) var inTH headers.Transport
err = inTH.Read(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
th := headers.Transport{ th := headers.Transport{
@@ -706,7 +710,7 @@ func TestClientConnReadPause(t *testing.T) {
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
writerTerminate, writerDone := writeFrames(inTH, bconn) writerTerminate, writerDone := writeFrames(&inTH, bconn)
err = req.Read(bconn.Reader) err = req.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
@@ -729,7 +733,7 @@ func TestClientConnReadPause(t *testing.T) {
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
writerTerminate, writerDone = writeFrames(inTH, bconn) writerTerminate, writerDone = writeFrames(&inTH, bconn)
err = req.Read(bconn.Reader) err = req.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -30,7 +30,8 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) {
} }
return "" return ""
}(); headerAuthDigest != "" { }(); headerAuthDigest != "" {
auth, err := headers.ReadAuth(base.HeaderValue{headerAuthDigest}) var auth headers.Auth
err := auth.Read(base.HeaderValue{headerAuthDigest})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -60,7 +61,8 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) {
} }
return "" return ""
}(); headerAuthBasic != "" { }(); headerAuthBasic != "" {
auth, err := headers.ReadAuth(base.HeaderValue{headerAuthBasic}) var auth headers.Auth
err := auth.Read(base.HeaderValue{headerAuthBasic})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -135,7 +135,8 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *
} }
} else if strings.HasPrefix(v0, "Digest ") { } else if strings.HasPrefix(v0, "Digest ") {
auth, err := headers.ReadAuth(base.HeaderValue{v0}) var auth headers.Auth
err := auth.Read(base.HeaderValue{v0})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -79,23 +79,21 @@ func findValue(v0 string) (string, string, error) {
} }
} }
// ReadAuth decodes an Authenticate or a WWW-Authenticate header. // Read decodes an Authenticate or a WWW-Authenticate header.
func ReadAuth(v base.HeaderValue) (*Auth, error) { func (h *Auth) Read(v base.HeaderValue) error {
if len(v) == 0 { if len(v) == 0 {
return nil, fmt.Errorf("value not provided") return fmt.Errorf("value not provided")
} }
if len(v) > 1 { if len(v) > 1 {
return nil, fmt.Errorf("value provided multiple times (%v)", v) return fmt.Errorf("value provided multiple times (%v)", v)
} }
h := &Auth{}
v0 := v[0] v0 := v[0]
i := strings.IndexByte(v0, ' ') i := strings.IndexByte(v0, ' ')
if i < 0 { if i < 0 {
return nil, fmt.Errorf("unable to find method (%s)", v0) return fmt.Errorf("unable to find method (%s)", v0)
} }
switch v0[:i] { switch v0[:i] {
@@ -106,14 +104,14 @@ func ReadAuth(v base.HeaderValue) (*Auth, error) {
h.Method = AuthDigest h.Method = AuthDigest
default: default:
return nil, fmt.Errorf("invalid method (%s)", v0[:i]) return fmt.Errorf("invalid method (%s)", v0[:i])
} }
v0 = v0[i+1:] v0 = v0[i+1:]
for len(v0) > 0 { for len(v0) > 0 {
i := strings.IndexByte(v0, '=') i := strings.IndexByte(v0, '=')
if i < 0 { if i < 0 {
return nil, fmt.Errorf("unable to find key (%s)", v0) return fmt.Errorf("unable to find key (%s)", v0)
} }
var key string var key string
key, v0 = v0[:i], v0[i+1:] key, v0 = v0[:i], v0[i+1:]
@@ -122,7 +120,7 @@ func ReadAuth(v base.HeaderValue) (*Auth, error) {
var err error var err error
val, v0, err = findValue(v0) val, v0, err = findValue(v0)
if err != nil { if err != nil {
return nil, err return err
} }
switch key { switch key {
@@ -164,7 +162,7 @@ func ReadAuth(v base.HeaderValue) (*Auth, error) {
} }
} }
return h, nil return nil
} }
// Write encodes an Authenticate or a WWW-Authenticate header. // Write encodes an Authenticate or a WWW-Authenticate header.

View File

@@ -12,13 +12,13 @@ var casesAuth = []struct {
name string name string
vin base.HeaderValue vin base.HeaderValue
vout base.HeaderValue vout base.HeaderValue
h *Auth h Auth
}{ }{
{ {
"basic", "basic",
base.HeaderValue{`Basic realm="4419b63f5e51"`}, base.HeaderValue{`Basic realm="4419b63f5e51"`},
base.HeaderValue{`Basic realm="4419b63f5e51"`}, base.HeaderValue{`Basic realm="4419b63f5e51"`},
&Auth{ Auth{
Method: AuthBasic, Method: AuthBasic,
Realm: func() *string { Realm: func() *string {
v := "4419b63f5e51" v := "4419b63f5e51"
@@ -30,7 +30,7 @@ var casesAuth = []struct {
"digest request 1", "digest request 1",
base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`},
base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`},
&Auth{ Auth{
Method: AuthDigest, Method: AuthDigest,
Realm: func() *string { Realm: func() *string {
v := "4419b63f5e51" v := "4419b63f5e51"
@@ -50,7 +50,7 @@ var casesAuth = []struct {
"digest request 2", "digest request 2",
base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale=FALSE`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale=FALSE`},
base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`},
&Auth{ Auth{
Method: AuthDigest, Method: AuthDigest,
Realm: func() *string { Realm: func() *string {
v := "4419b63f5e51" v := "4419b63f5e51"
@@ -70,7 +70,7 @@ var casesAuth = []struct {
"digest request 3", "digest request 3",
base.HeaderValue{`Digest realm="4419b63f5e51",nonce="133767111917411116111311118211673010032", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51",nonce="133767111917411116111311118211673010032", stale="FALSE"`},
base.HeaderValue{`Digest realm="4419b63f5e51", nonce="133767111917411116111311118211673010032", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="133767111917411116111311118211673010032", stale="FALSE"`},
&Auth{ Auth{
Method: AuthDigest, Method: AuthDigest,
Realm: func() *string { Realm: func() *string {
v := "4419b63f5e51" v := "4419b63f5e51"
@@ -90,7 +90,7 @@ var casesAuth = []struct {
"digest response generic", "digest response generic",
base.HeaderValue{`Digest username="aa", realm="bb", nonce="cc", uri="dd", response="ee"`}, base.HeaderValue{`Digest username="aa", realm="bb", nonce="cc", uri="dd", response="ee"`},
base.HeaderValue{`Digest username="aa", realm="bb", nonce="cc", uri="dd", response="ee"`}, base.HeaderValue{`Digest username="aa", realm="bb", nonce="cc", uri="dd", response="ee"`},
&Auth{ Auth{
Method: AuthDigest, Method: AuthDigest,
Username: func() *string { Username: func() *string {
v := "aa" v := "aa"
@@ -118,7 +118,7 @@ var casesAuth = []struct {
"digest response with empty field", "digest response with empty field",
base.HeaderValue{`Digest username="", realm="IPCAM", nonce="5d17cd12b9fa8a85ac5ceef0926ea5a6", uri="rtsp://localhost:8554/mystream", response="c072ae90eb4a27f4cdcb90d62266b2a1"`}, base.HeaderValue{`Digest username="", realm="IPCAM", nonce="5d17cd12b9fa8a85ac5ceef0926ea5a6", uri="rtsp://localhost:8554/mystream", response="c072ae90eb4a27f4cdcb90d62266b2a1"`},
base.HeaderValue{`Digest username="", realm="IPCAM", nonce="5d17cd12b9fa8a85ac5ceef0926ea5a6", uri="rtsp://localhost:8554/mystream", response="c072ae90eb4a27f4cdcb90d62266b2a1"`}, base.HeaderValue{`Digest username="", realm="IPCAM", nonce="5d17cd12b9fa8a85ac5ceef0926ea5a6", uri="rtsp://localhost:8554/mystream", response="c072ae90eb4a27f4cdcb90d62266b2a1"`},
&Auth{ Auth{
Method: AuthDigest, Method: AuthDigest,
Username: func() *string { Username: func() *string {
v := "" v := ""
@@ -146,7 +146,7 @@ var casesAuth = []struct {
"digest response with no spaces and additional fields", "digest response with no spaces and additional fields",
base.HeaderValue{`Digest realm="Please log in with a valid username",nonce="752a62306daf32b401a41004555c7663",opaque="",stale=FALSE,algorithm=MD5`}, base.HeaderValue{`Digest realm="Please log in with a valid username",nonce="752a62306daf32b401a41004555c7663",opaque="",stale=FALSE,algorithm=MD5`},
base.HeaderValue{`Digest realm="Please log in with a valid username", nonce="752a62306daf32b401a41004555c7663", opaque="", stale="FALSE", algorithm="MD5"`}, base.HeaderValue{`Digest realm="Please log in with a valid username", nonce="752a62306daf32b401a41004555c7663", opaque="", stale="FALSE", algorithm="MD5"`},
&Auth{ Auth{
Method: AuthDigest, Method: AuthDigest,
Realm: func() *string { Realm: func() *string {
v := "Please log in with a valid username" v := "Please log in with a valid username"
@@ -175,9 +175,10 @@ var casesAuth = []struct {
func TestAuthRead(t *testing.T) { func TestAuthRead(t *testing.T) {
for _, c := range casesAuth { for _, c := range casesAuth {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadAuth(c.vin) var h Auth
err := h.Read(c.vin)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.h, req) require.Equal(t, c.h, h)
}) })
} }
} }
@@ -185,8 +186,8 @@ func TestAuthRead(t *testing.T) {
func TestAuthWrite(t *testing.T) { func TestAuthWrite(t *testing.T) {
for _, c := range casesAuth { for _, c := range casesAuth {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req := c.h.Write() vout := c.h.Write()
require.Equal(t, c.vout, req) require.Equal(t, c.vout, vout)
}) })
} }
} }

View File

@@ -18,25 +18,23 @@ type RTPInfoEntry struct {
// RTPInfo is a RTP-Info header. // RTPInfo is a RTP-Info header.
type RTPInfo []*RTPInfoEntry type RTPInfo []*RTPInfoEntry
// ReadRTPInfo decodes a RTP-Info header. // Read decodes a RTP-Info header.
func ReadRTPInfo(v base.HeaderValue) (*RTPInfo, error) { func (h *RTPInfo) Read(v base.HeaderValue) error {
if len(v) == 0 { if len(v) == 0 {
return nil, fmt.Errorf("value not provided") return fmt.Errorf("value not provided")
} }
if len(v) > 1 { if len(v) > 1 {
return nil, fmt.Errorf("value provided multiple times (%v)", v) return fmt.Errorf("value provided multiple times (%v)", v)
} }
h := &RTPInfo{}
for _, tmp := range strings.Split(v[0], ",") { for _, tmp := range strings.Split(v[0], ",") {
e := &RTPInfoEntry{} e := &RTPInfoEntry{}
for _, kv := range strings.Split(tmp, ";") { for _, kv := range strings.Split(tmp, ";") {
tmp := strings.SplitN(kv, "=", 2) tmp := strings.SplitN(kv, "=", 2)
if len(tmp) != 2 { if len(tmp) != 2 {
return nil, fmt.Errorf("unable to parse key-value (%v)", kv) return fmt.Errorf("unable to parse key-value (%v)", kv)
} }
k, v := tmp[0], tmp[1] k, v := tmp[0], tmp[1]
@@ -44,33 +42,33 @@ func ReadRTPInfo(v base.HeaderValue) (*RTPInfo, error) {
case "url": case "url":
vu, err := base.ParseURL(v) vu, err := base.ParseURL(v)
if err != nil { if err != nil {
return nil, err return err
} }
e.URL = vu e.URL = vu
case "seq": case "seq":
vi, err := strconv.ParseUint(v, 10, 16) vi, err := strconv.ParseUint(v, 10, 16)
if err != nil { if err != nil {
return nil, err return err
} }
e.SequenceNumber = uint16(vi) e.SequenceNumber = uint16(vi)
case "rtptime": case "rtptime":
vi, err := strconv.ParseUint(v, 10, 32) vi, err := strconv.ParseUint(v, 10, 32)
if err != nil { if err != nil {
return nil, err return err
} }
e.Timestamp = uint32(vi) e.Timestamp = uint32(vi)
default: default:
return nil, fmt.Errorf("invalid key: %v", k) return fmt.Errorf("invalid key: %v", k)
} }
} }
*h = append(*h, e) *h = append(*h, e)
} }
return h, nil return nil
} }
// Clone clones a RTPInfo. // Clone clones a RTPInfo.

View File

@@ -12,13 +12,13 @@ var casesRTPInfo = []struct {
name string name string
vin base.HeaderValue vin base.HeaderValue
vout base.HeaderValue vout base.HeaderValue
h *RTPInfo h RTPInfo
}{ }{
{ {
"single value", "single value",
base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`}, base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`},
base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`}, base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`},
&RTPInfo{ RTPInfo{
{ {
URL: base.MustParseURL("rtsp://127.0.0.1/test.mkv/track1"), URL: base.MustParseURL("rtsp://127.0.0.1/test.mkv/track1"),
SequenceNumber: 35243, SequenceNumber: 35243,
@@ -30,7 +30,7 @@ var casesRTPInfo = []struct {
"multiple value", "multiple value",
base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556,url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`}, base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556,url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`},
base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556,url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`}, base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556,url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`},
&RTPInfo{ RTPInfo{
{ {
URL: base.MustParseURL("rtsp://127.0.0.1/test.mkv/track1"), URL: base.MustParseURL("rtsp://127.0.0.1/test.mkv/track1"),
SequenceNumber: 35243, SequenceNumber: 35243,
@@ -48,9 +48,10 @@ var casesRTPInfo = []struct {
func TestRTPInfoRead(t *testing.T) { func TestRTPInfoRead(t *testing.T) {
for _, c := range casesRTPInfo { for _, c := range casesRTPInfo {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadRTPInfo(c.vin) var h RTPInfo
err := h.Read(c.vin)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.h, req) require.Equal(t, c.h, h)
}) })
} }
} }

View File

@@ -17,23 +17,21 @@ type Session struct {
Timeout *uint Timeout *uint
} }
// ReadSession decodes a Session header. // Read decodes a Session header.
func ReadSession(v base.HeaderValue) (*Session, error) { func (h *Session) Read(v base.HeaderValue) error {
if len(v) == 0 { if len(v) == 0 {
return nil, fmt.Errorf("value not provided") return fmt.Errorf("value not provided")
} }
if len(v) > 1 { if len(v) > 1 {
return nil, fmt.Errorf("value provided multiple times (%v)", v) return fmt.Errorf("value provided multiple times (%v)", v)
} }
parts := strings.Split(v[0], ";") parts := strings.Split(v[0], ";")
if len(parts) == 0 { if len(parts) == 0 {
return nil, fmt.Errorf("invalid value (%v)", v) return fmt.Errorf("invalid value (%v)", v)
} }
h := &Session{}
h.Session = parts[0] h.Session = parts[0]
for _, part := range parts[1:] { for _, part := range parts[1:] {
@@ -42,24 +40,24 @@ func ReadSession(v base.HeaderValue) (*Session, error) {
kv := strings.Split(part, "=") kv := strings.Split(part, "=")
if len(kv) != 2 { if len(kv) != 2 {
return nil, fmt.Errorf("invalid value") return fmt.Errorf("invalid value")
} }
key, strValue := kv[0], kv[1] key, strValue := kv[0], kv[1]
if key != "timeout" { if key != "timeout" {
return nil, fmt.Errorf("invalid key '%s'", key) return fmt.Errorf("invalid key '%s'", key)
} }
iv, err := strconv.ParseUint(strValue, 10, 64) iv, err := strconv.ParseUint(strValue, 10, 64)
if err != nil { if err != nil {
return nil, err return err
} }
uiv := uint(iv) uiv := uint(iv)
h.Timeout = &uiv h.Timeout = &uiv
} }
return h, nil return nil
} }
// Write encodes a Session header. // Write encodes a Session header.

View File

@@ -12,13 +12,13 @@ var casesSession = []struct {
name string name string
vin base.HeaderValue vin base.HeaderValue
vout base.HeaderValue vout base.HeaderValue
h *Session h Session
}{ }{
{ {
"base", "base",
base.HeaderValue{`A3eqwsafq3rFASqew`}, base.HeaderValue{`A3eqwsafq3rFASqew`},
base.HeaderValue{`A3eqwsafq3rFASqew`}, base.HeaderValue{`A3eqwsafq3rFASqew`},
&Session{ Session{
Session: "A3eqwsafq3rFASqew", Session: "A3eqwsafq3rFASqew",
}, },
}, },
@@ -26,7 +26,7 @@ var casesSession = []struct {
"with timeout", "with timeout",
base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`},
base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`},
&Session{ Session{
Session: "A3eqwsafq3rFASqew", Session: "A3eqwsafq3rFASqew",
Timeout: func() *uint { Timeout: func() *uint {
v := uint(47) v := uint(47)
@@ -38,7 +38,7 @@ var casesSession = []struct {
"with timeout and space", "with timeout and space",
base.HeaderValue{`A3eqwsafq3rFASqew; timeout=47`}, base.HeaderValue{`A3eqwsafq3rFASqew; timeout=47`},
base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`},
&Session{ Session{
Session: "A3eqwsafq3rFASqew", Session: "A3eqwsafq3rFASqew",
Timeout: func() *uint { Timeout: func() *uint {
v := uint(47) v := uint(47)
@@ -51,9 +51,10 @@ var casesSession = []struct {
func TestSessionRead(t *testing.T) { func TestSessionRead(t *testing.T) {
for _, c := range casesSession { for _, c := range casesSession {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadSession(c.vin) var h Session
err := h.Read(c.vin)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.h, req) require.Equal(t, c.h, h)
}) })
} }
} }

View File

@@ -89,21 +89,19 @@ func parsePorts(val string) (*[2]int, error) {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val) return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
} }
// ReadTransport decodes a Transport header. // Read decodes a Transport header.
func ReadTransport(v base.HeaderValue) (*Transport, error) { func (h *Transport) Read(v base.HeaderValue) error {
if len(v) == 0 { if len(v) == 0 {
return nil, fmt.Errorf("value not provided") return fmt.Errorf("value not provided")
} }
if len(v) > 1 { if len(v) > 1 {
return nil, fmt.Errorf("value provided multiple times (%v)", v) return fmt.Errorf("value provided multiple times (%v)", v)
} }
h := &Transport{}
parts := strings.Split(v[0], ";") parts := strings.Split(v[0], ";")
if len(parts) == 0 { if len(parts) == 0 {
return nil, fmt.Errorf("invalid value (%v)", v) return fmt.Errorf("invalid value (%v)", v)
} }
switch parts[0] { switch parts[0] {
@@ -114,7 +112,7 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) {
h.Protocol = base.StreamProtocolTCP h.Protocol = base.StreamProtocolTCP
default: default:
return nil, fmt.Errorf("invalid protocol (%v)", v) return fmt.Errorf("invalid protocol (%v)", v)
} }
parts = parts[1:] parts = parts[1:]
@@ -140,7 +138,7 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) {
} else if strings.HasPrefix(t, "ttl=") { } else if strings.HasPrefix(t, "ttl=") {
v, err := strconv.ParseUint(t[len("ttl="):], 10, 64) v, err := strconv.ParseUint(t[len("ttl="):], 10, 64)
if err != nil { if err != nil {
return nil, err return err
} }
vu := uint(v) vu := uint(v)
h.TTL = &vu h.TTL = &vu
@@ -148,28 +146,28 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) {
} else if strings.HasPrefix(t, "port=") { } else if strings.HasPrefix(t, "port=") {
ports, err := parsePorts(t[len("port="):]) ports, err := parsePorts(t[len("port="):])
if err != nil { if err != nil {
return nil, err return err
} }
h.Ports = ports h.Ports = ports
} else if strings.HasPrefix(t, "client_port=") { } else if strings.HasPrefix(t, "client_port=") {
ports, err := parsePorts(t[len("client_port="):]) ports, err := parsePorts(t[len("client_port="):])
if err != nil { if err != nil {
return nil, err return err
} }
h.ClientPorts = ports h.ClientPorts = ports
} else if strings.HasPrefix(t, "server_port=") { } else if strings.HasPrefix(t, "server_port=") {
ports, err := parsePorts(t[len("server_port="):]) ports, err := parsePorts(t[len("server_port="):])
if err != nil { if err != nil {
return nil, err return err
} }
h.ServerPorts = ports h.ServerPorts = ports
} else if strings.HasPrefix(t, "interleaved=") { } else if strings.HasPrefix(t, "interleaved=") {
ports, err := parsePorts(t[len("interleaved="):]) ports, err := parsePorts(t[len("interleaved="):])
if err != nil { if err != nil {
return nil, err return err
} }
h.InterleavedIds = ports h.InterleavedIds = ports
@@ -190,14 +188,14 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) {
h.Mode = &v h.Mode = &v
default: default:
return nil, fmt.Errorf("invalid transport mode: '%s'", str) return fmt.Errorf("invalid transport mode: '%s'", str)
} }
} }
// ignore non-standard keys // ignore non-standard keys
} }
return h, nil return nil
} }
// Write encodes a Transport header // Write encodes a Transport header

View File

@@ -12,13 +12,13 @@ var casesTransport = []struct {
name string name string
vin base.HeaderValue vin base.HeaderValue
vout base.HeaderValue vout base.HeaderValue
h *Transport h Transport
}{ }{
{ {
"udp unicast play request", "udp unicast play request",
base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`}, base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`},
base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`}, base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`},
&Transport{ Transport{
Protocol: base.StreamProtocolUDP, Protocol: base.StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
@@ -35,7 +35,7 @@ var casesTransport = []struct {
"udp unicast play response", "udp unicast play response",
base.HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`}, base.HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`},
base.HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`}, base.HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`},
&Transport{ Transport{
Protocol: base.StreamProtocolUDP, Protocol: base.StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
@@ -49,7 +49,7 @@ var casesTransport = []struct {
"udp multicast play request / response", "udp multicast play request / response",
base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`}, base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`},
base.HeaderValue{`RTP/AVP;multicast`}, base.HeaderValue{`RTP/AVP;multicast`},
&Transport{ Transport{
Protocol: base.StreamProtocolUDP, Protocol: base.StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryMulticast v := base.StreamDeliveryMulticast
@@ -70,7 +70,7 @@ var casesTransport = []struct {
"tcp play request / response", "tcp play request / response",
base.HeaderValue{`RTP/AVP/TCP;interleaved=0-1`}, base.HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
base.HeaderValue{`RTP/AVP/TCP;interleaved=0-1`}, base.HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
&Transport{ Transport{
Protocol: base.StreamProtocolTCP, Protocol: base.StreamProtocolTCP,
InterleavedIds: &[2]int{0, 1}, InterleavedIds: &[2]int{0, 1},
}, },
@@ -79,7 +79,7 @@ var casesTransport = []struct {
"udp unicast play response with a single port", "udp unicast play response with a single port",
base.HeaderValue{`RTP/AVP/UDP;unicast;server_port=8052;client_port=14186;ssrc=39140788;mode=PLAY`}, base.HeaderValue{`RTP/AVP/UDP;unicast;server_port=8052;client_port=14186;ssrc=39140788;mode=PLAY`},
base.HeaderValue{`RTP/AVP;unicast;client_port=14186-14187;server_port=8052-8053;mode=play`}, base.HeaderValue{`RTP/AVP;unicast;client_port=14186-14187;server_port=8052-8053;mode=play`},
&Transport{ Transport{
Protocol: base.StreamProtocolUDP, Protocol: base.StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
@@ -97,7 +97,7 @@ var casesTransport = []struct {
"udp record response with receive", "udp record response with receive",
base.HeaderValue{`RTP/AVP/UDP;unicast;mode=receive;source=localhost;client_port=14186-14187;server_port=5000-5001`}, base.HeaderValue{`RTP/AVP/UDP;unicast;mode=receive;source=localhost;client_port=14186-14187;server_port=5000-5001`},
base.HeaderValue{`RTP/AVP;unicast;client_port=14186-14187;server_port=5000-5001;mode=record`}, base.HeaderValue{`RTP/AVP;unicast;client_port=14186-14187;server_port=5000-5001;mode=record`},
&Transport{ Transport{
Protocol: base.StreamProtocolUDP, Protocol: base.StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
@@ -116,9 +116,10 @@ var casesTransport = []struct {
func TestTransportRead(t *testing.T) { func TestTransportRead(t *testing.T) {
for _, c := range casesTransport { for _, c := range casesTransport {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadTransport(c.vin) var h Transport
err := h.Read(c.vin)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.h, req) require.Equal(t, c.h, h)
}) })
} }
} }

View File

@@ -672,7 +672,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
th, err := headers.ReadTransport(req.Header["Transport"]) var th headers.Transport
err = th.Read(req.Header["Transport"])
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -755,7 +756,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
Path: path, Path: path,
Query: query, Query: query,
TrackID: trackID, TrackID: trackID,
Transport: th, Transport: &th,
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {

View File

@@ -616,7 +616,7 @@ func TestServerConnPublishReceivePackets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
th := &headers.Transport{ inTH := &headers.Transport{
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
return &v return &v
@@ -628,11 +628,11 @@ func TestServerConnPublishReceivePackets(t *testing.T) {
} }
if proto == "udp" { if proto == "udp" {
th.Protocol = StreamProtocolUDP inTH.Protocol = StreamProtocolUDP
th.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466, 35467}
} else { } else {
th.Protocol = StreamProtocolTCP inTH.Protocol = StreamProtocolTCP
th.InterleavedIds = &[2]int{0, 1} inTH.InterleavedIds = &[2]int{0, 1}
} }
err = base.Request{ err = base.Request{
@@ -640,7 +640,7 @@ func TestServerConnPublishReceivePackets(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": th.Write(), "Transport": inTH.Write(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -649,7 +649,8 @@ func TestServerConnPublishReceivePackets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
th, err = headers.ReadTransport(res.Header["Transport"]) var th headers.Transport
err = th.Read(res.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
err = base.Request{ err = base.Request{

View File

@@ -348,7 +348,7 @@ func TestServerConnReadReceivePackets(t *testing.T) {
defer conn.Close() defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
th := &headers.Transport{ inTH := &headers.Transport{
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
return &v return &v
@@ -360,11 +360,11 @@ func TestServerConnReadReceivePackets(t *testing.T) {
} }
if proto == "udp" { if proto == "udp" {
th.Protocol = StreamProtocolUDP inTH.Protocol = StreamProtocolUDP
th.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466, 35467}
} else { } else {
th.Protocol = StreamProtocolTCP inTH.Protocol = StreamProtocolTCP
th.InterleavedIds = &[2]int{0, 1} inTH.InterleavedIds = &[2]int{0, 1}
} }
err = base.Request{ err = base.Request{
@@ -372,7 +372,7 @@ func TestServerConnReadReceivePackets(t *testing.T) {
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"1"}, "CSeq": base.HeaderValue{"1"},
"Transport": th.Write(), "Transport": inTH.Write(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -382,7 +382,8 @@ func TestServerConnReadReceivePackets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
th, err = headers.ReadTransport(res.Header["Transport"]) var th headers.Transport
err = th.Read(res.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
err = base.Request{ err = base.Request{