From 25772271dbdb3c2642dac3711ed7b8da90b1fc14 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 20 Mar 2021 09:55:04 +0100 Subject: [PATCH] headers: rewrite initializers as members of their structs --- clientconn.go | 6 ++++-- clientconnread.go | 3 +-- clientconnread_test.go | 16 ++++++++++------ pkg/auth/sender.go | 6 ++++-- pkg/auth/validator.go | 3 ++- pkg/headers/auth.go | 20 +++++++++----------- pkg/headers/auth_test.go | 25 +++++++++++++------------ pkg/headers/rtpinfo.go | 22 ++++++++++------------ pkg/headers/rtpinfo_test.go | 11 ++++++----- pkg/headers/session.go | 20 +++++++++----------- pkg/headers/session_test.go | 13 +++++++------ pkg/headers/transport.go | 28 +++++++++++++--------------- pkg/headers/transport_test.go | 19 ++++++++++--------- serverconn.go | 5 +++-- serverconnpublish_test.go | 15 ++++++++------- serverconnread_test.go | 15 ++++++++------- 16 files changed, 117 insertions(+), 110 deletions(-) diff --git a/clientconn.go b/clientconn.go index 2c79082d..dbe6b02e 100644 --- a/clientconn.go +++ b/clientconn.go @@ -266,7 +266,8 @@ func (c *ClientConn) Do(req *base.Request) (*base.Response, error) { // get session from response if v, ok := res.Header["Session"]; ok { - sx, err := headers.ReadSession(v) + var sx headers.Session + err := sx.Read(v) if err != nil { return nil, fmt.Errorf("unable to parse session header: %s", err) } @@ -568,7 +569,8 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track, return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } - thRes, err := headers.ReadTransport(res.Header["Transport"]) + var thRes headers.Transport + err = thRes.Read(res.Header["Transport"]) if err != nil { if proto == StreamProtocolUDP { rtpListener.close() diff --git a/clientconnread.go b/clientconnread.go index f5137384..35f04dc6 100644 --- a/clientconnread.go +++ b/clientconnread.go @@ -32,8 +32,7 @@ func (c *ClientConn) Play() (*base.Response, error) { } if v, ok := res.Header["RTP-Info"]; ok { - var err error - c.rtpInfo, err = headers.ReadRTPInfo(v) + err := c.rtpInfo.Read(v) if err != nil { return nil, fmt.Errorf("unable to parse RTP-Info: %v", err) } diff --git a/clientconnread_test.go b/clientconnread_test.go index 9348e27f..84f2d305 100644 --- a/clientconnread_test.go +++ b/clientconnread_test.go @@ -102,7 +102,8 @@ func TestClientConnRead(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - th, err := headers.ReadTransport(req.Header["Transport"]) + var th headers.Transport + err = th.Read(req.Header["Transport"]) require.NoError(t, err) if ca.proto == "udp" { @@ -259,7 +260,8 @@ func TestClientConnReadNoServerPorts(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - th, err := headers.ReadTransport(req.Header["Transport"]) + var th headers.Transport + err = th.Read(req.Header["Transport"]) require.NoError(t, err) err = base.Response{ @@ -517,7 +519,8 @@ func TestClientConnReadRedirect(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - th, err := headers.ReadTransport(req.Header["Transport"]) + var th headers.Transport + err = th.Read(req.Header["Transport"]) require.NoError(t, err) err = base.Response{ @@ -669,7 +672,8 @@ func TestClientConnReadPause(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - inTH, err := headers.ReadTransport(req.Header["Transport"]) + var inTH headers.Transport + err = inTH.Read(req.Header["Transport"]) require.NoError(t, err) th := headers.Transport{ @@ -706,7 +710,7 @@ func TestClientConnReadPause(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) - writerTerminate, writerDone := writeFrames(inTH, bconn) + writerTerminate, writerDone := writeFrames(&inTH, bconn) err = req.Read(bconn.Reader) require.NoError(t, err) @@ -729,7 +733,7 @@ func TestClientConnReadPause(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) - writerTerminate, writerDone = writeFrames(inTH, bconn) + writerTerminate, writerDone = writeFrames(&inTH, bconn) err = req.Read(bconn.Reader) require.NoError(t, err) diff --git a/pkg/auth/sender.go b/pkg/auth/sender.go index 859a7276..dd68e6b4 100644 --- a/pkg/auth/sender.go +++ b/pkg/auth/sender.go @@ -30,7 +30,8 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) { } return "" }(); headerAuthDigest != "" { - auth, err := headers.ReadAuth(base.HeaderValue{headerAuthDigest}) + var auth headers.Auth + err := auth.Read(base.HeaderValue{headerAuthDigest}) if err != nil { return nil, err } @@ -60,7 +61,8 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) { } return "" }(); headerAuthBasic != "" { - auth, err := headers.ReadAuth(base.HeaderValue{headerAuthBasic}) + var auth headers.Auth + err := auth.Read(base.HeaderValue{headerAuthBasic}) if err != nil { return nil, err } diff --git a/pkg/auth/validator.go b/pkg/auth/validator.go index 1e1242a9..024f648c 100644 --- a/pkg/auth/validator.go +++ b/pkg/auth/validator.go @@ -135,7 +135,8 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur * } } else if strings.HasPrefix(v0, "Digest ") { - auth, err := headers.ReadAuth(base.HeaderValue{v0}) + var auth headers.Auth + err := auth.Read(base.HeaderValue{v0}) if err != nil { return err } diff --git a/pkg/headers/auth.go b/pkg/headers/auth.go index ab1a6c38..1839d1ab 100644 --- a/pkg/headers/auth.go +++ b/pkg/headers/auth.go @@ -79,23 +79,21 @@ func findValue(v0 string) (string, string, error) { } } -// ReadAuth decodes an Authenticate or a WWW-Authenticate header. -func ReadAuth(v base.HeaderValue) (*Auth, error) { +// Read decodes an Authenticate or a WWW-Authenticate header. +func (h *Auth) Read(v base.HeaderValue) error { if len(v) == 0 { - return nil, fmt.Errorf("value not provided") + return fmt.Errorf("value not provided") } if len(v) > 1 { - return nil, fmt.Errorf("value provided multiple times (%v)", v) + return fmt.Errorf("value provided multiple times (%v)", v) } - h := &Auth{} - v0 := v[0] i := strings.IndexByte(v0, ' ') if i < 0 { - return nil, fmt.Errorf("unable to find method (%s)", v0) + return fmt.Errorf("unable to find method (%s)", v0) } switch v0[:i] { @@ -106,14 +104,14 @@ func ReadAuth(v base.HeaderValue) (*Auth, error) { h.Method = AuthDigest default: - return nil, fmt.Errorf("invalid method (%s)", v0[:i]) + return fmt.Errorf("invalid method (%s)", v0[:i]) } v0 = v0[i+1:] for len(v0) > 0 { i := strings.IndexByte(v0, '=') if i < 0 { - return nil, fmt.Errorf("unable to find key (%s)", v0) + return fmt.Errorf("unable to find key (%s)", v0) } var key string key, v0 = v0[:i], v0[i+1:] @@ -122,7 +120,7 @@ func ReadAuth(v base.HeaderValue) (*Auth, error) { var err error val, v0, err = findValue(v0) if err != nil { - return nil, err + return err } switch key { @@ -164,7 +162,7 @@ func ReadAuth(v base.HeaderValue) (*Auth, error) { } } - return h, nil + return nil } // Write encodes an Authenticate or a WWW-Authenticate header. diff --git a/pkg/headers/auth_test.go b/pkg/headers/auth_test.go index 3820692d..fcd09661 100644 --- a/pkg/headers/auth_test.go +++ b/pkg/headers/auth_test.go @@ -12,13 +12,13 @@ var casesAuth = []struct { name string vin base.HeaderValue vout base.HeaderValue - h *Auth + h Auth }{ { "basic", base.HeaderValue{`Basic realm="4419b63f5e51"`}, base.HeaderValue{`Basic realm="4419b63f5e51"`}, - &Auth{ + Auth{ Method: AuthBasic, Realm: func() *string { v := "4419b63f5e51" @@ -30,7 +30,7 @@ var casesAuth = []struct { "digest request 1", base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`}, - &Auth{ + Auth{ Method: AuthDigest, Realm: func() *string { v := "4419b63f5e51" @@ -50,7 +50,7 @@ var casesAuth = []struct { "digest request 2", base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale=FALSE`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`}, - &Auth{ + Auth{ Method: AuthDigest, Realm: func() *string { v := "4419b63f5e51" @@ -70,7 +70,7 @@ var casesAuth = []struct { "digest request 3", base.HeaderValue{`Digest realm="4419b63f5e51",nonce="133767111917411116111311118211673010032", stale="FALSE"`}, base.HeaderValue{`Digest realm="4419b63f5e51", nonce="133767111917411116111311118211673010032", stale="FALSE"`}, - &Auth{ + Auth{ Method: AuthDigest, Realm: func() *string { v := "4419b63f5e51" @@ -90,7 +90,7 @@ var casesAuth = []struct { "digest response generic", base.HeaderValue{`Digest username="aa", realm="bb", nonce="cc", uri="dd", response="ee"`}, base.HeaderValue{`Digest username="aa", realm="bb", nonce="cc", uri="dd", response="ee"`}, - &Auth{ + Auth{ Method: AuthDigest, Username: func() *string { v := "aa" @@ -118,7 +118,7 @@ var casesAuth = []struct { "digest response with empty field", base.HeaderValue{`Digest username="", realm="IPCAM", nonce="5d17cd12b9fa8a85ac5ceef0926ea5a6", uri="rtsp://localhost:8554/mystream", response="c072ae90eb4a27f4cdcb90d62266b2a1"`}, base.HeaderValue{`Digest username="", realm="IPCAM", nonce="5d17cd12b9fa8a85ac5ceef0926ea5a6", uri="rtsp://localhost:8554/mystream", response="c072ae90eb4a27f4cdcb90d62266b2a1"`}, - &Auth{ + Auth{ Method: AuthDigest, Username: func() *string { v := "" @@ -146,7 +146,7 @@ var casesAuth = []struct { "digest response with no spaces and additional fields", base.HeaderValue{`Digest realm="Please log in with a valid username",nonce="752a62306daf32b401a41004555c7663",opaque="",stale=FALSE,algorithm=MD5`}, base.HeaderValue{`Digest realm="Please log in with a valid username", nonce="752a62306daf32b401a41004555c7663", opaque="", stale="FALSE", algorithm="MD5"`}, - &Auth{ + Auth{ Method: AuthDigest, Realm: func() *string { v := "Please log in with a valid username" @@ -175,9 +175,10 @@ var casesAuth = []struct { func TestAuthRead(t *testing.T) { for _, c := range casesAuth { t.Run(c.name, func(t *testing.T) { - req, err := ReadAuth(c.vin) + var h Auth + err := h.Read(c.vin) require.NoError(t, err) - require.Equal(t, c.h, req) + require.Equal(t, c.h, h) }) } } @@ -185,8 +186,8 @@ func TestAuthRead(t *testing.T) { func TestAuthWrite(t *testing.T) { for _, c := range casesAuth { t.Run(c.name, func(t *testing.T) { - req := c.h.Write() - require.Equal(t, c.vout, req) + vout := c.h.Write() + require.Equal(t, c.vout, vout) }) } } diff --git a/pkg/headers/rtpinfo.go b/pkg/headers/rtpinfo.go index 148eb9da..9522b296 100644 --- a/pkg/headers/rtpinfo.go +++ b/pkg/headers/rtpinfo.go @@ -18,25 +18,23 @@ type RTPInfoEntry struct { // RTPInfo is a RTP-Info header. type RTPInfo []*RTPInfoEntry -// ReadRTPInfo decodes a RTP-Info header. -func ReadRTPInfo(v base.HeaderValue) (*RTPInfo, error) { +// Read decodes a RTP-Info header. +func (h *RTPInfo) Read(v base.HeaderValue) error { if len(v) == 0 { - return nil, fmt.Errorf("value not provided") + return fmt.Errorf("value not provided") } if len(v) > 1 { - return nil, fmt.Errorf("value provided multiple times (%v)", v) + return fmt.Errorf("value provided multiple times (%v)", v) } - h := &RTPInfo{} - for _, tmp := range strings.Split(v[0], ",") { e := &RTPInfoEntry{} for _, kv := range strings.Split(tmp, ";") { tmp := strings.SplitN(kv, "=", 2) if len(tmp) != 2 { - return nil, fmt.Errorf("unable to parse key-value (%v)", kv) + return fmt.Errorf("unable to parse key-value (%v)", kv) } k, v := tmp[0], tmp[1] @@ -44,33 +42,33 @@ func ReadRTPInfo(v base.HeaderValue) (*RTPInfo, error) { case "url": vu, err := base.ParseURL(v) if err != nil { - return nil, err + return err } e.URL = vu case "seq": vi, err := strconv.ParseUint(v, 10, 16) if err != nil { - return nil, err + return err } e.SequenceNumber = uint16(vi) case "rtptime": vi, err := strconv.ParseUint(v, 10, 32) if err != nil { - return nil, err + return err } e.Timestamp = uint32(vi) default: - return nil, fmt.Errorf("invalid key: %v", k) + return fmt.Errorf("invalid key: %v", k) } } *h = append(*h, e) } - return h, nil + return nil } // Clone clones a RTPInfo. diff --git a/pkg/headers/rtpinfo_test.go b/pkg/headers/rtpinfo_test.go index 7eddc91b..7391bd77 100644 --- a/pkg/headers/rtpinfo_test.go +++ b/pkg/headers/rtpinfo_test.go @@ -12,13 +12,13 @@ var casesRTPInfo = []struct { name string vin base.HeaderValue vout base.HeaderValue - h *RTPInfo + h RTPInfo }{ { "single value", base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`}, base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`}, - &RTPInfo{ + RTPInfo{ { URL: base.MustParseURL("rtsp://127.0.0.1/test.mkv/track1"), SequenceNumber: 35243, @@ -30,7 +30,7 @@ var casesRTPInfo = []struct { "multiple value", base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556,url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`}, base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556,url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`}, - &RTPInfo{ + RTPInfo{ { URL: base.MustParseURL("rtsp://127.0.0.1/test.mkv/track1"), SequenceNumber: 35243, @@ -48,9 +48,10 @@ var casesRTPInfo = []struct { func TestRTPInfoRead(t *testing.T) { for _, c := range casesRTPInfo { t.Run(c.name, func(t *testing.T) { - req, err := ReadRTPInfo(c.vin) + var h RTPInfo + err := h.Read(c.vin) require.NoError(t, err) - require.Equal(t, c.h, req) + require.Equal(t, c.h, h) }) } } diff --git a/pkg/headers/session.go b/pkg/headers/session.go index 2b5be3fe..f3bcacec 100644 --- a/pkg/headers/session.go +++ b/pkg/headers/session.go @@ -17,23 +17,21 @@ type Session struct { Timeout *uint } -// ReadSession decodes a Session header. -func ReadSession(v base.HeaderValue) (*Session, error) { +// Read decodes a Session header. +func (h *Session) Read(v base.HeaderValue) error { if len(v) == 0 { - return nil, fmt.Errorf("value not provided") + return fmt.Errorf("value not provided") } if len(v) > 1 { - return nil, fmt.Errorf("value provided multiple times (%v)", v) + return fmt.Errorf("value provided multiple times (%v)", v) } parts := strings.Split(v[0], ";") if len(parts) == 0 { - return nil, fmt.Errorf("invalid value (%v)", v) + return fmt.Errorf("invalid value (%v)", v) } - h := &Session{} - h.Session = parts[0] for _, part := range parts[1:] { @@ -42,24 +40,24 @@ func ReadSession(v base.HeaderValue) (*Session, error) { kv := strings.Split(part, "=") if len(kv) != 2 { - return nil, fmt.Errorf("invalid value") + return fmt.Errorf("invalid value") } key, strValue := kv[0], kv[1] if key != "timeout" { - return nil, fmt.Errorf("invalid key '%s'", key) + return fmt.Errorf("invalid key '%s'", key) } iv, err := strconv.ParseUint(strValue, 10, 64) if err != nil { - return nil, err + return err } uiv := uint(iv) h.Timeout = &uiv } - return h, nil + return nil } // Write encodes a Session header. diff --git a/pkg/headers/session_test.go b/pkg/headers/session_test.go index dd4f72a3..39966c5c 100644 --- a/pkg/headers/session_test.go +++ b/pkg/headers/session_test.go @@ -12,13 +12,13 @@ var casesSession = []struct { name string vin base.HeaderValue vout base.HeaderValue - h *Session + h Session }{ { "base", base.HeaderValue{`A3eqwsafq3rFASqew`}, base.HeaderValue{`A3eqwsafq3rFASqew`}, - &Session{ + Session{ Session: "A3eqwsafq3rFASqew", }, }, @@ -26,7 +26,7 @@ var casesSession = []struct { "with timeout", base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, - &Session{ + Session{ Session: "A3eqwsafq3rFASqew", Timeout: func() *uint { v := uint(47) @@ -38,7 +38,7 @@ var casesSession = []struct { "with timeout and space", base.HeaderValue{`A3eqwsafq3rFASqew; timeout=47`}, base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, - &Session{ + Session{ Session: "A3eqwsafq3rFASqew", Timeout: func() *uint { v := uint(47) @@ -51,9 +51,10 @@ var casesSession = []struct { func TestSessionRead(t *testing.T) { for _, c := range casesSession { t.Run(c.name, func(t *testing.T) { - req, err := ReadSession(c.vin) + var h Session + err := h.Read(c.vin) require.NoError(t, err) - require.Equal(t, c.h, req) + require.Equal(t, c.h, h) }) } } diff --git a/pkg/headers/transport.go b/pkg/headers/transport.go index 763e5e44..3fafcced 100644 --- a/pkg/headers/transport.go +++ b/pkg/headers/transport.go @@ -89,21 +89,19 @@ func parsePorts(val string) (*[2]int, error) { return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val) } -// ReadTransport decodes a Transport header. -func ReadTransport(v base.HeaderValue) (*Transport, error) { +// Read decodes a Transport header. +func (h *Transport) Read(v base.HeaderValue) error { if len(v) == 0 { - return nil, fmt.Errorf("value not provided") + return fmt.Errorf("value not provided") } if len(v) > 1 { - return nil, fmt.Errorf("value provided multiple times (%v)", v) + return fmt.Errorf("value provided multiple times (%v)", v) } - h := &Transport{} - parts := strings.Split(v[0], ";") if len(parts) == 0 { - return nil, fmt.Errorf("invalid value (%v)", v) + return fmt.Errorf("invalid value (%v)", v) } switch parts[0] { @@ -114,7 +112,7 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) { h.Protocol = base.StreamProtocolTCP default: - return nil, fmt.Errorf("invalid protocol (%v)", v) + return fmt.Errorf("invalid protocol (%v)", v) } parts = parts[1:] @@ -140,7 +138,7 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) { } else if strings.HasPrefix(t, "ttl=") { v, err := strconv.ParseUint(t[len("ttl="):], 10, 64) if err != nil { - return nil, err + return err } vu := uint(v) h.TTL = &vu @@ -148,28 +146,28 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) { } else if strings.HasPrefix(t, "port=") { ports, err := parsePorts(t[len("port="):]) if err != nil { - return nil, err + return err } h.Ports = ports } else if strings.HasPrefix(t, "client_port=") { ports, err := parsePorts(t[len("client_port="):]) if err != nil { - return nil, err + return err } h.ClientPorts = ports } else if strings.HasPrefix(t, "server_port=") { ports, err := parsePorts(t[len("server_port="):]) if err != nil { - return nil, err + return err } h.ServerPorts = ports } else if strings.HasPrefix(t, "interleaved=") { ports, err := parsePorts(t[len("interleaved="):]) if err != nil { - return nil, err + return err } h.InterleavedIds = ports @@ -190,14 +188,14 @@ func ReadTransport(v base.HeaderValue) (*Transport, error) { h.Mode = &v default: - return nil, fmt.Errorf("invalid transport mode: '%s'", str) + return fmt.Errorf("invalid transport mode: '%s'", str) } } // ignore non-standard keys } - return h, nil + return nil } // Write encodes a Transport header diff --git a/pkg/headers/transport_test.go b/pkg/headers/transport_test.go index d60ceb3c..0fef2846 100644 --- a/pkg/headers/transport_test.go +++ b/pkg/headers/transport_test.go @@ -12,13 +12,13 @@ var casesTransport = []struct { name string vin base.HeaderValue vout base.HeaderValue - h *Transport + h Transport }{ { "udp unicast play request", base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`}, base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`}, - &Transport{ + Transport{ Protocol: base.StreamProtocolUDP, Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryUnicast @@ -35,7 +35,7 @@ var casesTransport = []struct { "udp unicast play response", base.HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`}, base.HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`}, - &Transport{ + Transport{ Protocol: base.StreamProtocolUDP, Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryUnicast @@ -49,7 +49,7 @@ var casesTransport = []struct { "udp multicast play request / response", base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`}, base.HeaderValue{`RTP/AVP;multicast`}, - &Transport{ + Transport{ Protocol: base.StreamProtocolUDP, Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryMulticast @@ -70,7 +70,7 @@ var casesTransport = []struct { "tcp play request / response", base.HeaderValue{`RTP/AVP/TCP;interleaved=0-1`}, base.HeaderValue{`RTP/AVP/TCP;interleaved=0-1`}, - &Transport{ + Transport{ Protocol: base.StreamProtocolTCP, InterleavedIds: &[2]int{0, 1}, }, @@ -79,7 +79,7 @@ var casesTransport = []struct { "udp unicast play response with a single port", base.HeaderValue{`RTP/AVP/UDP;unicast;server_port=8052;client_port=14186;ssrc=39140788;mode=PLAY`}, base.HeaderValue{`RTP/AVP;unicast;client_port=14186-14187;server_port=8052-8053;mode=play`}, - &Transport{ + Transport{ Protocol: base.StreamProtocolUDP, Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryUnicast @@ -97,7 +97,7 @@ var casesTransport = []struct { "udp record response with receive", base.HeaderValue{`RTP/AVP/UDP;unicast;mode=receive;source=localhost;client_port=14186-14187;server_port=5000-5001`}, base.HeaderValue{`RTP/AVP;unicast;client_port=14186-14187;server_port=5000-5001;mode=record`}, - &Transport{ + Transport{ Protocol: base.StreamProtocolUDP, Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryUnicast @@ -116,9 +116,10 @@ var casesTransport = []struct { func TestTransportRead(t *testing.T) { for _, c := range casesTransport { t.Run(c.name, func(t *testing.T) { - req, err := ReadTransport(c.vin) + var h Transport + err := h.Read(c.vin) require.NoError(t, err) - require.Equal(t, c.h, req) + require.Equal(t, c.h, h) }) } } diff --git a/serverconn.go b/serverconn.go index f64d2a49..eaf12e1e 100644 --- a/serverconn.go +++ b/serverconn.go @@ -672,7 +672,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } - th, err := headers.ReadTransport(req.Header["Transport"]) + var th headers.Transport + err = th.Read(req.Header["Transport"]) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -755,7 +756,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { Path: path, Query: query, TrackID: trackID, - Transport: th, + Transport: &th, }) if res.StatusCode == base.StatusOK { diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index 19ea49b7..8b5d3b2d 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -616,7 +616,7 @@ func TestServerConnPublishReceivePackets(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - th := &headers.Transport{ + inTH := &headers.Transport{ Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryUnicast return &v @@ -628,11 +628,11 @@ func TestServerConnPublishReceivePackets(t *testing.T) { } if proto == "udp" { - th.Protocol = StreamProtocolUDP - th.ClientPorts = &[2]int{35466, 35467} + inTH.Protocol = StreamProtocolUDP + inTH.ClientPorts = &[2]int{35466, 35467} } else { - th.Protocol = StreamProtocolTCP - th.InterleavedIds = &[2]int{0, 1} + inTH.Protocol = StreamProtocolTCP + inTH.InterleavedIds = &[2]int{0, 1} } err = base.Request{ @@ -640,7 +640,7 @@ func TestServerConnPublishReceivePackets(t *testing.T) { URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, - "Transport": th.Write(), + "Transport": inTH.Write(), }, }.Write(bconn.Writer) require.NoError(t, err) @@ -649,7 +649,8 @@ func TestServerConnPublishReceivePackets(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - th, err = headers.ReadTransport(res.Header["Transport"]) + var th headers.Transport + err = th.Read(res.Header["Transport"]) require.NoError(t, err) err = base.Request{ diff --git a/serverconnread_test.go b/serverconnread_test.go index 073b7b9f..eb63e9da 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -348,7 +348,7 @@ func TestServerConnReadReceivePackets(t *testing.T) { defer conn.Close() bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - th := &headers.Transport{ + inTH := &headers.Transport{ Delivery: func() *base.StreamDelivery { v := base.StreamDeliveryUnicast return &v @@ -360,11 +360,11 @@ func TestServerConnReadReceivePackets(t *testing.T) { } if proto == "udp" { - th.Protocol = StreamProtocolUDP - th.ClientPorts = &[2]int{35466, 35467} + inTH.Protocol = StreamProtocolUDP + inTH.ClientPorts = &[2]int{35466, 35467} } else { - th.Protocol = StreamProtocolTCP - th.InterleavedIds = &[2]int{0, 1} + inTH.Protocol = StreamProtocolTCP + inTH.InterleavedIds = &[2]int{0, 1} } err = base.Request{ @@ -372,7 +372,7 @@ func TestServerConnReadReceivePackets(t *testing.T) { URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"1"}, - "Transport": th.Write(), + "Transport": inTH.Write(), }, }.Write(bconn.Writer) require.NoError(t, err) @@ -382,7 +382,8 @@ func TestServerConnReadReceivePackets(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - th, err = headers.ReadTransport(res.Header["Transport"]) + var th headers.Transport + err = th.Read(res.Header["Transport"]) require.NoError(t, err) err = base.Request{