Files
kubevpn/pkg/util/util.go
2024-07-19 22:07:35 +08:00

419 lines
11 KiB
Go

package util
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"os"
osexec "os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
rbacv1 "k8s.io/api/rbac/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/fields"
k8sruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
"k8s.io/client-go/tools/cache"
watchtools "k8s.io/client-go/tools/watch"
"k8s.io/client-go/util/retry"
cmdutil "k8s.io/kubectl/pkg/cmd/util"
"k8s.io/kubectl/pkg/polymorphichelpers"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
"github.com/wencaiwulue/kubevpn/v2/pkg/driver"
)
// GetAvailableUDPPortOrDie asks the operating system for a currently unused
// UDP port on localhost and returns its number.
//
// It binds a UDP socket to port 0, reads back the port that was assigned and
// closes the socket before returning. Despite the "OrDie" suffix the function
// never terminates the process: address-resolution and bind failures are
// returned as the error, together with port 0.
func GetAvailableUDPPortOrDie() (int, error) {
	var (
		addr *net.UDPAddr
		conn *net.UDPConn
		err  error
	)
	if addr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:0", "localhost")); err != nil {
		return 0, err
	}
	if conn, err = net.ListenUDP("udp", addr); err != nil {
		return 0, err
	}
	// The socket exists only to discover a free port; release it on return.
	defer conn.Close()

	local := conn.LocalAddr().(*net.UDPAddr)
	return local.Port, nil
}
// GetAvailableTCPPortOrDie asks the operating system for a currently unused
// TCP port on localhost and returns its number.
//
// It opens a TCP listener on port 0, reads back the port that was assigned and
// closes the listener before returning. Despite the "OrDie" suffix the
// function never terminates the process: address-resolution and listen
// failures are returned as the error, together with port 0.
func GetAvailableTCPPortOrDie() (int, error) {
	var (
		addr *net.TCPAddr
		ln   *net.TCPListener
		err  error
	)
	if addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:0", "localhost")); err != nil {
		return 0, err
	}
	if ln, err = net.ListenTCP("tcp", addr); err != nil {
		return 0, err
	}
	// The listener exists only to discover a free port; release it on return.
	defer ln.Close()

	bound := ln.Addr().(*net.TCPAddr)
	return bound.Port, nil
}
// IsWindows reports whether the running program was built for Windows,
// i.e. whether runtime.GOOS equals "windows".
func IsWindows() bool {
	const windowsOS = "windows"
	return runtime.GOOS == windowsOS
}
// BytesToInt decodes the leading bytes of b as a big-endian uint32.
//
// If binary.Read fails (for example because b holds fewer than four bytes) the
// error is logged at warning level and the result variable is returned as it
// stands, which is zero unless binary.Read stored something before failing.
func BytesToInt(b []byte) uint32 {
	var value uint32
	err := binary.Read(bytes.NewBuffer(b), binary.BigEndian, &value)
	if err != nil {
		log.Warn(err)
	}
	return value
}
// RolloutStatus blocks until the rollout of a single workload has finished,
// logging every intermediate status line.
//
// Parameters:
//   - ctx1: parent context used for the list/watch calls and for the wait.
//   - factory: kubectl factory used to build the dynamic client and the
//     resource builder.
//   - namespace: namespace in which workloads is resolved.
//   - workloads: resource argument (type/name form) that must resolve to
//     exactly one object.
//   - timeout: passed to watchtools.ContextWithOptionalTimeout to bound the
//     wait.
//
// It returns nil once the status viewer reports the rollout as done. It
// returns an error when the client or the resource cannot be resolved, when
// more or fewer than one object matches, when the object is deleted while
// waiting, when an unexpected watch event arrives, or when the wait ends
// early (for example on timeout). The outcome is logged on return.
func RolloutStatus(ctx1 context.Context, factory cmdutil.Factory, namespace, workloads string, timeout time.Duration) (err error) {
	log.Infof("rollout status for %s", workloads)
	defer func() {
		if err != nil {
			log.Errorf("rollout status for %s failed: %s", workloads, err.Error())
		} else {
			log.Infof("rollout status for %s successfully", workloads)
		}
	}()

	// A failed client construction must be reported here; otherwise the
	// list/watch closures below would dereference a nil client.
	client, err := factory.DynamicClient()
	if err != nil {
		return err
	}

	r := factory.NewBuilder().
		WithScheme(scheme.Scheme, scheme.Scheme.PrioritizedVersionsAllGroups()...).
		NamespaceParam(namespace).DefaultNamespace().
		ResourceTypeOrNameArgs(true, workloads).
		SingleResourceType().
		Latest().
		Do()
	if err = r.Err(); err != nil {
		return err
	}

	infos, err := r.Infos()
	if err != nil {
		return err
	}
	if len(infos) != 1 {
		return fmt.Errorf("rollout status is only supported on individual resources and resource collections - %d resources were found", len(infos))
	}
	info := infos[0]

	statusViewer, err := polymorphichelpers.StatusViewerFn(info.ResourceMapping())
	if err != nil {
		return err
	}

	// Restrict both list and watch to the single object that was resolved.
	fieldSelector := fields.OneTermEqualSelector("metadata.name", info.Name).String()
	lw := &cache.ListWatch{
		ListFunc: func(options metav1.ListOptions) (k8sruntime.Object, error) {
			options.FieldSelector = fieldSelector
			return client.Resource(info.Mapping.Resource).Namespace(info.Namespace).List(ctx1, options)
		},
		WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
			options.FieldSelector = fieldSelector
			return client.Resource(info.Mapping.Resource).Namespace(info.Namespace).Watch(ctx1, options)
		},
	}

	// if the rollout isn't done yet, keep watching deployment status
	ctx, cancel := watchtools.ContextWithOptionalTimeout(ctx1, timeout)
	defer cancel()

	_, err = watchtools.UntilWithSync(ctx, lw, &unstructured.Unstructured{}, nil, func(e watch.Event) (bool, error) {
		switch t := e.Type; t {
		case watch.Added, watch.Modified:
			obj, ok := e.Object.(k8sruntime.Unstructured)
			if !ok {
				return true, fmt.Errorf("internal error: unexpected object type %T", e.Object)
			}
			status, done, statusErr := statusViewer.Status(obj, 0)
			if statusErr != nil {
				return false, statusErr
			}
			log.Info(strings.TrimSpace(status))
			// Quit waiting if the rollout is done
			return done, nil
		case watch.Deleted:
			// We need to abort to avoid cases of recreation and not to silently watch the wrong (new) object
			return true, fmt.Errorf("object has been deleted")
		default:
			return true, fmt.Errorf("internal error: unexpected event %#v", e)
		}
	})
	return err
}
// proxyWriter is an in-memory buffer that reports its accumulated content to
// an optional callback after every write.
type proxyWriter struct {
	*bytes.Buffer
	// checker, when non-nil, is invoked after each Write with everything
	// buffered so far (not just the newly written bytes).
	checker func(output string)
}

