mirror of https://github.com/kubenetworks/kubevpn.git
refactor: add prefix to log
@@ -91,7 +91,7 @@ type Rule struct {
    PortMap map[int32]string
}

-func (a *Virtual) To(enableIPv6 bool, logger *log.Logger) (
+func (a *Virtual) To(enableIPv6 bool, logger *log.Entry) (
    listeners []types.Resource,
    clusters []types.Resource,
    routes []types.Resource,

@@ -11,7 +11,7 @@ import (
    plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
)

-func Main(ctx context.Context, factory cmdutil.Factory, port uint, logger *log.Logger) error {
+func Main(ctx context.Context, factory cmdutil.Factory, port uint, logger *log.Entry) error {
    snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger)
    proc := NewProcessor(snapshotCache, logger)

@@ -21,13 +21,13 @@ import (

type Processor struct {
    cache       cache.SnapshotCache
-    logger      *log.Logger
+    logger      *log.Entry
    version     int64

    expireCache *utilcache.Expiring
}

-func NewProcessor(cache cache.SnapshotCache, log *log.Logger) *Processor {
+func NewProcessor(cache cache.SnapshotCache, log *log.Entry) *Processor {
    return &Processor{
        cache:  cache,
        logger: log,

@@ -2,6 +2,7 @@ package core

import (
    "context"
+    "fmt"

    "gvisor.dev/gvisor/pkg/tcpip"
    "gvisor.dev/gvisor/pkg/tcpip/link/channel"

@@ -10,6 +11,7 @@ import (
    "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"

    "github.com/wencaiwulue/kubevpn/v2/pkg/config"
+    plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
    "github.com/wencaiwulue/kubevpn/v2/pkg/util"
)

@@ -47,7 +49,7 @@ func (h *gvisorLocalHandler) Run(ctx context.Context) {
        readFromEndpointWriteToTun(ctx, endpoint, h.outbound)
        util.SafeClose(h.errChan)
    }()
-    s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] "))
+    s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))))
    defer s.Destroy()
    select {
    case <-h.errChan:

@@ -2,6 +2,7 @@ package core

import (
    "context"
+    "fmt"

    "gvisor.dev/gvisor/pkg/buffer"
    "gvisor.dev/gvisor/pkg/tcpip"

@@ -16,10 +17,11 @@ import (
)

func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint, out chan<- *Packet) {
+    prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))
    for ctx.Err() == nil {
        pkt := endpoint.ReadContext(ctx)
        if pkt != nil {
-            sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt)
+            sniffer.LogPacket(prefix, sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt)
            data := pkt.ToView().AsSlice()
            buf := config.LPool.Get().([]byte)[:]
            n := copy(buf[1:], data)

@@ -30,6 +32,7 @@ func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint,
}

func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet, endpoint *channel.Endpoint) {
+    prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))
    for ctx.Err() == nil {
        var packet *Packet
        select {

@@ -60,7 +63,7 @@ func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet,
            Payload: buffer.MakeWithData(packet.data[1:packet.length]),
        })
        config.LPool.Put(packet.data[:])
-        sniffer.LogPacket("[gVISOR] ", sniffer.DirectionRecv, protocol, pkt)
+        sniffer.LogPacket(prefix, sniffer.DirectionRecv, protocol, pkt)
        endpoint.InjectInbound(protocol, pkt)
        pkt.DecRef()
    }

@@ -34,6 +34,12 @@ func TunHandler(node *Node, forward *Forwarder) Handler {
}

func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) {
+    tunIfi, err := util.GetTunDeviceByConn(tun)
+    if err != nil {
+        plog.G(ctx).Errorf("Failed to get tun device: %v", err)
+        return
+    }
+    ctx = plog.WithField(ctx, tunIfi.Name, "")
    if !h.forward.IsEmpty() {
        h.HandleClient(ctx, tun, h.forward)
    } else {

@@ -162,7 +162,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
            config.LPool.Put(buf[:])
            continue
        }
-        plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
+        plog.G(plog.WithFields(context.Background(), plog.GetFields(ctx))).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
        packet := NewPacket(buf[:], n+1, src, dst)
        if packet.src.Equal(packet.dst) {
            gvisorInbound <- packet

@@ -336,7 +336,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
            }
        }()
    }
-    out := plog.G(ctx).Out
+    out := plog.G(ctx).Logger.Out
    err = util.PortForwardPod(
        c.config,
        c.restclient,
@@ -345,7 +345,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
        portPair,
        readyChan,
        childCtx.Done(),
        out,
-        nil,
+        out,
    )
    if *first {

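A note on the .Out call sites in this hunk and the ones below: plog.G(ctx) now yields a *logrus.Entry rather than a *logrus.Logger, and an Entry has no Out field of its own, so the writer has to be reached through the entry's embedded Logger. A minimal sketch of that pattern (hypothetical helper, not part of the commit; assumes the plog alias used throughout this diff plus the usual context and io imports):

func writerFrom(ctx context.Context) io.Writer {
    // G(ctx) returns *logrus.Entry after this refactor; the output writer
    // lives on the entry's underlying *logrus.Logger.
    return plog.G(ctx).Logger.Out
}
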
@@ -187,7 +187,7 @@ func (d *SyncOptions) DoSync(ctx context.Context, kubeconfigJsonBytes []byte, im
        }
    }
    {
-        container, err := podcmd.FindOrDefaultContainerByName(&v1.Pod{Spec: v1.PodSpec{Containers: containers}}, d.TargetContainer, false, plog.G(ctx).Out)
+        container, err := podcmd.FindOrDefaultContainerByName(&v1.Pod{Spec: v1.PodSpec{Containers: containers}}, d.TargetContainer, false, plog.G(ctx).Logger.Out)
        if err != nil {
            return err
        }

@@ -34,11 +34,59 @@ func WithoutLogger(ctx context.Context) context.Context {
    return ctx
}

-// GetLogger retrieves the current logger from the context. If no logger is
+// getLogger retrieves the current logger from the context. If no logger is
// available, the default logger is returned.
-func GetLogger(ctx context.Context) *log.Logger {
+func getLogger(ctx context.Context) *log.Logger {
    if logger := ctx.Value(loggerKey{}); logger != nil && logger.(*loggerValue).logger != nil {
        return logger.(*loggerValue).logger
    }
    return L
}
+
+type fieldsKey struct{}
+
+// WithFields adds the given fields to the context; they are attached
+// automatically whenever a logger is later obtained from that context.
+func WithFields(ctx context.Context, fields map[string]any) context.Context {
+    existingFields := GetFields(ctx)
+    if existingFields == nil {
+        return context.WithValue(ctx, fieldsKey{}, fields)
+    }
+
+    // Merge the fields; new fields overwrite existing ones.
+    mergedFields := make(map[string]any)
+    for k, v := range existingFields {
+        mergedFields[k] = v
+    }
+    for k, v := range fields {
+        mergedFields[k] = v
+    }
+
+    return context.WithValue(ctx, fieldsKey{}, mergedFields)
+}
+
+// WithField adds a single field to the context.
+func WithField(ctx context.Context, key string, value any) context.Context {
+    return WithFields(ctx, map[string]any{key: value})
+}
+
+// GetFields returns all fields stored in the context.
+func GetFields(ctx context.Context) map[string]any {
+    if fields := ctx.Value(fieldsKey{}); fields != nil {
+        if f, ok := fields.(map[string]any); ok {
+            return f
+        }
+    }
+    return nil
+}
+
+// GetLogger retrieves the logger from the context and automatically attaches
+// the fields stored in the context.
+func GetLogger(ctx context.Context) *log.Entry {
+    logger := getLogger(ctx)
+    fields := GetFields(ctx)
+
+    if len(fields) > 0 {
+        return logger.WithFields(fields)
+    }
+
+    return log.NewEntry(logger)
+}

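To make the intent of the new API concrete, here is a small illustrative sketch (hypothetical caller, not part of the commit) of how fields attached to a context surface both in log entries and in the gVISOR sniffer prefix. It relies only on functions visible in this diff (WithField, WithFields, GetFields, GenStr) plus the G accessor used elsewhere in the diff; the printed output is an assumption based on GenStr's sorting and formatting rules.

package main

import (
    "context"
    "fmt"

    plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
)

func main() {
    // Attach fields once; they travel with the context.
    ctx := plog.WithField(context.Background(), "tun", "utun4")
    ctx = plog.WithFields(ctx, map[string]any{"pod": "nginx-0"})

    // Every entry obtained from this context carries tun=utun4 and pod=nginx-0.
    plog.G(ctx).Info("forwarding traffic")

    // The same fields drive the packet-sniffer prefix used in the gVISOR hunks above.
    prefix := fmt.Sprintf("[gVISOR]%s ", plog.GenStr(plog.GetFields(ctx)))
    fmt.Println(prefix) // expected: [gVISOR][pod=nginx-0 tun=utun4]
}
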
@@ -3,24 +3,56 @@ package log
import (
    "context"
    "testing"
-    "time"
)

-func TestGetLoggerFromContext(t *testing.T) {
-    logger := InitLoggerForServer()
-    ctx := WithLogger(context.Background(), logger)
-    cancel, cancelFunc := context.WithCancel(ctx)
-    defer cancelFunc()
-    timeout, c := context.WithTimeout(cancel, time.Second*10)
-    defer c()
-    l := GetLogger(timeout)
-    if logger != l {
-        panic("not same")
-    }
-    cancel = WithoutLogger(cancel)
-    defaultLogger := GetLogger(cancel)
-    if defaultLogger != L {
-        panic("not same")
-    }
+func TestLog(t *testing.T) {
+    ctx := context.Background()
+    G(ctx).WithField("tun", "abc").Debug("debug")
+    logger := G(ctx).WithField("tun", "abc").Logger
+    logger.Debug("debug")
+    logger.Warn("warn")
+}
+
+func TestWithFields(t *testing.T) {
+    ctx := WithField(context.Background(), "request_id", "12345")
+    ctx = WithField(ctx, "user_id", "user-abc")
+
+    logger := GetLogger(ctx)
+    logger.Info("this log will contain request_id and user_id")
+
+    ctx2 := WithFields(ctx, map[string]any{
+        "action": "login",
+        "ip":     "192.168.1.1",
+    })
+
+    logger2 := GetLogger(ctx2)
+    logger2.Info("this log will contain four fields")
+
+    // reuse the same context in another function
+    processRequest(ctx2)
+}
+
+func processRequest(ctx context.Context) {
+    logger := GetLogger(ctx)
+    logger.Debug("request handling...")
+
+    logger.WithField("step", "validation").Info("performing input validation")
+}
+
+func TestWithFieldsMerge(t *testing.T) {
+    ctx := WithFields(context.Background(), map[string]any{
+        "service": "api",
+        "version": "v1",
+    })
+
+    // merge fields
+    ctx = WithFields(ctx, map[string]any{
+        "endpoint": "/users",
+        "method":   "GET",
+    })
+
+    ctx = WithField(ctx, "version", "v2")
+
+    logger := GetLogger(ctx)
+    logger.Info("should show all fields, version changed to v2")
+}

@@ -6,6 +6,7 @@ import (
    "os"
    "path/filepath"
    "runtime"
    "sort"
    "strings"
    "time"

@@ -65,6 +66,18 @@ type serverFormat struct {
// 2009/01/23 01:23:23 d.go:23: message
func (*serverFormat) Format(e *log.Entry) ([]byte, error) {
    // e.Caller may be nil, see pkg/handler/connect.go:252
+    if len(e.Data) > 0 {
+        return []byte(
+            fmt.Sprintf("%s %s %s:%d %s: %s\n",
+                GenStr(e.Data),
+                e.Time.Format("2006-01-02 15:04:05.000"),
+                filepath.Base(ptr.Deref(e.Caller, runtime.Frame{}).File),
+                ptr.Deref(e.Caller, runtime.Frame{}).Line,
+                e.Level.String(),
+                e.Message,
+            )), nil
+    }
+
    return []byte(
        fmt.Sprintf("%s %s:%d %s: %s\n",
            e.Time.Format("2006-01-02 15:04:05.000"),

@@ -106,3 +119,38 @@ func (g ServerEmitter) Emit(depth int, level glog.Level, timestamp time.Time, fo
        message,
    )
}
+
+func GenStr(allFields map[string]any) string {
+    fieldsStr := ""
+
+    keys := make([]string, len(allFields))
+    i := 0
+    for field := range allFields {
+        keys[i] = field
+        i++
+    }
+
+    sort.Strings(keys)
+
+    for _, key := range keys {
+        var valueStr string
+        value := allFields[key]
+
+        if stringer, ok := value.(fmt.Stringer); ok {
+            valueStr = stringer.String()
+        } else {
+            valueStr = fmt.Sprintf("%v", value)
+        }
+
+        if strings.Contains(valueStr, " ") {
+            valueStr = `"` + valueStr + `"`
+        }
+        if valueStr == "" {
+            fieldsStr += key + " "
+        } else {
+            fieldsStr += key + "=" + valueStr + " "
+        }
+
+    }
+    return fmt.Sprintf("[%s]", strings.TrimSpace(fieldsStr))
+}

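A hypothetical unit test (not part of the commit, but of the kind that would sit next to the log tests shown earlier in this diff) pinning down the prefix format GenStr produces: keys come out sorted, an empty value collapses to the bare key, and values containing spaces are quoted.

func TestGenStr(t *testing.T) {
    // sorted keys; the empty value renders as the bare key
    if got := GenStr(map[string]any{"tun": "utun4", "node": ""}); got != "[node tun=utun4]" {
        t.Fatalf("unexpected prefix: %q", got)
    }

    // values containing a space are wrapped in quotes
    if got := GenStr(map[string]any{"msg": "hello world"}); got != `[msg="hello world"]` {
        t.Fatalf("unexpected prefix: %q", got)
    }
}
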
@@ -113,9 +113,9 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
            var openChannelError *gossh.OpenChannelError
            // if ssh server not permitted ssh port-forward, do nothing until exit
            if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
-                plog.G(ctx).Debugf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
-                plog.G(ctx).Errorf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
+                plog.G(ctx).Errorf("Prohibited to open ssh port-forward to %s: %v", remote.String(), err)
                cancelFunc1()
                return
            }
            plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err)
            return

@@ -1,470 +0,0 @@
-package util
-
-import (
-    "errors"
-    "fmt"
-    "io"
-    "net"
-    "net/http"
-    "sort"
-    "strconv"
-    "strings"
-    "sync"
-    "sync/atomic"
-    "time"
-
-    "k8s.io/api/core/v1"
-    k8serrors "k8s.io/apimachinery/pkg/api/errors"
-    "k8s.io/apimachinery/pkg/util/httpstream"
-    "k8s.io/apimachinery/pkg/util/runtime"
-    "k8s.io/client-go/tools/portforward"
-)
-
-// PortForwarder knows how to listen for local connections and forward them to
-// a remote pod via an upgraded HTTP request.
-type PortForwarder struct {
-    addresses []listenAddress
-    ports     []ForwardedPort
-    stopChan  <-chan struct{}
-    // if failed to find socat, send error
-    // if pod is not found, send error
-    errChan chan error
-
-    dialer        httpstream.Dialer
-    streamConn    httpstream.Connection
-    listeners     []io.Closer
-    Ready         chan struct{}
-    requestIDLock sync.Mutex
-    requestID     int
-    out           io.Writer
-    errOut        io.Writer
-}
-
-// ForwardedPort contains a Local:Remote port pairing.
-type ForwardedPort struct {
-    Local  uint16
-    Remote uint16
-}
-
-/*
-valid port specifications:
-
-5000
-- forwards from localhost:5000 to pod:5000
-
-8888:5000
-- forwards from localhost:8888 to pod:5000
-
-0:5000
-:5000
-- selects a random available local port,
-  forwards from localhost:<random port> to pod:5000
-*/
-func parsePorts(ports []string) ([]ForwardedPort, error) {
-    var forwards []ForwardedPort
-    for _, portString := range ports {
-        parts := strings.Split(portString, ":")
-        var localString, remoteString string
-        if len(parts) == 1 {
-            localString = parts[0]
-            remoteString = parts[0]
-        } else if len(parts) == 2 {
-            localString = parts[0]
-            if localString == "" {
-                // support :5000
-                localString = "0"
-            }
-            remoteString = parts[1]
-        } else {
-            return nil, fmt.Errorf("invalid port format '%s'", portString)
-        }
-
-        localPort, err := strconv.ParseUint(localString, 10, 16)
-        if err != nil {
-            return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
-        }
-
-        remotePort, err := strconv.ParseUint(remoteString, 10, 16)
-        if err != nil {
-            return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
-        }
-        if remotePort == 0 {
-            return nil, fmt.Errorf("remote port must be > 0")
-        }
-
-        forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
-    }
-
-    return forwards, nil
-}
-
-type listenAddress struct {
-    address     string
-    protocol    string
-    failureMode string
-}
-
-func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
-    var addresses []listenAddress
-    parsed := make(map[string]listenAddress)
-    for _, address := range addressesToParse {
-        if address == "localhost" {
-            if _, exists := parsed["127.0.0.1"]; !exists {
-                ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
-                parsed[ip.address] = ip
-            }
-            if _, exists := parsed["::1"]; !exists {
-                ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
-                parsed[ip.address] = ip
-            }
-        } else if net.ParseIP(address).To4() != nil {
-            parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
-        } else if net.ParseIP(address) != nil {
-            parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
-        } else {
-            return nil, fmt.Errorf("%s is not a valid IP", address)
-        }
-    }
-    addresses = make([]listenAddress, len(parsed))
-    id := 0
-    for _, v := range parsed {
-        addresses[id] = v
-        id++
-    }
-    // Sort addresses before returning to get a stable order
-    sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
-
-    return addresses, nil
-}
-
-// NewOnAddresses creates a new PortForwarder with custom listen addresses.
-func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
-    if len(addresses) == 0 {
-        return nil, errors.New("you must specify at least 1 address")
-    }
-    parsedAddresses, err := parseAddresses(addresses)
-    if err != nil {
-        return nil, err
-    }
-    if len(ports) == 0 {
-        return nil, errors.New("you must specify at least 1 port")
-    }
-    parsedPorts, err := parsePorts(ports)
-    if err != nil {
-        return nil, err
-    }
-    return &PortForwarder{
-        dialer:    dialer,
-        addresses: parsedAddresses,
-        ports:     parsedPorts,
-        stopChan:  stopChan,
-        errChan:   make(chan error, 1),
-        Ready:     readyChan,
-        out:       out,
-        errOut:    errOut,
-    }, nil
-}
-
-// ForwardPorts formats and executes a port forwarding request. The connection will remain
-// open until stopChan is closed.
-func (pf *PortForwarder) ForwardPorts() error {
-    defer pf.Close()
-
-    var err error
-    pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
-    if err != nil {
-        return fmt.Errorf("error upgrading connection: %s", err)
-    }
-    defer pf.streamConn.Close()
-
-    return pf.forward()
-}
-
-// forward dials the remote host specific in req, upgrades the request, starts
-// listeners for each port specified in ports, and forwards local connections
-// to the remote host via streams.
-func (pf *PortForwarder) forward() error {
-    var err error
-
-    listenSuccess := false
-    for i := range pf.ports {
-        port := &pf.ports[i]
-        err = pf.listenOnPort(port)
-        switch {
-        case err == nil:
-            listenSuccess = true
-        default:
-            if pf.errOut != nil {
-                fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
-            }
-        }
-    }
-
-    if !listenSuccess {
-        return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
-    }
-
-    if pf.Ready != nil {
-        close(pf.Ready)
-    }
-
-    // wait for interrupt or conn closure
-    select {
-    case <-pf.stopChan:
-        runtime.HandleError(errors.New("lost connection to pod"))
-    }
-    select {
-    case errs, ok := <-pf.errChan:
-        if ok {
-            return errs
-        }
-        return nil
-    default:
-        return nil
-    }
-}
-
-// listenOnPort delegates listener creation and waits for connections on requested bind addresses.
-// An error is raised based on address groups (default and localhost) and their failure modes
-func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
-    var errors []error
-    failCounters := make(map[string]int, 2)
-    successCounters := make(map[string]int, 2)
-    for _, addr := range pf.addresses {
-        err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
-        if err != nil {
-            errors = append(errors, err)
-            failCounters[addr.failureMode]++
-        } else {
-            successCounters[addr.failureMode]++
-        }
-    }
-    if successCounters["all"] == 0 && failCounters["all"] > 0 {
-        return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
-    }
-    if failCounters["any"] > 0 {
-        return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
-    }
-    return nil
-}
-
-// listenOnPortAndAddress delegates listener creation and waits for new connections
-// in the background f
-func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
-    listener, err := pf.getListener(protocol, address, port)
-    if err != nil {
-        return err
-    }
-    pf.listeners = append(pf.listeners, listener)
-    go pf.waitForConnection(listener, *port)
-    return nil
-}
-
-// getListener creates a listener on the interface targeted by the given hostname on the given port with
-// the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
-func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
-    listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
-    if err != nil {
-        return nil, fmt.Errorf("unable to create listener: Error %s", err)
-    }
-    listenerAddress := listener.Addr().String()
-    host, localPort, _ := net.SplitHostPort(listenerAddress)
-    localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
-
-    if err != nil {
-        fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote)
-        return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
-    }
-    port.Local = uint16(localPortUInt)
-    if pf.out != nil {
-        fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
-    }
-
-    return listener, nil
-}
-
-// waitForConnection waits for new connections to listener and handles them in
-// the background.
-func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
-    for {
-        conn, err := listener.Accept()
-        if err != nil {
-            // TODO consider using something like https://github.com/hydrogen18/stoppableListener?
-            if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
-                runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
-            }
-            return
-        }
-        go pf.handleConnection(conn, port)
-    }
-}
-
-func (pf *PortForwarder) nextRequestID() int {
-    pf.requestIDLock.Lock()
-    defer pf.requestIDLock.Unlock()
-    id := pf.requestID
-    pf.requestID++
-    return id
-}
-
-// handleConnection copies data between the local connection and the stream to
-// the remote server.
-func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
-    defer conn.Close()
-
-    if pf.out != nil {
-        fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
-    }
-
-    requestID := pf.nextRequestID()
-    // create error stream
-    headers := http.Header{}
-    headers.Set(v1.StreamType, v1.StreamTypeError)
-    headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
-    headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
-    var err error
-    errorStream, err := pf.streamConn.CreateStream(headers)
-    if err != nil {
-        runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
-        return
-    }
-    // we're not writing to this stream
-    errorStream.Close()
-
-    errorChan := make(chan error)
-    go func() {
-        message, err := io.ReadAll(errorStream)
-        switch {
-        case err != nil:
-            errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
-        case len(message) > 0:
-            errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
-        }
-        close(errorChan)
-    }()
-
-    // create data stream
-    headers.Set(v1.StreamType, v1.StreamTypeData)
-    dataStream, err := pf.streamConn.CreateStream(headers)
-    if err != nil {
-        runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
-        return
-    }
-
-    localError := make(chan struct{})
-    remoteDone := make(chan struct{})
-
-    go func() {
-        // Copy from the remote side to the local port.
-        if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
-            runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
-        }
-
-        // inform the select below that the remote copy is done
-        close(remoteDone)
-    }()
-
-    go func() {
-        // inform server we're not sending any more data after copy unblocks
-        defer dataStream.Close()
-
-        // Copy from the local port to the remote side.
-        if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
-            runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
-            // break out of the select below without waiting for the other copy to finish
-            close(localError)
-        }
-    }()
-
-    // wait for either a local->remote error or for copying from remote->local to finish
-    select {
-    case <-remoteDone:
-    case <-localError:
-    // wait for interrupt or conn closure
-    case <-pf.stopChan:
-        runtime.HandleError(errors.New("lost connection to pod"))
-    }
-
-    // always expect something on errorChan (it may be nil)
-    select {
-    case err = <-errorChan:
-    default:
-    }
-    if err != nil {
-        if strings.Contains(err.Error(), "failed to find socat") {
-            select {
-            case pf.errChan <- err:
-            default:
-            }
-        }
-        runtime.HandleError(err)
-    }
-}
-
-// Close stops all listeners of PortForwarder.
-func (pf *PortForwarder) Close() {
-    // stop all listeners
-    for _, l := range pf.listeners {
-        if err := l.Close(); err != nil {
-            runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
-        }
-    }
-}
-
-// GetPorts will return the ports that were forwarded; this can be used to
-// retrieve the locally-bound port in cases where the input was port 0. This
-// function will signal an error if the Ready channel is nil or if the
-// listeners are not ready yet; this function will succeed after the Ready
-// channel has been closed.
-func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
-    if pf.Ready == nil {
-        return nil, fmt.Errorf("no ready channel provided")
-    }
-    select {
-    case <-pf.Ready:
-        return pf.ports, nil
-    default:
-        return nil, fmt.Errorf("listeners not ready")
-    }
-}
-
-func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stream, error) {
-    errorChan := make(chan error, 2)
-    var resultChan atomic.Value
-    time.AfterFunc(time.Second*1, func() {
-        errorChan <- errors.New("timeout")
-    })
-    go func() {
-        if pf.streamConn != nil {
-            if stream, err := pf.streamConn.CreateStream(*header); err == nil && stream != nil {
-                errorChan <- nil
-                resultChan.Store(stream)
-                return
-            }
-        }
-        errorChan <- errors.New("")
-    }()
-    if err := <-errorChan; err == nil && resultChan.Load() != nil {
-        return resultChan.Load().(httpstream.Stream), nil
-    }
-    // close old connection in case of resource leak
-    if pf.streamConn != nil {
-        _ = pf.streamConn.Close()
-    }
-    var err error
-    pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
-    if err != nil {
-        if k8serrors.IsNotFound(err) {
-            runtime.HandleError(fmt.Errorf("pod not found: %s", err))
-            select {
-            case pf.errChan <- err:
-            default:
-            }
-        } else {
-            runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err))
-        }
-        return nil, err
-    }
-    header.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(pf.nextRequestID()))
-    return pf.streamConn.CreateStream(*header)
-}

@@ -47,7 +47,7 @@ func RolloutStatus(ctx1 context.Context, f cmdutil.Factory, ns, workloads string
    defer func() {
        if err != nil {
            plog.G(ctx1).Errorf("Rollout status for %s failed: %s", workloads, err.Error())
-            out := plog.GetLogger(ctx1).Out
+            out := plog.GetLogger(ctx1).Logger.Out
            streams := genericiooptions.IOStreams{
                In:  os.Stdin,
                Out: out,