refactor: add prefix to log

This commit is contained in:
fengcaiwen
2025-12-22 11:03:45 +08:00
parent 1fe1e29f72
commit bbd5e121e5
15 changed files with 172 additions and 503 deletions

View File

@@ -91,7 +91,7 @@ type Rule struct {
PortMap map[int32]string PortMap map[int32]string
} }
func (a *Virtual) To(enableIPv6 bool, logger *log.Logger) ( func (a *Virtual) To(enableIPv6 bool, logger *log.Entry) (
listeners []types.Resource, listeners []types.Resource,
clusters []types.Resource, clusters []types.Resource,
routes []types.Resource, routes []types.Resource,

View File

@@ -11,7 +11,7 @@ import (
plog "github.com/wencaiwulue/kubevpn/v2/pkg/log" plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
) )
func Main(ctx context.Context, factory cmdutil.Factory, port uint, logger *log.Logger) error { func Main(ctx context.Context, factory cmdutil.Factory, port uint, logger *log.Entry) error {
snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger) snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger)
proc := NewProcessor(snapshotCache, logger) proc := NewProcessor(snapshotCache, logger)

View File

@@ -21,13 +21,13 @@ import (
type Processor struct { type Processor struct {
cache cache.SnapshotCache cache cache.SnapshotCache
logger *log.Logger logger *log.Entry
version int64 version int64
expireCache *utilcache.Expiring expireCache *utilcache.Expiring
} }
func NewProcessor(cache cache.SnapshotCache, log *log.Logger) *Processor { func NewProcessor(cache cache.SnapshotCache, log *log.Entry) *Processor {
return &Processor{ return &Processor{
cache: cache, cache: cache,
logger: log, logger: log,

View File

@@ -2,6 +2,7 @@ package core
import ( import (
"context" "context"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/channel"
@@ -10,6 +11,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"github.com/wencaiwulue/kubevpn/v2/pkg/config" "github.com/wencaiwulue/kubevpn/v2/pkg/config"
plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
"github.com/wencaiwulue/kubevpn/v2/pkg/util" "github.com/wencaiwulue/kubevpn/v2/pkg/util"
) )
@@ -47,7 +49,7 @@ func (h *gvisorLocalHandler) Run(ctx context.Context) {
readFromEndpointWriteToTun(ctx, endpoint, h.outbound) readFromEndpointWriteToTun(ctx, endpoint, h.outbound)
util.SafeClose(h.errChan) util.SafeClose(h.errChan)
}() }()
s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] ")) s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))))
defer s.Destroy() defer s.Destroy()
select { select {
case <-h.errChan: case <-h.errChan:

View File

@@ -2,6 +2,7 @@ package core
import ( import (
"context" "context"
"fmt"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
@@ -16,10 +17,11 @@ import (
) )
func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint, out chan<- *Packet) { func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint, out chan<- *Packet) {
prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))
for ctx.Err() == nil { for ctx.Err() == nil {
pkt := endpoint.ReadContext(ctx) pkt := endpoint.ReadContext(ctx)
if pkt != nil { if pkt != nil {
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt) sniffer.LogPacket(prefix, sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt)
data := pkt.ToView().AsSlice() data := pkt.ToView().AsSlice()
buf := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
n := copy(buf[1:], data) n := copy(buf[1:], data)
@@ -30,6 +32,7 @@ func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint,
} }
func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet, endpoint *channel.Endpoint) { func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet, endpoint *channel.Endpoint) {
prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))
for ctx.Err() == nil { for ctx.Err() == nil {
var packet *Packet var packet *Packet
select { select {
@@ -60,7 +63,7 @@ func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet
Payload: buffer.MakeWithData(packet.data[1:packet.length]), Payload: buffer.MakeWithData(packet.data[1:packet.length]),
}) })
config.LPool.Put(packet.data[:]) config.LPool.Put(packet.data[:])
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionRecv, protocol, pkt) sniffer.LogPacket(prefix, sniffer.DirectionRecv, protocol, pkt)
endpoint.InjectInbound(protocol, pkt) endpoint.InjectInbound(protocol, pkt)
pkt.DecRef() pkt.DecRef()
} }

View File

@@ -34,6 +34,12 @@ func TunHandler(node *Node, forward *Forwarder) Handler {
} }
func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) { func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) {
tunIfi, err := util.GetTunDeviceByConn(tun)
if err != nil {
plog.G(ctx).Errorf("Failed to get tun device: %v", err)
return
}
ctx = plog.WithField(ctx, tunIfi.Name, "")
if !h.forward.IsEmpty() { if !h.forward.IsEmpty() {
h.HandleClient(ctx, tun, h.forward) h.HandleClient(ctx, tun, h.forward)
} else { } else {

View File

@@ -162,7 +162,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
config.LPool.Put(buf[:]) config.LPool.Put(buf[:])
continue continue
} }
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) plog.G(plog.WithFields(context.Background(), plog.GetFields(ctx))).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
packet := NewPacket(buf[:], n+1, src, dst) packet := NewPacket(buf[:], n+1, src, dst)
if packet.src.Equal(packet.dst) { if packet.src.Equal(packet.dst) {
gvisorInbound <- packet gvisorInbound <- packet

View File

@@ -336,7 +336,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
} }
}() }()
} }
out := plog.G(ctx).Out out := plog.G(ctx).Logger.Out
err = util.PortForwardPod( err = util.PortForwardPod(
c.config, c.config,
c.restclient, c.restclient,
@@ -345,7 +345,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
portPair, portPair,
readyChan, readyChan,
childCtx.Done(), childCtx.Done(),
out, nil,
out, out,
) )
if *first { if *first {

View File

@@ -187,7 +187,7 @@ func (d *SyncOptions) DoSync(ctx context.Context, kubeconfigJsonBytes []byte, im
} }
} }
{ {
container, err := podcmd.FindOrDefaultContainerByName(&v1.Pod{Spec: v1.PodSpec{Containers: containers}}, d.TargetContainer, false, plog.G(ctx).Out) container, err := podcmd.FindOrDefaultContainerByName(&v1.Pod{Spec: v1.PodSpec{Containers: containers}}, d.TargetContainer, false, plog.G(ctx).Logger.Out)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -34,11 +34,59 @@ func WithoutLogger(ctx context.Context) context.Context {
return ctx return ctx
} }
// GetLogger retrieves the current logger from the context. If no logger is // getLogger retrieves the current logger from the context. If no logger is
// available, the default logger is returned. // available, the default logger is returned.
func GetLogger(ctx context.Context) *log.Logger { func getLogger(ctx context.Context) *log.Logger {
if logger := ctx.Value(loggerKey{}); logger != nil && logger.(*loggerValue).logger != nil { if logger := ctx.Value(loggerKey{}); logger != nil && logger.(*loggerValue).logger != nil {
return logger.(*loggerValue).logger return logger.(*loggerValue).logger
} }
return L return L
} }
type fieldsKey struct{}
// WithFields 将指定的字段添加到 context 中,这些字段会在后续从 context 获取 logger 时自动添加
func WithFields(ctx context.Context, fields map[string]any) context.Context {
existingFields := GetFields(ctx)
if existingFields == nil {
return context.WithValue(ctx, fieldsKey{}, fields)
}
// 合并字段,新字段会覆盖旧字段
mergedFields := make(map[string]any)
for k, v := range existingFields {
mergedFields[k] = v
}
for k, v := range fields {
mergedFields[k] = v
}
return context.WithValue(ctx, fieldsKey{}, mergedFields)
}
// WithField 将单个字段添加到 context 中
func WithField(ctx context.Context, key string, value any) context.Context {
return WithFields(ctx, map[string]any{key: value})
}
// GetFields 从 context 中获取所有已存储的字段
func GetFields(ctx context.Context) map[string]any {
if fields := ctx.Value(fieldsKey{}); fields != nil {
if f, ok := fields.(map[string]any); ok {
return f
}
}
return nil
}
// GetLogger 从 context 中获取 logger并自动添加 context 中存储的字段
func GetLogger(ctx context.Context) *log.Entry {
logger := getLogger(ctx)
fields := GetFields(ctx)
if len(fields) > 0 {
return logger.WithFields(fields)
}
return log.NewEntry(logger)
}

View File

@@ -3,24 +3,56 @@ package log
import ( import (
"context" "context"
"testing" "testing"
"time"
) )
func TestGetLoggerFromContext(t *testing.T) { func TestLog(t *testing.T) {
logger := InitLoggerForServer() ctx := context.Background()
ctx := WithLogger(context.Background(), logger) G(ctx).WithField("tun", "abc").Debug("debug")
cancel, cancelFunc := context.WithCancel(ctx) logger := G(ctx).WithField("tun", "abc").Logger
defer cancelFunc()
timeout, c := context.WithTimeout(cancel, time.Second*10)
defer c()
l := GetLogger(timeout)
if logger != l {
panic("not same")
}
cancel = WithoutLogger(cancel)
defaultLogger := GetLogger(cancel)
if defaultLogger != L {
panic("not same")
}
logger.Debug("debug") logger.Debug("debug")
logger.Warn("warn")
}
func TestWithFields(t *testing.T) {
ctx := WithField(context.Background(), "request_id", "12345")
ctx = WithField(ctx, "user_id", "user-abc")
logger := GetLogger(ctx)
logger.Info("this log will contains request_id and user_id")
ctx2 := WithFields(ctx, map[string]any{
"action": "login",
"ip": "192.168.1.1",
})
logger2 := GetLogger(ctx2)
logger2.Info("this log will contains four fields")
// 在不同方法中使用
processRequest(ctx2)
}
func processRequest(ctx context.Context) {
logger := GetLogger(ctx)
logger.Debug("request handling...")
logger.WithField("step", "validation").Info("please input validation")
}
func TestWithFieldsMerge(t *testing.T) {
ctx := WithFields(context.Background(), map[string]any{
"service": "api",
"version": "v1",
})
// merge fields
ctx = WithFields(ctx, map[string]any{
"endpoint": "/users",
"method": "GET",
})
ctx = WithField(ctx, "version", "v2")
logger := GetLogger(ctx)
logger.Info("should show all fieldsversion changed to v2")
} }

View File

@@ -6,6 +6,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sort"
"strings" "strings"
"time" "time"
@@ -65,6 +66,18 @@ type serverFormat struct {
// 2009/01/23 01:23:23 d.go:23: message // 2009/01/23 01:23:23 d.go:23: message
func (*serverFormat) Format(e *log.Entry) ([]byte, error) { func (*serverFormat) Format(e *log.Entry) ([]byte, error) {
// e.Caller maybe is nil, because pkg/handler/connect.go:252 // e.Caller maybe is nil, because pkg/handler/connect.go:252
if len(e.Data) > 0 {
return []byte(
fmt.Sprintf("%s %s %s:%d %s: %s\n",
GenStr(e.Data),
e.Time.Format("2006-01-02 15:04:05.000"),
filepath.Base(ptr.Deref(e.Caller, runtime.Frame{}).File),
ptr.Deref(e.Caller, runtime.Frame{}).Line,
e.Level.String(),
e.Message,
)), nil
}
return []byte( return []byte(
fmt.Sprintf("%s %s:%d %s: %s\n", fmt.Sprintf("%s %s:%d %s: %s\n",
e.Time.Format("2006-01-02 15:04:05.000"), e.Time.Format("2006-01-02 15:04:05.000"),
@@ -106,3 +119,38 @@ func (g ServerEmitter) Emit(depth int, level glog.Level, timestamp time.Time, fo
message, message,
) )
} }
func GenStr(allFields map[string]any) string {
fieldsStr := ""
keys := make([]string, len(allFields))
i := 0
for field := range allFields {
keys[i] = field
i++
}
sort.Strings(keys)
for _, key := range keys {
var valueStr string
value := allFields[key]
if stringer, ok := value.(fmt.Stringer); ok {
valueStr = stringer.String()
} else {
valueStr = fmt.Sprintf("%v", value)
}
if strings.Contains(valueStr, " ") {
valueStr = `"` + valueStr + `"`
}
if valueStr == "" {
fieldsStr += key + " "
} else {
fieldsStr += key + "=" + valueStr + " "
}
}
return fmt.Sprintf("[%s]", strings.TrimSpace(fieldsStr))
}

View File

@@ -113,9 +113,9 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
var openChannelError *gossh.OpenChannelError var openChannelError *gossh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit // if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited { if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
plog.G(ctx).Debugf("Failed to open ssh port-forward to %s: %v", remote.String(), err) plog.G(ctx).Errorf("Prohibited to open ssh port-forward to %s: %v", remote.String(), err)
plog.G(ctx).Errorf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
cancelFunc1() cancelFunc1()
return
} }
plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err) plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err)
return return

View File

@@ -1,470 +0,0 @@
package util
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"k8s.io/api/core/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/tools/portforward"
)
// PortForwarder knows how to listen for local connections and forward them to
// a remote pod via an upgraded HTTP request.
type PortForwarder struct {
addresses []listenAddress
ports []ForwardedPort
stopChan <-chan struct{}
// if failed to find socat, send error
// if pod is not found, send error
errChan chan error
dialer httpstream.Dialer
streamConn httpstream.Connection
listeners []io.Closer
Ready chan struct{}
requestIDLock sync.Mutex
requestID int
out io.Writer
errOut io.Writer
}
// ForwardedPort contains a Local:Remote port pairing.
type ForwardedPort struct {
Local uint16
Remote uint16
}
/*
valid port specifications:
5000
- forwards from localhost:5000 to pod:5000
8888:5000
- forwards from localhost:8888 to pod:5000
0:5000
:5000
- selects a random available local port,
forwards from localhost:<random port> to pod:5000
*/
func parsePorts(ports []string) ([]ForwardedPort, error) {
var forwards []ForwardedPort
for _, portString := range ports {
parts := strings.Split(portString, ":")
var localString, remoteString string
if len(parts) == 1 {
localString = parts[0]
remoteString = parts[0]
} else if len(parts) == 2 {
localString = parts[0]
if localString == "" {
// support :5000
localString = "0"
}
remoteString = parts[1]
} else {
return nil, fmt.Errorf("invalid port format '%s'", portString)
}
localPort, err := strconv.ParseUint(localString, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
}
remotePort, err := strconv.ParseUint(remoteString, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
}
if remotePort == 0 {
return nil, fmt.Errorf("remote port must be > 0")
}
forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
}
return forwards, nil
}
type listenAddress struct {
address string
protocol string
failureMode string
}
func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
var addresses []listenAddress
parsed := make(map[string]listenAddress)
for _, address := range addressesToParse {
if address == "localhost" {
if _, exists := parsed["127.0.0.1"]; !exists {
ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
parsed[ip.address] = ip
}
if _, exists := parsed["::1"]; !exists {
ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
parsed[ip.address] = ip
}
} else if net.ParseIP(address).To4() != nil {
parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
} else if net.ParseIP(address) != nil {
parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
} else {
return nil, fmt.Errorf("%s is not a valid IP", address)
}
}
addresses = make([]listenAddress, len(parsed))
id := 0
for _, v := range parsed {
addresses[id] = v
id++
}
// Sort addresses before returning to get a stable order
sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
return addresses, nil
}
// NewOnAddresses creates a new PortForwarder with custom listen addresses.
func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
if len(addresses) == 0 {
return nil, errors.New("you must specify at least 1 address")
}
parsedAddresses, err := parseAddresses(addresses)
if err != nil {
return nil, err
}
if len(ports) == 0 {
return nil, errors.New("you must specify at least 1 port")
}
parsedPorts, err := parsePorts(ports)
if err != nil {
return nil, err
}
return &PortForwarder{
dialer: dialer,
addresses: parsedAddresses,
ports: parsedPorts,
stopChan: stopChan,
errChan: make(chan error, 1),
Ready: readyChan,
out: out,
errOut: errOut,
}, nil
}
// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
return pf.forward()
}
// forward dials the remote host specific in req, upgrades the request, starts
// listeners for each port specified in ports, and forwards local connections
// to the remote host via streams.
func (pf *PortForwarder) forward() error {
var err error
listenSuccess := false
for i := range pf.ports {
port := &pf.ports[i]
err = pf.listenOnPort(port)
switch {
case err == nil:
listenSuccess = true
default:
if pf.errOut != nil {
fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
}
}
}
if !listenSuccess {
return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
}
if pf.Ready != nil {
close(pf.Ready)
}
// wait for interrupt or conn closure
select {
case <-pf.stopChan:
runtime.HandleError(errors.New("lost connection to pod"))
}
select {
case errs, ok := <-pf.errChan:
if ok {
return errs
}
return nil
default:
return nil
}
}
// listenOnPort delegates listener creation and waits for connections on requested bind addresses.
// An error is raised based on address groups (default and localhost) and their failure modes
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
var errors []error
failCounters := make(map[string]int, 2)
successCounters := make(map[string]int, 2)
for _, addr := range pf.addresses {
err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
if err != nil {
errors = append(errors, err)
failCounters[addr.failureMode]++
} else {
successCounters[addr.failureMode]++
}
}
if successCounters["all"] == 0 && failCounters["all"] > 0 {
return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
}
if failCounters["any"] > 0 {
return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
}
return nil
}
// listenOnPortAndAddress delegates listener creation and waits for new connections
// in the background f
func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
listener, err := pf.getListener(protocol, address, port)
if err != nil {
return err
}
pf.listeners = append(pf.listeners, listener)
go pf.waitForConnection(listener, *port)
return nil
}
// getListener creates a listener on the interface targeted by the given hostname on the given port with
// the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
if err != nil {
return nil, fmt.Errorf("unable to create listener: Error %s", err)
}
listenerAddress := listener.Addr().String()
host, localPort, _ := net.SplitHostPort(listenerAddress)
localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
if err != nil {
fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote)
return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
}
port.Local = uint16(localPortUInt)
if pf.out != nil {
fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
}
return listener, nil
}
// waitForConnection waits for new connections to listener and handles them in
// the background.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for {
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
}
return
}
go pf.handleConnection(conn, port)
}
}
func (pf *PortForwarder) nextRequestID() int {
pf.requestIDLock.Lock()
defer pf.requestIDLock.Unlock()
id := pf.requestID
pf.requestID++
return id
}
// handleConnection copies data between the local connection and the stream to
// the remote server.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
defer conn.Close()
if pf.out != nil {
fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
}
requestID := pf.nextRequestID()
// create error stream
headers := http.Header{}
headers.Set(v1.StreamType, v1.StreamTypeError)
headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
var err error
errorStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
// we're not writing to this stream
errorStream.Close()
errorChan := make(chan error)
go func() {
message, err := io.ReadAll(errorStream)
switch {
case err != nil:
errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
case len(message) > 0:
errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
}
close(errorChan)
}()
// create data stream
headers.Set(v1.StreamType, v1.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
localError := make(chan struct{})
remoteDone := make(chan struct{})
go func() {
// Copy from the remote side to the local port.
if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
}
// inform the select below that the remote copy is done
close(remoteDone)
}()
go func() {
// inform server we're not sending any more data after copy unblocks
defer dataStream.Close()
// Copy from the local port to the remote side.
if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
// break out of the select below without waiting for the other copy to finish
close(localError)
}
}()
// wait for either a local->remote error or for copying from remote->local to finish
select {
case <-remoteDone:
case <-localError:
// wait for interrupt or conn closure
case <-pf.stopChan:
runtime.HandleError(errors.New("lost connection to pod"))
}
// always expect something on errorChan (it may be nil)
select {
case err = <-errorChan:
default:
}
if err != nil {
if strings.Contains(err.Error(), "failed to find socat") {
select {
case pf.errChan <- err:
default:
}
}
runtime.HandleError(err)
}
}
// Close stops all listeners of PortForwarder.
func (pf *PortForwarder) Close() {
// stop all listeners
for _, l := range pf.listeners {
if err := l.Close(); err != nil {
runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
}
}
}
// GetPorts will return the ports that were forwarded; this can be used to
// retrieve the locally-bound port in cases where the input was port 0. This
// function will signal an error if the Ready channel is nil or if the
// listeners are not ready yet; this function will succeed after the Ready
// channel has been closed.
func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
if pf.Ready == nil {
return nil, fmt.Errorf("no ready channel provided")
}
select {
case <-pf.Ready:
return pf.ports, nil
default:
return nil, fmt.Errorf("listeners not ready")
}
}
func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stream, error) {
errorChan := make(chan error, 2)
var resultChan atomic.Value
time.AfterFunc(time.Second*1, func() {
errorChan <- errors.New("timeout")
})
go func() {
if pf.streamConn != nil {
if stream, err := pf.streamConn.CreateStream(*header); err == nil && stream != nil {
errorChan <- nil
resultChan.Store(stream)
return
}
}
errorChan <- errors.New("")
}()
if err := <-errorChan; err == nil && resultChan.Load() != nil {
return resultChan.Load().(httpstream.Stream), nil
}
// close old connection in case of resource leak
if pf.streamConn != nil {
_ = pf.streamConn.Close()
}
var err error
pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
if err != nil {
if k8serrors.IsNotFound(err) {
runtime.HandleError(fmt.Errorf("pod not found: %s", err))
select {
case pf.errChan <- err:
default:
}
} else {
runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err))
}
return nil, err
}
header.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(pf.nextRequestID()))
return pf.streamConn.CreateStream(*header)
}

View File

@@ -47,7 +47,7 @@ func RolloutStatus(ctx1 context.Context, f cmdutil.Factory, ns, workloads string
defer func() { defer func() {
if err != nil { if err != nil {
plog.G(ctx1).Errorf("Rollout status for %s failed: %s", workloads, err.Error()) plog.G(ctx1).Errorf("Rollout status for %s failed: %s", workloads, err.Error())
out := plog.GetLogger(ctx1).Out out := plog.GetLogger(ctx1).Logger.Out
streams := genericiooptions.IOStreams{ streams := genericiooptions.IOStreams{
In: os.Stdin, In: os.Stdin,
Out: out, Out: out,