mirror of
https://github.com/pion/ice.git
synced 2025-09-27 03:45:54 +08:00
Add methods to add and remove extensions
Added `AddExtension` and `RemoveExtension` methods to `ICECandidate`, allowing extensions to be managed dynamically. Ensure that `TCPType` is stored in one place (candidate.TCPType)
This commit is contained in:
@@ -58,12 +58,18 @@ type Candidate interface {
|
|||||||
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
|
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
|
||||||
//.
|
//.
|
||||||
Extensions() []CandidateExtension
|
Extensions() []CandidateExtension
|
||||||
|
|
||||||
// GetExtension returns the value of the extension attribute associated with the ICECandidate.
|
// GetExtension returns the value of the extension attribute associated with the ICECandidate.
|
||||||
// Extension attributes are defined in RFC 5245, Section 15.1:
|
// Extension attributes are defined in RFC 5245, Section 15.1:
|
||||||
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
|
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
|
||||||
//.
|
//.
|
||||||
GetExtension(key string) (value CandidateExtension, ok bool)
|
GetExtension(key string) (value CandidateExtension, ok bool)
|
||||||
|
// AddExtension adds an extension attribute to the ICECandidate.
|
||||||
|
// If an extension with the same key already exists, it will be overwritten.
|
||||||
|
// Extension attributes are defined in RFC 5245, Section 15.1:
|
||||||
|
AddExtension(extension CandidateExtension) error
|
||||||
|
// RemoveExtension removes an extension attribute from the ICECandidate.
|
||||||
|
// Extension attributes are defined in RFC 5245, Section 15.1:
|
||||||
|
RemoveExtension(key string) (ok bool)
|
||||||
|
|
||||||
String() string
|
String() string
|
||||||
Type() CandidateType
|
Type() CandidateType
|
||||||
|
@@ -548,17 +548,22 @@ type CandidateExtension struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *candidateBase) Extensions() []CandidateExtension {
|
func (c *candidateBase) Extensions() []CandidateExtension {
|
||||||
// IF Extensions were not parsed using UnmarshalCandidate
|
tcpType := c.TCPType()
|
||||||
// For backwards compatibility when the TCPType is set manually
|
hasTCPType := 0
|
||||||
if len(c.extensions) == 0 && c.TCPType() != TCPTypeUnspecified {
|
if tcpType != TCPTypeUnspecified {
|
||||||
return []CandidateExtension{{
|
hasTCPType = 1
|
||||||
Key: "tcptype",
|
|
||||||
Value: c.TCPType().String(),
|
|
||||||
}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extensions := make([]CandidateExtension, len(c.extensions))
|
extensions := make([]CandidateExtension, len(c.extensions)+hasTCPType)
|
||||||
copy(extensions, c.extensions)
|
// We store the TCPType in c.tcpType, but we need to return it as an extension.
|
||||||
|
if hasTCPType == 1 {
|
||||||
|
extensions[0] = CandidateExtension{
|
||||||
|
Key: "tcptype",
|
||||||
|
Value: tcpType.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(extensions[hasTCPType:], c.extensions)
|
||||||
|
|
||||||
return extensions
|
return extensions
|
||||||
}
|
}
|
||||||
@@ -576,7 +581,7 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TCPType was manually set.
|
// TCPType was manually set.
|
||||||
if key == "tcptype" && c.TCPType() != TCPTypeUnspecified {
|
if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { //nolint:goconst
|
||||||
extension.Value = c.TCPType().String()
|
extension.Value = c.TCPType().String()
|
||||||
|
|
||||||
return extension, true
|
return extension, true
|
||||||
@@ -585,6 +590,55 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) {
|
|||||||
return extension, false
|
return extension, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *candidateBase) AddExtension(ext CandidateExtension) error {
|
||||||
|
if ext.Key == "tcptype" {
|
||||||
|
tcpType := NewTCPType(ext.Value)
|
||||||
|
if tcpType == TCPTypeUnspecified {
|
||||||
|
return fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, ext.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.tcpType = tcpType
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ext.Key == "" {
|
||||||
|
return fmt.Errorf("%w: key is empty", errParseExtension)
|
||||||
|
}
|
||||||
|
|
||||||
|
// per spec, Extensions aren't explicitly unique, we only set the first one.
|
||||||
|
// If the exteion is set multiple times.
|
||||||
|
for i := range c.extensions {
|
||||||
|
if c.extensions[i].Key == ext.Key {
|
||||||
|
c.extensions[i] = ext
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.extensions = append(c.extensions, ext)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *candidateBase) RemoveExtension(key string) (ok bool) {
|
||||||
|
if key == "tcptype" {
|
||||||
|
c.tcpType = TCPTypeUnspecified
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range c.extensions {
|
||||||
|
if c.extensions[i].Key == key {
|
||||||
|
c.extensions = append(c.extensions[:i], c.extensions[i+1:]...)
|
||||||
|
ok = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// marshalExtensions returns the string representation of the candidate extensions.
|
// marshalExtensions returns the string representation of the candidate extensions.
|
||||||
func (c *candidateBase) marshalExtensions() string {
|
func (c *candidateBase) marshalExtensions() string {
|
||||||
value := ""
|
value := ""
|
||||||
@@ -994,6 +1048,8 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
|
|||||||
|
|
||||||
if key == "tcptype" {
|
if key == "tcptype" {
|
||||||
rawTCPTypeRaw = value
|
rawTCPTypeRaw = value
|
||||||
|
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
extensions = append(extensions, CandidateExtension{key, value})
|
extensions = append(extensions, CandidateExtension{key, value})
|
||||||
|
@@ -1271,3 +1271,159 @@ func TestBaseCandidateExtensionsEqual(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCandidateAddExtension(t *testing.T) {
|
||||||
|
t.Run("Add extension", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeUDP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
|
||||||
|
|
||||||
|
extensions := candidate.Extensions()
|
||||||
|
require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Add extension with existing key", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeUDP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "d"}))
|
||||||
|
|
||||||
|
extensions := candidate.Extensions()
|
||||||
|
require.Equal(t, []CandidateExtension{{"a", "d"}}, extensions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Keep tcptype extension", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeTCP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
TCPType: TCPTypeActive,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ext, ok := candidate.GetExtension("tcptype")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, ext, CandidateExtension{"tcptype", "active"})
|
||||||
|
require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}})
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
|
||||||
|
|
||||||
|
ext, ok = candidate.GetExtension("tcptype")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, ext, CandidateExtension{"tcptype", "active"})
|
||||||
|
require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}, {"a", "b"}})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TcpType change extension", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeTCP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "active"}))
|
||||||
|
|
||||||
|
extensions := candidate.Extensions()
|
||||||
|
require.Equal(t, []CandidateExtension{{"tcptype", "active"}}, extensions)
|
||||||
|
require.Equal(t, TCPTypeActive, candidate.TCPType())
|
||||||
|
|
||||||
|
require.Error(t, candidate.AddExtension(CandidateExtension{"tcptype", "INVALID"}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCandidateRemoveExtension(t *testing.T) {
|
||||||
|
t.Run("Remove extension", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeUDP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
|
||||||
|
|
||||||
|
require.True(t, candidate.RemoveExtension("a"))
|
||||||
|
|
||||||
|
extensions := candidate.Extensions()
|
||||||
|
require.Equal(t, []CandidateExtension{{"c", "d"}}, extensions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Remove extension that does not exist", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeUDP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
|
||||||
|
|
||||||
|
require.False(t, candidate.RemoveExtension("b"))
|
||||||
|
|
||||||
|
extensions := candidate.Extensions()
|
||||||
|
require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Remove tcptype extension", func(t *testing.T) {
|
||||||
|
candidate, err := NewCandidateHost(&CandidateHostConfig{
|
||||||
|
Network: NetworkTypeTCP4.String(),
|
||||||
|
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||||
|
Port: 53987,
|
||||||
|
Priority: 500,
|
||||||
|
Foundation: "750",
|
||||||
|
TCPType: TCPTypeActive,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcptype extension should be removed, even if it's not in the extensions list (Not Parsed)
|
||||||
|
require.True(t, candidate.RemoveExtension("tcptype"))
|
||||||
|
require.Equal(t, TCPTypeUnspecified, candidate.TCPType())
|
||||||
|
require.Empty(t, candidate.Extensions())
|
||||||
|
|
||||||
|
require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "passive"}))
|
||||||
|
|
||||||
|
require.True(t, candidate.RemoveExtension("tcptype"))
|
||||||
|
require.Equal(t, TCPTypeUnspecified, candidate.TCPType())
|
||||||
|
require.Empty(t, candidate.Extensions())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user