From 47dad556f13a115afbc3120d565f6e4e09e3deec Mon Sep 17 00:00:00 2001 From: Joe Turki Date: Thu, 30 Jan 2025 00:17:06 -0600 Subject: [PATCH] Add methods to add and remove extensions Added `AddExtension` and `RemoveExtension` methods to `ICECandidate`, allowing extensions to be managed dynamically. Ensure that `TCPType` is stored in one place (candidate.TCPType) --- candidate.go | 8 ++- candidate_base.go | 76 +++++++++++++++++++--- candidate_test.go | 156 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 229 insertions(+), 11 deletions(-) diff --git a/candidate.go b/candidate.go index 4eb5206..89082f9 100644 --- a/candidate.go +++ b/candidate.go @@ -58,12 +58,18 @@ type Candidate interface { // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 //. Extensions() []CandidateExtension - // GetExtension returns the value of the extension attribute associated with the ICECandidate. // Extension attributes are defined in RFC 5245, Section 15.1: // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 //. GetExtension(key string) (value CandidateExtension, ok bool) + // AddExtension adds an extension attribute to the ICECandidate. + // If an extension with the same key already exists, it will be overwritten. + // Extension attributes are defined in RFC 5245, Section 15.1: + AddExtension(extension CandidateExtension) error + // RemoveExtension removes an extension attribute from the ICECandidate. + // Extension attributes are defined in RFC 5245, Section 15.1: + RemoveExtension(key string) (ok bool) String() string Type() CandidateType diff --git a/candidate_base.go b/candidate_base.go index 55a6ce8..f2ef422 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -548,17 +548,22 @@ type CandidateExtension struct { } func (c *candidateBase) Extensions() []CandidateExtension { - // IF Extensions were not parsed using UnmarshalCandidate - // For backwards compatibility when the TCPType is set manually - if len(c.extensions) == 0 && c.TCPType() != TCPTypeUnspecified { - return []CandidateExtension{{ - Key: "tcptype", - Value: c.TCPType().String(), - }} + tcpType := c.TCPType() + hasTCPType := 0 + if tcpType != TCPTypeUnspecified { + hasTCPType = 1 } - extensions := make([]CandidateExtension, len(c.extensions)) - copy(extensions, c.extensions) + extensions := make([]CandidateExtension, len(c.extensions)+hasTCPType) + // We store the TCPType in c.tcpType, but we need to return it as an extension. + if hasTCPType == 1 { + extensions[0] = CandidateExtension{ + Key: "tcptype", + Value: tcpType.String(), + } + } + + copy(extensions[hasTCPType:], c.extensions) return extensions } @@ -576,7 +581,7 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) { } // TCPType was manually set. - if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { + if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { //nolint:goconst extension.Value = c.TCPType().String() return extension, true @@ -585,6 +590,55 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) { return extension, false } +func (c *candidateBase) AddExtension(ext CandidateExtension) error { + if ext.Key == "tcptype" { + tcpType := NewTCPType(ext.Value) + if tcpType == TCPTypeUnspecified { + return fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, ext.Value) + } + + c.tcpType = tcpType + + return nil + } + + if ext.Key == "" { + return fmt.Errorf("%w: key is empty", errParseExtension) + } + + // per spec, Extensions aren't explicitly unique, we only set the first one. + // If the exteion is set multiple times. + for i := range c.extensions { + if c.extensions[i].Key == ext.Key { + c.extensions[i] = ext + + return nil + } + } + + c.extensions = append(c.extensions, ext) + + return nil +} + +func (c *candidateBase) RemoveExtension(key string) (ok bool) { + if key == "tcptype" { + c.tcpType = TCPTypeUnspecified + ok = true + } + + for i := range c.extensions { + if c.extensions[i].Key == key { + c.extensions = append(c.extensions[:i], c.extensions[i+1:]...) + ok = true + + break + } + } + + return ok +} + // marshalExtensions returns the string representation of the candidate extensions. func (c *candidateBase) marshalExtensions() string { value := "" @@ -994,6 +1048,8 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension, if key == "tcptype" { rawTCPTypeRaw = value + + continue } extensions = append(extensions, CandidateExtension{key, value}) diff --git a/candidate_test.go b/candidate_test.go index caafefe..514a9de 100644 --- a/candidate_test.go +++ b/candidate_test.go @@ -1271,3 +1271,159 @@ func TestBaseCandidateExtensionsEqual(t *testing.T) { }) } } + +func TestCandidateAddExtension(t *testing.T) { + t.Run("Add extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions) + }) + + t.Run("Add extension with existing key", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "d"})) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"a", "d"}}, extensions) + }) + + t.Run("Keep tcptype extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeTCP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + TCPType: TCPTypeActive, + }) + if err != nil { + t.Error(err) + } + + ext, ok := candidate.GetExtension("tcptype") + require.True(t, ok) + require.Equal(t, ext, CandidateExtension{"tcptype", "active"}) + require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}}) + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + + ext, ok = candidate.GetExtension("tcptype") + require.True(t, ok) + require.Equal(t, ext, CandidateExtension{"tcptype", "active"}) + require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}, {"a", "b"}}) + }) + + t.Run("TcpType change extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeTCP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "active"})) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"tcptype", "active"}}, extensions) + require.Equal(t, TCPTypeActive, candidate.TCPType()) + + require.Error(t, candidate.AddExtension(CandidateExtension{"tcptype", "INVALID"})) + }) +} + +func TestCandidateRemoveExtension(t *testing.T) { + t.Run("Remove extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) + + require.True(t, candidate.RemoveExtension("a")) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"c", "d"}}, extensions) + }) + + t.Run("Remove extension that does not exist", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) + + require.False(t, candidate.RemoveExtension("b")) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions) + }) + + t.Run("Remove tcptype extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeTCP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + TCPType: TCPTypeActive, + }) + if err != nil { + t.Error(err) + } + + // tcptype extension should be removed, even if it's not in the extensions list (Not Parsed) + require.True(t, candidate.RemoveExtension("tcptype")) + require.Equal(t, TCPTypeUnspecified, candidate.TCPType()) + require.Empty(t, candidate.Extensions()) + + require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "passive"})) + + require.True(t, candidate.RemoveExtension("tcptype")) + require.Equal(t, TCPTypeUnspecified, candidate.TCPType()) + require.Empty(t, candidate.Extensions()) + }) +}