refactor: refactor print GRPC message (#386)

This commit is contained in:
naison
2024-12-06 19:29:11 +08:00
committed by GitHub
parent d9a978d330
commit 81f62eab31
19 changed files with 168 additions and 306 deletions

View File

@@ -2,13 +2,10 @@ package action
import (
"context"
"fmt"
"io"
log "github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc"
@@ -44,20 +41,9 @@ func (svr *Server) Clone(req *rpc.CloneRequest, resp rpc.Daemon_CloneServer) (er
if err != nil {
return err
}
var msg *rpc.ConnectResponse
for {
msg, err = connResp.Recv()
if err == io.EOF {
break
} else if err == nil {
fmt.Fprint(out, msg.Message)
} else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled {
return nil
} else if code := status.Code(err); code == codes.AlreadyExists {
return fmt.Errorf("connect with cluster already established, disconnect required before proceeding")
} else {
return err
}
err = util.PrintGRPCStream[rpc.ConnectResponse](connResp, out)
if err != nil {
return err
}
util.InitLoggerForClient(config.Debug)
log.SetOutput(out)

View File

@@ -173,17 +173,9 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
if err != nil {
return err
}
for {
recv, err := connResp.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}
err = resp.Send(recv)
if err != nil {
return err
}
err = util.CopyGRPCStream[rpc.ConnectResponse](connResp, resp)
if err != nil {
return err
}
if resp.Context().Err() != nil {

View File

@@ -190,17 +190,9 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
if err != nil {
return err
}
for {
recv, err := connResp.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}
err = resp.Send(recv)
if err != nil {
return err
}
err = util.CopyGRPCStream[rpc.ConnectResponse](connResp, resp)
if err != nil {
return err
}
if resp.Context().Err() != nil {

View File

@@ -118,18 +118,9 @@ func (svr *Server) Disconnect(req *rpc.DisconnectRequest, resp rpc.Daemon_Discon
if err != nil {
return err
}
var recv *rpc.DisconnectResponse
for {
recv, err = connResp.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}
err = resp.Send(recv)
if err != nil {
return err
}
err = util.CopyGRPCStream[rpc.DisconnectResponse](connResp, resp)
if err != nil {
return err
}
}

View File

@@ -52,7 +52,7 @@ func (svr *Server) Logs(req *rpc.LogRequest, resp rpc.Daemon_LogsServer) error {
continue
}
err = resp.Send(&rpc.LogResponse{Message: line.Text})
err = resp.Send(&rpc.LogResponse{Message: line.Text + "\n"})
if err != nil {
return err
}

View File

@@ -99,19 +99,15 @@ func (svr *Server) Proxy(req *rpc.ConnectRequest, resp rpc.Daemon_ProxyServer) (
if err != nil {
return err
}
var recv *rpc.DisconnectResponse
for {
recv, err = disconnectResp.Recv()
if err == io.EOF {
break
} else if err != nil {
log.Errorf("Receive from disconnect failed: %v", err)
return err
}
err = resp.Send(&rpc.ConnectResponse{Message: recv.Message})
if err != nil {
return err
}
err = util.CopyAndConvertGRPCStream[rpc.DisconnectResponse, rpc.ConnectResponse](
disconnectResp,
resp,
func(response *rpc.DisconnectResponse) *rpc.ConnectResponse {
return &rpc.ConnectResponse{Message: response.Message}
},
)
if err != nil {
return err
}
util.InitLoggerForClient(config.Debug)
log.SetOutput(out)
@@ -125,18 +121,9 @@ func (svr *Server) Proxy(req *rpc.ConnectRequest, resp rpc.Daemon_ProxyServer) (
if err != nil {
return err
}
var recv *rpc.ConnectResponse
for {
recv, err = connResp.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}
err = resp.Send(recv)
if err != nil {
return err
}
err = util.CopyGRPCStream[rpc.ConnectResponse](connResp, resp)
if err != nil {
return err
}
util.InitLoggerForClient(config.Debug)
log.SetOutput(out)

View File

@@ -18,6 +18,7 @@ import (
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/elevate"
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc"
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
var daemonClient, sudoDaemonClient rpc.DaemonClient
@@ -67,11 +68,8 @@ func GetClient(isSudo bool) (cli rpc.DaemonClient) {
if err != nil {
return nil
}
for {
if _, err = quitStream.Recv(); err != nil {
return nil
}
}
err = util.PrintGRPCStream[rpc.QuitResponse](quitStream, nil)
return
}
if isSudo {

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
@@ -158,11 +157,15 @@ func (option *Options) Connect(ctx context.Context, sshConfig *pkgssh.SshConfig,
req.Workloads = nil
}
option.AddRollbackFunc(func() error {
_ = disconnect(ctx, daemonCli, &rpc.DisconnectRequest{
resp, err := daemonCli.Disconnect(ctx, &rpc.DisconnectRequest{
KubeconfigBytes: ptr.To(string(kubeConfigBytes)),
Namespace: ptr.To(ns),
SshJump: sshConfig.ToRPC(),
})
if err != nil {
return err
}
_ = util.PrintGRPCStream[rpc.DisconnectResponse](resp)
return nil
})
var resp rpc.Daemon_ConnectClient
@@ -171,15 +174,8 @@ func (option *Options) Connect(ctx context.Context, sshConfig *pkgssh.SshConfig,
log.Errorf("Connect to cluster error: %s", err.Error())
return err
}
for {
resp, err := resp.Recv()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
_, _ = fmt.Fprint(os.Stdout, resp.Message)
}
err = util.PrintGRPCStream[rpc.CloneResponse](resp)
return err
case ConnectModeContainer:
runConfig, err := option.CreateConnectContainer(portBindings)
@@ -339,15 +335,8 @@ func disconnect(ctx context.Context, daemonClient rpc.DaemonClient, req *rpc.Dis
if err != nil {
return err
}
for {
recv, err := resp.Recv()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
_, _ = fmt.Fprint(os.Stdout, recv.Message)
}
err = util.PrintGRPCStream[rpc.DisconnectResponse](resp)
return err
}
func (option *Options) CreateConnectContainer(portBindings nat.PortMap) (*RunConfig, error) {

View File

@@ -77,7 +77,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1
defer ticker.Stop()
immediate := make(chan struct{}, 1)
immediate <- struct{}{}
var ErrChanDone = errors.New("watch service chan done")
for ctx.Err() == nil {
err := func() error {
w, err := serviceInterface.Watch(ctx, v1.ListOptions{Watch: true})
@@ -91,7 +91,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1
return ctx.Err()
case event, ok := <-w.ResultChan():
if !ok {
return errors.New("watch service chan done")
return ErrChanDone
}
svc, ok := event.Object.(*v12.Service)
if !ok {
@@ -154,7 +154,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1
if ctx.Err() != nil {
return
}
if err != nil && !errors.Is(err, context.Canceled) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, ErrChanDone) {
log.Debugf("Failed to watch service to add route table: %v", err)
}
if utilnet.IsConnectionRefused(err) || apierrors.IsTooManyRequests(err) || apierrors.IsForbidden(err) {

83
pkg/util/grpc.go Normal file
View File

@@ -0,0 +1,83 @@
package util
import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type Printable interface {
GetMessage() string
}
func PrintGRPCStream[T any](clientStream grpc.ClientStream, writers ...io.Writer) error {
var out io.Writer = os.Stdout
for _, writer := range writers {
out = writer
break
}
for {
var t = new(T)
err := clientStream.RecvMsg(t)
if errors.Is(err, io.EOF) {
return nil
}
if status.Code(err) == codes.Canceled {
return nil
}
if err != nil {
return err
}
if out == nil {
continue
}
if p, ok := any(t).(Printable); ok {
_, _ = fmt.Fprintf(out, p.GetMessage())
} else {
buf, _ := json.Marshal(t)
_, _ = fmt.Fprintf(out, string(buf))
}
}
}
func CopyGRPCStream[T any](r grpc.ClientStream, w grpc.ServerStream) error {
for {
var t = new(T)
err := r.RecvMsg(t)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return err
}
err = w.SendMsg(t)
if err != nil {
return err
}
}
}
func CopyAndConvertGRPCStream[I any, O any](r grpc.ClientStream, w grpc.ServerStream, convert func(*I) *O) error {
for {
var i = new(I)
err := r.RecvMsg(i)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return err
}
o := convert(i)
err = w.SendMsg(o)
if err != nil {
return err
}
}
}