diff --git a/cmd/main.go b/cmd/main.go index 86fb278..8589866 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "log" @@ -27,7 +28,7 @@ func main() { log.Fatal(err) } - e, err := engine.Run(cfg) + e, err := engine.Run(context.Background(), cfg) if err != nil { log.Fatal(err) } diff --git a/core/config/config.go b/core/config/config.go index 770d446..6d7258a 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -23,6 +23,8 @@ type Config struct { PeerID string Bootstraps []string PeersRouteTable map[string]netip.Prefix + Relays []string + EnableAutoRelay bool EnableMDNS bool // log diff --git a/core/engine/engine.go b/core/engine/engine.go index 881c2ef..bdf42a0 100644 --- a/core/engine/engine.go +++ b/core/engine/engine.go @@ -56,27 +56,65 @@ type Engine struct { } } -func Run(cfg *config.Config) (*Engine, error) { +func Run(ctx context.Context, cfg *config.Config) (*Engine, error) { var ( - e = new(Engine) - err error + e = new(Engine) + err error + options []libp2p.Option ) e.cfg = cfg mlog.SetOutputTypes(cfg.LogConfigs...) e.log = mlog.New("engine") - e.ctx, e.cancel = context.WithCancel(context.Background()) + e.ctx, e.cancel = context.WithCancel(ctx) e.devWriter = make(PacketChan, ChanSize) e.devReader = make(PacketChan, ChanSize) - e.relayChan = make(chan peer.AddrInfo, ChanSize) pk, err := cfg.PrivateKey.PrivKey() if err != nil { return nil, err } - node, err := libp2p.New( - libp2p.Identity(pk), - ) + options = append(options, libp2p.Identity(pk)) + + if len(cfg.Relays) > 0 { + var relays []peer.AddrInfo + for _, relay := range cfg.Relays { + addrInfo, err := peer.AddrInfoFromString(relay) + if err != nil { + e.log.Warnf("fail to parse '%s': %v", relay, err) + continue + } + relays = append(relays, *addrInfo) + } + options = append(options, libp2p.EnableAutoRelayWithStaticRelays(relays)) + } else if cfg.EnableAutoRelay { + e.relayChan = make(chan peer.AddrInfo, ChanSize) + options = append(options, libp2p.EnableAutoRelayWithPeerSource(func(ctx context.Context, num int) <-chan peer.AddrInfo { + c := make(chan peer.AddrInfo, num) + go func() { + defer close(c) + for ; num >= 0; num-- { + select { + case v, ok := <-e.relayChan: + if !ok { + return + } + e.log.Debugf("auto relay find node: %v", v) + select { + case c <- v: + case <-ctx.Done(): + return + } + case <-ctx.Done(): + return + } + } + }() + return c + })) + } + + node, err := libp2p.New(options...) if err != nil { return nil, err } @@ -138,6 +176,11 @@ func (e *Engine) Run() error { } } + if len(e.cfg.Relays) == 0 && e.cfg.EnableAutoRelay { + // start auto relay detect + go e.autoRelayFinder(e.ctx) + } + e.host.SetStreamHandler(VPNStreamProtocol, e.VPNHandler) util.Advertise(e.ctx, e.discovery, e.host.ID().String()) diff --git a/core/engine/relay.go b/core/engine/relay.go new file mode 100644 index 0000000..0f1b7e8 --- /dev/null +++ b/core/engine/relay.go @@ -0,0 +1,53 @@ +package engine + +import ( + "context" + "time" + + "github.com/libp2p/go-libp2p/core/peer" +) + +func (e *Engine) autoRelayFinder(ctx context.Context) { + e.log.Debugf("successfully start auto relay finder!") + peers := e.host.Network().Peers() + for _, p := range peers { + addrs := e.host.Peerstore().Addrs(p) + if len(addrs) == 0 { + continue + } + node := peer.AddrInfo{ID: p, Addrs: addrs} + select { + case e.relayChan <- node: + e.log.Debugf("find relay candidate node %s", node) + default: + } + } + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + closestPeers, err := e.dht.GetClosestPeers(ctx, e.host.ID().String()) + if err != nil { + e.log.Warnf("autoRelay get cloest peers error: %v", err) + continue + } + + for _, p := range closestPeers { + addrs := e.host.Peerstore().Addrs(p) + if len(addrs) == 0 { + continue + } + node := peer.AddrInfo{ID: p, Addrs: addrs} + select { + case e.relayChan <- node: + e.log.Debugf("find relay candidate node %s", node) + default: + } + } + } + } +}