Files
gvisor-tap-vsock/cmd/gvproxy/main.go
openshift-merge-bot[bot] 45ea6bbb97 Merge pull request #429 from lstocchi/i425
Add --services flag to start API without using --listen flag
2025-01-30 16:01:54 +00:00

586 lines
15 KiB
Go

package main
import (
"bufio"
"context"
"flag"
"fmt"
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/containers/gvisor-tap-vsock/pkg/net/stdio"
"github.com/containers/gvisor-tap-vsock/pkg/sshclient"
"github.com/containers/gvisor-tap-vsock/pkg/transport"
"github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork"
"github.com/containers/winquit/pkg/winquit"
"github.com/dustin/go-humanize"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
// Command-line options (bound to flags in main) and process-wide state.
var (
	// debug is set by --debug; it raises the log level, enables the packet
	// capture file name and the pprof handlers.
	debug bool
	// mtu is set by --mtu and copied into the network configuration.
	mtu int
	// endpoints collects every --listen control endpoint.
	endpoints arrayFlags
	// vpnkitSocket is set by --listen-vpnkit.
	vpnkitSocket string
	// qemuSocket is set by --listen-qemu.
	qemuSocket string
	// bessSocket is set by --listen-bess; must be a unixpacket:// address.
	bessSocket string
	// stdioSocket is set by --listen-stdio; a non-empty value makes run
	// accept a connection over stdio.
	stdioSocket string
	// vfkitSocket is set by --listen-vfkit; must be a unixgram:// address.
	vfkitSocket string
	// forwardSocket, forwardDest, forwardUser and forwardIdentify are
	// parallel lists (--forward-sock, --forward-dest, --forward-user,
	// --forward-identity); main requires them to have the same length.
	forwardSocket arrayFlags
	forwardDest arrayFlags
	forwardUser arrayFlags
	forwardIdentify arrayFlags
	// sshPort is set by --ssh-port; -1 disables the SSH port forward.
	sshPort int
	// pidFile is set by --pid-file; when non-empty main writes the PID there.
	pidFile string
	// exitCode is the status main exits with; set to 1 when the error group
	// reports a failure.
	exitCode int
	// logFile is set by --log-file; when non-empty logrus output goes there.
	logFile string
	// servicesEndpoint is set by --services.
	servicesEndpoint string
)
// Fixed addresses and DNS record names used to build the network configuration.
const (
	// gatewayIP is the gateway address of the virtual network; run also
	// listens on port 80 of this address.
	gatewayIP = "192.168.127.1"
	// sshHostPort is the target of the SSH port forward and the host part of
	// the SSH destination used for socket forwarding.
	sshHostPort = "192.168.127.2:22"
	// hostIP is published as the "host" DNS record and NATed to 127.0.0.1.
	hostIP = "192.168.127.254"
	// host and gateway are the DNS record names registered in each zone.
	host = "host"
	gateway = "gateway"
)
func main() {
version := types.NewVersion("gvproxy")
version.AddFlag()
flag.Var(&endpoints, "listen", "control endpoint")
flag.BoolVar(&debug, "debug", false, "Print debug info")
flag.IntVar(&mtu, "mtu", 1500, "Set the MTU")
flag.IntVar(&sshPort, "ssh-port", 2222, "Port to access the guest virtual machine. Must be between 1024 and 65535")
flag.StringVar(&vpnkitSocket, "listen-vpnkit", "", "VPNKit socket to be used by Hyperkit")
flag.StringVar(&qemuSocket, "listen-qemu", "", "Socket to be used by Qemu")
flag.StringVar(&bessSocket, "listen-bess", "", "unixpacket socket to be used by Bess-compatible applications")
flag.StringVar(&stdioSocket, "listen-stdio", "", "accept stdio pipe")
flag.StringVar(&vfkitSocket, "listen-vfkit", "", "unixgram socket to be used by vfkit-compatible applications")
flag.Var(&forwardSocket, "forward-sock", "Forwards a unix socket to the guest virtual machine over SSH")
flag.Var(&forwardDest, "forward-dest", "Forwards a unix socket to the guest virtual machine over SSH")
flag.Var(&forwardUser, "forward-user", "SSH user to use for unix socket forward")
flag.Var(&forwardIdentify, "forward-identity", "Path to SSH identity key for forwarding")
flag.StringVar(&pidFile, "pid-file", "", "Generate a file with the PID in it")
flag.StringVar(&logFile, "log-file", "", "Output log messages (logrus) to a given file path")
flag.StringVar(&servicesEndpoint, "services", "", "Exposes the same HTTP API as the --listen flag, without the /connect endpoint")
flag.Parse()
if version.ShowVersion() {
fmt.Println(version.String())
os.Exit(0)
}
// If the user provides a log-file, we re-direct log messages
// from logrus to the file
if logFile != "" {
lf, err := os.Create(logFile)
if err != nil {
fmt.Printf("unable to open log file %s, exiting...\n", logFile)
os.Exit(1)
}
defer func() {
if err := lf.Close(); err != nil {
fmt.Printf("unable to close log-file: %q\n", err)
}
}()
log.SetOutput(lf)
// If debug is set, lets seed the log file with some basic information
// about the environment and how it was called
log.Debugf("gvproxy version: %q", version.String())
log.Debugf("os: %q arch: %q", runtime.GOOS, runtime.GOARCH)
log.Debugf("command line: %q", os.Args)
}
log.Info(version.String())
ctx, cancel := context.WithCancel(context.Background())
// Make this the last defer statement in the stack
defer os.Exit(exitCode)
groupErrs, ctx := errgroup.WithContext(ctx)
// Setup signal channel for catching user signals
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
if debug {
log.SetLevel(log.DebugLevel)
}
// Intercept WM_QUIT/WM_CLOSE events if on Windows as SIGTERM (noop on other OSs)
winquit.SimulateSigTermOnQuit(sigChan)
// Make sure the qemu socket provided is valid syntax
if len(qemuSocket) > 0 {
uri, err := url.Parse(qemuSocket)
if err != nil || uri == nil {
exitWithError(errors.Wrapf(err, "invalid value for listen-qemu"))
}
if _, err := os.Stat(uri.Path); err == nil && uri.Scheme == "unix" {
exitWithError(errors.Errorf("%q already exists", uri.Path))
}
}
if len(bessSocket) > 0 {
uri, err := url.Parse(bessSocket)
if err != nil || uri == nil {
exitWithError(errors.Wrapf(err, "invalid value for listen-bess"))
}
if uri.Scheme != "unixpacket" {
exitWithError(errors.New("listen-bess must be unixpacket:// address"))
}
if _, err := os.Stat(uri.Path); err == nil {
exitWithError(errors.Errorf("%q already exists", uri.Path))
}
}
if len(vfkitSocket) > 0 {
uri, err := url.Parse(vfkitSocket)
if err != nil || uri == nil {
exitWithError(errors.Wrapf(err, "invalid value for listen-vfkit"))
}
if uri.Scheme != "unixgram" {
exitWithError(errors.New("listen-vfkit must be unixgram:// address"))
}
if _, err := os.Stat(uri.Path); err == nil {
exitWithError(errors.Errorf("%q already exists", uri.Path))
}
}
if vpnkitSocket != "" && qemuSocket != "" {
exitWithError(errors.New("cannot use qemu and vpnkit protocol at the same time"))
}
if vpnkitSocket != "" && bessSocket != "" {
exitWithError(errors.New("cannot use bess and vpnkit protocol at the same time"))
}
if qemuSocket != "" && bessSocket != "" {
exitWithError(errors.New("cannot use qemu and bess protocol at the same time"))
}
// If the given port is not between the privileged ports
// and the oft considered maximum port, return an error.
if sshPort != -1 && sshPort < 1024 || sshPort > 65535 {
exitWithError(errors.New("ssh-port value must be between 1024 and 65535"))
}
protocol := types.HyperKitProtocol
if qemuSocket != "" {
protocol = types.QemuProtocol
}
if bessSocket != "" {
protocol = types.BessProtocol
}
if vfkitSocket != "" {
protocol = types.VfkitProtocol
}
if c := len(forwardSocket); c != len(forwardDest) || c != len(forwardUser) || c != len(forwardIdentify) {
exitWithError(errors.New("--forward-sock, --forward-dest, --forward-user, and --forward-identity must all be specified together, " +
"the same number of times, or not at all"))
}
for i := 0; i < len(forwardSocket); i++ {
_, err := os.Stat(forwardIdentify[i])
if err != nil {
exitWithError(errors.Wrapf(err, "Identity file %s can't be loaded", forwardIdentify[i]))
}
}
// Create a PID file if requested
if len(pidFile) > 0 {
f, err := os.Create(pidFile)
if err != nil {
exitWithError(err)
}
// Remove the pid-file when exiting
defer func() {
if err := os.Remove(pidFile); err != nil {
log.Error(err)
}
}()
pid := os.Getpid()
if _, err := f.WriteString(strconv.Itoa(pid)); err != nil {
exitWithError(err)
}
}
config := types.Configuration{
Debug: debug,
CaptureFile: captureFile(),
MTU: mtu,
Subnet: "192.168.127.0/24",
GatewayIP: gatewayIP,
GatewayMacAddress: "5a:94:ef:e4:0c:dd",
DHCPStaticLeases: map[string]string{
"192.168.127.2": "5a:94:ef:e4:0c:ee",
},
DNS: []types.Zone{
{
Name: "containers.internal.",
Records: []types.Record{
{
Name: gateway,
IP: net.ParseIP(gatewayIP),
},
{
Name: host,
IP: net.ParseIP(hostIP),
},
},
},
{
Name: "docker.internal.",
Records: []types.Record{
{
Name: gateway,
IP: net.ParseIP(gatewayIP),
},
{
Name: host,
IP: net.ParseIP(hostIP),
},
},
},
},
DNSSearchDomains: searchDomains(),
Forwards: getForwardsMap(sshPort, sshHostPort),
NAT: map[string]string{
hostIP: "127.0.0.1",
},
GatewayVirtualIPs: []string{hostIP},
VpnKitUUIDMacAddresses: map[string]string{
"c3d68012-0208-11ea-9fd7-f2189899ab08": "5a:94:ef:e4:0c:ee",
},
Protocol: protocol,
}
groupErrs.Go(func() error {
return run(ctx, groupErrs, &config, endpoints, servicesEndpoint)
})
// Wait for something to happen
groupErrs.Go(func() error {
select {
// Catch signals so exits are graceful and defers can run
case <-sigChan:
cancel()
return errors.New("signal caught")
case <-ctx.Done():
return nil
}
})
// Wait for all of the go funcs to finish up
if err := groupErrs.Wait(); err != nil {
log.Errorf("gvproxy exiting: %v", err)
exitCode = 1
}
}
// getForwardsMap builds the port-forward table for the SSH port. It maps
// 127.0.0.1:<sshPort> to sshHostPort, or returns an empty (non-nil) map when
// sshPort is -1.
func getForwardsMap(sshPort int, sshHostPort string) map[string]string {
	forwards := map[string]string{}
	if sshPort != -1 {
		local := fmt.Sprintf("127.0.0.1:%d", sshPort)
		forwards[local] = sshHostPort
	}
	return forwards
}
// arrayFlags is a flag.Value that accumulates every occurrence of a
// repeatable command-line flag, in the order given.
type arrayFlags []string

// String returns the collected values joined with commas. It returns the
// empty string for a nil receiver, because the flag package may call String
// on a zero-valued receiver.
func (a *arrayFlags) String() string {
	if a == nil {
		return ""
	}
	return strings.Join(*a, ",")
}

// Set appends value to the list; it never fails.
func (a *arrayFlags) Set(value string) error {
	*a = append(*a, value)
	return nil
}
// captureFile returns the packet capture file name, "capture.pcap", when the
// debug flag is set, and the empty string otherwise.
func captureFile() string {
	if debug {
		return "capture.pcap"
	}
	return ""
}
// run creates the virtual network from configuration and starts every
// requested listener as goroutines in g: the HTTP control endpoints, the
// optional services endpoint, the forwarder API on the gateway address, the
// vpnkit/qemu/bess/vfkit/stdio transports and the SSH socket forwards.
//
// It returns once everything has been started (or as soon as a listener fails
// to start); the goroutines keep running until ctx is cancelled.
func run(ctx context.Context, g *errgroup.Group, configuration *types.Configuration, endpoints []string, servicesEndpoint string) error {
	vn, err := virtualnetwork.New(configuration)
	if err != nil {
		return err
	}
	log.Info("waiting for clients...")

	// Control endpoints given with --listen (with pprof handlers in debug mode).
	for _, endpoint := range endpoints {
		log.Infof("listening %s", endpoint)
		ln, err := transport.Listen(endpoint)
		if err != nil {
			return errors.Wrap(err, "cannot listen")
		}
		httpServe(ctx, g, ln, withProfiler(vn))
	}

	if servicesEndpoint != "" {
		log.Infof("enabling services API. Listening %s", servicesEndpoint)
		ln, err := transport.Listen(servicesEndpoint)
		if err != nil {
			return errors.Wrap(err, "cannot listen")
		}
		httpServe(ctx, g, ln, vn.ServicesMux())
	}

	// Forwarder API served inside the virtual network on the gateway address.
	ln, err := vn.Listen("tcp", fmt.Sprintf("%s:80", gatewayIP))
	if err != nil {
		return err
	}
	mux := http.NewServeMux()
	mux.Handle("/services/forwarder/all", vn.Mux())
	mux.Handle("/services/forwarder/expose", vn.Mux())
	mux.Handle("/services/forwarder/unexpose", vn.Mux())
	httpServe(ctx, g, ln, mux)

	// In debug mode, log the traffic counters every five seconds.
	if debug {
		g.Go(func() error {
		debugLog:
			for {
				select {
				case <-time.After(5 * time.Second):
					log.Debugf("%v sent to the VM, %v received from the VM\n", humanize.Bytes(vn.BytesSent()), humanize.Bytes(vn.BytesReceived()))
				case <-ctx.Done():
					break debugLog
				}
			}
			return nil
		})
	}

	if vpnkitSocket != "" {
		vpnkitListener, err := transport.Listen(vpnkitSocket)
		if err != nil {
			return errors.Wrap(err, "vpnkit listen error")
		}
		// Close the listener on cancellation, as the qemu/bess/vfkit
		// transports do; otherwise the accept loop below can stay blocked in
		// Accept and keep the error group from finishing.
		g.Go(func() error {
			<-ctx.Done()
			if err := vpnkitListener.Close(); err != nil {
				log.Errorf("error closing %s: %q", vpnkitSocket, err)
			}
			return nil
		})
		g.Go(func() error {
		vpnloop:
			for {
				select {
				case <-ctx.Done():
					break vpnloop
				default:
					// pass through
				}
				conn, err := vpnkitListener.Accept()
				if err != nil {
					// An accept failure after cancellation is the expected
					// result of closing the listener, not an error.
					if ctx.Err() != nil {
						break vpnloop
					}
					log.Errorf("vpnkit accept error: %s", err)
					continue
				}
				g.Go(func() error {
					return vn.AcceptVpnKit(conn)
				})
			}
			return nil
		})
	}

	if qemuSocket != "" {
		qemuListener, err := transport.Listen(qemuSocket)
		if err != nil {
			return errors.Wrap(err, "qemu listen error")
		}
		g.Go(func() error {
			<-ctx.Done()
			if err := qemuListener.Close(); err != nil {
				log.Errorf("error closing %s: %q", qemuSocket, err)
			}
			return os.Remove(qemuSocket)
		})
		// Only a single connection is accepted on this listener.
		g.Go(func() error {
			conn, err := qemuListener.Accept()
			if err != nil {
				return errors.Wrap(err, "qemu accept error")
			}
			return vn.AcceptQemu(ctx, conn)
		})
	}

	if bessSocket != "" {
		bessListener, err := transport.Listen(bessSocket)
		if err != nil {
			return errors.Wrap(err, "bess listen error")
		}
		g.Go(func() error {
			<-ctx.Done()
			if err := bessListener.Close(); err != nil {
				log.Errorf("error closing %s: %q", bessSocket, err)
			}
			return os.Remove(bessSocket)
		})
		// Only a single connection is accepted on this listener.
		g.Go(func() error {
			conn, err := bessListener.Accept()
			if err != nil {
				return errors.Wrap(err, "bess accept error")
			}
			return vn.AcceptBess(ctx, conn)
		})
	}

	if vfkitSocket != "" {
		conn, err := transport.ListenUnixgram(vfkitSocket)
		if err != nil {
			return errors.Wrap(err, "vfkit listen error")
		}
		g.Go(func() error {
			<-ctx.Done()
			if err := conn.Close(); err != nil {
				log.Errorf("error closing %s: %q", vfkitSocket, err)
			}
			return os.Remove(vfkitSocket)
		})
		g.Go(func() error {
			vfkitConn, err := transport.AcceptVfkit(conn)
			if err != nil {
				return errors.Wrap(err, "vfkit accept error")
			}
			return vn.AcceptVfkit(ctx, vfkitConn)
		})
	}

	if stdioSocket != "" {
		g.Go(func() error {
			conn := stdio.GetStdioConn()
			return vn.AcceptStdio(ctx, conn)
		})
	}

	// One SSH forward per --forward-sock entry; main has already checked that
	// the four forward lists have the same length.
	for i := 0; i < len(forwardSocket); i++ {
		var (
			src *url.URL
			err error
		)
		// A value without a scheme is treated as a unix socket path.
		if strings.Contains(forwardSocket[i], "://") {
			src, err = url.Parse(forwardSocket[i])
			if err != nil {
				return err
			}
		} else {
			src = &url.URL{
				Scheme: "unix",
				Path:   forwardSocket[i],
			}
		}
		dest := &url.URL{
			Scheme: "ssh",
			User:   url.User(forwardUser[i]),
			Host:   sshHostPort,
			Path:   forwardDest[i],
		}
		// Copy the index so the goroutine does not observe later iterations.
		j := i
		g.Go(func() error {
			// Best-effort cleanup of the forwarded socket; the result is
			// deliberately ignored.
			defer os.Remove(forwardSocket[j])
			forward, err := sshclient.CreateSSHForward(ctx, src, dest, forwardIdentify[j], vn)
			if err != nil {
				return err
			}
			go func() {
				<-ctx.Done()
				// Abort pending accepts
				forward.Close()
			}()
		loop:
			for {
				select {
				case <-ctx.Done():
					break loop
				default:
					// proceed
				}
				err := forward.AcceptAndTunnel(ctx)
				if err != nil {
					log.Debugf("Error occurred handling ssh forwarded connection: %q", err)
				}
			}
			return nil
		})
	}
	return nil
}
// httpServe serves mux on ln using two goroutines registered in g: one closes
// ln when ctx is done, the other runs an HTTP server with 10 second read and
// write timeouts. http.ErrServerClosed is treated as a clean shutdown; any
// other error from Serve is reported to the group.
func httpServe(ctx context.Context, g *errgroup.Group, ln net.Listener, mux http.Handler) {
	g.Go(func() error {
		<-ctx.Done()
		return ln.Close()
	})
	g.Go(func() error {
		s := &http.Server{
			Handler:      mux,
			ReadTimeout:  10 * time.Second,
			WriteTimeout: 10 * time.Second,
		}
		// The original check returned err on both branches, so the
		// ErrServerClosed sentinel was reported as a failure.
		if err := s.Serve(ln); err != nil && err != http.ErrServerClosed {
			return err
		}
		return nil
	})
}
// withProfiler returns the virtual network's HTTP mux; when the debug flag is
// set, the pprof index, cmdline, profile and symbol handlers are registered on
// it under /debug/pprof/ first.
func withProfiler(vn *virtualnetwork.VirtualNetwork) http.Handler {
	mux := vn.Mux()
	if !debug {
		return mux
	}
	profilers := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/debug/pprof/", pprof.Index},
		{"/debug/pprof/cmdline", pprof.Cmdline},
		{"/debug/pprof/profile", pprof.Profile},
		{"/debug/pprof/symbol", pprof.Symbol},
	}
	for _, p := range profilers {
		mux.HandleFunc(p.pattern, p.handler)
	}
	return mux
}
// exitWithError logs err at error level and terminates the process with exit
// status 1. Because it calls os.Exit, deferred functions do not run.
func exitWithError(err error) {
	const failureStatus = 1
	log.Error(err)
	os.Exit(failureStatus)
}
// searchDomains returns the domains listed on the first "search" line of
// /etc/resolv.conf. It returns nil on operating systems other than darwin and
// linux, when the file cannot be read, or when no search line with at least
// one domain is found.
//
// The keyword must start the line and may be separated from the domains, and
// the domains from each other, by any run of spaces or tabs.
func searchDomains() []string {
	if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
		return nil
	}

	f, err := os.Open("/etc/resolv.conf")
	if err != nil {
		log.Errorf("open file error: %v", err)
		return nil
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	for sc.Scan() {
		line := sc.Text()
		// Fields splits on any whitespace, so tabs and repeated spaces do not
		// produce empty entries; HasPrefix keeps the keyword anchored at the
		// start of the line.
		fields := strings.Fields(line)
		if len(fields) < 2 || fields[0] != "search" || !strings.HasPrefix(line, "search") {
			continue
		}
		domains := fields[1:]
		log.Debugf("Using search domains: %v", domains)
		return domains
	}
	if err := sc.Err(); err != nil {
		log.Errorf("scan file error: %v", err)
	}
	return nil
}