optimize slice usage

This commit is contained in:
p_caiwfeng
2022-04-06 15:49:48 +08:00
parent 9fae2e8e0e
commit 7e53c1659a
5 changed files with 44 additions and 520 deletions

View File

@@ -7,6 +7,7 @@ import (
config2 "github.com/wencaiwulue/kubevpn/config"
"github.com/wencaiwulue/kubevpn/pkg"
"github.com/wencaiwulue/kubevpn/util"
"net/http"
)
var config pkg.Route
@@ -24,6 +25,7 @@ var ServerCmd = &cobra.Command{
Long: `serve`,
PreRun: func(*cobra.Command, []string) {
util.InitLogger(config2.Debug)
go func() { log.Info(http.ListenAndServe("localhost:6060", nil)) }()
},
Run: func(cmd *cobra.Command, args []string) {
if err := pkg.Start(context.TODO(), config); err != nil {

View File

@@ -2,6 +2,7 @@ package core
import (
"context"
"errors"
log "github.com/sirupsen/logrus"
"github.com/wencaiwulue/kubevpn/config"
"net"
@@ -17,7 +18,7 @@ func UDPOverTCPTunnelConnector() Connector {
func (c *fakeUDPTunnelConnector) ConnectContext(ctx context.Context, conn net.Conn) (net.Conn, error) {
defer conn.SetDeadline(time.Time{})
return newFakeUDPTunnelConnOverTCP(conn)
return newFakeUDPTunnelConnOverTCP(ctx, conn)
}
type fakeUdpHandler struct {
@@ -54,8 +55,11 @@ var Server8422, _ = net.ResolveUDPAddr("udp", "127.0.0.1:8422")
func (h *fakeUdpHandler) tunnelServerUDP(tcpConn net.Conn, udpConn *net.UDPConn) (err error) {
errChan := make(chan error, 2)
go func() {
b := LPool.Get().([]byte)
defer LPool.Put(b)
for {
dgram, err := ReadDatagramPacket(tcpConn)
dgram, err := readDatagramPacket(tcpConn, b[:])
if err != nil {
log.Debugf("[udp-tun] %s -> 0 : %v", tcpConn.RemoteAddr(), err)
errChan <- err
@@ -78,7 +82,7 @@ func (h *fakeUdpHandler) tunnelServerUDP(tcpConn net.Conn, udpConn *net.UDPConn)
defer MPool.Put(b)
for {
n, err := udpConn.Read(b)
n, err := udpConn.Read(b[:])
if err != nil {
log.Debugf("[udp-tun] %s : %s", tcpConn.RemoteAddr(), err)
errChan <- err
@@ -86,7 +90,7 @@ func (h *fakeUdpHandler) tunnelServerUDP(tcpConn net.Conn, udpConn *net.UDPConn)
}
// pipe from peer to tunnel
dgram := NewDatagramPacket(b[:n])
dgram := newDatagramPacket(b[:n])
if err = dgram.Write(tcpConn); err != nil {
log.Debugf("[tcpserver] udp-tun %s <- %s : %s", tcpConn.RemoteAddr(), dgram.Addr(), err)
errChan <- err
@@ -104,27 +108,30 @@ func (h *fakeUdpHandler) tunnelServerUDP(tcpConn net.Conn, udpConn *net.UDPConn)
type fakeUDPTunnelConn struct {
// tcp connection
net.Conn
ctx context.Context
}
func newFakeUDPTunnelConnOverTCP(conn net.Conn) (net.Conn, error) {
return &fakeUDPTunnelConn{Conn: conn}, nil
func newFakeUDPTunnelConnOverTCP(ctx context.Context, conn net.Conn) (net.Conn, error) {
return &fakeUDPTunnelConn{ctx: ctx, Conn: conn}, nil
}
func (c *fakeUDPTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
dgram, err := ReadDatagramPacket(c.Conn)
if err != nil {
log.Debug(err)
return
func (c *fakeUDPTunnelConn) ReadFrom(b []byte) (int, net.Addr, error) {
select {
case <-c.ctx.Done():
return 0, nil, errors.New("closed connection")
default:
dgram, err := readDatagramPacket(c.Conn, b)
if err != nil {
return 0, nil, err
}
return int(dgram.DataLength), dgram.Addr(), nil
}
n = copy(b, dgram.Data)
addr = dgram.Addr()
return
}
func (c *fakeUDPTunnelConn) WriteTo(b []byte, _ net.Addr) (n int, err error) {
dgram := NewDatagramPacket(b)
if err = dgram.Write(c.Conn); err != nil {
return
func (c *fakeUDPTunnelConn) WriteTo(b []byte, _ net.Addr) (int, error) {
dgram := newDatagramPacket(b)
if err := dgram.Write(c.Conn); err != nil {
return 0, err
}
return len(b), nil
}

View File

@@ -130,11 +130,12 @@ func (h *tunHandler) transportTun(ctx context.Context, tun net.Conn, conn net.Pa
_ = conn.Close()
}()
go func() {
b := SPool.Get().([]byte)
defer SPool.Put(b)
for ctx.Err() == nil {
err := func() error {
b := SPool.Get().([]byte)
defer SPool.Put(b)
n, err := tun.Read(b)
n, err := tun.Read(b[:])
if err != nil {
select {
case h.chExit <- struct{}{}:
@@ -196,12 +197,12 @@ func (h *tunHandler) transportTun(ctx context.Context, tun net.Conn, conn net.Pa
}()
go func() {
b := LPool.Get().([]byte)
defer LPool.Put(b)
for ctx.Err() == nil {
err := func() error {
b := SPool.Get().([]byte)
defer SPool.Put(b)
n, addr, err := conn.ReadFrom(b)
n, addr, err := conn.ReadFrom(b[:])
if err != nil && err != shadowaead.ErrShortPacket {
return err
}

View File

@@ -1,7 +1,6 @@
package core
import (
"bytes"
"encoding/binary"
"fmt"
"io"
@@ -20,7 +19,7 @@ func (addr *datagramPacket) String() string {
return fmt.Sprintf("DataLength: %d, Data: %v\n", addr.DataLength, addr.Data)
}
func NewDatagramPacket(data []byte) (r *datagramPacket) {
func newDatagramPacket(data []byte) (r *datagramPacket) {
return &datagramPacket{
DataLength: uint16(len(data)),
Data: data,
@@ -31,9 +30,7 @@ func (addr *datagramPacket) Addr() net.Addr {
return Server8422
}
func ReadDatagramPacket(r io.Reader) (*datagramPacket, error) {
b := LPool.Get().([]byte)
defer LPool.Put(b)
func readDatagramPacket(r io.Reader, b []byte) (*datagramPacket, error) {
_, err := io.ReadFull(r, b[:2])
if err != nil {
return nil, err
@@ -48,13 +45,13 @@ func ReadDatagramPacket(r io.Reader) (*datagramPacket, error) {
}
func (addr *datagramPacket) Write(w io.Writer) error {
buf := bytes.Buffer{}
i := make([]byte, 2)
binary.BigEndian.PutUint16(i, uint16(len(addr.Data)))
buf.Write(i)
if _, err := buf.Write(addr.Data); err != nil {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b[:], uint16(len(addr.Data)))
if _, err := w.Write(b); err != nil {
return err
}
_, err := buf.WriteTo(w)
return err
if _, err := w.Write(addr.Data); err != nil {
return err
}
return nil
}

View File

@@ -1,483 +0,0 @@
package util
import (
"errors"
"fmt"
containerderrors "github.com/containerd/containerd/errdefs"
"io"
"io/ioutil"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/client-go/tools/portforward"
"net"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/runtime"
)
// PortForwarder knows how to listen for local connections and forward them to
// a remote pod via an upgraded HTTP request.
type PortForwarder struct {
addresses []listenAddress
ports []ForwardedPort
stopChan <-chan struct{}
innerStopChan chan struct{}
dialer httpstream.Dialer
streamConn httpstream.Connection
listeners []io.Closer
Ready chan struct{}
requestIDLock sync.Mutex
requestID int
out io.Writer
errOut io.Writer
}
// ForwardedPort contains a Local:Remote port pairing.
type ForwardedPort struct {
Local uint16
Remote uint16
}
/*
valid port specifications:
5000
- forwards from localhost:5000 to pod:5000
8888:5000
- forwards from localhost:8888 to pod:5000
0:5000
:5000
- selects a random available local port,
forwards from localhost:<random port> to pod:5000
*/
func parsePorts(ports []string) ([]ForwardedPort, error) {
var forwards []ForwardedPort
for _, portString := range ports {
parts := strings.Split(portString, ":")
var localString, remoteString string
if len(parts) == 1 {
localString = parts[0]
remoteString = parts[0]
} else if len(parts) == 2 {
localString = parts[0]
if localString == "" {
// support :5000
localString = "0"
}
remoteString = parts[1]
} else {
return nil, fmt.Errorf("invalid port format '%s'", portString)
}
localPort, err := strconv.ParseUint(localString, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
}
remotePort, err := strconv.ParseUint(remoteString, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
}
if remotePort == 0 {
return nil, fmt.Errorf("remote port must be > 0")
}
forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
}
return forwards, nil
}
type listenAddress struct {
address string
protocol string
failureMode string
}
func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
var addresses []listenAddress
parsed := make(map[string]listenAddress)
for _, address := range addressesToParse {
if address == "localhost" {
if _, exists := parsed["127.0.0.1"]; !exists {
ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
parsed[ip.address] = ip
}
if _, exists := parsed["::1"]; !exists {
ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
parsed[ip.address] = ip
}
} else if net.ParseIP(address).To4() != nil {
parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
} else if net.ParseIP(address) != nil {
parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
} else {
return nil, fmt.Errorf("%s is not a valid IP", address)
}
}
addresses = make([]listenAddress, len(parsed))
id := 0
for _, v := range parsed {
addresses[id] = v
id++
}
// Sort addresses before returning to get a stable order
sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
return addresses, nil
}
// New creates a new PortForwarder with localhost listen addresses.
func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
return NewOnAddresses(dialer, []string{"localhost"}, ports, stopChan, readyChan, out, errOut)
}
// NewOnAddresses creates a new PortForwarder with custom listen addresses.
func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
if len(addresses) == 0 {
return nil, errors.New("you must specify at least 1 address")
}
parsedAddresses, err := parseAddresses(addresses)
if err != nil {
return nil, err
}
if len(ports) == 0 {
return nil, errors.New("you must specify at least 1 port")
}
parsedPorts, err := parsePorts(ports)
if err != nil {
return nil, err
}
return &PortForwarder{
dialer: dialer,
addresses: parsedAddresses,
ports: parsedPorts,
stopChan: stopChan,
innerStopChan: make(chan struct{}, 1),
Ready: readyChan,
out: out,
errOut: errOut,
}, nil
}
// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
return pf.forward()
}
// forward dials the remote host specific in req, upgrades the request, starts
// listeners for each port specified in ports, and forwards local connections
// to the remote host via streams.
func (pf *PortForwarder) forward() error {
var err error
listenSuccess := false
for i := range pf.ports {
port := &pf.ports[i]
err = pf.listenOnPort(port)
switch {
case err == nil:
listenSuccess = true
default:
if pf.errOut != nil {
fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
}
}
}
if !listenSuccess {
return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
}
if pf.Ready != nil {
close(pf.Ready)
}
// wait for interrupt or conn closure
select {
case <-pf.stopChan:
case <-pf.innerStopChan:
runtime.HandleError(errors.New("lost connection to pod"))
}
return nil
}
// listenOnPort delegates listener creation and waits for connections on requested bind addresses.
// An error is raised based on address groups (default and localhost) and their failure modes
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
var errors []error
failCounters := make(map[string]int, 2)
successCounters := make(map[string]int, 2)
for _, addr := range pf.addresses {
err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
if err != nil {
errors = append(errors, err)
failCounters[addr.failureMode]++
} else {
successCounters[addr.failureMode]++
}
}
if successCounters["all"] == 0 && failCounters["all"] > 0 {
return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
}
if failCounters["any"] > 0 {
return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
}
return nil
}
// listenOnPortAndAddress delegates listener creation and waits for new connections
// in the background f
func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
listener, err := pf.getListener(protocol, address, port)
if err != nil {
return err
}
pf.listeners = append(pf.listeners, listener)
go pf.waitForConnection(listener, *port)
return nil
}
// getListener creates a listener on the interface targeted by the given hostname on the given port with
// the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
if err != nil {
return nil, fmt.Errorf("unable to create listener: Error %s", err)
}
listenerAddress := listener.Addr().String()
host, localPort, _ := net.SplitHostPort(listenerAddress)
localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
if err != nil {
fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote)
return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
}
port.Local = uint16(localPortUInt)
if pf.out != nil {
fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
}
return listener, nil
}
// waitForConnection waits for new connections to listener and handles them in
// the background.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for {
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
}
return
}
go pf.handleConnection(conn, port)
}
}
func (pf *PortForwarder) nextRequestID() int {
pf.requestIDLock.Lock()
defer pf.requestIDLock.Unlock()
id := pf.requestID
pf.requestID++
return id
}
// handleConnection copies data between the local connection and the stream to
// the remote server.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
defer conn.Close()
if pf.out != nil {
fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
}
defaultRetry := 5
firstCreateStream:
requestID := pf.nextRequestID()
// create error stream
headers := http.Header{}
headers.Set(v1.StreamType, v1.StreamTypeError)
headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
var err error
errorStream, err := pf.tryToCreateStream(&headers)
if err != nil {
runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
// we're not writing to this stream
errorStream.Close()
errorChan := make(chan error)
go func() {
message, err := ioutil.ReadAll(errorStream)
switch {
case err != nil:
errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
case len(message) > 0:
errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
}
close(errorChan)
}()
// create data stream
headers.Set(v1.StreamType, v1.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
defaultRetry--
if defaultRetry > 0 {
goto firstCreateStream
}
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
localError := make(chan struct{})
remoteDone := make(chan struct{})
go func() {
// Copy from the remote side to the local port.
if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
}
// inform the select below that the remote copy is done
close(remoteDone)
}()
go func() {
// inform server we're not sending any more data after copy unblocks
defer dataStream.Close()
// Copy from the local port to the remote side.
if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
// break out of the select below without waiting for the other copy to finish
close(localError)
}
}()
// wait for either a local->remote error or for copying from remote->local to finish
select {
case <-remoteDone:
case <-localError:
}
// always expect something on errorChan (it may be nil)
err = <-errorChan
if err != nil {
// docker
if IsContainerNotFoundError(err) {
close(pf.innerStopChan)
} else
// containerd
if containerderrors.IsNotFound(err) {
close(pf.innerStopChan)
} else
// others
if strings.Contains(err.Error(), "no such container") ||
strings.Contains(err.Error(), "does not exist") {
close(pf.innerStopChan)
}
runtime.HandleError(err)
}
}
// Close stops all listeners of PortForwarder.
func (pf *PortForwarder) Close() {
// stop all listeners
for _, l := range pf.listeners {
if err := l.Close(); err != nil {
runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
}
}
}
// GetPorts will return the ports that were forwarded; this can be used to
// retrieve the locally-bound port in cases where the input was port 0. This
// function will signal an error if the Ready channel is nil or if the
// listeners are not ready yet; this function will succeed after the Ready
// channel has been closed.
func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
if pf.Ready == nil {
return nil, fmt.Errorf("no Ready channel provided")
}
select {
case <-pf.Ready:
return pf.ports, nil
default:
return nil, fmt.Errorf("listeners not ready")
}
}
func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stream, error) {
errorChan := make(chan error, 2)
var resultChan atomic.Value
time.AfterFunc(time.Second*1, func() {
errorChan <- errors.New("timeout")
})
go func() {
if pf.streamConn != nil {
if stream, err := pf.streamConn.CreateStream(*header); err == nil && stream != nil {
errorChan <- nil
resultChan.Store(stream)
return
}
}
errorChan <- errors.New("")
}()
if err := <-errorChan; err == nil && resultChan.Load() != nil {
return resultChan.Load().(httpstream.Stream), nil
}
// close old connection in case of resource leak
if pf.streamConn != nil {
_ = pf.streamConn.Close()
}
var err error
pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
if err != nil {
if k8serrors.IsNotFound(err) {
runtime.HandleError(fmt.Errorf("pod not found: %s", err))
close(pf.innerStopChan)
} else {
runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err))
}
return nil, err
}
header.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(pf.nextRequestID()))
return pf.streamConn.CreateStream(*header)
}
// containerNotFoundErrorRegx is the regexp of container not found error message.
var containerNotFoundErrorRegx = regexp.MustCompile(`No such container: [0-9a-z]+`)
// IsContainerNotFoundError checks whether the error is container not found error.
func IsContainerNotFoundError(err error) bool {
return containerNotFoundErrorRegx.MatchString(err.Error())
}