diff --git a/pkg/controlplane/cache.go b/pkg/controlplane/cache.go index f7072537..8e2f60ee 100644 --- a/pkg/controlplane/cache.go +++ b/pkg/controlplane/cache.go @@ -91,7 +91,7 @@ type Rule struct { PortMap map[int32]string } -func (a *Virtual) To(enableIPv6 bool, logger *log.Logger) ( +func (a *Virtual) To(enableIPv6 bool, logger *log.Entry) ( listeners []types.Resource, clusters []types.Resource, routes []types.Resource, diff --git a/pkg/controlplane/main.go b/pkg/controlplane/main.go index aed4f9ed..ff0a1505 100644 --- a/pkg/controlplane/main.go +++ b/pkg/controlplane/main.go @@ -11,7 +11,7 @@ import ( plog "github.com/wencaiwulue/kubevpn/v2/pkg/log" ) -func Main(ctx context.Context, factory cmdutil.Factory, port uint, logger *log.Logger) error { +func Main(ctx context.Context, factory cmdutil.Factory, port uint, logger *log.Entry) error { snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger) proc := NewProcessor(snapshotCache, logger) diff --git a/pkg/controlplane/processor.go b/pkg/controlplane/processor.go index 1353995c..d94dbe6c 100644 --- a/pkg/controlplane/processor.go +++ b/pkg/controlplane/processor.go @@ -21,13 +21,13 @@ import ( type Processor struct { cache cache.SnapshotCache - logger *log.Logger + logger *log.Entry version int64 expireCache *utilcache.Expiring } -func NewProcessor(cache cache.SnapshotCache, log *log.Logger) *Processor { +func NewProcessor(cache cache.SnapshotCache, log *log.Entry) *Processor { return &Processor{ cache: cache, logger: log, diff --git a/pkg/core/gvisorlocaltcphandler.go b/pkg/core/gvisorlocaltcphandler.go index ce8670d2..6cf0b4e7 100644 --- a/pkg/core/gvisorlocaltcphandler.go +++ b/pkg/core/gvisorlocaltcphandler.go @@ -2,6 +2,7 @@ package core import ( "context" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/channel" @@ -10,6 +11,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "github.com/wencaiwulue/kubevpn/v2/pkg/config" + plog "github.com/wencaiwulue/kubevpn/v2/pkg/log" "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) @@ -47,7 +49,7 @@ func (h *gvisorLocalHandler) Run(ctx context.Context) { readFromEndpointWriteToTun(ctx, endpoint, h.outbound) util.SafeClose(h.errChan) }() - s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] ")) + s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx))))) defer s.Destroy() select { case <-h.errChan: diff --git a/pkg/core/gvisorlocaltunendpoint.go b/pkg/core/gvisorlocaltunendpoint.go index eed15e03..78235782 100755 --- a/pkg/core/gvisorlocaltunendpoint.go +++ b/pkg/core/gvisorlocaltunendpoint.go @@ -2,6 +2,7 @@ package core import ( "context" + "fmt" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -16,10 +17,11 @@ import ( ) func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint, out chan<- *Packet) { + prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx))) for ctx.Err() == nil { pkt := endpoint.ReadContext(ctx) if pkt != nil { - sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt) + sniffer.LogPacket(prefix, sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt) data := pkt.ToView().AsSlice() buf := config.LPool.Get().([]byte)[:] n := copy(buf[1:], data) @@ -30,6 +32,7 @@ func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint, } func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet, endpoint *channel.Endpoint) { + prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx))) for ctx.Err() == nil { var packet *Packet select { @@ -60,7 +63,7 @@ func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet Payload: buffer.MakeWithData(packet.data[1:packet.length]), }) config.LPool.Put(packet.data[:]) - sniffer.LogPacket("[gVISOR] ", sniffer.DirectionRecv, protocol, pkt) + sniffer.LogPacket(prefix, sniffer.DirectionRecv, protocol, pkt) endpoint.InjectInbound(protocol, pkt) pkt.DecRef() } diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 6db66b53..27745390 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -34,6 +34,12 @@ func TunHandler(node *Node, forward *Forwarder) Handler { } func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) { + tunIfi, err := util.GetTunDeviceByConn(tun) + if err != nil { + plog.G(ctx).Errorf("Failed to get tun device: %v", err) + return + } + ctx = plog.WithField(ctx, tunIfi.Name, "") if !h.forward.IsEmpty() { h.HandleClient(ctx, tun, h.forward) } else { diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 1b221345..46dcfc67 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -162,7 +162,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) { config.LPool.Put(buf[:]) continue } - plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) + plog.G(plog.WithFields(context.Background(), plog.GetFields(ctx))).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) packet := NewPacket(buf[:], n+1, src, dst) if packet.src.Equal(packet.dst) { gvisorInbound <- packet diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index 152c2242..6487d5f3 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -336,7 +336,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err } }() } - out := plog.G(ctx).Out + out := plog.G(ctx).Logger.Out err = util.PortForwardPod( c.config, c.restclient, @@ -345,7 +345,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err portPair, readyChan, childCtx.Done(), - out, + nil, out, ) if *first { diff --git a/pkg/handler/sync.go b/pkg/handler/sync.go index 1844e80a..9e5feffb 100644 --- a/pkg/handler/sync.go +++ b/pkg/handler/sync.go @@ -187,7 +187,7 @@ func (d *SyncOptions) DoSync(ctx context.Context, kubeconfigJsonBytes []byte, im } } { - container, err := podcmd.FindOrDefaultContainerByName(&v1.Pod{Spec: v1.PodSpec{Containers: containers}}, d.TargetContainer, false, plog.G(ctx).Out) + container, err := podcmd.FindOrDefaultContainerByName(&v1.Pod{Spec: v1.PodSpec{Containers: containers}}, d.TargetContainer, false, plog.G(ctx).Logger.Out) if err != nil { return err } diff --git a/pkg/log/context.go b/pkg/log/context.go index 8f2ee049..96916412 100644 --- a/pkg/log/context.go +++ b/pkg/log/context.go @@ -34,11 +34,59 @@ func WithoutLogger(ctx context.Context) context.Context { return ctx } -// GetLogger retrieves the current logger from the context. If no logger is +// getLogger retrieves the current logger from the context. If no logger is // available, the default logger is returned. -func GetLogger(ctx context.Context) *log.Logger { +func getLogger(ctx context.Context) *log.Logger { if logger := ctx.Value(loggerKey{}); logger != nil && logger.(*loggerValue).logger != nil { return logger.(*loggerValue).logger } return L } + +type fieldsKey struct{} + +// WithFields 将指定的字段添加到 context 中,这些字段会在后续从 context 获取 logger 时自动添加 +func WithFields(ctx context.Context, fields map[string]any) context.Context { + existingFields := GetFields(ctx) + if existingFields == nil { + return context.WithValue(ctx, fieldsKey{}, fields) + } + + // 合并字段,新字段会覆盖旧字段 + mergedFields := make(map[string]any) + for k, v := range existingFields { + mergedFields[k] = v + } + for k, v := range fields { + mergedFields[k] = v + } + + return context.WithValue(ctx, fieldsKey{}, mergedFields) +} + +// WithField 将单个字段添加到 context 中 +func WithField(ctx context.Context, key string, value any) context.Context { + return WithFields(ctx, map[string]any{key: value}) +} + +// GetFields 从 context 中获取所有已存储的字段 +func GetFields(ctx context.Context) map[string]any { + if fields := ctx.Value(fieldsKey{}); fields != nil { + if f, ok := fields.(map[string]any); ok { + return f + } + } + return nil +} + +// GetLogger 从 context 中获取 logger,并自动添加 context 中存储的字段 +func GetLogger(ctx context.Context) *log.Entry { + logger := getLogger(ctx) + fields := GetFields(ctx) + + if len(fields) > 0 { + return logger.WithFields(fields) + } + + return log.NewEntry(logger) +} diff --git a/pkg/log/context_test.go b/pkg/log/context_test.go index a490a2c4..ccfc7f84 100644 --- a/pkg/log/context_test.go +++ b/pkg/log/context_test.go @@ -3,24 +3,56 @@ package log import ( "context" "testing" - "time" ) -func TestGetLoggerFromContext(t *testing.T) { - logger := InitLoggerForServer() - ctx := WithLogger(context.Background(), logger) - cancel, cancelFunc := context.WithCancel(ctx) - defer cancelFunc() - timeout, c := context.WithTimeout(cancel, time.Second*10) - defer c() - l := GetLogger(timeout) - if logger != l { - panic("not same") - } - cancel = WithoutLogger(cancel) - defaultLogger := GetLogger(cancel) - if defaultLogger != L { - panic("not same") - } +func TestLog(t *testing.T) { + ctx := context.Background() + G(ctx).WithField("tun", "abc").Debug("debug") + logger := G(ctx).WithField("tun", "abc").Logger logger.Debug("debug") + logger.Warn("warn") +} + +func TestWithFields(t *testing.T) { + ctx := WithField(context.Background(), "request_id", "12345") + ctx = WithField(ctx, "user_id", "user-abc") + + logger := GetLogger(ctx) + logger.Info("this log will contains request_id and user_id") + + ctx2 := WithFields(ctx, map[string]any{ + "action": "login", + "ip": "192.168.1.1", + }) + + logger2 := GetLogger(ctx2) + logger2.Info("this log will contains four fields") + + // 在不同方法中使用 + processRequest(ctx2) +} + +func processRequest(ctx context.Context) { + logger := GetLogger(ctx) + logger.Debug("request handling...") + + logger.WithField("step", "validation").Info("please input validation") +} + +func TestWithFieldsMerge(t *testing.T) { + ctx := WithFields(context.Background(), map[string]any{ + "service": "api", + "version": "v1", + }) + + // merge fields + ctx = WithFields(ctx, map[string]any{ + "endpoint": "/users", + "method": "GET", + }) + + ctx = WithField(ctx, "version", "v2") + + logger := GetLogger(ctx) + logger.Info("should show all fields,version changed to v2") } diff --git a/pkg/log/logger.go b/pkg/log/logger.go index e10a3849..b1f72120 100644 --- a/pkg/log/logger.go +++ b/pkg/log/logger.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "runtime" + "sort" "strings" "time" @@ -65,6 +66,18 @@ type serverFormat struct { // 2009/01/23 01:23:23 d.go:23: message func (*serverFormat) Format(e *log.Entry) ([]byte, error) { // e.Caller maybe is nil, because pkg/handler/connect.go:252 + if len(e.Data) > 0 { + return []byte( + fmt.Sprintf("%s %s %s:%d %s: %s\n", + GenStr(e.Data), + e.Time.Format("2006-01-02 15:04:05.000"), + filepath.Base(ptr.Deref(e.Caller, runtime.Frame{}).File), + ptr.Deref(e.Caller, runtime.Frame{}).Line, + e.Level.String(), + e.Message, + )), nil + } + return []byte( fmt.Sprintf("%s %s:%d %s: %s\n", e.Time.Format("2006-01-02 15:04:05.000"), @@ -106,3 +119,38 @@ func (g ServerEmitter) Emit(depth int, level glog.Level, timestamp time.Time, fo message, ) } + +func GenStr(allFields map[string]any) string { + fieldsStr := "" + + keys := make([]string, len(allFields)) + i := 0 + for field := range allFields { + keys[i] = field + i++ + } + + sort.Strings(keys) + + for _, key := range keys { + var valueStr string + value := allFields[key] + + if stringer, ok := value.(fmt.Stringer); ok { + valueStr = stringer.String() + } else { + valueStr = fmt.Sprintf("%v", value) + } + + if strings.Contains(valueStr, " ") { + valueStr = `"` + valueStr + `"` + } + if valueStr == "" { + fieldsStr += key + " " + } else { + fieldsStr += key + "=" + valueStr + " " + } + + } + return fmt.Sprintf("[%s]", strings.TrimSpace(fieldsStr)) +} diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index e0ca5f93..7b8ec40e 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -113,9 +113,9 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr var openChannelError *gossh.OpenChannelError // if ssh server not permitted ssh port-forward, do nothing until exit if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited { - plog.G(ctx).Debugf("Failed to open ssh port-forward to %s: %v", remote.String(), err) - plog.G(ctx).Errorf("Failed to open ssh port-forward to %s: %v", remote.String(), err) + plog.G(ctx).Errorf("Prohibited to open ssh port-forward to %s: %v", remote.String(), err) cancelFunc1() + return } plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err) return diff --git a/pkg/util/portforward.go b/pkg/util/portforward.go deleted file mode 100644 index 055ac20f..00000000 --- a/pkg/util/portforward.go +++ /dev/null @@ -1,470 +0,0 @@ -package util - -import ( - "errors" - "fmt" - "io" - "net" - "net/http" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/util/httpstream" - "k8s.io/apimachinery/pkg/util/runtime" - "k8s.io/client-go/tools/portforward" -) - -// PortForwarder knows how to listen for local connections and forward them to -// a remote pod via an upgraded HTTP request. -type PortForwarder struct { - addresses []listenAddress - ports []ForwardedPort - stopChan <-chan struct{} - // if failed to find socat, send error - // if pod is not found, send error - errChan chan error - - dialer httpstream.Dialer - streamConn httpstream.Connection - listeners []io.Closer - Ready chan struct{} - requestIDLock sync.Mutex - requestID int - out io.Writer - errOut io.Writer -} - -// ForwardedPort contains a Local:Remote port pairing. -type ForwardedPort struct { - Local uint16 - Remote uint16 -} - -/* -valid port specifications: - -5000 -- forwards from localhost:5000 to pod:5000 - -8888:5000 -- forwards from localhost:8888 to pod:5000 - -0:5000 -:5000 - - selects a random available local port, - forwards from localhost: to pod:5000 -*/ -func parsePorts(ports []string) ([]ForwardedPort, error) { - var forwards []ForwardedPort - for _, portString := range ports { - parts := strings.Split(portString, ":") - var localString, remoteString string - if len(parts) == 1 { - localString = parts[0] - remoteString = parts[0] - } else if len(parts) == 2 { - localString = parts[0] - if localString == "" { - // support :5000 - localString = "0" - } - remoteString = parts[1] - } else { - return nil, fmt.Errorf("invalid port format '%s'", portString) - } - - localPort, err := strconv.ParseUint(localString, 10, 16) - if err != nil { - return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err) - } - - remotePort, err := strconv.ParseUint(remoteString, 10, 16) - if err != nil { - return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err) - } - if remotePort == 0 { - return nil, fmt.Errorf("remote port must be > 0") - } - - forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)}) - } - - return forwards, nil -} - -type listenAddress struct { - address string - protocol string - failureMode string -} - -func parseAddresses(addressesToParse []string) ([]listenAddress, error) { - var addresses []listenAddress - parsed := make(map[string]listenAddress) - for _, address := range addressesToParse { - if address == "localhost" { - if _, exists := parsed["127.0.0.1"]; !exists { - ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"} - parsed[ip.address] = ip - } - if _, exists := parsed["::1"]; !exists { - ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"} - parsed[ip.address] = ip - } - } else if net.ParseIP(address).To4() != nil { - parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"} - } else if net.ParseIP(address) != nil { - parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"} - } else { - return nil, fmt.Errorf("%s is not a valid IP", address) - } - } - addresses = make([]listenAddress, len(parsed)) - id := 0 - for _, v := range parsed { - addresses[id] = v - id++ - } - // Sort addresses before returning to get a stable order - sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address }) - - return addresses, nil -} - -// NewOnAddresses creates a new PortForwarder with custom listen addresses. -func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) { - if len(addresses) == 0 { - return nil, errors.New("you must specify at least 1 address") - } - parsedAddresses, err := parseAddresses(addresses) - if err != nil { - return nil, err - } - if len(ports) == 0 { - return nil, errors.New("you must specify at least 1 port") - } - parsedPorts, err := parsePorts(ports) - if err != nil { - return nil, err - } - return &PortForwarder{ - dialer: dialer, - addresses: parsedAddresses, - ports: parsedPorts, - stopChan: stopChan, - errChan: make(chan error, 1), - Ready: readyChan, - out: out, - errOut: errOut, - }, nil -} - -// ForwardPorts formats and executes a port forwarding request. The connection will remain -// open until stopChan is closed. -func (pf *PortForwarder) ForwardPorts() error { - defer pf.Close() - - var err error - pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name) - if err != nil { - return fmt.Errorf("error upgrading connection: %s", err) - } - defer pf.streamConn.Close() - - return pf.forward() -} - -// forward dials the remote host specific in req, upgrades the request, starts -// listeners for each port specified in ports, and forwards local connections -// to the remote host via streams. -func (pf *PortForwarder) forward() error { - var err error - - listenSuccess := false - for i := range pf.ports { - port := &pf.ports[i] - err = pf.listenOnPort(port) - switch { - case err == nil: - listenSuccess = true - default: - if pf.errOut != nil { - fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err) - } - } - } - - if !listenSuccess { - return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports) - } - - if pf.Ready != nil { - close(pf.Ready) - } - - // wait for interrupt or conn closure - select { - case <-pf.stopChan: - runtime.HandleError(errors.New("lost connection to pod")) - } - select { - case errs, ok := <-pf.errChan: - if ok { - return errs - } - return nil - default: - return nil - } -} - -// listenOnPort delegates listener creation and waits for connections on requested bind addresses. -// An error is raised based on address groups (default and localhost) and their failure modes -func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error { - var errors []error - failCounters := make(map[string]int, 2) - successCounters := make(map[string]int, 2) - for _, addr := range pf.addresses { - err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address) - if err != nil { - errors = append(errors, err) - failCounters[addr.failureMode]++ - } else { - successCounters[addr.failureMode]++ - } - } - if successCounters["all"] == 0 && failCounters["all"] > 0 { - return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors) - } - if failCounters["any"] > 0 { - return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors) - } - return nil -} - -// listenOnPortAndAddress delegates listener creation and waits for new connections -// in the background f -func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error { - listener, err := pf.getListener(protocol, address, port) - if err != nil { - return err - } - pf.listeners = append(pf.listeners, listener) - go pf.waitForConnection(listener, *port) - return nil -} - -// getListener creates a listener on the interface targeted by the given hostname on the given port with -// the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6 -func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) { - listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local)))) - if err != nil { - return nil, fmt.Errorf("unable to create listener: Error %s", err) - } - listenerAddress := listener.Addr().String() - host, localPort, _ := net.SplitHostPort(listenerAddress) - localPortUInt, err := strconv.ParseUint(localPort, 10, 16) - - if err != nil { - fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote) - return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host) - } - port.Local = uint16(localPortUInt) - if pf.out != nil { - fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote) - } - - return listener, nil -} - -// waitForConnection waits for new connections to listener and handles them in -// the background. -func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) { - for { - conn, err := listener.Accept() - if err != nil { - // TODO consider using something like https://github.com/hydrogen18/stoppableListener? - if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") { - runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err)) - } - return - } - go pf.handleConnection(conn, port) - } -} - -func (pf *PortForwarder) nextRequestID() int { - pf.requestIDLock.Lock() - defer pf.requestIDLock.Unlock() - id := pf.requestID - pf.requestID++ - return id -} - -// handleConnection copies data between the local connection and the stream to -// the remote server. -func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { - defer conn.Close() - - if pf.out != nil { - fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local) - } - - requestID := pf.nextRequestID() - // create error stream - headers := http.Header{} - headers.Set(v1.StreamType, v1.StreamTypeError) - headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote)) - headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID)) - var err error - errorStream, err := pf.streamConn.CreateStream(headers) - if err != nil { - runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)) - return - } - // we're not writing to this stream - errorStream.Close() - - errorChan := make(chan error) - go func() { - message, err := io.ReadAll(errorStream) - switch { - case err != nil: - errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) - case len(message) > 0: - errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) - } - close(errorChan) - }() - - // create data stream - headers.Set(v1.StreamType, v1.StreamTypeData) - dataStream, err := pf.streamConn.CreateStream(headers) - if err != nil { - runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) - return - } - - localError := make(chan struct{}) - remoteDone := make(chan struct{}) - - go func() { - // Copy from the remote side to the local port. - if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err)) - } - - // inform the select below that the remote copy is done - close(remoteDone) - }() - - go func() { - // inform server we're not sending any more data after copy unblocks - defer dataStream.Close() - - // Copy from the local port to the remote side. - if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err)) - // break out of the select below without waiting for the other copy to finish - close(localError) - } - }() - - // wait for either a local->remote error or for copying from remote->local to finish - select { - case <-remoteDone: - case <-localError: - // wait for interrupt or conn closure - case <-pf.stopChan: - runtime.HandleError(errors.New("lost connection to pod")) - } - - // always expect something on errorChan (it may be nil) - select { - case err = <-errorChan: - default: - } - if err != nil { - if strings.Contains(err.Error(), "failed to find socat") { - select { - case pf.errChan <- err: - default: - } - } - runtime.HandleError(err) - } -} - -// Close stops all listeners of PortForwarder. -func (pf *PortForwarder) Close() { - // stop all listeners - for _, l := range pf.listeners { - if err := l.Close(); err != nil { - runtime.HandleError(fmt.Errorf("error closing listener: %v", err)) - } - } -} - -// GetPorts will return the ports that were forwarded; this can be used to -// retrieve the locally-bound port in cases where the input was port 0. This -// function will signal an error if the Ready channel is nil or if the -// listeners are not ready yet; this function will succeed after the Ready -// channel has been closed. -func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) { - if pf.Ready == nil { - return nil, fmt.Errorf("no ready channel provided") - } - select { - case <-pf.Ready: - return pf.ports, nil - default: - return nil, fmt.Errorf("listeners not ready") - } -} - -func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stream, error) { - errorChan := make(chan error, 2) - var resultChan atomic.Value - time.AfterFunc(time.Second*1, func() { - errorChan <- errors.New("timeout") - }) - go func() { - if pf.streamConn != nil { - if stream, err := pf.streamConn.CreateStream(*header); err == nil && stream != nil { - errorChan <- nil - resultChan.Store(stream) - return - } - } - errorChan <- errors.New("") - }() - if err := <-errorChan; err == nil && resultChan.Load() != nil { - return resultChan.Load().(httpstream.Stream), nil - } - // close old connection in case of resource leak - if pf.streamConn != nil { - _ = pf.streamConn.Close() - } - var err error - pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name) - if err != nil { - if k8serrors.IsNotFound(err) { - runtime.HandleError(fmt.Errorf("pod not found: %s", err)) - select { - case pf.errChan <- err: - default: - } - } else { - runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err)) - } - return nil, err - } - header.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(pf.nextRequestID())) - return pf.streamConn.CreateStream(*header) -} diff --git a/pkg/util/util.go b/pkg/util/util.go index bc199593..c8b7052b 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -47,7 +47,7 @@ func RolloutStatus(ctx1 context.Context, f cmdutil.Factory, ns, workloads string defer func() { if err != nil { plog.G(ctx1).Errorf("Rollout status for %s failed: %s", workloads, err.Error()) - out := plog.GetLogger(ctx1).Out + out := plog.GetLogger(ctx1).Logger.Out streams := genericiooptions.IOStreams{ In: os.Stdin, Out: out,