feat: implement mDNS discovery

This commit is contained in:
lynx
2024-04-26 16:13:40 +08:00
parent f9ad3540ce
commit 2b7ed4cb40
4 changed files with 79 additions and 36 deletions

View File

@@ -11,6 +11,7 @@ import (
"unsafe"
"github.com/wlynxg/NetHive/pkgs/command"
"github.com/wlynxg/NetHive/pkgs/system"
"golang.org/x/sys/unix"
)
@@ -18,15 +19,6 @@ import (
// https://man7.org/linux/man-pages/man7/netdevice.7.html
type ifReq [40]byte
// https://man7.org/linux/man-pages/man2/ioctl.2.html
func ioctl(fd uintptr, request uintptr, argp uintptr) error {
_, _, err := unix.Syscall(unix.SYS_IOCTL, fd, request, argp)
if err != 0 {
return os.NewSyscallError("ioctl", err)
}
return nil
}
// compilation time interface check
var _ Device = new(tun)
@@ -107,7 +99,7 @@ func (t *tun) changeState(state bool) error {
} else {
*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) &^= syscall.IFF_UP
}
err = ioctl(uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifr[0])))
err = system.Ioctl(uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifr[0])))
if err != nil {
return err
}
@@ -124,7 +116,7 @@ func (t *tun) getNameFromSys() (string, error) {
var ifr ifReq
var errno syscall.Errno
err = conn.Control(func(fd uintptr) {
ioctl(fd, unix.TUNGETIFF, uintptr(unsafe.Pointer(&ifr[0])))
system.Ioctl(fd, unix.TUNGETIFF, uintptr(unsafe.Pointer(&ifr[0])))
})
if err != nil || errno != 0 {
return "", fmt.Errorf("failed to get name of TUN device: %w", err)
@@ -147,7 +139,7 @@ func (t *tun) getMTUFromSys() (int, error) {
var ifr ifReq
copy(ifr[:], t.name)
err = ioctl(uintptr(fd), unix.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifr[0])))
err = system.Ioctl(uintptr(fd), unix.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifr[0])))
if err != nil {
return -1, err
}
@@ -164,7 +156,7 @@ func (t *tun) setMTU(n int) error {
var ifr ifReq
copy(ifr[:], t.name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
err = ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr[0])))
err = system.Ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr[0])))
if err != nil {
return err
}
@@ -180,7 +172,7 @@ func (t *tun) getIFIndex() (int32, error) {
var ifr ifReq
copy(ifr[:], t.name)
err = ioctl(uintptr(fd), unix.SIOCGIFINDEX, uintptr(unsafe.Pointer(&ifr[0])))
err = system.Ioctl(uintptr(fd), unix.SIOCGIFINDEX, uintptr(unsafe.Pointer(&ifr[0])))
if err != nil {
return 0, err
}

View File

@@ -21,6 +21,7 @@ func (e *Engine) addConnByDst(dst netip.Addr) (PacketChan, error) {
e.routeTable.m.Range(func(key string, value netip.Prefix) bool {
if value.Addr().Compare(dst) == 0 {
conn = make(PacketChan, ChanSize)
e.routeTable.addr.Store(dst, conn)
go func() {
defer e.routeTable.id.Delete(key)
defer e.routeTable.addr.Delete(dst)
@@ -77,7 +78,7 @@ func (e *Engine) addConn(peerChan PacketChan, id string) {
if err != nil {
peerc, err := e.discovery.FindPeers(e.ctx, id)
if err != nil {
e.log.Warnf("Finding node by dht %s failed because %s", string(id), err)
e.log.Warnf("Finding node by dht %s failed because %s", id, err)
return
}
@@ -89,7 +90,7 @@ func (e *Engine) addConn(peerChan PacketChan, id string) {
}
}
}
e.log.Warnf("Connection establishment with node %s failed", string(id))
e.log.Warnf("Connection establishment with node %s failed", id)
return
}
}
@@ -98,19 +99,19 @@ func (e *Engine) addConn(peerChan PacketChan, id string) {
return
}
e.log.Infof("Peer [%s] connect success", string(id))
e.log.Infof("Peer [%s] connect success", id)
defer stream.Close()
go func() {
defer stream.Close()
_, err := io.Copy(stream, dev)
if err != nil && err != io.EOF {
e.log.Errorf("Peer [%s] stream write error: %s", string(id), err)
e.log.Errorf("Peer [%s] stream write error: %s", id, err)
}
}()
_, err = io.Copy(dev, stream)
if err != nil && err != io.EOF {
e.log.Errorf("Peer [%s] stream read error: %s", string(id), err)
e.log.Errorf("Peer [%s] stream read error: %s", id, err)
}
}

View File

@@ -3,11 +3,12 @@ package engine
import (
"context"
"fmt"
"github.com/wlynxg/NetHive/core/route"
"io"
"net/netip"
"sync"
"github.com/wlynxg/NetHive/core/route"
"github.com/wlynxg/NetHive/core/config"
"github.com/wlynxg/NetHive/core/device"
"github.com/wlynxg/NetHive/core/protocol"
@@ -64,8 +65,9 @@ func Run(cfg *config.Config) (*Engine, error) {
err error
)
e.log = mlog.New("engine")
e.cfg = cfg
mlog.SetOutputTypes(cfg.LogConfigs...)
e.log = mlog.New("engine")
e.ctx, e.cancel = context.WithCancel(context.Background())
e.devWriter = make(PacketChan, ChanSize)
e.devReader = make(PacketChan, ChanSize)
@@ -97,7 +99,7 @@ func (e *Engine) Run() error {
var err error
defer e.cancel()
// create tun
// TUN init
e.device, err = device.CreateTUN(e.cfg.TUNName, e.cfg.MTU)
if err != nil {
return err
@@ -108,17 +110,6 @@ func (e *Engine) Run() error {
return err
}
for id, prefix := range e.cfg.PeersRouteTable {
e.routeTable.m.Store(id, prefix)
err := route.Add(name, prefix)
if err != nil {
e.log.Warnf("fail to add %s's route: %s", id, prefix)
continue
}
e.log.Debugf("successfully add %s's route: %s", id, prefix)
}
if err := e.device.AddAddress(e.cfg.LocalAddr); err != nil {
return err
}
@@ -127,6 +118,18 @@ func (e *Engine) Run() error {
return err
}
for id, prefix := range e.cfg.PeersRouteTable {
e.routeTable.m.Store(id, prefix)
err := route.Add(name, prefix)
if err != nil {
e.log.Warnf("fail to add %s's route %s: %v", id, prefix, err)
continue
}
e.log.Debugf("successfully add %s's route: %s", id, prefix)
}
// DHT init
wg := sync.WaitGroup{}
for _, info := range e.cfg.Bootstraps {
addrInfo, err := peer.AddrInfoFromString(info)
@@ -145,6 +148,13 @@ func (e *Engine) Run() error {
}
wg.Wait()
if e.cfg.EnableMDNS {
err := e.EnableMdns()
if err != nil {
return err
}
}
e.host.SetStreamHandler(VPNStreamProtocol, e.VPNHandler)
util.Advertise(e.ctx, e.discovery, e.host.ID().String())

View File

@@ -4,9 +4,49 @@ import (
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/discovery/mdns"
)
func (e *Engine) HandlePeerFound(pi peer.AddrInfo) {
e.log.Infof("find %s by mDNS", pi)
e.host.Peerstore().AddAddrs(pi.ID, pi.Addrs, 5*time.Minute)
const (
MDNSRetryInterval = 5 * time.Minute
)
func (e *Engine) HandlePeerFound(info peer.AddrInfo) {
e.log.Debugf("mDNS get node addr info: %s", info)
e.host.Peerstore().AddAddrs(info.ID, info.Addrs, peerstore.AddressTTL)
}
func (e *Engine) EnableMdns() error {
if e.mdns == nil {
// init mdns serve
e.mdns = mdns.NewMdnsService(e.host, "_p2proxy._udp", e)
}
go e.mdnsLoop()
return nil
}
func (e *Engine) mdnsLoop() {
if err := e.mdns.Start(); err != nil {
e.log.Warnf("fail to run mDNS service for the %dth time: %v", 1, err)
} else {
e.log.Infof("successfully run mDNS service!")
return
}
ticker := time.NewTicker(MDNSRetryInterval)
defer ticker.Stop()
for i := 2; ; i++ {
select {
case <-e.ctx.Done():
case <-ticker.C:
if err := e.mdns.Start(); err != nil {
e.log.Warnf("fail to run mDNS service for the %dth time: %v", 1, err)
} else {
e.log.Infof("successfully run mDNS service!")
return
}
}
}
}