// Write appends b to the embedded buffer, then hands the complete buffered
// text to checker if one is set. It returns the buffer's own Write result.
func (w *proxyWriter) Write(b []byte) (int, error) {
	n, err := w.Buffer.Write(b)
	if notify := w.checker; notify != nil {
		notify(w.Buffer.String())
	}
	return n, err
}
// RunWithRollingOutWithChecker runs cmd to completion while mirroring its
// stdout and stderr to the current process's os.Stdout / os.Stderr and
// capturing both streams in memory.
//
// checker, if non-nil, is called after every captured chunk with the full
// text captured so far on that stream. Stdout and stderr are copied by two
// separate goroutines, so checker may be invoked concurrently.
//
// Returns the whitespace-trimmed stdout and stderr text and a nil error when
// the command exits successfully. If the pipes cannot be created or the
// command exits with a failure, empty strings and the error are returned. If
// the command cannot be started, whatever was captured so far is returned
// together with the start error.
func RunWithRollingOutWithChecker(cmd *osexec.Cmd, checker func(log string)) (string, string, error) {
	stdoutBuf := &proxyWriter{Buffer: new(bytes.Buffer), checker: checker}
	stderrBuf := &proxyWriter{Buffer: new(bytes.Buffer), checker: checker}

	// Pipe creation fails e.g. when the caller already set cmd.Stdout or
	// cmd.Stderr; copying from a nil pipe would panic, so surface the error.
	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return "", "", err
	}
	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		_ = stdoutPipe.Close()
		return "", "", err
	}

	if err = cmd.Start(); err != nil {
		if cmd.Process != nil {
			_ = cmd.Process.Kill()
		}
		return stdoutBuf.String(), stderrBuf.String(), err
	}

	// Wait closes the pipes, so every read must finish before it is called;
	// otherwise trailing output can be lost and the buffers would be read
	// while the copy goroutines are still writing to them.
	copied := make(chan struct{}, 2)
	go func() {
		_, _ = io.Copy(io.MultiWriter(os.Stdout, stdoutBuf), stdoutPipe)
		copied <- struct{}{}
	}()
	go func() {
		_, _ = io.Copy(io.MultiWriter(os.Stderr, stderrBuf), stderrPipe)
		copied <- struct{}{}
	}()
	<-copied
	<-copied

	// Wait already reports a non-zero exit status as an error, so no separate
	// exit-code check is needed after it succeeds.
	if err = cmd.Wait(); err != nil {
		return "", "", err
	}
	return strings.TrimSpace(stdoutBuf.String()), strings.TrimSpace(stderrBuf.String()), nil
}
// WaitPortToBeFree polls every two seconds until IsPortListening reports that
// the given port is no longer in use, then returns nil.
//
// The first probe happens after the first tick, not immediately. If ctx is
// done before the port becomes free, an error is returned; the message says
// "timeout" whatever the reason the context ended.
func WaitPortToBeFree(ctx context.Context, port int) error {
	log.Infof("wait port %v to be free...", port)

	poll := time.NewTicker(2 * time.Second)
	defer poll.Stop()

	for {
		select {
		case <-poll.C:
			if IsPortListening(port) {
				continue
			}
			log.Infof("port %v are free", port)
			return nil
		case <-ctx.Done():
			return fmt.Errorf("wait port %v to be free timeout", port)
		}
	}
}
// IsPortListening reports whether the given TCP port on localhost (IPv4)
// appears to be in use.
//
// It tries to open a listener on the port: any failure to listen is treated
// as "in use" and yields true; if listening succeeds the probe listener is
// closed again and false is returned.
func IsPortListening(port int) bool {
	addr := net.JoinHostPort("localhost", strconv.Itoa(port))
	probe, err := net.Listen("tcp4", addr)
	if err != nil {
		return true
	}
	_ = probe.Close()
	return false
}
// CanI reports whether the service account named sa is granted everything
// described by resource, through a RoleBinding in namespace ns or through any
// ClusterRoleBinding.
//
// Subjects are matched by kind ("ServiceAccount") and name only. A rule grants
// the request when it lists all requested resources and verbs and either has
// no ResourceNames restriction or lists all requested resource names.
// Matching is literal set containment; wildcard entries are not expanded.
//
// A RoleBinding may reference either a Role in ns or a ClusterRole; the
// referenced object is looked up according to RoleRef.Kind.
//
// Returns (true, nil) on the first granting rule, (false, nil) when no bound
// rule grants the request, and (false, err) as soon as any API call fails.
func CanI(clientset *kubernetes.Clientset, sa, ns string, resource *rbacv1.PolicyRule) (allowed bool, err error) {
	ctx := context.Background()
	rbac := clientset.RbacV1()

	roleBindings, err := rbac.RoleBindings(ns).List(ctx, metav1.ListOptions{})
	if err != nil {
		return false, err
	}
	for _, binding := range roleBindings.Items {
		if !subjectsContainServiceAccount(binding.Subjects, sa) {
			continue
		}
		var rules []rbacv1.PolicyRule
		if binding.RoleRef.Kind == "ClusterRole" {
			// A namespaced binding may grant a ClusterRole's rules within ns.
			clusterRole, getErr := rbac.ClusterRoles().Get(ctx, binding.RoleRef.Name, metav1.GetOptions{})
			if getErr != nil {
				return false, getErr
			}
			rules = clusterRole.Rules
		} else {
			role, getErr := rbac.Roles(ns).Get(ctx, binding.RoleRef.Name, metav1.GetOptions{})
			if getErr != nil {
				return false, getErr
			}
			rules = role.Rules
		}
		if policyRulesPermit(rules, resource) {
			return true, nil
		}
	}

	clusterRoleBindings, err := rbac.ClusterRoleBindings().List(ctx, metav1.ListOptions{})
	if err != nil {
		// Without this check a failed list would dereference a nil result.
		return false, err
	}
	for _, binding := range clusterRoleBindings.Items {
		if !subjectsContainServiceAccount(binding.Subjects, sa) {
			continue
		}
		clusterRole, getErr := rbac.ClusterRoles().Get(ctx, binding.RoleRef.Name, metav1.GetOptions{})
		if getErr != nil {
			return false, getErr
		}
		if policyRulesPermit(clusterRole.Rules, resource) {
			return true, nil
		}
	}
	return false, nil
}

