diff --git a/go.mod b/go.mod index aae9099..db61da2 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/urfave/cli/v2 v2.27.1 github.com/vishvananda/netlink v1.1.0 + golang.org/x/sync v0.7.0 golang.org/x/sys v0.19.0 golang.org/x/term v0.19.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 6a8ab93..d981473 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/client/client.go b/pkg/client/client.go index cb3e984..d5e2505 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -2,6 +2,7 @@ package client import ( "bufio" + "errors" "fmt" "net" "os" @@ -87,13 +88,12 @@ func GetLoginInfo() (string, string, error) { return strings.TrimSpace(user), strings.TrimSpace(passwd), nil } -func checkLoginTimeout(c chan string) { +func checkLoginTimeout(c chan string) error { select { case <-c: - return + return nil case <-time.After(10 * time.Second): - log.Error("login timeout") - os.Exit(1) + return errors.New("login timeout") } } func (c *Client) Close() error { @@ -112,15 +112,16 @@ func (c *Client) Close() error { return nil } -func (c *Client) HandleSignal(sigChan chan os.Signal) { +func (c *Client) HandleSignal(sigChan chan os.Signal) error { sig := <-sigChan log.Infof("received signal: %v, send fin pkt to close conn\n", sig) if err := c.Close(); err != nil { - log.Errorf("send fin pkt %s", err.Error()) + return fmt.Errorf("send fin pkt %s", err.Error()) } os.Exit(0) + return nil } func (c *Client) SetLogLevel() { diff --git a/pkg/client/conn.go b/pkg/client/conn.go index 5190883..97fac95 100644 --- a/pkg/client/conn.go +++ b/pkg/client/conn.go @@ -1,16 +1,19 @@ package client import ( + "fmt" "net" - "os" "time" "github.com/lucheng0127/virtuallan/pkg/packet" log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" ) -func (c *Client) HandleConn(netToIface chan *packet.VLPkt) { - go func() { +func (c *Client) HandleConn(netToIface chan *packet.VLPkt) error { + g := new(errgroup.Group) + + g.Go(func() error { for { pkt := <-netToIface if pkt.Type != packet.P_RAW { @@ -25,34 +28,40 @@ func (c *Client) HandleConn(netToIface chan *packet.VLPkt) { _, err = c.Iface.Write(stream) if err != nil { - log.Errorf("write to tap %s %s\n", c.Iface.Name(), err.Error()) - continue + return fmt.Errorf("write to tap %s %s", c.Iface.Name(), err.Error()) } } - }() + }) - for { - var buf [65535]byte + g.Go(func() error { + for { + var buf [65535]byte - n, err := c.Iface.Read(buf[:]) - if err != nil { - log.Errorf("read from tap %s %s\n", c.Iface.Name(), err.Error()) - continue + n, err := c.Iface.Read(buf[:]) + if err != nil { + log.Errorf("read from tap %s %s\n", c.Iface.Name(), err.Error()) + continue + } + + pkt := packet.NewRawPkt(buf[:n]) + stream, err := pkt.Encode() + if err != nil { + log.Warn("encode raw vlpkt failed: ", err) + continue + } + + _, err = c.Conn.Write(stream) + if err != nil { + return fmt.Errorf("send udp stream to %s %s", c.Conn.RemoteAddr().String(), err.Error()) + } } + }) - pkt := packet.NewRawPkt(buf[:n]) - stream, err := pkt.Encode() - if err != nil { - log.Warn("encode raw vlpkt failed: ", err) - continue - } - - _, err = c.Conn.Write(stream) - if err != nil { - log.Errorf("send udp stream to %s %s\n", c.Conn.RemoteAddr().String(), err.Error()) - os.Exit(1) - } + if err := g.Wait(); err != nil { + return err } + + return nil } func SendKeepalive(conn *net.UDPConn, addr string) error { diff --git a/pkg/client/run_linux.go b/pkg/client/run_linux.go index 6c9878e..d54de77 100644 --- a/pkg/client/run_linux.go +++ b/pkg/client/run_linux.go @@ -1,12 +1,12 @@ package client import ( + "errors" "fmt" "net" "os" "os/signal" "strings" - "sync" "github.com/erikdubbelboer/gspt" "github.com/lucheng0127/virtuallan/pkg/cipher" @@ -14,6 +14,7 @@ import ( "github.com/lucheng0127/virtuallan/pkg/utils" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" + "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" ) @@ -74,26 +75,28 @@ func (c *Client) Launch() error { c.Conn = conn + // Use errgroup check goruntine error + g := new(errgroup.Group) + // Handle signal sigChan := make(chan os.Signal, 8) signal.Notify(sigChan, unix.SIGTERM, unix.SIGINT) - go c.HandleSignal(sigChan) + g.Go(func() error { + return c.HandleSignal(sigChan) + }) // Do auth ipChan := make(chan string) netToIface := make(chan *packet.VLPkt, 1024) - var wg sync.WaitGroup - wg.Add(3) // Handle udp packet - go func() { + g.Go(func() error { for { var buf [65535]byte n, _, err := conn.ReadFromUDP(buf[:]) if err != nil { - log.Error("read from conn ", err) - os.Exit(1) + return fmt.Errorf("read from conn %s", err.Error()) } if n < 2 { @@ -110,14 +113,11 @@ func (c *Client) Launch() error { case packet.P_RESPONSE: switch pkt.VLBody.(*packet.RspBody).Code { case packet.RSP_AUTH_REQUIRED: - log.Error("auth failed") - os.Exit(1) + return errors.New("auth failed") case packet.RSP_IP_NOT_MATCH: - log.Error("ip not match") - os.Exit(1) + return errors.New("ip not match") case packet.RSP_USER_LOGGED: - log.Error("user already logged by other endpoint") - os.Exit(1) + return errors.New("user already logged by other endpoint") default: continue } @@ -131,24 +131,24 @@ func (c *Client) Launch() error { continue } } - }() + }) // Auth authPkt := packet.NewAuthPkt(c.user, c.password) authStream, err := authPkt.Encode() if err != nil { - log.Error("encode auth packet ", err) - os.Exit(1) + return fmt.Errorf("encode auth packet %s", err.Error()) } _, err = conn.Write(authStream) if err != nil { - log.Error("send auth packet ", err) - os.Exit(1) + return fmt.Errorf("send auth packet %s", err.Error()) } authChan := make(chan string, 1) - go checkLoginTimeout(authChan) + g.Go(func() error { + return checkLoginTimeout(authChan) + }) // Waiting for dhcp ip ipAddr := <-ipChan @@ -186,14 +186,21 @@ func (c *Client) Launch() error { // XXX: Sometime when client restart too fast will not reveice the first multicast pkt // Monitor multicast for route bordcast - go packet.MonitorRouteMulticast(tapIface, strings.Split(c.IPAddr, "/")[0]) + g.Go(func() error { + return packet.MonitorRouteMulticast(tapIface, strings.Split(c.IPAddr, "/")[0]) + }) // Send keepalive go c.DoKeepalive(10) // Switch io between udp net and tap interface - go c.HandleConn(netToIface) + g.Go(func() error { + return c.HandleConn(netToIface) + }) + + if err := g.Wait(); err != nil { + return err + } - wg.Wait() return nil } diff --git a/pkg/client/run_windows.go b/pkg/client/run_windows.go index 5d5cb74..3e9970e 100644 --- a/pkg/client/run_windows.go +++ b/pkg/client/run_windows.go @@ -1,16 +1,18 @@ package client import ( + "errors" + "fmt" "net" "os" "os/signal" - "sync" "github.com/lucheng0127/virtuallan/pkg/cipher" "github.com/lucheng0127/virtuallan/pkg/packet" "github.com/lucheng0127/virtuallan/pkg/utils" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" + "golang.org/x/sync/errgroup" ) func Run(cCtx *cli.Context) error { @@ -67,26 +69,28 @@ func (c *Client) Launch() error { c.Conn = conn + // Use errgroup check goruntine error + g := new(errgroup.Group) + // Handle signal sigChan := make(chan os.Signal, 8) signal.Notify(sigChan, os.Interrupt) - go c.HandleSignal(sigChan) + g.Go(func() error { + return c.HandleSignal(sigChan) + }) // Do auth ipChan := make(chan string) netToIface := make(chan *packet.VLPkt, 1024) - var wg sync.WaitGroup - wg.Add(3) // Handle udp packet - go func() { + g.Go(func() error { for { var buf [65535]byte n, _, err := conn.ReadFromUDP(buf[:]) if err != nil { - log.Error("read from conn ", err) - os.Exit(1) + return fmt.Errorf("read from conn %s", err) } if n < 2 { @@ -103,14 +107,11 @@ func (c *Client) Launch() error { case packet.P_RESPONSE: switch pkt.VLBody.(*packet.RspBody).Code { case packet.RSP_AUTH_REQUIRED: - log.Error("auth failed") - os.Exit(1) + return errors.New("auth failed") case packet.RSP_IP_NOT_MATCH: - log.Error("ip not match") - os.Exit(1) + return errors.New("ip not match") case packet.RSP_USER_LOGGED: - log.Error("user already logged by other endpoint") - os.Exit(1) + return errors.New("user already logged by other endpoint") default: continue } @@ -124,24 +125,24 @@ func (c *Client) Launch() error { continue } } - }() + }) // Auth authPkt := packet.NewAuthPkt(c.user, c.password) authStream, err := authPkt.Encode() if err != nil { - log.Error("encode auth packet ", err) - os.Exit(1) + return fmt.Errorf("encode auth packet %s", err.Error()) } _, err = conn.Write(authStream) if err != nil { - log.Error("send auth packet ", err) - os.Exit(1) + return fmt.Errorf("send auth packet %s", err.Error()) } authChan := make(chan string, 1) - go checkLoginTimeout(authChan) + g.Go(func() error { + return checkLoginTimeout(authChan) + }) // Waiting for dhcp ip ipAddr := <-ipChan @@ -159,8 +160,13 @@ func (c *Client) Launch() error { go c.DoKeepalive(10) // Switch io between udp net and tap interface - go c.HandleConn(netToIface) + g.Go(func() error { + return c.HandleConn(netToIface) + }) + + if err := g.Wait(); err != nil { + return err + } - wg.Wait() return nil } diff --git a/pkg/packet/multicast_linux.go b/pkg/packet/multicast_linux.go index a64afca..d007874 100644 --- a/pkg/packet/multicast_linux.go +++ b/pkg/packet/multicast_linux.go @@ -35,19 +35,17 @@ func MulticastStream(data []byte) error { return nil } -func MonitorRouteMulticast(iface *net.Interface, tapIP string) { +func MonitorRouteMulticast(iface *net.Interface, tapIP string) error { // Monitor route multicast will run as a goruntine in client so log error but don't exit udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", MULTICAST_ADDR, MULTICAST_PORT)) if err != nil { - log.Errorf("parse multicast addr %s", err.Error()) - return + return fmt.Errorf("parse multicast addr %s", err.Error()) } // Listen multicast on tap interface ln, err := net.ListenMulticastUDP("udp", iface, udpAddr) if err != nil { - log.Errorf("listen multicast address %s", err.Error()) - return + return fmt.Errorf("listen multicast address %s", err.Error()) } // Read data from udp @@ -82,7 +80,7 @@ func MonitorRouteMulticast(iface *net.Interface, tapIP string) { // TODO: Implement in windows // Sync routes, use flag replace, for unknow ip need delete if err := utils.SyncRoutesForIface(iface.Name, tapIP, routes); err != nil { - log.Errorf("sync route for %s %s", iface.Name, err.Error()) + return fmt.Errorf("sync route for %s %s", iface.Name, err.Error()) } } }