much simpler resolve

This commit is contained in:
Juan Batiz-Benet
2015-01-18 20:39:30 -08:00
parent d2ae70d361
commit eb450c7b6d
4 changed files with 44 additions and 8 deletions

20
addr.go Normal file
View File

@@ -0,0 +1,20 @@
package reuseport
import (
"net"
)
func ResolveAddr(network, address string) (net.Addr, error) {
switch network {
default:
return nil, net.UnknownNetworkError(network)
case "ip", "ip4", "ip6":
return net.ResolveIPAddr(network, address)
case "tcp", "tcp4", "tcp6":
return net.ResolveTCPAddr(network, address)
case "udp", "udp4", "udp6":
return net.ResolveUDPAddr(network, address)
case "unix", "unixgram", "unixpacket":
return net.ResolveUnixAddr(network, address)
}
}

View File

@@ -8,7 +8,6 @@ import (
"strconv" "strconv"
"syscall" "syscall"
resolve "github.com/jbenet/go-net-resolve-addr"
sockaddrnet "github.com/jbenet/go-sockaddr/net" sockaddrnet "github.com/jbenet/go-sockaddr/net"
) )
@@ -26,7 +25,7 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
localSockaddr syscall.Sockaddr localSockaddr syscall.Sockaddr
) )
netAddr, err := resolve.ResolveAddr("dial", netw, addr) netAddr, err := ResolveAddr(netw, addr)
if err != nil { if err != nil {
// fmt.Println("resolve addr failed") // fmt.Println("resolve addr failed")
return nil, err return nil, err
@@ -96,7 +95,7 @@ func listen(netw, addr string) (l net.Listener, err error) {
sockaddr syscall.Sockaddr sockaddr syscall.Sockaddr
) )
netAddr, err := resolve.ResolveAddr("listen", netw, addr) netAddr, err := ResolveAddr(netw, addr)
if err != nil { if err != nil {
// fmt.Println("resolve addr failed") // fmt.Println("resolve addr failed")
return nil, err return nil, err

View File

@@ -20,8 +20,6 @@ package reuseport
import ( import (
"errors" "errors"
"net" "net"
resolve "github.com/jbenet/go-net-resolve-addr"
) )
// ErrUnsuportedProtocol signals that the protocol is not currently // ErrUnsuportedProtocol signals that the protocol is not currently
@@ -45,7 +43,7 @@ func Dial(network, laddr, raddr string) (net.Conn, error) {
var d Dialer var d Dialer
if laddr != "" { if laddr != "" {
netladdr, err := resolve.ResolveAddr("dial", network, laddr) netladdr, err := ResolveAddr(network, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -5,6 +5,7 @@ import (
"io" "io"
"net" "net"
"os" "os"
"strings"
"testing" "testing"
) )
@@ -31,6 +32,13 @@ func TestListenSamePort(t *testing.T) {
// any ports // any ports
any := [][]string{ any := [][]string{
[]string{"tcp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp6", "[::]:0", "[::]:0"},
[]string{"udp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp6", "[::]:0", "[::]:0"},
[]string{"tcp", "127.0.0.1:0"}, []string{"tcp", "127.0.0.1:0"},
[]string{"tcp", "[::1]:0"}, []string{"tcp", "[::1]:0"},
[]string{"tcp4", "127.0.0.1:0"}, []string{"tcp4", "127.0.0.1:0"},
@@ -100,6 +108,13 @@ func TestListenSamePort(t *testing.T) {
func TestListenDialSamePort(t *testing.T) { func TestListenDialSamePort(t *testing.T) {
any := [][]string{ any := [][]string{
[]string{"tcp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp6", "[::]:0", "[::]:0"},
[]string{"udp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp6", "[::]:0", "[::]:0"},
[]string{"tcp", "127.0.0.1:0", "127.0.0.1:0"}, []string{"tcp", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"tcp4", "127.0.0.1:0", "127.0.0.1:0"}, []string{"tcp4", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"tcp6", "[::1]:0", "[::1]:0"}, []string{"tcp6", "[::1]:0", "[::1]:0"},
@@ -155,11 +170,11 @@ func TestListenDialSamePort(t *testing.T) {
defer c1.Close() defer c1.Close()
t.Log("dialed", c1.LocalAddr(), c1.RemoteAddr()) t.Log("dialed", c1.LocalAddr(), c1.RemoteAddr())
if l1.Addr().String() != c1.LocalAddr().String() { if getPort(l1.Addr()) != getPort(c1.LocalAddr()) {
t.Fatal("addrs should match", l1.Addr(), c1.LocalAddr()) t.Fatal("addrs should match", l1.Addr(), c1.LocalAddr())
} }
if l2.Addr().String() != c1.RemoteAddr().String() { if getPort(l2.Addr()) != getPort(c1.RemoteAddr()) {
t.Fatal("addrs should match", l2.Addr(), c1.RemoteAddr()) t.Fatal("addrs should match", l2.Addr(), c1.RemoteAddr())
} }
@@ -200,3 +215,7 @@ func TestUnixNotSupported(t *testing.T) {
} }
} }
} }
func getPort(a net.Addr) string {
return strings.Split(a.String(), ":")[1]
}