diff --git a/TODO.MD b/TODO.MD index 8287f0be..7c8e709d 100644 --- a/TODO.MD +++ b/TODO.MD @@ -12,4 +12,5 @@ - [ ] 加入 TLS 以提高安全性 - [ ] 写个 CNI 网络插件,直接提供 VPN 功能 - [ ] 优化重连逻辑 +- [ ] 支持 service mesh diff --git a/tun/tun.go b/tun/tun.go index 2b3ef168..253a6988 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -27,27 +27,23 @@ type tunListener struct { // TunListener creates a listener for tun tunnel. func TunListener(cfg Config) (Listener, error) { - threads := 1 ln := &tunListener{ - conns: make(chan net.Conn, threads), + conns: make(chan net.Conn, 1), closed: make(chan struct{}), config: cfg, } - for i := 0; i < threads; i++ { - conn, ifce, err := createTun(cfg) - if err != nil { - return nil, err - } - ln.addr = conn.LocalAddr() - - addrs, _ := ifce.Addrs() - _ = os.Setenv("tunName", ifce.Name) - log.Debugf("[tun] %s: name: %s, mtu: %d, addrs: %s", - conn.LocalAddr(), ifce.Name, ifce.MTU, addrs) - - ln.conns <- conn + conn, ifce, err := createTun(cfg) + if err != nil { + return nil, err } + ln.addr = conn.LocalAddr() + + addrs, _ := ifce.Addrs() + _ = os.Setenv("tunName", ifce.Name) + log.Debugf("[tun] %s: name: %s, mtu: %d, addrs: %s", conn.LocalAddr(), ifce.Name, ifce.MTU, addrs) + + ln.conns <- conn return ln, nil } @@ -106,16 +102,16 @@ func (c *tunConn) RemoteAddr() net.Addr { return &net.IPAddr{} } -func (c *tunConn) SetDeadline(t time.Time) error { +func (c *tunConn) SetDeadline(time.Time) error { return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (c *tunConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +func (c *tunConn) SetReadDeadline(time.Time) error { + return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("read deadline not supported")} } -func (c *tunConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +func (c *tunConn) SetWriteDeadline(time.Time) error { + return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("write deadline not supported")} } // IPRoute is an IP routing entry. diff --git a/util/portforward.go b/util/portforward.go index bc2f65c0..c93367ff 100644 --- a/util/portforward.go +++ b/util/portforward.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + k8serrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/tools/portforward" "net" "net/http" @@ -23,9 +24,10 @@ import ( // PortForwarder knows how to listen for local connections and forward them to // a remote pod via an upgraded HTTP request. type PortForwarder struct { - addresses []listenAddress - ports []ForwardedPort - stopChan <-chan struct{} + addresses []listenAddress + ports []ForwardedPort + stopChan <-chan struct{} + innerStopChan chan struct{} dialer httpstream.Dialer streamConn httpstream.Connection @@ -156,13 +158,14 @@ func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string return nil, err } return &PortForwarder{ - dialer: dialer, - addresses: parsedAddresses, - ports: parsedPorts, - stopChan: stopChan, - Ready: readyChan, - out: out, - errOut: errOut, + dialer: dialer, + addresses: parsedAddresses, + ports: parsedPorts, + stopChan: stopChan, + innerStopChan: make(chan struct{}, 1), + Ready: readyChan, + out: out, + errOut: errOut, }, nil } @@ -212,8 +215,8 @@ func (pf *PortForwarder) forward() error { // wait for interrupt or conn closure select { case <-pf.stopChan: - //case <-pf.streamConn.CloseChan(): - // runtime.HandleError(errors.New("lost connection to pod")) + case <-pf.innerStopChan: + runtime.HandleError(errors.New("lost connection to pod")) } return nil @@ -420,7 +423,7 @@ func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) { func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stream, error) { errorChan := make(chan error, 2) - var value atomic.Value + var resultChan atomic.Value time.AfterFunc(time.Second*1, func() { errorChan <- errors.New("timeout") }) @@ -428,14 +431,14 @@ func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stre if pf.streamConn != nil { if stream, err := pf.streamConn.CreateStream(*header); err == nil && stream != nil { errorChan <- nil - value.Store(stream) + resultChan.Store(stream) return } } errorChan <- errors.New("") }() - if err := <-errorChan; err == nil { - return value.Load().(httpstream.Stream), nil + if err := <-errorChan; err == nil && resultChan.Load() != nil { + return resultChan.Load().(httpstream.Stream), nil } // close old connection in case of resource leak if pf.streamConn != nil { @@ -444,7 +447,12 @@ func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stre var err error pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name) if err != nil { - runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err)) + if k8serrors.IsNotFound(err) { + runtime.HandleError(fmt.Errorf("pod not found: %s", err)) + close(pf.innerStopChan) + } else { + runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err)) + } return nil, err } header.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(pf.nextRequestID()))