// subjectsContainServiceAccount reports whether subjects includes a
// ServiceAccount subject with the given name.
func subjectsContainServiceAccount(subjects []rbacv1.Subject, name string) bool {
	for _, subject := range subjects {
		if subject.Kind == "ServiceAccount" && subject.Name == name {
			return true
		}
	}
	return false
}

// policyRulesPermit reports whether any single rule in rules covers all
// resources, verbs and (when the rule restricts them) resource names of want.
func policyRulesPermit(rules []rbacv1.PolicyRule, want *rbacv1.PolicyRule) bool {
	for _, rule := range rules {
		if !sets.New[string](rule.Resources...).HasAll(want.Resources...) {
			continue
		}
		if !sets.New[string](rule.Verbs...).HasAll(want.Verbs...) {
			continue
		}
		if len(rule.ResourceNames) == 0 || sets.New[string](rule.ResourceNames...).HasAll(want.ResourceNames...) {
			return true
		}
	}
	return false
}
// DoReq sends request over an HTTP client that trusts only the CA
// certificate(s) found in the environment variable named by
// config.TLSCertKey, and returns the response body.
//
// Returns:
//   - (body, nil) when the response status is 200 OK;
//   - (body, error) for any other status, so callers can still inspect the
//     payload;
//   - (nil, error) when the environment variable is missing, holds no
//     parsable PEM certificate, the request fails, or the body cannot be
//     read. Transport and read errors are wrapped and remain reachable via
//     errors.Is / errors.As.
//
// No client-side timeout is set; cancellation relies on the request's own
// context.
func DoReq(request *http.Request) (body []byte, err error) {
	cert, ok := os.LookupEnv(config.TLSCertKey)
	if !ok {
		return nil, fmt.Errorf("can not get %s from env", config.TLSCertKey)
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM([]byte(cert)) {
		// An empty pool would only surface later as an opaque TLS failure.
		return nil, fmt.Errorf("can not parse any PEM certificate from env %s", config.TLSCertKey)
	}

	transport := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs: caCertPool,
		},
	}
	// The transport is private to this call, so its keep-alive connection can
	// never be reused; close it instead of leaving it open indefinitely.
	defer transport.CloseIdleConnections()
	client := &http.Client{Transport: transport}

	resp, err := client.Do(request)
	if err != nil {
		return nil, fmt.Errorf("err: %w", err)
	}
	defer resp.Body.Close()

	body, err = io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("can not read body, err: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return body, fmt.Errorf("http status is %d", resp.StatusCode)
	}
	return body, nil
}
// GetTlsDomain builds the dotted name
// "<config.ConfigMapPodTrafficManager>.<namespace>.svc" for the given
// namespace.
func GetTlsDomain(namespace string) string {
	const dot = "."
	qualified := config.ConfigMapPodTrafficManager + dot + namespace
	return qualified + dot + "svc"
}
// Deduplicate returns the networks of cidr with nil entries and duplicates
// removed, keeping the first occurrence of each network in its original
// order.
//
// Two entries are duplicates when their String() forms are equal. The
// returned slice shares the *net.IPNet pointers with the input and is nil
// when nothing is kept.
func Deduplicate(cidr []*net.IPNet) (result []*net.IPNet) {
	seen := make(map[string]struct{}, len(cidr))
	for _, ipNet := range cidr {
		if ipNet == nil {
			continue
		}
		// Format once per element; the key serves both lookup and insert.
		key := ipNet.String()
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		result = append(result, ipNet)
	}
	return result
}
// CleanExtensionLib removes the "wintun.dll" that sits next to the running
// executable. It does nothing on non-Windows systems or when the executable
// path cannot be determined.
//
// It repeatedly calls driver.UninstallWireGuardTunDriver under
// retry.DefaultRetry. The retried function always returns a non-nil error, so
// whether another attempt is made is decided by the retriable predicate
// alone, which answers "retry" for as long as Lstat on the DLL does not report
// os.ErrNotExist. If the DLL is still present after the retries, MoveToTemp is
// called as a fallback.
func CleanExtensionLib() {
	if !IsWindows() {
		return
	}
	exe, err := os.Executable()
	if err != nil {
		return
	}
	dll := filepath.Join(filepath.Dir(exe), "wintun.dll")

	// dllGone reports whether Lstat says the DLL no longer exists.
	dllGone := func() bool {
		_, statErr := os.Lstat(dll)
		return errors.Is(statErr, os.ErrNotExist)
	}

	_ = retry.OnError(
		retry.DefaultRetry,
		func(error) bool { return !dllGone() },
		func() error {
			// Deliberately never nil: see the function comment.
			return fmt.Errorf("%v", driver.UninstallWireGuardTunDriver())
		},
	)

	if dllGone() {
		return
	}
	MoveToTemp()
}
// Print writes slogan to writer inside a three-line ASCII box:
//
//	+----------+
//	|    hi    |
//	+----------+
//
// The text is padded with four spaces on each side. The border width is based
// on the number of runes in slogan rather than bytes, so multi-byte UTF-8
// text gets a border as long as the text line. Characters that occupy more
// than one terminal column are still counted as one.
//
// The box is emitted with a single write; a write error is ignored.
func Print(writer io.Writer, slogan string) {
	const padding = 4

	width := len([]rune(slogan)) + 2*padding
	border := "+" + strings.Repeat("-", width) + "+\n"
	margin := strings.Repeat(" ", padding)

	var sb strings.Builder
	sb.Grow(2*len(border) + len(slogan) + 2*padding + 3)
	sb.WriteString(border)
	sb.WriteString("|" + margin + slogan + margin + "|\n")
	sb.WriteString(border)

	_, _ = io.WriteString(writer, sb.String())
}
// StartupPProf serves HTTP on localhost at the given port using the default
// serve mux (a nil handler is passed to http.ListenAndServe).
//
// The call blocks until ListenAndServe returns, and the error it returns is
// discarded. Presumably the pprof handlers are registered on the default mux
// elsewhere (e.g. by importing net/http/pprof); that is not visible here.
func StartupPProf(port int) {
	addr := fmt.Sprintf("localhost:%d", port)
	_ = http.ListenAndServe(addr, nil)
}
// MoveToTemp renames the "wintun.dll" located next to the running executable
// to a freshly generated path in the default temporary directory.
//
// The destination name is obtained by creating a temp file, closing it and
// removing it again, so that the rename targets a path that does not exist.
// Every failure along the way makes the function return silently, except a
// failed rename, which is logged at debug level.
func MoveToTemp() {
	exe, err := os.Executable()
	if err != nil {
		return
	}
	dll := filepath.Join(filepath.Dir(exe), "wintun.dll")
	if _, statErr := os.Lstat(dll); errors.Is(statErr, os.ErrNotExist) {
		// Nothing to move.
		return
	}

	placeholder, err := os.CreateTemp("", "")
	if err != nil {
		return
	}
	if closeErr := placeholder.Close(); closeErr != nil {
		return
	}
	target := placeholder.Name()
	if removeErr := os.Remove(target); removeErr != nil {
		return
	}

	if renameErr := os.Rename(dll, target); renameErr != nil {
		log.Debugln(renameErr)
	}
}
// Merge copies every entry of ToMap into fromMap and returns fromMap.
//
// Entries already present in fromMap are overwritten by the ones from ToMap,
// and fromMap is modified in place. When fromMap is nil nothing is copied and
// ToMap itself is returned (which may also be nil).
func Merge[K comparable, V any](fromMap, ToMap map[K]V) map[K]V {
	if fromMap != nil {
		for key, value := range ToMap {
			fromMap[key] = value
		}
		return fromMap
	}
	return ToMap
}