mirror of
https://github.com/pion/ice.git
synced 2025-10-03 22:56:36 +08:00
@@ -54,6 +54,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
|
||||
* [Lander Noterman](https://github.com/LanderN)
|
||||
* [BUPTCZQ](https://github.com/buptczq)
|
||||
* [Henry](https://github.com/cryptix)
|
||||
* [Jerko Steiner](https://github.com/jeremija)
|
||||
|
||||
### License
|
||||
MIT License - see [LICENSE](LICENSE) for full text
|
||||
|
35
agent.go
35
agent.go
@@ -121,6 +121,7 @@ type Agent struct {
|
||||
log logging.LeveledLogger
|
||||
|
||||
net *vnet.Net
|
||||
tcp *tcpIPMux
|
||||
|
||||
interfaceFilter func(string) bool
|
||||
|
||||
@@ -305,6 +306,12 @@ func NewAgent(config *AgentConfig) (*Agent, error) {
|
||||
insecureSkipVerify: config.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
a.tcp = newTCPIPMux(tcpIPMuxParams{
|
||||
ListenPort: config.TCPListenPort,
|
||||
Logger: log,
|
||||
ReadBufferSize: 8,
|
||||
})
|
||||
|
||||
if a.net == nil {
|
||||
a.net = vnet.NewNet(nil)
|
||||
} else if a.net.IsVirtual() {
|
||||
@@ -695,6 +702,15 @@ func (a *Agent) checkKeepalive() {
|
||||
|
||||
// AddRemoteCandidate adds a new remote candidate
|
||||
func (a *Agent) AddRemoteCandidate(c Candidate) error {
|
||||
// canot check for network yet because it might not be applied
|
||||
// when mDNS hostame is used.
|
||||
if c.TCPType() == TCPTypeActive {
|
||||
// TCP Candidates with tcptype active will probe server passive ones, so
|
||||
// no need to do anything with them.
|
||||
a.log.Infof("Ignoring remote candidate with tcpType active: %s", c)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we have a mDNS Candidate lets fully resolve it before adding it locally
|
||||
if c.Type() == CandidateTypeHost && strings.HasSuffix(c.Address(), ".local") {
|
||||
if a.mDNSMode == MulticastDNSModeDisabled {
|
||||
@@ -871,6 +887,7 @@ func (a *Agent) Close() error {
|
||||
|
||||
a.gatherCandidateCancel()
|
||||
a.err.Store(ErrClosed)
|
||||
a.tcp.RemoveUfrag(a.localUfrag)
|
||||
close(a.done)
|
||||
|
||||
<-done
|
||||
@@ -901,9 +918,9 @@ func (a *Agent) deleteAllCandidates() {
|
||||
}
|
||||
|
||||
func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Candidate {
|
||||
ip, port, err := addrIPAndPort(addr)
|
||||
if err != nil {
|
||||
a.log.Warn(err.Error())
|
||||
ip, port, _, ok := parseAddr(addr)
|
||||
if !ok {
|
||||
a.log.Warnf("Error parsing addr: %s", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -932,10 +949,17 @@ func (a *Agent) sendBindingRequest(m *stun.Message, local, remote Candidate) {
|
||||
|
||||
func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) {
|
||||
base := remote
|
||||
|
||||
ip, port, _, ok := parseAddr(base.addr())
|
||||
if !ok {
|
||||
a.log.Warnf("Error parsing addr: %s", base.addr())
|
||||
return
|
||||
}
|
||||
|
||||
if out, err := stun.Build(m, stun.BindingSuccess,
|
||||
&stun.XORMappedAddress{
|
||||
IP: base.addr().IP,
|
||||
Port: base.addr().Port,
|
||||
IP: ip,
|
||||
Port: port,
|
||||
},
|
||||
stun.NewShortTermIntegrity(a.localPwd),
|
||||
stun.Fingerprint,
|
||||
@@ -1048,6 +1072,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
|
||||
Component: local.Component(),
|
||||
RelAddr: "",
|
||||
RelPort: 0,
|
||||
// TODO set TCPType
|
||||
}
|
||||
|
||||
prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
|
||||
|
@@ -138,6 +138,11 @@ type AgentConfig struct {
|
||||
// InsecureSkipVerify controls if self-signed certificates are accepted when connecting
|
||||
// to TURN servers via TLS or DTLS
|
||||
InsecureSkipVerify bool
|
||||
|
||||
// TCPListenPort will be used to start a TCP listener on all allowed interfaces for
|
||||
// ICE TCP. Currently only passive candidates are supported. This functionality is
|
||||
// experimental and this API will likely change in the future.
|
||||
TCPListenPort int
|
||||
}
|
||||
|
||||
// initWithDefaults populates an agent and falls back to defaults if fields are unset
|
||||
|
@@ -29,10 +29,11 @@ type Candidate interface {
|
||||
RelatedAddress() *CandidateRelatedAddress
|
||||
String() string
|
||||
Type() CandidateType
|
||||
TCPType() TCPType
|
||||
|
||||
Equal(other Candidate) bool
|
||||
|
||||
addr() *net.UDPAddr
|
||||
addr() net.Addr
|
||||
agent() *Agent
|
||||
context() context.Context
|
||||
|
||||
|
@@ -20,8 +20,9 @@ type candidateBase struct {
|
||||
address string
|
||||
port int
|
||||
relatedAddress *CandidateRelatedAddress
|
||||
tcpType TCPType
|
||||
|
||||
resolvedAddr *net.UDPAddr
|
||||
resolvedAddr net.Addr
|
||||
|
||||
lastSent atomic.Value
|
||||
lastReceived atomic.Value
|
||||
@@ -97,6 +98,10 @@ func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress {
|
||||
return c.relatedAddress
|
||||
}
|
||||
|
||||
func (c *candidateBase) TCPType() TCPType {
|
||||
return c.tcpType
|
||||
}
|
||||
|
||||
// start runs the candidate using the provided connection
|
||||
func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) {
|
||||
if c.conn != nil {
|
||||
@@ -227,7 +232,7 @@ func (c *candidateBase) Equal(other Candidate) bool {
|
||||
|
||||
// String makes the candidateBase printable
|
||||
func (c *candidateBase) String() string {
|
||||
return fmt.Sprintf("%s %s:%d%s", c.Type(), c.Address(), c.Port(), c.relatedAddress)
|
||||
return fmt.Sprintf("%s %s %s:%d%s", c.NetworkType(), c.Type(), c.Address(), c.Port(), c.relatedAddress)
|
||||
}
|
||||
|
||||
// LastReceived returns a time.Time indicating the last time
|
||||
@@ -266,7 +271,7 @@ func (c *candidateBase) seen(outbound bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *candidateBase) addr() *net.UDPAddr {
|
||||
func (c *candidateBase) addr() net.Addr {
|
||||
return c.resolvedAddr
|
||||
}
|
||||
|
||||
|
@@ -19,6 +19,7 @@ type CandidateHostConfig struct {
|
||||
Address string
|
||||
Port int
|
||||
Component uint16
|
||||
TCPType TCPType
|
||||
}
|
||||
|
||||
// NewCandidateHost creates a new host candidate
|
||||
@@ -36,6 +37,7 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
|
||||
candidateType: CandidateTypeHost,
|
||||
component: config.Component,
|
||||
port: config.Port,
|
||||
tcpType: config.TCPType,
|
||||
},
|
||||
network: config.Network,
|
||||
}
|
||||
@@ -50,6 +52,7 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -60,6 +63,7 @@ func (c *CandidateHost) setIP(ip net.IP) error {
|
||||
}
|
||||
|
||||
c.candidateBase.networkType = networkType
|
||||
c.candidateBase.resolvedAddr = &net.UDPAddr{IP: ip, Port: c.port}
|
||||
c.candidateBase.resolvedAddr = createAddr(networkType, ip, c.port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -44,7 +44,7 @@ func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*Candidate
|
||||
candidateType: CandidateTypePeerReflexive,
|
||||
address: config.Address,
|
||||
port: config.Port,
|
||||
resolvedAddr: &net.UDPAddr{IP: ip, Port: config.Port},
|
||||
resolvedAddr: createAddr(networkType, ip, config.Port),
|
||||
component: config.Component,
|
||||
relatedAddress: &CandidateRelatedAddress{
|
||||
Address: config.RelAddr,
|
||||
|
@@ -102,4 +102,7 @@ var (
|
||||
|
||||
// ErrRunCanceled indicates a run operation was canceled by its individual done
|
||||
ErrRunCanceled = errors.New("run was canceled by done")
|
||||
|
||||
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
|
||||
ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists")
|
||||
)
|
||||
|
55
gather.go
55
gather.go
@@ -123,6 +123,15 @@ func (a *Agent) gatherCandidates(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) {
|
||||
networks := map[string]struct{}{}
|
||||
for _, networkType := range networkTypes {
|
||||
if networkType.IsTCP() {
|
||||
networks[tcp] = struct{}{}
|
||||
} else {
|
||||
networks[udp] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, networkTypes)
|
||||
if err != nil {
|
||||
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
|
||||
@@ -144,19 +153,51 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
address = a.mDNSName
|
||||
}
|
||||
|
||||
for _, network := range supportedNetworks {
|
||||
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
|
||||
for network := range networks {
|
||||
var port int
|
||||
var conn net.PacketConn
|
||||
var err error
|
||||
|
||||
var tcpType TCPType
|
||||
switch network {
|
||||
case tcp:
|
||||
if a.tcp == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// below is for passive mode
|
||||
// TODO active mode
|
||||
// TODO S-O mode
|
||||
|
||||
mux, muxErr := a.tcp.Listen(ip)
|
||||
if muxErr != nil {
|
||||
a.log.Warnf("could not listen %s %s\n", network, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
|
||||
conn, err = mux.GetConn(a.localUfrag)
|
||||
if err != nil {
|
||||
a.log.Warnf("error getting tcp conn by ufrag: %s %s\n", network, ip, a.localUfrag)
|
||||
continue
|
||||
}
|
||||
port = conn.LocalAddr().(*net.TCPAddr).Port
|
||||
tcpType = TCPTypePassive
|
||||
case udp:
|
||||
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
|
||||
if err != nil {
|
||||
a.log.Warnf("could not listen %s %s\n", network, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
port := conn.LocalAddr().(*net.UDPAddr).Port
|
||||
port = conn.LocalAddr().(*net.UDPAddr).Port
|
||||
}
|
||||
hostConfig := CandidateHostConfig{
|
||||
Network: network,
|
||||
Address: address,
|
||||
Port: port,
|
||||
Component: ComponentRTP,
|
||||
TCPType: tcpType,
|
||||
}
|
||||
|
||||
c, err := NewCandidateHost(&hostConfig)
|
||||
@@ -187,6 +228,10 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
|
||||
defer wg.Wait()
|
||||
|
||||
for _, networkType := range networkTypes {
|
||||
if networkType.IsTCP() {
|
||||
continue
|
||||
}
|
||||
|
||||
network := networkType.String()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -237,6 +282,10 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*URL, networkT
|
||||
defer wg.Wait()
|
||||
|
||||
for _, networkType := range networkTypes {
|
||||
if networkType.IsTCP() {
|
||||
continue
|
||||
}
|
||||
|
||||
for i := range urls {
|
||||
wg.Add(1)
|
||||
go func(url URL, network string) {
|
||||
|
@@ -119,15 +119,18 @@ func TestSTUNConcurrency(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes,
|
||||
Urls: urls,
|
||||
CandidateTypes: []CandidateType{CandidateTypeServerReflexive},
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive},
|
||||
TCPListenPort: 9999,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
assert.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
if c != nil {
|
||||
if c == nil {
|
||||
candidateGatheredFunc()
|
||||
return
|
||||
}
|
||||
t.Log(c.NetworkType(), c.Priority(), c)
|
||||
}))
|
||||
assert.NoError(t, a.GatherCandidates())
|
||||
|
||||
|
@@ -11,16 +11,11 @@ const (
|
||||
tcp = "tcp"
|
||||
)
|
||||
|
||||
var supportedNetworks = []string{
|
||||
udp,
|
||||
// tcp, // Not supported yet
|
||||
}
|
||||
|
||||
var supportedNetworkTypes = []NetworkType{
|
||||
NetworkTypeUDP4,
|
||||
NetworkTypeUDP6,
|
||||
// NetworkTypeTCP4, // Not supported yet
|
||||
// NetworkTypeTCP6, // Not supported yet
|
||||
NetworkTypeTCP4,
|
||||
NetworkTypeTCP6,
|
||||
}
|
||||
|
||||
// NetworkType represents the type of network
|
||||
@@ -55,6 +50,16 @@ func (t NetworkType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// IsUDP returns true when network is UDP4 or UDP6.
|
||||
func (t NetworkType) IsUDP() bool {
|
||||
return t == NetworkTypeUDP4 || t == NetworkTypeUDP6
|
||||
}
|
||||
|
||||
// IsTCP returns true when network is TCP4 or TCP6.
|
||||
func (t NetworkType) IsTCP() bool {
|
||||
return t == NetworkTypeTCP4 || t == NetworkTypeTCP6
|
||||
}
|
||||
|
||||
// NetworkShort returns the short network description
|
||||
func (t NetworkType) NetworkShort() string {
|
||||
switch t {
|
||||
|
@@ -3,6 +3,8 @@ package ice
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNetworkTypeParsing_Success(t *testing.T) {
|
||||
@@ -72,3 +74,17 @@ func TestNetworkTypeParsing_Failure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTypeIsUDP(t *testing.T) {
|
||||
assert.True(t, NetworkTypeUDP4.IsUDP())
|
||||
assert.True(t, NetworkTypeUDP6.IsUDP())
|
||||
assert.False(t, NetworkTypeUDP4.IsTCP())
|
||||
assert.False(t, NetworkTypeUDP6.IsTCP())
|
||||
}
|
||||
|
||||
func TestNetworkTypeIsTCP(t *testing.T) {
|
||||
assert.True(t, NetworkTypeTCP4.IsTCP())
|
||||
assert.True(t, NetworkTypeTCP6.IsTCP())
|
||||
assert.False(t, NetworkTypeTCP4.IsUDP())
|
||||
assert.False(t, NetworkTypeTCP6.IsUDP())
|
||||
}
|
||||
|
102
tcp_ip_mux.go
Normal file
102
tcp_ip_mux.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/logging"
|
||||
)
|
||||
|
||||
// tcpMuxes is a map of local addr listeners to tcpMux
|
||||
var tcpMuxes map[string]*tcpMux
|
||||
var tcpMuxesMu sync.Mutex
|
||||
|
||||
type tcpIPMux struct {
|
||||
params *tcpIPMuxParams
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type tcpIPMuxParams struct {
|
||||
ListenPort int
|
||||
ReadBufferSize int
|
||||
Logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
func newTCPIPMux(params tcpIPMuxParams) *tcpIPMux {
|
||||
m := &tcpIPMux{
|
||||
params: ¶ms,
|
||||
}
|
||||
|
||||
tcpMuxesMu.Lock()
|
||||
|
||||
if tcpMuxes == nil {
|
||||
tcpMuxes = map[string]*tcpMux{}
|
||||
}
|
||||
|
||||
tcpMuxesMu.Unlock()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *tcpIPMux) Remove(key string) {
|
||||
tcpMuxesMu.Lock()
|
||||
defer tcpMuxesMu.Unlock()
|
||||
|
||||
if tcpMux, ok := tcpMuxes[key]; ok {
|
||||
err := tcpMux.Close()
|
||||
if err != nil {
|
||||
m.params.Logger.Errorf("Error closing tcpMux for key: %s: %s", key, err)
|
||||
}
|
||||
delete(tcpMuxes, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *tcpIPMux) RemoveUfrag(ufrag string) {
|
||||
tcpMuxesMu.Lock()
|
||||
defer tcpMuxesMu.Unlock()
|
||||
|
||||
for _, tcpMux := range tcpMuxes {
|
||||
tcpMux.RemoveConn(ufrag)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *tcpIPMux) Listen(ip net.IP) (*tcpMux, error) {
|
||||
tcpMuxesMu.Lock()
|
||||
defer tcpMuxesMu.Unlock()
|
||||
|
||||
key := net.JoinHostPort(ip.String(), strconv.Itoa(m.params.ListenPort))
|
||||
|
||||
tcpMux, ok := tcpMuxes[key]
|
||||
if ok {
|
||||
return tcpMux, nil
|
||||
}
|
||||
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: ip,
|
||||
Port: m.params.ListenPort,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key = net.JoinHostPort(ip.String(), strconv.Itoa(listener.Addr().(*net.TCPAddr).Port))
|
||||
|
||||
tcpMux = newTCPMux(tcpMuxParams{
|
||||
Listener: listener,
|
||||
Logger: m.params.Logger,
|
||||
ReadBufferSize: m.params.ReadBufferSize,
|
||||
})
|
||||
|
||||
tcpMuxes[key] = tcpMux
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
<-tcpMux.CloseChannel()
|
||||
m.Remove(key)
|
||||
}()
|
||||
|
||||
return tcpMux, nil
|
||||
}
|
58
tcp_ip_mux_test.go
Normal file
58
tcp_ip_mux_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCP_Recv(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
|
||||
tim := newTCPIPMux(tcpIPMuxParams{
|
||||
ListenPort: 8080,
|
||||
Logger: loggerFactory.NewLogger("ice"),
|
||||
ReadBufferSize: 20,
|
||||
})
|
||||
|
||||
tcpMux, err := tim.Listen(net.IP{127, 0, 0, 1})
|
||||
require.NoError(t, err, "error starting listener")
|
||||
defer func() {
|
||||
_ = tcpMux.Close()
|
||||
}()
|
||||
|
||||
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
|
||||
require.NoError(t, err, "error dialing test tcp connection")
|
||||
|
||||
msg := stun.New()
|
||||
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
|
||||
msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
|
||||
msg.Add(stun.AttrICEControlling, nil)
|
||||
msg.Encode()
|
||||
|
||||
n, err := writeStreamingPacket(conn, msg.Raw)
|
||||
require.NoError(t, err, "error writing tcp stun packet")
|
||||
|
||||
pktConn, err := tcpMux.GetConn("myufrag")
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
defer func() {
|
||||
_ = pktConn.Close()
|
||||
}()
|
||||
|
||||
recv := make([]byte, n)
|
||||
n2, raddr, err := pktConn.ReadFrom(recv)
|
||||
require.NoError(t, err, "error receiving data")
|
||||
assert.Equal(t, conn.LocalAddr(), raddr, "remote tcp address mismatch")
|
||||
assert.Equal(t, n, n2, "received byte size mismatch")
|
||||
assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
|
||||
}
|
273
tcp_mux.go
Normal file
273
tcp_mux.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun"
|
||||
)
|
||||
|
||||
type tcpMux struct {
|
||||
params *tcpMuxParams
|
||||
|
||||
// conns is a map of all tcpPacketConns indexed by ufrag
|
||||
conns map[string]*tcpPacketConn
|
||||
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type tcpMuxParams struct {
|
||||
Listener net.Listener
|
||||
Logger logging.LeveledLogger
|
||||
ReadBufferSize int
|
||||
}
|
||||
|
||||
func newTCPMux(params tcpMuxParams) *tcpMux {
|
||||
m := &tcpMux{
|
||||
params: ¶ms,
|
||||
|
||||
conns: map[string]*tcpPacketConn{},
|
||||
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
m.start()
|
||||
}()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *tcpMux) start() {
|
||||
m.params.Logger.Infof("Listening TCP on %s\n", m.params.Listener.Addr())
|
||||
for {
|
||||
conn, err := m.params.Listener.Accept()
|
||||
if err != nil {
|
||||
m.params.Logger.Infof("Error accepting connection: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.params.Logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
m.handleConn(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *tcpMux) LocalAddr() net.Addr {
|
||||
return m.params.Listener.Addr()
|
||||
}
|
||||
|
||||
func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
conn, ok := m.conns[ufrag]
|
||||
|
||||
if ok {
|
||||
return conn, nil
|
||||
// return nil, fmt.Errorf("duplicate ufrag %v", ufrag)
|
||||
}
|
||||
|
||||
conn = m.createConn(ufrag, m.LocalAddr())
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
|
||||
conn := newTCPPacketConn(tcpPacketParams{
|
||||
ReadBuffer: m.params.ReadBufferSize,
|
||||
LocalAddr: localAddr,
|
||||
Logger: m.params.Logger,
|
||||
})
|
||||
m.conns[ufrag] = conn
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
<-conn.CloseChannel()
|
||||
m.RemoveConn(ufrag)
|
||||
}()
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (m *tcpMux) closeAndLogError(closer io.Closer) {
|
||||
err := closer.Close()
|
||||
if err != nil {
|
||||
m.params.Logger.Warnf("Error closing connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *tcpMux) handleConn(conn net.Conn) {
|
||||
buf := make([]byte, receiveMTU)
|
||||
|
||||
n, err := readStreamingPacket(conn, buf)
|
||||
|
||||
if err != nil {
|
||||
m.params.Logger.Warnf("Error reading first packet: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
buf = buf[:n]
|
||||
|
||||
msg := &stun.Message{
|
||||
Raw: make([]byte, len(buf)),
|
||||
}
|
||||
// Explicitly copy raw buffer so Message can own the memory.
|
||||
copy(msg.Raw, buf)
|
||||
if err = msg.Decode(); err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v\n", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if m == nil || msg.Type.Method != stun.MethodBinding { // not a stun
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Not a STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
|
||||
return
|
||||
}
|
||||
|
||||
for _, attr := range msg.Attributes {
|
||||
m.params.Logger.Debugf("msg attr: %s\n", attr.String())
|
||||
}
|
||||
|
||||
// Firefox will send ICEControlling for its Active canddiate. We
|
||||
// currently support passive local TCP candidates only.
|
||||
//
|
||||
// TODO: not sure what will be sent for caniddate with tcptype S-O.
|
||||
_, err = msg.Get(stun.AttrICEControlling)
|
||||
if err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("No ICEControlling attribute in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
|
||||
return
|
||||
}
|
||||
|
||||
attr, err := msg.Get(stun.AttrUsername)
|
||||
if err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
|
||||
return
|
||||
}
|
||||
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
m.params.Logger.Debugf("Ufrag: %s\n", ufrag)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
packetConn, ok := m.conns[ufrag]
|
||||
if !ok {
|
||||
packetConn = m.createConn(ufrag, conn.LocalAddr())
|
||||
}
|
||||
|
||||
if err := packetConn.AddConn(conn, buf); err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s, %w\n", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *tcpMux) Close() error {
|
||||
m.mu.Lock()
|
||||
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.closedChan)
|
||||
})
|
||||
|
||||
m.conns = map[string]*tcpPacketConn{}
|
||||
m.mu.Unlock()
|
||||
|
||||
err := m.params.Listener.Close()
|
||||
|
||||
m.wg.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *tcpMux) CloseChannel() <-chan struct{} {
|
||||
return m.closedChan
|
||||
}
|
||||
|
||||
func (m *tcpMux) RemoveConn(ufrag string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if conn, ok := m.conns[ufrag]; ok {
|
||||
m.closeAndLogError(conn)
|
||||
delete(m.conns, ufrag)
|
||||
}
|
||||
|
||||
if len(m.conns) == 0 {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.closedChan)
|
||||
})
|
||||
|
||||
m.closeAndLogError(m.params.Listener)
|
||||
}
|
||||
}
|
||||
|
||||
const streamingPacketHeaderLen = 2
|
||||
|
||||
// readStreamingPacket reads 1 packet from stream
|
||||
// read packet bytes https://tools.ietf.org/html/rfc4571#section-2
|
||||
// 2-byte length header prepends each packet:
|
||||
// 0 1 2 3
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// -----------------------------------------------------------------
|
||||
// | LENGTH | RTP or RTCP packet ... |
|
||||
// -----------------------------------------------------------------
|
||||
func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
|
||||
var header = make([]byte, streamingPacketHeaderLen)
|
||||
var bytesRead, n int
|
||||
var err error
|
||||
|
||||
for bytesRead < streamingPacketHeaderLen {
|
||||
if n, err = conn.Read(header[bytesRead:streamingPacketHeaderLen]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
bytesRead += n
|
||||
}
|
||||
|
||||
length := int(binary.BigEndian.Uint16(header))
|
||||
|
||||
if length > cap(buf) {
|
||||
return length, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
bytesRead = 0
|
||||
for bytesRead < length {
|
||||
if n, err = conn.Read(buf[bytesRead:length]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
bytesRead += n
|
||||
}
|
||||
|
||||
return bytesRead, nil
|
||||
}
|
||||
|
||||
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
|
||||
bufferCopy := make([]byte, streamingPacketHeaderLen+len(buf))
|
||||
binary.BigEndian.PutUint16(bufferCopy, uint16(len(buf)))
|
||||
copy(bufferCopy[2:], buf)
|
||||
|
||||
n, err := conn.Write(bufferCopy)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n - streamingPacketHeaderLen, nil
|
||||
}
|
237
tcp_packet_conn.go
Normal file
237
tcp_packet_conn.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/logging"
|
||||
)
|
||||
|
||||
type tcpPacketConn struct {
|
||||
params *tcpPacketParams
|
||||
|
||||
// conns is a map of net.Conns indexed by remote net.Addr.String()
|
||||
conns map[string]net.Conn
|
||||
|
||||
recvChan chan streamingPacket
|
||||
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type streamingPacket struct {
|
||||
Data []byte
|
||||
RAddr net.Addr
|
||||
Err error
|
||||
}
|
||||
|
||||
type tcpPacketParams struct {
|
||||
ReadBuffer int
|
||||
LocalAddr net.Addr
|
||||
Logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
|
||||
p := &tcpPacketConn{
|
||||
params: ¶ms,
|
||||
|
||||
conns: map[string]net.Conn{},
|
||||
|
||||
recvChan: make(chan streamingPacket, params.ReadBuffer),
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
|
||||
t.params.Logger.Infof("AddConn: %s %s", conn.RemoteAddr().Network(), conn.RemoteAddr())
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-t.closedChan:
|
||||
return io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
if _, ok := t.conns[conn.RemoteAddr().String()]; ok {
|
||||
return ErrTCPRemoteAddrAlreadyExists
|
||||
}
|
||||
|
||||
t.conns[conn.RemoteAddr().String()] = conn
|
||||
|
||||
if firstPacketData != nil {
|
||||
t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil}
|
||||
}
|
||||
|
||||
t.wg.Add(1)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
t.startReading(conn)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) startReading(conn net.Conn) {
|
||||
buf := make([]byte, receiveMTU)
|
||||
|
||||
for {
|
||||
n, err := readStreamingPacket(conn, buf)
|
||||
// t.params.Logger.Infof("readStreamingPacket read %d bytes", n)
|
||||
|
||||
if err != nil {
|
||||
t.params.Logger.Infof("Error reading streaming packet: %s\n", err)
|
||||
t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
|
||||
t.removeConn(conn)
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]byte, n)
|
||||
copy(data, buf[:n])
|
||||
|
||||
// t.params.Logger.Infof("Writing read streaming packet to recvChan: %d bytes", len(data))
|
||||
t.handleRecv(streamingPacket{data, conn.RemoteAddr(), nil})
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) handleRecv(pkt streamingPacket) {
|
||||
t.mu.Lock()
|
||||
|
||||
recvChan := t.recvChan
|
||||
if t.isClosed() {
|
||||
recvChan = nil
|
||||
}
|
||||
|
||||
t.mu.Unlock()
|
||||
|
||||
select {
|
||||
case recvChan <- pkt:
|
||||
case <-t.closedChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) isClosed() bool {
|
||||
select {
|
||||
case <-t.closedChan:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo is for passive and s-o candidates.
|
||||
func (t *tcpPacketConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
|
||||
pkt, ok := <-t.recvChan
|
||||
|
||||
if !ok {
|
||||
return 0, nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
if pkt.Err != nil {
|
||||
return 0, pkt.RAddr, pkt.Err
|
||||
}
|
||||
|
||||
if cap(b) < len(pkt.Data) {
|
||||
return 0, pkt.RAddr, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
n = len(pkt.Data)
|
||||
copy(b, pkt.Data[:n])
|
||||
return n, pkt.RAddr, err
|
||||
}
|
||||
|
||||
// WriteTo is for active and s-o candidates.
|
||||
func (t *tcpPacketConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
conn, ok := t.conns[raddr.String()]
|
||||
if !ok {
|
||||
return 0, io.ErrClosedPipe
|
||||
// conn, err := net.DialTCP(tcp, nil, raddr.(*net.TCPAddr))
|
||||
|
||||
// if err != nil {
|
||||
// t.params.Logger.Tracef("DialTCP error: %s", err)
|
||||
// return 0, err
|
||||
// }
|
||||
|
||||
// go t.startReading(conn)
|
||||
// t.conns[raddr.String()] = conn
|
||||
}
|
||||
|
||||
n, err = writeStreamingPacket(conn, buf)
|
||||
if err != nil {
|
||||
t.params.Logger.Tracef("Error writing to %s\n", raddr)
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) closeAndLogError(closer io.Closer) {
|
||||
err := closer.Close()
|
||||
if err != nil {
|
||||
t.params.Logger.Warnf("Error closing connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) removeConn(conn net.Conn) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.closeAndLogError(conn)
|
||||
|
||||
delete(t.conns, conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) Close() error {
|
||||
t.mu.Lock()
|
||||
|
||||
t.closeOnce.Do(func() {
|
||||
close(t.closedChan)
|
||||
close(t.recvChan)
|
||||
})
|
||||
|
||||
for _, conn := range t.conns {
|
||||
t.closeAndLogError(conn)
|
||||
delete(t.conns, conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
t.mu.Unlock()
|
||||
|
||||
t.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) LocalAddr() net.Addr {
|
||||
return t.params.LocalAddr
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) SetDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) SetReadDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) SetWriteDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) CloseChannel() <-chan struct{} {
|
||||
return t.closedChan
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) String() string {
|
||||
return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", t.params.LocalAddr)
|
||||
}
|
48
tcptype.go
Normal file
48
tcptype.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package ice
|
||||
|
||||
import "strings"
|
||||
|
||||
// TCPType is the type of ICE TCP candidate as described in
|
||||
// ttps://tools.ietf.org/html/rfc6544#section-4.5
|
||||
type TCPType int
|
||||
|
||||
const (
|
||||
// TCPTypeUnspecified is the default value. For example UDP candidates do not
|
||||
// need this field.
|
||||
TCPTypeUnspecified TCPType = iota
|
||||
// TCPTypeActive is active TCP candidate, which initiates TCP connections.
|
||||
TCPTypeActive
|
||||
// TCPTypePassive is passive TCP candidate, only accepts TCP connections.
|
||||
TCPTypePassive
|
||||
// TCPTypeSimultaneousOpen is like active and passive at the same time.
|
||||
TCPTypeSimultaneousOpen
|
||||
)
|
||||
|
||||
// NewTCPType creates a new TCPType from string.
|
||||
func NewTCPType(value string) TCPType {
|
||||
switch strings.ToLower(value) {
|
||||
case "active":
|
||||
return TCPTypeActive
|
||||
case "passive":
|
||||
return TCPTypePassive
|
||||
case "so":
|
||||
return TCPTypeSimultaneousOpen
|
||||
default:
|
||||
return TCPTypeUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func (t TCPType) String() string {
|
||||
switch t {
|
||||
case TCPTypeUnspecified:
|
||||
return ""
|
||||
case TCPTypeActive:
|
||||
return "active"
|
||||
case TCPTypePassive:
|
||||
return "passive"
|
||||
case TCPTypeSimultaneousOpen:
|
||||
return "so"
|
||||
default:
|
||||
return ErrUnknownType.Error()
|
||||
}
|
||||
}
|
23
tcptype_test.go
Normal file
23
tcptype_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTCPType(t *testing.T) {
|
||||
var tcpType TCPType
|
||||
|
||||
assert.Equal(t, TCPTypeUnspecified, tcpType)
|
||||
assert.Equal(t, TCPTypeActive, NewTCPType("active"))
|
||||
assert.Equal(t, TCPTypePassive, NewTCPType("passive"))
|
||||
assert.Equal(t, TCPTypeSimultaneousOpen, NewTCPType("so"))
|
||||
assert.Equal(t, TCPTypeUnspecified, NewTCPType("something else"))
|
||||
|
||||
assert.Equal(t, "", TCPTypeUnspecified.String())
|
||||
assert.Equal(t, "active", TCPTypeActive.String())
|
||||
assert.Equal(t, "passive", TCPTypePassive.String())
|
||||
assert.Equal(t, "so", TCPTypeSimultaneousOpen.String())
|
||||
assert.Equal(t, "Unknown", TCPType(-1).String())
|
||||
}
|
20
util.go
20
util.go
@@ -53,6 +53,15 @@ func parseAddr(in net.Addr) (net.IP, int, NetworkType, bool) {
|
||||
return nil, 0, 0, false
|
||||
}
|
||||
|
||||
func createAddr(network NetworkType, ip net.IP, port int) net.Addr {
|
||||
switch {
|
||||
case network.IsTCP():
|
||||
return &net.TCPAddr{IP: ip, Port: port}
|
||||
default:
|
||||
return &net.UDPAddr{IP: ip, Port: port}
|
||||
}
|
||||
}
|
||||
|
||||
func addrEqual(a, b net.Addr) bool {
|
||||
aIP, aPort, aType, aOk := parseAddr(a)
|
||||
if !aOk {
|
||||
@@ -221,14 +230,3 @@ func listenUDPInPortRange(vnet *vnet.Net, log logging.LeveledLogger, portMax, po
|
||||
}
|
||||
return nil, ErrPort
|
||||
}
|
||||
|
||||
func addrIPAndPort(addr net.Addr) (net.IP, int, error) {
|
||||
switch casted := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return casted.IP, casted.Port, nil
|
||||
case *net.TCPAddr:
|
||||
return casted.IP, casted.Port, nil
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("unsupported address type %T", addr)
|
||||
}
|
||||
}
|
||||
|
13
util_test.go
13
util_test.go
@@ -3,6 +3,8 @@ package ice
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsSupportedIPv6(t *testing.T) {
|
||||
@@ -26,3 +28,14 @@ func TestIsSupportedIPv6(t *testing.T) {
|
||||
t.Errorf("isSupportedIPv6 return false with IPv6 global unicast address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAddr(t *testing.T) {
|
||||
ipv4 := net.IP{127, 0, 0, 1}
|
||||
ipv6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
port := 9000
|
||||
|
||||
assert.Equal(t, &net.UDPAddr{IP: ipv4, Port: port}, createAddr(NetworkTypeUDP4, ipv4, port))
|
||||
assert.Equal(t, &net.UDPAddr{IP: ipv6, Port: port}, createAddr(NetworkTypeUDP6, ipv6, port))
|
||||
assert.Equal(t, &net.TCPAddr{IP: ipv4, Port: port}, createAddr(NetworkTypeTCP4, ipv4, port))
|
||||
assert.Equal(t, &net.TCPAddr{IP: ipv6, Port: port}, createAddr(NetworkTypeTCP6, ipv6, port))
|
||||
}
|
||||
|
Reference in New Issue
Block a user