Use errgroup to return error of goruntine

This commit is contained in:
lucheng
2024-07-12 09:51:43 +08:00
parent 9d2038ff6f
commit c4db99d32a
7 changed files with 103 additions and 79 deletions

1
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/urfave/cli/v2 v2.27.1
github.com/vishvananda/netlink v1.1.0
golang.org/x/sync v0.7.0
golang.org/x/sys v0.19.0
golang.org/x/term v0.19.0
gopkg.in/yaml.v3 v3.0.1

2
go.sum
View File

@@ -88,6 +88,8 @@ golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -2,6 +2,7 @@ package client
import (
"bufio"
"errors"
"fmt"
"net"
"os"
@@ -87,13 +88,12 @@ func GetLoginInfo() (string, string, error) {
return strings.TrimSpace(user), strings.TrimSpace(passwd), nil
}
func checkLoginTimeout(c chan string) {
func checkLoginTimeout(c chan string) error {
select {
case <-c:
return
return nil
case <-time.After(10 * time.Second):
log.Error("login timeout")
os.Exit(1)
return errors.New("login timeout")
}
}
func (c *Client) Close() error {
@@ -112,15 +112,16 @@ func (c *Client) Close() error {
return nil
}
func (c *Client) HandleSignal(sigChan chan os.Signal) {
func (c *Client) HandleSignal(sigChan chan os.Signal) error {
sig := <-sigChan
log.Infof("received signal: %v, send fin pkt to close conn\n", sig)
if err := c.Close(); err != nil {
log.Errorf("send fin pkt %s", err.Error())
return fmt.Errorf("send fin pkt %s", err.Error())
}
os.Exit(0)
return nil
}
func (c *Client) SetLogLevel() {

View File

@@ -1,16 +1,19 @@
package client
import (
"fmt"
"net"
"os"
"time"
"github.com/lucheng0127/virtuallan/pkg/packet"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
func (c *Client) HandleConn(netToIface chan *packet.VLPkt) {
go func() {
func (c *Client) HandleConn(netToIface chan *packet.VLPkt) error {
g := new(errgroup.Group)
g.Go(func() error {
for {
pkt := <-netToIface
if pkt.Type != packet.P_RAW {
@@ -25,12 +28,12 @@ func (c *Client) HandleConn(netToIface chan *packet.VLPkt) {
_, err = c.Iface.Write(stream)
if err != nil {
log.Errorf("write to tap %s %s\n", c.Iface.Name(), err.Error())
continue
return fmt.Errorf("write to tap %s %s", c.Iface.Name(), err.Error())
}
}
}()
})
g.Go(func() error {
for {
var buf [65535]byte
@@ -49,10 +52,16 @@ func (c *Client) HandleConn(netToIface chan *packet.VLPkt) {
_, err = c.Conn.Write(stream)
if err != nil {
log.Errorf("send udp stream to %s %s\n", c.Conn.RemoteAddr().String(), err.Error())
os.Exit(1)
return fmt.Errorf("send udp stream to %s %s", c.Conn.RemoteAddr().String(), err.Error())
}
}
})
if err := g.Wait(); err != nil {
return err
}
return nil
}
func SendKeepalive(conn *net.UDPConn, addr string) error {

View File

@@ -1,12 +1,12 @@
package client
import (
"errors"
"fmt"
"net"
"os"
"os/signal"
"strings"
"sync"
"github.com/erikdubbelboer/gspt"
"github.com/lucheng0127/virtuallan/pkg/cipher"
@@ -14,6 +14,7 @@ import (
"github.com/lucheng0127/virtuallan/pkg/utils"
log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
)
@@ -74,26 +75,28 @@ func (c *Client) Launch() error {
c.Conn = conn
// Use errgroup check goruntine error
g := new(errgroup.Group)
// Handle signal
sigChan := make(chan os.Signal, 8)
signal.Notify(sigChan, unix.SIGTERM, unix.SIGINT)
go c.HandleSignal(sigChan)
g.Go(func() error {
return c.HandleSignal(sigChan)
})
// Do auth
ipChan := make(chan string)
netToIface := make(chan *packet.VLPkt, 1024)
var wg sync.WaitGroup
wg.Add(3)
// Handle udp packet
go func() {
g.Go(func() error {
for {
var buf [65535]byte
n, _, err := conn.ReadFromUDP(buf[:])
if err != nil {
log.Error("read from conn ", err)
os.Exit(1)
return fmt.Errorf("read from conn %s", err.Error())
}
if n < 2 {
@@ -110,14 +113,11 @@ func (c *Client) Launch() error {
case packet.P_RESPONSE:
switch pkt.VLBody.(*packet.RspBody).Code {
case packet.RSP_AUTH_REQUIRED:
log.Error("auth failed")
os.Exit(1)
return errors.New("auth failed")
case packet.RSP_IP_NOT_MATCH:
log.Error("ip not match")
os.Exit(1)
return errors.New("ip not match")
case packet.RSP_USER_LOGGED:
log.Error("user already logged by other endpoint")
os.Exit(1)
return errors.New("user already logged by other endpoint")
default:
continue
}
@@ -131,24 +131,24 @@ func (c *Client) Launch() error {
continue
}
}
}()
})
// Auth
authPkt := packet.NewAuthPkt(c.user, c.password)
authStream, err := authPkt.Encode()
if err != nil {
log.Error("encode auth packet ", err)
os.Exit(1)
return fmt.Errorf("encode auth packet %s", err.Error())
}
_, err = conn.Write(authStream)
if err != nil {
log.Error("send auth packet ", err)
os.Exit(1)
return fmt.Errorf("send auth packet %s", err.Error())
}
authChan := make(chan string, 1)
go checkLoginTimeout(authChan)
g.Go(func() error {
return checkLoginTimeout(authChan)
})
// Waiting for dhcp ip
ipAddr := <-ipChan
@@ -186,14 +186,21 @@ func (c *Client) Launch() error {
// XXX: Sometime when client restart too fast will not reveice the first multicast pkt
// Monitor multicast for route bordcast
go packet.MonitorRouteMulticast(tapIface, strings.Split(c.IPAddr, "/")[0])
g.Go(func() error {
return packet.MonitorRouteMulticast(tapIface, strings.Split(c.IPAddr, "/")[0])
})
// Send keepalive
go c.DoKeepalive(10)
// Switch io between udp net and tap interface
go c.HandleConn(netToIface)
g.Go(func() error {
return c.HandleConn(netToIface)
})
if err := g.Wait(); err != nil {
return err
}
wg.Wait()
return nil
}

View File

@@ -1,16 +1,18 @@
package client
import (
"errors"
"fmt"
"net"
"os"
"os/signal"
"sync"
"github.com/lucheng0127/virtuallan/pkg/cipher"
"github.com/lucheng0127/virtuallan/pkg/packet"
"github.com/lucheng0127/virtuallan/pkg/utils"
log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup"
)
func Run(cCtx *cli.Context) error {
@@ -67,26 +69,28 @@ func (c *Client) Launch() error {
c.Conn = conn
// Use errgroup check goruntine error
g := new(errgroup.Group)
// Handle signal
sigChan := make(chan os.Signal, 8)
signal.Notify(sigChan, os.Interrupt)
go c.HandleSignal(sigChan)
g.Go(func() error {
return c.HandleSignal(sigChan)
})
// Do auth
ipChan := make(chan string)
netToIface := make(chan *packet.VLPkt, 1024)
var wg sync.WaitGroup
wg.Add(3)
// Handle udp packet
go func() {
g.Go(func() error {
for {
var buf [65535]byte
n, _, err := conn.ReadFromUDP(buf[:])
if err != nil {
log.Error("read from conn ", err)
os.Exit(1)
return fmt.Errorf("read from conn %s", err)
}
if n < 2 {
@@ -103,14 +107,11 @@ func (c *Client) Launch() error {
case packet.P_RESPONSE:
switch pkt.VLBody.(*packet.RspBody).Code {
case packet.RSP_AUTH_REQUIRED:
log.Error("auth failed")
os.Exit(1)
return errors.New("auth failed")
case packet.RSP_IP_NOT_MATCH:
log.Error("ip not match")
os.Exit(1)
return errors.New("ip not match")
case packet.RSP_USER_LOGGED:
log.Error("user already logged by other endpoint")
os.Exit(1)
return errors.New("user already logged by other endpoint")
default:
continue
}
@@ -124,24 +125,24 @@ func (c *Client) Launch() error {
continue
}
}
}()
})
// Auth
authPkt := packet.NewAuthPkt(c.user, c.password)
authStream, err := authPkt.Encode()
if err != nil {
log.Error("encode auth packet ", err)
os.Exit(1)
return fmt.Errorf("encode auth packet %s", err.Error())
}
_, err = conn.Write(authStream)
if err != nil {
log.Error("send auth packet ", err)
os.Exit(1)
return fmt.Errorf("send auth packet %s", err.Error())
}
authChan := make(chan string, 1)
go checkLoginTimeout(authChan)
g.Go(func() error {
return checkLoginTimeout(authChan)
})
// Waiting for dhcp ip
ipAddr := <-ipChan
@@ -159,8 +160,13 @@ func (c *Client) Launch() error {
go c.DoKeepalive(10)
// Switch io between udp net and tap interface
go c.HandleConn(netToIface)
g.Go(func() error {
return c.HandleConn(netToIface)
})
if err := g.Wait(); err != nil {
return err
}
wg.Wait()
return nil
}

View File

@@ -35,19 +35,17 @@ func MulticastStream(data []byte) error {
return nil
}
func MonitorRouteMulticast(iface *net.Interface, tapIP string) {
func MonitorRouteMulticast(iface *net.Interface, tapIP string) error {
// Monitor route multicast will run as a goruntine in client so log error but don't exit
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", MULTICAST_ADDR, MULTICAST_PORT))
if err != nil {
log.Errorf("parse multicast addr %s", err.Error())
return
return fmt.Errorf("parse multicast addr %s", err.Error())
}
// Listen multicast on tap interface
ln, err := net.ListenMulticastUDP("udp", iface, udpAddr)
if err != nil {
log.Errorf("listen multicast address %s", err.Error())
return
return fmt.Errorf("listen multicast address %s", err.Error())
}
// Read data from udp
@@ -82,7 +80,7 @@ func MonitorRouteMulticast(iface *net.Interface, tapIP string) {
// TODO: Implement in windows
// Sync routes, use flag replace, for unknow ip need delete
if err := utils.SyncRoutesForIface(iface.Name, tapIP, routes); err != nil {
log.Errorf("sync route for %s %s", iface.Name, err.Error())
return fmt.Errorf("sync route for %s %s", iface.Name, err.Error())
}
}
}