Use errgroup to return error of goruntine

This commit is contained in:
lucheng
2024-07-12 09:51:43 +08:00
parent 9d2038ff6f
commit c4db99d32a
7 changed files with 103 additions and 79 deletions

1
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/urfave/cli/v2 v2.27.1 github.com/urfave/cli/v2 v2.27.1
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/sync v0.7.0
golang.org/x/sys v0.19.0 golang.org/x/sys v0.19.0
golang.org/x/term v0.19.0 golang.org/x/term v0.19.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1

2
go.sum
View File

@@ -88,6 +88,8 @@ golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -2,6 +2,7 @@ package client
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@@ -87,13 +88,12 @@ func GetLoginInfo() (string, string, error) {
return strings.TrimSpace(user), strings.TrimSpace(passwd), nil return strings.TrimSpace(user), strings.TrimSpace(passwd), nil
} }
func checkLoginTimeout(c chan string) { func checkLoginTimeout(c chan string) error {
select { select {
case <-c: case <-c:
return return nil
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
log.Error("login timeout") return errors.New("login timeout")
os.Exit(1)
} }
} }
func (c *Client) Close() error { func (c *Client) Close() error {
@@ -112,15 +112,16 @@ func (c *Client) Close() error {
return nil return nil
} }
func (c *Client) HandleSignal(sigChan chan os.Signal) { func (c *Client) HandleSignal(sigChan chan os.Signal) error {
sig := <-sigChan sig := <-sigChan
log.Infof("received signal: %v, send fin pkt to close conn\n", sig) log.Infof("received signal: %v, send fin pkt to close conn\n", sig)
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
log.Errorf("send fin pkt %s", err.Error()) return fmt.Errorf("send fin pkt %s", err.Error())
} }
os.Exit(0) os.Exit(0)
return nil
} }
func (c *Client) SetLogLevel() { func (c *Client) SetLogLevel() {

View File

@@ -1,16 +1,19 @@
package client package client
import ( import (
"fmt"
"net" "net"
"os"
"time" "time"
"github.com/lucheng0127/virtuallan/pkg/packet" "github.com/lucheng0127/virtuallan/pkg/packet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
) )
func (c *Client) HandleConn(netToIface chan *packet.VLPkt) { func (c *Client) HandleConn(netToIface chan *packet.VLPkt) error {
go func() { g := new(errgroup.Group)
g.Go(func() error {
for { for {
pkt := <-netToIface pkt := <-netToIface
if pkt.Type != packet.P_RAW { if pkt.Type != packet.P_RAW {
@@ -25,12 +28,12 @@ func (c *Client) HandleConn(netToIface chan *packet.VLPkt) {
_, err = c.Iface.Write(stream) _, err = c.Iface.Write(stream)
if err != nil { if err != nil {
log.Errorf("write to tap %s %s\n", c.Iface.Name(), err.Error()) return fmt.Errorf("write to tap %s %s", c.Iface.Name(), err.Error())
continue
} }
} }
}() })
g.Go(func() error {
for { for {
var buf [65535]byte var buf [65535]byte
@@ -49,10 +52,16 @@ func (c *Client) HandleConn(netToIface chan *packet.VLPkt) {
_, err = c.Conn.Write(stream) _, err = c.Conn.Write(stream)
if err != nil { if err != nil {
log.Errorf("send udp stream to %s %s\n", c.Conn.RemoteAddr().String(), err.Error()) return fmt.Errorf("send udp stream to %s %s", c.Conn.RemoteAddr().String(), err.Error())
os.Exit(1)
} }
} }
})
if err := g.Wait(); err != nil {
return err
}
return nil
} }
func SendKeepalive(conn *net.UDPConn, addr string) error { func SendKeepalive(conn *net.UDPConn, addr string) error {

View File

@@ -1,12 +1,12 @@
package client package client
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"sync"
"github.com/erikdubbelboer/gspt" "github.com/erikdubbelboer/gspt"
"github.com/lucheng0127/virtuallan/pkg/cipher" "github.com/lucheng0127/virtuallan/pkg/cipher"
@@ -14,6 +14,7 @@ import (
"github.com/lucheng0127/virtuallan/pkg/utils" "github.com/lucheng0127/virtuallan/pkg/utils"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -74,26 +75,28 @@ func (c *Client) Launch() error {
c.Conn = conn c.Conn = conn
// Use errgroup check goruntine error
g := new(errgroup.Group)
// Handle signal // Handle signal
sigChan := make(chan os.Signal, 8) sigChan := make(chan os.Signal, 8)
signal.Notify(sigChan, unix.SIGTERM, unix.SIGINT) signal.Notify(sigChan, unix.SIGTERM, unix.SIGINT)
go c.HandleSignal(sigChan) g.Go(func() error {
return c.HandleSignal(sigChan)
})
// Do auth // Do auth
ipChan := make(chan string) ipChan := make(chan string)
netToIface := make(chan *packet.VLPkt, 1024) netToIface := make(chan *packet.VLPkt, 1024)
var wg sync.WaitGroup
wg.Add(3)
// Handle udp packet // Handle udp packet
go func() { g.Go(func() error {
for { for {
var buf [65535]byte var buf [65535]byte
n, _, err := conn.ReadFromUDP(buf[:]) n, _, err := conn.ReadFromUDP(buf[:])
if err != nil { if err != nil {
log.Error("read from conn ", err) return fmt.Errorf("read from conn %s", err.Error())
os.Exit(1)
} }
if n < 2 { if n < 2 {
@@ -110,14 +113,11 @@ func (c *Client) Launch() error {
case packet.P_RESPONSE: case packet.P_RESPONSE:
switch pkt.VLBody.(*packet.RspBody).Code { switch pkt.VLBody.(*packet.RspBody).Code {
case packet.RSP_AUTH_REQUIRED: case packet.RSP_AUTH_REQUIRED:
log.Error("auth failed") return errors.New("auth failed")
os.Exit(1)
case packet.RSP_IP_NOT_MATCH: case packet.RSP_IP_NOT_MATCH:
log.Error("ip not match") return errors.New("ip not match")
os.Exit(1)
case packet.RSP_USER_LOGGED: case packet.RSP_USER_LOGGED:
log.Error("user already logged by other endpoint") return errors.New("user already logged by other endpoint")
os.Exit(1)
default: default:
continue continue
} }
@@ -131,24 +131,24 @@ func (c *Client) Launch() error {
continue continue
} }
} }
}() })
// Auth // Auth
authPkt := packet.NewAuthPkt(c.user, c.password) authPkt := packet.NewAuthPkt(c.user, c.password)
authStream, err := authPkt.Encode() authStream, err := authPkt.Encode()
if err != nil { if err != nil {
log.Error("encode auth packet ", err) return fmt.Errorf("encode auth packet %s", err.Error())
os.Exit(1)
} }
_, err = conn.Write(authStream) _, err = conn.Write(authStream)
if err != nil { if err != nil {
log.Error("send auth packet ", err) return fmt.Errorf("send auth packet %s", err.Error())
os.Exit(1)
} }
authChan := make(chan string, 1) authChan := make(chan string, 1)
go checkLoginTimeout(authChan) g.Go(func() error {
return checkLoginTimeout(authChan)
})
// Waiting for dhcp ip // Waiting for dhcp ip
ipAddr := <-ipChan ipAddr := <-ipChan
@@ -186,14 +186,21 @@ func (c *Client) Launch() error {
// XXX: Sometime when client restart too fast will not reveice the first multicast pkt // XXX: Sometime when client restart too fast will not reveice the first multicast pkt
// Monitor multicast for route bordcast // Monitor multicast for route bordcast
go packet.MonitorRouteMulticast(tapIface, strings.Split(c.IPAddr, "/")[0]) g.Go(func() error {
return packet.MonitorRouteMulticast(tapIface, strings.Split(c.IPAddr, "/")[0])
})
// Send keepalive // Send keepalive
go c.DoKeepalive(10) go c.DoKeepalive(10)
// Switch io between udp net and tap interface // Switch io between udp net and tap interface
go c.HandleConn(netToIface) g.Go(func() error {
return c.HandleConn(netToIface)
})
if err := g.Wait(); err != nil {
return err
}
wg.Wait()
return nil return nil
} }

View File

@@ -1,16 +1,18 @@
package client package client
import ( import (
"errors"
"fmt"
"net" "net"
"os" "os"
"os/signal" "os/signal"
"sync"
"github.com/lucheng0127/virtuallan/pkg/cipher" "github.com/lucheng0127/virtuallan/pkg/cipher"
"github.com/lucheng0127/virtuallan/pkg/packet" "github.com/lucheng0127/virtuallan/pkg/packet"
"github.com/lucheng0127/virtuallan/pkg/utils" "github.com/lucheng0127/virtuallan/pkg/utils"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup"
) )
func Run(cCtx *cli.Context) error { func Run(cCtx *cli.Context) error {
@@ -67,26 +69,28 @@ func (c *Client) Launch() error {
c.Conn = conn c.Conn = conn
// Use errgroup check goruntine error
g := new(errgroup.Group)
// Handle signal // Handle signal
sigChan := make(chan os.Signal, 8) sigChan := make(chan os.Signal, 8)
signal.Notify(sigChan, os.Interrupt) signal.Notify(sigChan, os.Interrupt)
go c.HandleSignal(sigChan) g.Go(func() error {
return c.HandleSignal(sigChan)
})
// Do auth // Do auth
ipChan := make(chan string) ipChan := make(chan string)
netToIface := make(chan *packet.VLPkt, 1024) netToIface := make(chan *packet.VLPkt, 1024)
var wg sync.WaitGroup
wg.Add(3)
// Handle udp packet // Handle udp packet
go func() { g.Go(func() error {
for { for {
var buf [65535]byte var buf [65535]byte
n, _, err := conn.ReadFromUDP(buf[:]) n, _, err := conn.ReadFromUDP(buf[:])
if err != nil { if err != nil {
log.Error("read from conn ", err) return fmt.Errorf("read from conn %s", err)
os.Exit(1)
} }
if n < 2 { if n < 2 {
@@ -103,14 +107,11 @@ func (c *Client) Launch() error {
case packet.P_RESPONSE: case packet.P_RESPONSE:
switch pkt.VLBody.(*packet.RspBody).Code { switch pkt.VLBody.(*packet.RspBody).Code {
case packet.RSP_AUTH_REQUIRED: case packet.RSP_AUTH_REQUIRED:
log.Error("auth failed") return errors.New("auth failed")
os.Exit(1)
case packet.RSP_IP_NOT_MATCH: case packet.RSP_IP_NOT_MATCH:
log.Error("ip not match") return errors.New("ip not match")
os.Exit(1)
case packet.RSP_USER_LOGGED: case packet.RSP_USER_LOGGED:
log.Error("user already logged by other endpoint") return errors.New("user already logged by other endpoint")
os.Exit(1)
default: default:
continue continue
} }
@@ -124,24 +125,24 @@ func (c *Client) Launch() error {
continue continue
} }
} }
}() })
// Auth // Auth
authPkt := packet.NewAuthPkt(c.user, c.password) authPkt := packet.NewAuthPkt(c.user, c.password)
authStream, err := authPkt.Encode() authStream, err := authPkt.Encode()
if err != nil { if err != nil {
log.Error("encode auth packet ", err) return fmt.Errorf("encode auth packet %s", err.Error())
os.Exit(1)
} }
_, err = conn.Write(authStream) _, err = conn.Write(authStream)
if err != nil { if err != nil {
log.Error("send auth packet ", err) return fmt.Errorf("send auth packet %s", err.Error())
os.Exit(1)
} }
authChan := make(chan string, 1) authChan := make(chan string, 1)
go checkLoginTimeout(authChan) g.Go(func() error {
return checkLoginTimeout(authChan)
})
// Waiting for dhcp ip // Waiting for dhcp ip
ipAddr := <-ipChan ipAddr := <-ipChan
@@ -159,8 +160,13 @@ func (c *Client) Launch() error {
go c.DoKeepalive(10) go c.DoKeepalive(10)
// Switch io between udp net and tap interface // Switch io between udp net and tap interface
go c.HandleConn(netToIface) g.Go(func() error {
return c.HandleConn(netToIface)
})
if err := g.Wait(); err != nil {
return err
}
wg.Wait()
return nil return nil
} }

View File

@@ -35,19 +35,17 @@ func MulticastStream(data []byte) error {
return nil return nil
} }
func MonitorRouteMulticast(iface *net.Interface, tapIP string) { func MonitorRouteMulticast(iface *net.Interface, tapIP string) error {
// Monitor route multicast will run as a goruntine in client so log error but don't exit // Monitor route multicast will run as a goruntine in client so log error but don't exit
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", MULTICAST_ADDR, MULTICAST_PORT)) udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", MULTICAST_ADDR, MULTICAST_PORT))
if err != nil { if err != nil {
log.Errorf("parse multicast addr %s", err.Error()) return fmt.Errorf("parse multicast addr %s", err.Error())
return
} }
// Listen multicast on tap interface // Listen multicast on tap interface
ln, err := net.ListenMulticastUDP("udp", iface, udpAddr) ln, err := net.ListenMulticastUDP("udp", iface, udpAddr)
if err != nil { if err != nil {
log.Errorf("listen multicast address %s", err.Error()) return fmt.Errorf("listen multicast address %s", err.Error())
return
} }
// Read data from udp // Read data from udp
@@ -82,7 +80,7 @@ func MonitorRouteMulticast(iface *net.Interface, tapIP string) {
// TODO: Implement in windows // TODO: Implement in windows
// Sync routes, use flag replace, for unknow ip need delete // Sync routes, use flag replace, for unknow ip need delete
if err := utils.SyncRoutesForIface(iface.Name, tapIP, routes); err != nil { if err := utils.SyncRoutesForIface(iface.Name, tapIP, routes); err != nil {
log.Errorf("sync route for %s %s", iface.Name, err.Error()) return fmt.Errorf("sync route for %s %s", iface.Name, err.Error())
} }
} }
} }