mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-06 00:57:23 +08:00
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
@@ -38,7 +39,7 @@ type netTun struct {
|
||||
events chan tun.Event
|
||||
incomingPacket chan buffer.VectorisedView
|
||||
mtu int
|
||||
dnsServers []net.IP
|
||||
dnsServers []netip.Addr
|
||||
hasV4, hasV6 bool
|
||||
}
|
||||
type endpoint netTun
|
||||
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
|
||||
}
|
||||
|
||||
func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
|
||||
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
||||
opts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
|
||||
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
|
||||
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||
}
|
||||
for _, ip := range localAddresses {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
|
||||
}
|
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
|
||||
}
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if ip.Is4() {
|
||||
protoNumber = ipv4.ProtocolNumber
|
||||
} else if ip.Is6() {
|
||||
protoNumber = ipv6.ProtocolNumber
|
||||
}
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: protoNumber,
|
||||
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
|
||||
}
|
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
||||
}
|
||||
if ip.Is4() {
|
||||
dev.hasV4 = true
|
||||
} else {
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
|
||||
}
|
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
||||
}
|
||||
} else if ip.Is6() {
|
||||
dev.hasV6 = true
|
||||
}
|
||||
}
|
||||
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
||||
|
||||
func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.Address(ip4),
|
||||
Port: uint16(port),
|
||||
}, ipv4.ProtocolNumber
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if endpoint.Addr().Is4() {
|
||||
protoNumber = ipv4.ProtocolNumber
|
||||
} else {
|
||||
return tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.Address(ip),
|
||||
Port: uint16(port),
|
||||
}, ipv6.ProtocolNumber
|
||||
protoNumber = ipv6.ProtocolNumber
|
||||
}
|
||||
return tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
|
||||
Port: endpoint.Port(),
|
||||
}, protoNumber
|
||||
}
|
||||
|
||||
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
||||
if addr == nil {
|
||||
panic("todo: deal with auto addr semantics for nil addr")
|
||||
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
|
||||
}
|
||||
fa, pn := convertToFullAddr(addr.IP, addr.Port)
|
||||
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
||||
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
|
||||
}
|
||||
|
||||
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.DialTCP(net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
||||
if addr == nil {
|
||||
panic("todo: deal with auto addr semantics for nil addr")
|
||||
return net.DialTCPAddrPort(netip.AddrPort{})
|
||||
}
|
||||
fa, pn := convertToFullAddr(addr.IP, addr.Port)
|
||||
return gonet.DialTCP(net.stack, fa, pn)
|
||||
return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
|
||||
}
|
||||
|
||||
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.ListenTCP(net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
|
||||
if addr == nil {
|
||||
panic("todo: deal with auto addr semantics for nil addr")
|
||||
return net.ListenTCPAddrPort(netip.AddrPort{})
|
||||
}
|
||||
fa, pn := convertToFullAddr(addr.IP, addr.Port)
|
||||
return gonet.ListenTCP(net.stack, fa, pn)
|
||||
return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
|
||||
}
|
||||
|
||||
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||
var lfa, rfa *tcpip.FullAddress
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
if laddr != nil {
|
||||
if laddr.IsValid() || laddr.Port() > 0 {
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
|
||||
addr, pn = convertToFullAddr(laddr)
|
||||
lfa = &addr
|
||||
}
|
||||
if raddr != nil {
|
||||
if raddr.IsValid() || raddr.Port() > 0 {
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
|
||||
addr, pn = convertToFullAddr(raddr)
|
||||
rfa = &addr
|
||||
}
|
||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
var la, ra netip.AddrPort
|
||||
if laddr != nil {
|
||||
la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
|
||||
}
|
||||
if raddr != nil {
|
||||
ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
|
||||
}
|
||||
return net.DialUDPAddrPort(la, ra)
|
||||
}
|
||||
|
||||
var (
|
||||
errNoSuchHost = errors.New("no such host")
|
||||
errLameReferral = errors.New("lame referral")
|
||||
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
|
||||
return p, h, nil
|
||||
}
|
||||
|
||||
func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
|
||||
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
|
||||
q.Class = dnsmessage.ClassINET
|
||||
id, udpReq, tcpReq, err := newRequest(q)
|
||||
if err != nil {
|
||||
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
|
||||
var c net.Conn
|
||||
var err error
|
||||
if useUDP {
|
||||
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
|
||||
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
|
||||
} else {
|
||||
c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
|
||||
c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
||||
zlen = zidx
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(host[:zlen]); ip != nil {
|
||||
return []string{host[:zlen]}, nil
|
||||
if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
|
||||
return []string{ip.String()}, nil
|
||||
}
|
||||
|
||||
if !isDomainName(host) {
|
||||
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
||||
server string
|
||||
error
|
||||
}
|
||||
var addrsV4, addrsV6 []net.IP
|
||||
var addrsV4, addrsV6 []netip.Addr
|
||||
lanes := 0
|
||||
if tnet.hasV4 {
|
||||
lanes++
|
||||
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
||||
}
|
||||
break loop
|
||||
}
|
||||
addrsV4 = append(addrsV4, net.IP(a.A[:]))
|
||||
addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
|
||||
|
||||
case dnsmessage.TypeAAAA:
|
||||
aaaa, err := result.p.AAAAResource()
|
||||
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
||||
}
|
||||
break loop
|
||||
}
|
||||
addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
|
||||
addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
|
||||
|
||||
default:
|
||||
if err := result.p.SkipAnswer(); err != nil {
|
||||
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
||||
}
|
||||
}
|
||||
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
|
||||
var addrs []net.IP
|
||||
var addrs []netip.Addr
|
||||
if tnet.hasV6 {
|
||||
addrs = append(addrsV6, addrsV4...)
|
||||
} else {
|
||||
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "dial", Err: err}
|
||||
}
|
||||
var addrs []net.IP
|
||||
var addrs []netip.AddrPort
|
||||
for _, addr := range allAddr {
|
||||
if strings.IndexByte(addr, ':') != -1 && acceptV6 {
|
||||
addrs = append(addrs, net.ParseIP(addr))
|
||||
} else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
|
||||
addrs = append(addrs, net.ParseIP(addr))
|
||||
ip, err := netip.ParseAddr(addr)
|
||||
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
|
||||
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
|
||||
}
|
||||
}
|
||||
if len(addrs) == 0 && len(allAddr) != 0 {
|
||||
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
|
||||
|
||||
var c net.Conn
|
||||
if useUDP {
|
||||
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
|
||||
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
||||
} else {
|
||||
c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
|
||||
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
||||
}
|
||||
if err == nil {
|
||||
return c, nil
|
||||
|
Reference in New Issue
Block a user