mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-16 21:51:17 +08:00
tun/netstack: implement ICMP ping
Provide a PacketConn interface for netstack's ICMP endpoint; netstack currently only provides EchoRequest/EchoResponse ICMP support, so this code exposes only an interface for doing ping. Currently is missing: - Write deadlines - Context support Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org> [Jason: rework structure, match std go interfaces, add example code] Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:

committed by
Jason A. Donenfeld

parent
e0b8f11489
commit
a702597e22
@@ -14,6 +14,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -29,8 +30,10 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
type netTun struct {
|
||||
@@ -101,7 +104,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network
|
||||
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
||||
opts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
|
||||
HandleLocal: true,
|
||||
}
|
||||
dev := &netTun{
|
||||
@@ -281,6 +284,178 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
return net.DialUDPAddrPort(la, ra)
|
||||
}
|
||||
|
||||
type PingConn struct {
|
||||
laddr PingAddr
|
||||
raddr PingAddr
|
||||
wq waiter.Queue
|
||||
ep tcpip.Endpoint
|
||||
deadline time.Time
|
||||
}
|
||||
|
||||
type PingAddr struct{ addr netip.Addr }
|
||||
|
||||
func (ia PingAddr) String() string {
|
||||
return ia.addr.String()
|
||||
}
|
||||
|
||||
func (ia PingAddr) Network() string {
|
||||
if ia.addr.Is4() {
|
||||
return "ping4"
|
||||
} else if ia.addr.Is6() {
|
||||
return "ping6"
|
||||
}
|
||||
return "ping"
|
||||
}
|
||||
|
||||
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
|
||||
v6 := laddr.Is6() || raddr.Is6()
|
||||
bind := laddr.IsValid()
|
||||
if !bind {
|
||||
if v6 {
|
||||
laddr = netip.IPv6Unspecified()
|
||||
} else {
|
||||
laddr = netip.IPv4Unspecified()
|
||||
}
|
||||
}
|
||||
|
||||
tn := icmp.ProtocolNumber4
|
||||
pn := ipv4.ProtocolNumber
|
||||
if v6 {
|
||||
tn = icmp.ProtocolNumber6
|
||||
pn = ipv6.ProtocolNumber
|
||||
}
|
||||
|
||||
pc := &PingConn{laddr: PingAddr{laddr}}
|
||||
|
||||
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
|
||||
if tcpipErr != nil {
|
||||
return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
|
||||
}
|
||||
pc.ep = ep
|
||||
|
||||
if bind {
|
||||
fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
|
||||
if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
|
||||
return nil, fmt.Errorf("ping bind: %s", tcpipErr)
|
||||
}
|
||||
}
|
||||
|
||||
if raddr.IsValid() {
|
||||
pc.raddr = PingAddr{raddr}
|
||||
fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
|
||||
if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
|
||||
return nil, fmt.Errorf("ping connect: %s", tcpipErr)
|
||||
}
|
||||
}
|
||||
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
func (pc *PingConn) LocalAddr() net.Addr {
|
||||
return pc.laddr
|
||||
}
|
||||
|
||||
func (pc *PingConn) RemoteAddr() net.Addr {
|
||||
return pc.raddr
|
||||
}
|
||||
|
||||
func (pc *PingConn) Close() error {
|
||||
pc.ep.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *PingConn) SetWriteDeadline(t time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
ia, ok := addr.(PingAddr)
|
||||
if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
|
||||
return 0, fmt.Errorf("ping write: mismatched protocols")
|
||||
}
|
||||
|
||||
var buf buffer.View
|
||||
if ia.addr.Is4() {
|
||||
buf = buffer.NewView(header.ICMPv4MinimumSize + len(p))
|
||||
copy(buf[header.ICMPv4MinimumSize:], p)
|
||||
icmp := header.ICMPv4(buf)
|
||||
icmp.SetType(header.ICMPv4Echo)
|
||||
} else if ia.addr.Is6() {
|
||||
buf = buffer.NewView(header.ICMPv6MinimumSize + len(p))
|
||||
copy(buf[header.ICMPv6MinimumSize:], p)
|
||||
icmp := header.ICMPv6(buf)
|
||||
icmp.SetType(header.ICMPv6EchoRequest)
|
||||
}
|
||||
|
||||
rdr := buf.Reader()
|
||||
rfa, _ := convertToFullAddr(netip.AddrPortFrom(ia.addr, 0))
|
||||
// won't block, no deadlines
|
||||
n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
|
||||
To: &rfa,
|
||||
})
|
||||
if tcpipErr != nil {
|
||||
return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
|
||||
}
|
||||
|
||||
return int(n64), nil
|
||||
}
|
||||
|
||||
func (pc *PingConn) Write(p []byte) (n int, err error) {
|
||||
return pc.WriteTo(p, pc.raddr)
|
||||
}
|
||||
|
||||
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
e, notifyCh := waiter.NewChannelEntry(nil)
|
||||
pc.wq.EventRegister(&e, waiter.EventIn)
|
||||
defer pc.wq.EventUnregister(&e)
|
||||
|
||||
deadline := pc.deadline
|
||||
|
||||
if deadline.IsZero() {
|
||||
<-notifyCh
|
||||
} else {
|
||||
select {
|
||||
case <-time.NewTimer(deadline.Sub(time.Now())).C:
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-notifyCh:
|
||||
}
|
||||
}
|
||||
|
||||
min := header.ICMPv6MinimumSize
|
||||
if pc.laddr.addr.Is4() {
|
||||
min = header.ICMPv4MinimumSize
|
||||
}
|
||||
reply := make([]byte, min+len(p))
|
||||
w := tcpip.SliceWriter(reply)
|
||||
|
||||
res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
|
||||
NeedRemoteAddr: true,
|
||||
})
|
||||
if tcpipErr != nil {
|
||||
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
||||
}
|
||||
|
||||
addr = PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
|
||||
copy(p, reply[min:res.Count])
|
||||
return res.Count - min, addr, nil
|
||||
}
|
||||
|
||||
func (pc *PingConn) Read(p []byte) (n int, err error) {
|
||||
n, _, err = pc.ReadFrom(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (pc *PingConn) SetDeadline(t time.Time) error {
|
||||
// pc.SetWriteDeadline is unimplemented
|
||||
|
||||
return pc.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (pc *PingConn) SetReadDeadline(t time.Time) error {
|
||||
pc.deadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errNoSuchHost = errors.New("no such host")
|
||||
errLameReferral = errors.New("lame referral")
|
||||
@@ -755,33 +930,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
|
||||
return now.Add(timeout), nil
|
||||
}
|
||||
|
||||
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
|
||||
|
||||
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
var acceptV4, acceptV6, useUDP bool
|
||||
if len(network) == 3 {
|
||||
var acceptV4, acceptV6 bool
|
||||
matches := protoSplitter.FindStringSubmatch(network)
|
||||
if matches == nil {
|
||||
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
||||
} else if len(matches[2]) == 0 {
|
||||
acceptV4 = true
|
||||
acceptV6 = true
|
||||
} else if len(network) == 4 {
|
||||
acceptV4 = network[3] == '4'
|
||||
acceptV6 = network[3] == '6'
|
||||
} else {
|
||||
acceptV4 = matches[2][0] == '4'
|
||||
acceptV6 = !acceptV4
|
||||
}
|
||||
if !acceptV4 && !acceptV6 {
|
||||
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
||||
}
|
||||
if network[:3] == "udp" {
|
||||
useUDP = true
|
||||
} else if network[:3] != "tcp" {
|
||||
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
||||
}
|
||||
host, sport, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "dial", Err: err}
|
||||
}
|
||||
port, err := strconv.Atoi(sport)
|
||||
if err != nil || port < 0 || port > 65535 {
|
||||
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
|
||||
var host string
|
||||
var port int
|
||||
if matches[1] == "ping" {
|
||||
host = address
|
||||
} else {
|
||||
var sport string
|
||||
var err error
|
||||
host, sport, err = net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "dial", Err: err}
|
||||
}
|
||||
port, err = strconv.Atoi(sport)
|
||||
if err != nil || port < 0 || port > 65535 {
|
||||
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
|
||||
}
|
||||
}
|
||||
allAddr, err := tnet.LookupContextHost(ctx, host)
|
||||
if err != nil {
|
||||
@@ -829,10 +1009,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
|
||||
}
|
||||
|
||||
var c net.Conn
|
||||
if useUDP {
|
||||
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
||||
} else {
|
||||
switch matches[1] {
|
||||
case "tcp":
|
||||
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
||||
case "udp":
|
||||
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
||||
case "ping":
|
||||
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
|
||||
}
|
||||
if err == nil {
|
||||
return c, nil
|
||||
|
Reference in New Issue
Block a user