diff --git a/cmd/kubevpn/cmds/clone.go b/cmd/kubevpn/cmds/clone.go index 07e56925..2be9cd90 100644 --- a/cmd/kubevpn/cmds/clone.go +++ b/cmd/kubevpn/cmds/clone.go @@ -2,14 +2,11 @@ package cmds import ( "fmt" - "io" "os" pkgerr "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" utilcomp "k8s.io/kubectl/pkg/util/completion" "k8s.io/kubectl/pkg/util/i18n" @@ -141,16 +138,9 @@ func CmdClone(f cmdutil.Factory) *cobra.Command { if err != nil { return err } - for { - recv, err := resp.Recv() - if err == io.EOF { - break - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.GetMessage()) + err = util.PrintGRPCStream[rpc.CloneResponse](resp) + if err != nil { + return err } util.Print(os.Stdout, config.Slogan) return nil diff --git a/cmd/kubevpn/cmds/connect.go b/cmd/kubevpn/cmds/connect.go index de37431b..6c0d0e67 100644 --- a/cmd/kubevpn/cmds/connect.go +++ b/cmd/kubevpn/cmds/connect.go @@ -3,13 +3,11 @@ package cmds import ( "context" "fmt" - "io" "os" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "google.golang.org/grpc" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" @@ -98,38 +96,18 @@ func CmdConnect(f cmdutil.Factory) *cobra.Command { } // if is foreground, send to sudo daemon server cli := daemon.GetClient(false) + var resp grpc.ClientStream if lite { - resp, err := cli.ConnectFork(cmd.Context(), req) - if err != nil { - return err - } - for { - recv, err := resp.Recv() - if err == io.EOF { - break - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.GetMessage()) - } + resp, err = cli.ConnectFork(cmd.Context(), req) } else { - resp, err := cli.Connect(cmd.Context(), req) - if err != nil { - return err - } - for { - recv, err := resp.Recv() - if err == io.EOF { - break - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.GetMessage()) - } + resp, err = cli.Connect(cmd.Context(), req) + } + if err != nil { + return err + } + err = util.PrintGRPCStream[rpc.ConnectResponse](resp) + if err != nil { + return err } if !foreground { util.Print(os.Stdout, config.Slogan) @@ -144,15 +122,9 @@ func CmdConnect(f cmdutil.Factory) *cobra.Command { log.Errorf("Disconnect error: %v", err) return err } - for { - recv, err := disconnect.Recv() - if err == io.EOF { - break - } else if err != nil { - log.Errorf("Receive disconnect message error: %v", err) - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.Message) + err = util.PrintGRPCStream[rpc.DisconnectResponse](disconnect) + if err != nil { + return err } _, _ = fmt.Fprint(os.Stdout, "Disconnect completed") } diff --git a/cmd/kubevpn/cmds/disconnect.go b/cmd/kubevpn/cmds/disconnect.go index 5ace8e7f..bb0ff089 100644 --- a/cmd/kubevpn/cmds/disconnect.go +++ b/cmd/kubevpn/cmds/disconnect.go @@ -2,13 +2,10 @@ package cmds import ( "fmt" - "io" "os" "strconv" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" @@ -69,18 +66,12 @@ func CmdDisconnect(f cmdutil.Factory) *cobra.Command { All: pointer.Bool(all), }, ) - var resp *rpc.DisconnectResponse - for { - resp, err = client.Recv() - if err == io.EOF { - break - } else if err == nil { - _, _ = fmt.Fprint(os.Stdout, resp.Message) - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - break - } else { - return err - } + if err != nil { + return err + } + err = util.PrintGRPCStream[rpc.DisconnectResponse](client) + if err != nil { + return err } _, _ = fmt.Fprint(os.Stdout, "Disconnect completed") return nil diff --git a/cmd/kubevpn/cmds/leave.go b/cmd/kubevpn/cmds/leave.go index 9cff8dc6..a4e0b2cf 100644 --- a/cmd/kubevpn/cmds/leave.go +++ b/cmd/kubevpn/cmds/leave.go @@ -1,19 +1,14 @@ package cmds import ( - "fmt" - "io" - "os" - "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) func CmdLeave(f cmdutil.Factory) *cobra.Command { @@ -43,17 +38,8 @@ func CmdLeave(f cmdutil.Factory) *cobra.Command { if err != nil { return err } - for { - recv, err := leave.Recv() - if err == io.EOF { - return nil - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.GetMessage()) - } + err = util.PrintGRPCStream[rpc.LeaveResponse](leave) + return err }, } return leaveCmd diff --git a/cmd/kubevpn/cmds/logs.go b/cmd/kubevpn/cmds/logs.go index 4ba32343..6a574366 100644 --- a/cmd/kubevpn/cmds/logs.go +++ b/cmd/kubevpn/cmds/logs.go @@ -1,13 +1,7 @@ package cmds import ( - "fmt" - "io" - "os" - "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" @@ -41,20 +35,8 @@ func CmdLogs(f cmdutil.Factory) *cobra.Command { if err != nil { return err } - var resp *rpc.LogResponse - for { - resp, err = client.Recv() - if err == io.EOF { - break - } else if err == nil { - fmt.Fprintln(os.Stdout, resp.Message) - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else { - return err - } - } - return nil + err = util.PrintGRPCStream[rpc.LogResponse](client) + return err }, } cmd.Flags().BoolVarP(&req.Follow, "follow", "f", false, "Specify if the logs should be streamed.") diff --git a/cmd/kubevpn/cmds/proxy.go b/cmd/kubevpn/cmds/proxy.go index cd3a3f1f..0dbeb37d 100644 --- a/cmd/kubevpn/cmds/proxy.go +++ b/cmd/kubevpn/cmds/proxy.go @@ -3,13 +3,10 @@ package cmds import ( "context" "fmt" - "io" "os" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" utilcomp "k8s.io/kubectl/pkg/util/completion" "k8s.io/kubectl/pkg/util/i18n" @@ -139,18 +136,9 @@ func CmdProxy(f cmdutil.Factory) *cobra.Command { if err != nil { return err } - var resp *rpc.ConnectResponse - for { - resp, err = client.Recv() - if err == io.EOF { - break - } else if err == nil { - _, _ = fmt.Fprint(os.Stdout, resp.Message) - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else { - return err - } + err = util.PrintGRPCStream[rpc.ConnectResponse](client) + if err != nil { + return err } util.Print(os.Stdout, config.Slogan) // hangup @@ -161,17 +149,12 @@ func CmdProxy(f cmdutil.Factory) *cobra.Command { stream, err := cli.Leave(context.Background(), &rpc.LeaveRequest{ Workloads: args, }) - var resp *rpc.LeaveResponse - for { - resp, err = stream.Recv() - if err == io.EOF { - return nil - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, resp.Message) + if err != nil { + return err + } + err = util.PrintGRPCStream[rpc.LeaveResponse](stream) + if err != nil { + return err } } return nil diff --git a/cmd/kubevpn/cmds/quit.go b/cmd/kubevpn/cmds/quit.go index 868768b0..3941e169 100644 --- a/cmd/kubevpn/cmds/quit.go +++ b/cmd/kubevpn/cmds/quit.go @@ -3,12 +3,9 @@ package cmds import ( "context" "fmt" - "io" "os" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" @@ -49,18 +46,9 @@ func quit(ctx context.Context, isSudo bool) error { if err != nil { return err } - var resp *rpc.QuitResponse - for { - resp, err = client.Recv() - if err == io.EOF { - break - } else if err == nil { - _, _ = fmt.Fprint(os.Stdout, resp.Message) - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else { - return err - } + err = util.PrintGRPCStream[rpc.QuitResponse](client) + if err != nil { + return err } return nil } diff --git a/cmd/kubevpn/cmds/remove.go b/cmd/kubevpn/cmds/remove.go index aecb5263..b0542526 100644 --- a/cmd/kubevpn/cmds/remove.go +++ b/cmd/kubevpn/cmds/remove.go @@ -1,19 +1,14 @@ package cmds import ( - "fmt" - "io" - "os" - "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) func CmdRemove(f cmdutil.Factory) *cobra.Command { @@ -41,17 +36,8 @@ func CmdRemove(f cmdutil.Factory) *cobra.Command { if err != nil { return err } - for { - recv, err := leave.Recv() - if err == io.EOF { - return nil - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.GetMessage()) - } + err = util.PrintGRPCStream[rpc.RemoveResponse](leave) + return err }, } return cmd diff --git a/cmd/kubevpn/cmds/reset.go b/cmd/kubevpn/cmds/reset.go index 8dfd0a32..856168b9 100644 --- a/cmd/kubevpn/cmds/reset.go +++ b/cmd/kubevpn/cmds/reset.go @@ -1,14 +1,8 @@ package cmds import ( - "fmt" - "io" - "os" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" @@ -71,7 +65,7 @@ func CmdReset(f cmdutil.Factory) *cobra.Command { if err != nil { log.Warnf("Failed to disconnect from cluter: %v", err) } else { - _ = printDisconnectResp(disconnect) + _ = util.PrintGRPCStream[rpc.DisconnectResponse](disconnect) } req := &rpc.ResetRequest{ @@ -83,7 +77,7 @@ func CmdReset(f cmdutil.Factory) *cobra.Command { if err != nil { return err } - err = printResetResp(resp) + err = util.PrintGRPCStream[rpc.ResetResponse](resp) return err }, } @@ -91,31 +85,3 @@ func CmdReset(f cmdutil.Factory) *cobra.Command { pkgssh.AddSshFlags(cmd.Flags(), sshConf) return cmd } - -func printResetResp(resp rpc.Daemon_ResetClient) error { - for { - recv, err := resp.Recv() - if err == io.EOF { - return nil - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprintf(os.Stdout, recv.GetMessage()) - } -} - -func printDisconnectResp(disconnect rpc.Daemon_DisconnectClient) error { - for { - recv, err := disconnect.Recv() - if err == io.EOF { - return nil - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprintf(os.Stdout, recv.GetMessage()) - } -} diff --git a/pkg/daemon/action/clone.go b/pkg/daemon/action/clone.go index 72ccb461..05665d52 100644 --- a/pkg/daemon/action/clone.go +++ b/pkg/daemon/action/clone.go @@ -2,13 +2,10 @@ package action import ( "context" - "fmt" "io" log "github.com/sirupsen/logrus" "github.com/spf13/pflag" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/wencaiwulue/kubevpn/v2/pkg/config" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" @@ -44,20 +41,9 @@ func (svr *Server) Clone(req *rpc.CloneRequest, resp rpc.Daemon_CloneServer) (er if err != nil { return err } - var msg *rpc.ConnectResponse - for { - msg, err = connResp.Recv() - if err == io.EOF { - break - } else if err == nil { - fmt.Fprint(out, msg.Message) - } else if code := status.Code(err); code == codes.DeadlineExceeded || code == codes.Canceled { - return nil - } else if code := status.Code(err); code == codes.AlreadyExists { - return fmt.Errorf("connect with cluster already established, disconnect required before proceeding") - } else { - return err - } + err = util.PrintGRPCStream[rpc.ConnectResponse](connResp, out) + if err != nil { + return err } util.InitLoggerForClient(config.Debug) log.SetOutput(out) diff --git a/pkg/daemon/action/connect-fork.go b/pkg/daemon/action/connect-fork.go index 8a37864a..c419c990 100644 --- a/pkg/daemon/action/connect-fork.go +++ b/pkg/daemon/action/connect-fork.go @@ -173,17 +173,9 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp if err != nil { return err } - for { - recv, err := connResp.Recv() - if err == io.EOF { - break - } else if err != nil { - return err - } - err = resp.Send(recv) - if err != nil { - return err - } + err = util.CopyGRPCStream[rpc.ConnectResponse](connResp, resp) + if err != nil { + return err } if resp.Context().Err() != nil { diff --git a/pkg/daemon/action/connect.go b/pkg/daemon/action/connect.go index 0cdf3e22..9d90fce7 100644 --- a/pkg/daemon/action/connect.go +++ b/pkg/daemon/action/connect.go @@ -190,17 +190,9 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon if err != nil { return err } - for { - recv, err := connResp.Recv() - if err == io.EOF { - break - } else if err != nil { - return err - } - err = resp.Send(recv) - if err != nil { - return err - } + err = util.CopyGRPCStream[rpc.ConnectResponse](connResp, resp) + if err != nil { + return err } if resp.Context().Err() != nil { diff --git a/pkg/daemon/action/disconnect.go b/pkg/daemon/action/disconnect.go index 68c5f58f..4f920080 100644 --- a/pkg/daemon/action/disconnect.go +++ b/pkg/daemon/action/disconnect.go @@ -118,18 +118,9 @@ func (svr *Server) Disconnect(req *rpc.DisconnectRequest, resp rpc.Daemon_Discon if err != nil { return err } - var recv *rpc.DisconnectResponse - for { - recv, err = connResp.Recv() - if err == io.EOF { - break - } else if err != nil { - return err - } - err = resp.Send(recv) - if err != nil { - return err - } + err = util.CopyGRPCStream[rpc.DisconnectResponse](connResp, resp) + if err != nil { + return err } } diff --git a/pkg/daemon/action/logs.go b/pkg/daemon/action/logs.go index 5784ea83..c252c3ad 100644 --- a/pkg/daemon/action/logs.go +++ b/pkg/daemon/action/logs.go @@ -52,7 +52,7 @@ func (svr *Server) Logs(req *rpc.LogRequest, resp rpc.Daemon_LogsServer) error { continue } - err = resp.Send(&rpc.LogResponse{Message: line.Text}) + err = resp.Send(&rpc.LogResponse{Message: line.Text + "\n"}) if err != nil { return err } diff --git a/pkg/daemon/action/proxy.go b/pkg/daemon/action/proxy.go index 4c797ec3..1cb01e7f 100644 --- a/pkg/daemon/action/proxy.go +++ b/pkg/daemon/action/proxy.go @@ -99,19 +99,15 @@ func (svr *Server) Proxy(req *rpc.ConnectRequest, resp rpc.Daemon_ProxyServer) ( if err != nil { return err } - var recv *rpc.DisconnectResponse - for { - recv, err = disconnectResp.Recv() - if err == io.EOF { - break - } else if err != nil { - log.Errorf("Receive from disconnect failed: %v", err) - return err - } - err = resp.Send(&rpc.ConnectResponse{Message: recv.Message}) - if err != nil { - return err - } + err = util.CopyAndConvertGRPCStream[rpc.DisconnectResponse, rpc.ConnectResponse]( + disconnectResp, + resp, + func(response *rpc.DisconnectResponse) *rpc.ConnectResponse { + return &rpc.ConnectResponse{Message: response.Message} + }, + ) + if err != nil { + return err } util.InitLoggerForClient(config.Debug) log.SetOutput(out) @@ -125,18 +121,9 @@ func (svr *Server) Proxy(req *rpc.ConnectRequest, resp rpc.Daemon_ProxyServer) ( if err != nil { return err } - var recv *rpc.ConnectResponse - for { - recv, err = connResp.Recv() - if err == io.EOF { - break - } else if err != nil { - return err - } - err = resp.Send(recv) - if err != nil { - return err - } + err = util.CopyGRPCStream[rpc.ConnectResponse](connResp, resp) + if err != nil { + return err } util.InitLoggerForClient(config.Debug) log.SetOutput(out) diff --git a/pkg/daemon/client.go b/pkg/daemon/client.go index 73e4e34b..c614e77e 100644 --- a/pkg/daemon/client.go +++ b/pkg/daemon/client.go @@ -18,6 +18,7 @@ import ( "github.com/wencaiwulue/kubevpn/v2/pkg/config" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/elevate" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) var daemonClient, sudoDaemonClient rpc.DaemonClient @@ -67,11 +68,8 @@ func GetClient(isSudo bool) (cli rpc.DaemonClient) { if err != nil { return nil } - for { - if _, err = quitStream.Recv(); err != nil { - return nil - } - } + err = util.PrintGRPCStream[rpc.QuitResponse](quitStream, nil) + return } if isSudo { diff --git a/pkg/dev/options.go b/pkg/dev/options.go index 0da90c4e..9bd06cd0 100644 --- a/pkg/dev/options.go +++ b/pkg/dev/options.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "os" "strconv" "strings" @@ -158,11 +157,15 @@ func (option *Options) Connect(ctx context.Context, sshConfig *pkgssh.SshConfig, req.Workloads = nil } option.AddRollbackFunc(func() error { - _ = disconnect(ctx, daemonCli, &rpc.DisconnectRequest{ + resp, err := daemonCli.Disconnect(ctx, &rpc.DisconnectRequest{ KubeconfigBytes: ptr.To(string(kubeConfigBytes)), Namespace: ptr.To(ns), SshJump: sshConfig.ToRPC(), }) + if err != nil { + return err + } + _ = util.PrintGRPCStream[rpc.DisconnectResponse](resp) return nil }) var resp rpc.Daemon_ConnectClient @@ -171,15 +174,8 @@ func (option *Options) Connect(ctx context.Context, sshConfig *pkgssh.SshConfig, log.Errorf("Connect to cluster error: %s", err.Error()) return err } - for { - resp, err := resp.Recv() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, resp.Message) - } + err = util.PrintGRPCStream[rpc.CloneResponse](resp) + return err case ConnectModeContainer: runConfig, err := option.CreateConnectContainer(portBindings) @@ -339,15 +335,8 @@ func disconnect(ctx context.Context, daemonClient rpc.DaemonClient, req *rpc.Dis if err != nil { return err } - for { - recv, err := resp.Recv() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - _, _ = fmt.Fprint(os.Stdout, recv.Message) - } + err = util.PrintGRPCStream[rpc.DisconnectResponse](resp) + return err } func (option *Options) CreateConnectContainer(portBindings nat.PortMap) (*RunConfig, error) { diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index a79a832d..f3724b60 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -77,7 +77,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1 defer ticker.Stop() immediate := make(chan struct{}, 1) immediate <- struct{}{} - + var ErrChanDone = errors.New("watch service chan done") for ctx.Err() == nil { err := func() error { w, err := serviceInterface.Watch(ctx, v1.ListOptions{Watch: true}) @@ -91,7 +91,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1 return ctx.Err() case event, ok := <-w.ResultChan(): if !ok { - return errors.New("watch service chan done") + return ErrChanDone } svc, ok := event.Object.(*v12.Service) if !ok { @@ -154,7 +154,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1 if ctx.Err() != nil { return } - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, ErrChanDone) { log.Debugf("Failed to watch service to add route table: %v", err) } if utilnet.IsConnectionRefused(err) || apierrors.IsTooManyRequests(err) || apierrors.IsForbidden(err) { diff --git a/pkg/util/grpc.go b/pkg/util/grpc.go new file mode 100644 index 00000000..c899556b --- /dev/null +++ b/pkg/util/grpc.go @@ -0,0 +1,83 @@ +package util + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "os" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type Printable interface { + GetMessage() string +} + +func PrintGRPCStream[T any](clientStream grpc.ClientStream, writers ...io.Writer) error { + var out io.Writer = os.Stdout + for _, writer := range writers { + out = writer + break + } + + for { + var t = new(T) + err := clientStream.RecvMsg(t) + if errors.Is(err, io.EOF) { + return nil + } + if status.Code(err) == codes.Canceled { + return nil + } + if err != nil { + return err + } + if out == nil { + continue + } + if p, ok := any(t).(Printable); ok { + _, _ = fmt.Fprintf(out, p.GetMessage()) + } else { + buf, _ := json.Marshal(t) + _, _ = fmt.Fprintf(out, string(buf)) + } + } +} + +func CopyGRPCStream[T any](r grpc.ClientStream, w grpc.ServerStream) error { + for { + var t = new(T) + err := r.RecvMsg(t) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + err = w.SendMsg(t) + if err != nil { + return err + } + } +} + +func CopyAndConvertGRPCStream[I any, O any](r grpc.ClientStream, w grpc.ServerStream, convert func(*I) *O) error { + for { + var i = new(I) + err := r.RecvMsg(i) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + o := convert(i) + err = w.SendMsg(o) + if err != nil { + return err + } + } +}