diff --git a/agent.go b/agent.go index 2ab4700..94f92e1 100644 --- a/agent.go +++ b/agent.go @@ -3,16 +3,20 @@ package ice import ( + "context" "fmt" "math/rand" "net" + "strings" "sync" "sync/atomic" "time" "github.com/pion/logging" + "github.com/pion/mdns" "github.com/pion/stun" "github.com/pion/transport/packetio" + "golang.org/x/net/ipv4" ) const ( @@ -83,6 +87,10 @@ type Agent struct { connectionState ConnectionState gatheringState GatheringState + mDNSMode MulticastDNSMode + mDNSName string + mDNSConn *mdns.Conn + haveStarted atomic.Value isControlling bool @@ -166,6 +174,9 @@ type AgentConfig struct { // work perform synchronous gathering. Trickle bool + // MulticastDNSMode controls mDNS behavior for the ICE agent + MulticastDNSMode MulticastDNSMode + // ConnectionTimeout defaults to 30 seconds when this property is nil. // If the duration is 0, we will never timeout this connection. ConnectionTimeout *time.Duration @@ -216,6 +227,41 @@ func NewAgent(config *AgentConfig) (*Agent, error) { return nil, ErrPort } + mDNSName, err := generateMulticastDNSName() + if err != nil { + return nil, err + } + + mDNSMode := config.MulticastDNSMode + if mDNSMode == 0 { + mDNSMode = MulticastDNSModeQueryOnly + } + + var mDNSConn *mdns.Conn + if mDNSMode != MulticastDNSModeDisabled { + addr, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddress) + if err != nil { + return nil, err + } + + l, err := net.ListenUDP("udp4", addr) + if err != nil { + return nil, err + } + + switch mDNSMode { + case MulticastDNSModeQueryOnly: + mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l), &mdns.Config{}) + case MulticastDNSModeQueryAndGather: + mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l), &mdns.Config{ + LocalNames: []string{mDNSName}, + }) + } + if err != nil { + return nil, err + } + } + loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() @@ -243,6 +289,10 @@ func NewAgent(config *AgentConfig) (*Agent, error) { trickle: config.Trickle, log: loggerFactory.NewLogger("ice"), + mDNSMode: mDNSMode, + mDNSName: mDNSName, + mDNSConn: mDNSConn, + forceCandidateContact: make(chan bool, 1), } a.haveStarted.Store(false) @@ -557,11 +607,53 @@ func (a *Agent) checkKeepalive() { // AddRemoteCandidate adds a new remote candidate func (a *Agent) AddRemoteCandidate(c Candidate) error { + // If we have a mDNS Candidate lets fully resolve it before adding it locally + if c.Type() == CandidateTypeHost && strings.HasSuffix(c.Address(), ".local") { + if a.mDNSMode == MulticastDNSModeDisabled { + return nil + } + + hostCandidate, ok := c.(*CandidateHost) + if !ok { + return ErrAddressParseFailed + } + + go a.resolveAndAddMulticastCandidate(hostCandidate) + return nil + } + return a.run(func(agent *Agent) { agent.addRemoteCandidate(c) }) } +func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) { + _, src, err := a.mDNSConn.Query(context.TODO(), c.Address()) + if err != nil { + a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) + return + } + + ip, _, _, _ := parseAddr(src) + if ip == nil { + a.log.Warnf("Failed to discover mDNS candidate %s: failed to parse IP", c.Address()) + return + } + + if err = c.setIP(ip); err != nil { + a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) + return + } + + if err = a.run(func(agent *Agent) { + agent.addRemoteCandidate(c) + }); err != nil { + a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err) + return + + } +} + // addRemoteCandidate assumes you are holding the lock (must be execute using a.run) func (a *Agent) addRemoteCandidate(c Candidate) { set := a.remoteCandidates[c.NetworkType()] @@ -661,11 +753,16 @@ func (a *Agent) Close() error { } delete(agent.remoteCandidates, net) } - - err := a.buffer.Close() - if err != nil { + if err := a.buffer.Close(); err != nil { a.log.Warnf("failed to close buffer: %v", err) } + + if a.mDNSConn != nil { + if err := a.mDNSConn.Close(); err != nil { + a.log.Warnf("failed to close mDNS Conn: %v", err) + } + } + }) if err != nil { return err diff --git a/candidate_host.go b/candidate_host.go index 1300c7f..3a590bf 100644 --- a/candidate_host.go +++ b/candidate_host.go @@ -8,31 +8,42 @@ import ( // CandidateHost is a candidate of type host type CandidateHost struct { candidateBase + + network string } // NewCandidateHost creates a new host candidate func NewCandidateHost(network string, address string, port int, component uint16) (*CandidateHost, error) { c := &CandidateHost{ candidateBase: candidateBase{ + address: address, candidateType: CandidateTypeHost, - port: port, component: component, + port: port, }, + network: network, } + if !strings.HasSuffix(address, ".local") { ip := net.ParseIP(address) if ip == nil { return nil, ErrAddressParseFailed } - networkType, err := determineNetworkType(network, ip) - if err != nil { + if err := c.setIP(ip); err != nil { return nil, err } - - c.candidateBase.networkType = networkType - c.candidateBase.resolvedAddr = &net.UDPAddr{IP: ip, Port: port} } - return c, nil } + +func (c *CandidateHost) setIP(ip net.IP) error { + networkType, err := determineNetworkType(c.network, ip) + if err != nil { + return err + } + + c.candidateBase.networkType = networkType + c.candidateBase.resolvedAddr = &net.UDPAddr{IP: ip, Port: c.port} + return nil +} diff --git a/gather.go b/gather.go index 188c098..5f8d16e 100644 --- a/gather.go +++ b/gather.go @@ -176,13 +176,25 @@ func (a *Agent) gatherCandidatesLocal(networkTypes []NetworkType) { return } + address := ip.String() + if a.mDNSMode == MulticastDNSModeQueryAndGather { + address = a.mDNSName + } + port := conn.LocalAddr().(*net.UDPAddr).Port - c, err := NewCandidateHost(network, ip.String(), port, ComponentRTP) + c, err := NewCandidateHost(network, address, port, ComponentRTP) if err != nil { a.log.Warnf("Failed to create host candidate: %s %s %d: %v\n", network, ip, port, err) return } + if a.mDNSMode == MulticastDNSModeQueryAndGather { + if err = c.setIP(ip); err != nil { + a.log.Warnf("Failed to create host candidate: %s %s %d: %v\n", network, ip, port, err) + return + } + } + if err := a.run(func(agent *Agent) { a.addCandidate(c) }); err != nil { diff --git a/go.mod b/go.mod index 632d0f8..e4d9721 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,11 @@ go 1.12 require ( github.com/pion/logging v0.2.1 + github.com/pion/mdns v0.0.2 github.com/pion/stun v0.3.1 github.com/pion/transport v0.7.0 github.com/pion/turn v1.1.4 github.com/pion/turnc v0.0.6 github.com/stretchr/testify v1.3.0 + golang.org/x/net v0.0.0-20190619014844-b5b0513f8c1b ) diff --git a/go.sum b/go.sum index bab2304..e654dfc 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/gortc/turn v0.7.3 h1:CE72C79erbcsfa6L/QDhKztcl2kDq1UK20ImrJWDt/w= github.com/gortc/turn v0.7.3/go.mod h1:gvguwaGAFyv5/9KrcW9MkCgHALYD+e99mSM7pSCYYho= github.com/pion/logging v0.2.1 h1:LwASkBKZ+2ysGJ+jLv1E/9H1ge0k1nTfi1X+5zirkDk= github.com/pion/logging v0.2.1/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/mdns v0.0.2 h1:T22Gg4dSuYVYsZ21oRFh9z7twzAm27+5PEKiABbjCvM= +github.com/pion/mdns v0.0.2/go.mod h1:VrN3wefVgtfL8QgpEblPUC46ag1reLIfpqekCnKunLE= github.com/pion/stun v0.3.0/go.mod h1:xrCld6XM+6GWDZdvjPlLMsTU21rNxnO6UO8XsAvHr/M= github.com/pion/stun v0.3.1 h1:d09JJzOmOS8ZzIp8NppCMgrxGZpJ4Ix8qirfNYyI3BA= github.com/pion/stun v0.3.1/go.mod h1:xrCld6XM+6GWDZdvjPlLMsTU21rNxnO6UO8XsAvHr/M= @@ -24,5 +26,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190403144856-b630fd6fe46b h1:/zjbcJPEGAyu6Is/VBOALsgdi4z9+kz/Vtdm6S+beD0= golang.org/x/net v0.0.0-20190403144856-b630fd6fe46b/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190619014844-b5b0513f8c1b h1:lkjdUzSyJ5P1+eal9fxXX9Xg2BTfswsonKUse48C0uE= +golang.org/x/net v0.0.0-20190619014844-b5b0513f8c1b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/mdns.go b/mdns.go new file mode 100644 index 0000000..908a4fe --- /dev/null +++ b/mdns.go @@ -0,0 +1,32 @@ +package ice + +import ( + "crypto/rand" + "fmt" +) + +// MulticastDNSMode represents the different Multicast modes ICE can run in +type MulticastDNSMode byte + +// MulticastDNSMode enum +const ( + // MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs + MulticastDNSModeDisabled MulticastDNSMode = iota + 1 + + // MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs + MulticastDNSModeQueryOnly + + // MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted, and local host candidates will use mDNS + MulticastDNSModeQueryAndGather +) + +func generateMulticastDNSName() (string, error) { + b := make([]byte, 16) + _, err := rand.Read(b) //nolint + + if err != nil { + return "", err + } + + return fmt.Sprintf("%X-%X-%X-%X-%X.local", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), nil +} diff --git a/mdns_test.go b/mdns_test.go new file mode 100644 index 0000000..9c6b972 --- /dev/null +++ b/mdns_test.go @@ -0,0 +1,102 @@ +package ice + +import ( + "testing" + "time" + + "github.com/pion/transport/test" +) + +func TestMulticastDNSOnlyConnection(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + cfg := &AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4}, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + } + + aAgent, err := NewAgent(cfg) + if err != nil { + t.Fatal(err) + } + + aNotifier, aConnected := onConnected() + if err = aAgent.OnConnectionStateChange(aNotifier); err != nil { + t.Fatal(err) + } + + bAgent, err := NewAgent(cfg) + if err != nil { + t.Fatal(err) + } + + bNotifier, bConnected := onConnected() + if err = bAgent.OnConnectionStateChange(bNotifier); err != nil { + t.Fatal(err) + } + + connect(aAgent, bAgent) + <-aConnected + <-bConnected + + if err = aAgent.Close(); err != nil { + t.Fatal(err) + } + if err = bAgent.Close(); err != nil { + t.Fatal(err) + } +} + +func TestMulticastDNSMixedConnection(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4}, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + }) + if err != nil { + t.Fatal(err) + } + + aNotifier, aConnected := onConnected() + if err = aAgent.OnConnectionStateChange(aNotifier); err != nil { + t.Fatal(err) + } + + bAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4}, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryOnly, + }) + if err != nil { + t.Fatal(err) + } + + bNotifier, bConnected := onConnected() + if err = bAgent.OnConnectionStateChange(bNotifier); err != nil { + t.Fatal(err) + } + + connect(aAgent, bAgent) + <-aConnected + <-bConnected + + if err = aAgent.Close(); err != nil { + t.Fatal(err) + } + if err = bAgent.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/transport_test.go b/transport_test.go index 260fab3..d365888 100644 --- a/transport_test.go +++ b/transport_test.go @@ -330,45 +330,23 @@ func pipeWithTimeout(iceTimeout time.Duration, iceKeepalive time.Duration) (*Con return aConn, bConn } -func copyCandidate(o Candidate) Candidate { +func copyCandidate(o Candidate) (c Candidate) { + var err error switch orig := o.(type) { case *CandidateHost: - return &CandidateHost{ - candidateBase{ - candidateType: orig.candidateType, - networkType: orig.networkType, - address: orig.address, - port: orig.port, - component: orig.component, - }, - } + c, err = NewCandidateHost(udp, orig.address, orig.port, orig.component) case *CandidateServerReflexive: - return &CandidateServerReflexive{ - candidateBase{ - candidateType: orig.candidateType, - networkType: orig.networkType, - address: orig.address, - port: orig.port, - component: orig.component, - relatedAddress: orig.relatedAddress, - }, - } - + c, err = NewCandidateServerReflexive(udp, orig.address, orig.port, orig.component, orig.relatedAddress.Address, orig.relatedAddress.Port) case *CandidateRelay: - return &CandidateRelay{ - candidateBase{ - candidateType: orig.candidateType, - networkType: orig.networkType, - address: orig.address, - port: orig.port, - component: orig.component, - relatedAddress: orig.relatedAddress, - }, - nil, nil, nil, - } + c, err = NewCandidateRelay(udp, orig.address, orig.port, orig.component, orig.relatedAddress.Address, orig.relatedAddress.Port) default: - return nil + panic("Tried to copy unsupported candidate type") } + + if err != nil { + panic(err) + } + return c } func onConnected() (func(ConnectionState), chan struct{}) {