mirror of
https://github.com/containers/gvisor-tap-vsock.git
synced 2025-09-26 21:01:42 +08:00
586 lines
15 KiB
Go
586 lines
15 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/pprof"
|
|
"net/url"
|
|
"os"
|
|
"os/signal"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/containers/gvisor-tap-vsock/pkg/net/stdio"
|
|
"github.com/containers/gvisor-tap-vsock/pkg/sshclient"
|
|
"github.com/containers/gvisor-tap-vsock/pkg/transport"
|
|
"github.com/containers/gvisor-tap-vsock/pkg/types"
|
|
"github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork"
|
|
"github.com/containers/winquit/pkg/winquit"
|
|
"github.com/dustin/go-humanize"
|
|
"github.com/pkg/errors"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
var (
|
|
debug bool
|
|
mtu int
|
|
endpoints arrayFlags
|
|
vpnkitSocket string
|
|
qemuSocket string
|
|
bessSocket string
|
|
stdioSocket string
|
|
vfkitSocket string
|
|
forwardSocket arrayFlags
|
|
forwardDest arrayFlags
|
|
forwardUser arrayFlags
|
|
forwardIdentify arrayFlags
|
|
sshPort int
|
|
pidFile string
|
|
exitCode int
|
|
logFile string
|
|
servicesEndpoint string
|
|
)
|
|
|
|
const (
|
|
gatewayIP = "192.168.127.1"
|
|
sshHostPort = "192.168.127.2:22"
|
|
hostIP = "192.168.127.254"
|
|
host = "host"
|
|
gateway = "gateway"
|
|
)
|
|
|
|
func main() {
|
|
version := types.NewVersion("gvproxy")
|
|
version.AddFlag()
|
|
flag.Var(&endpoints, "listen", "control endpoint")
|
|
flag.BoolVar(&debug, "debug", false, "Print debug info")
|
|
flag.IntVar(&mtu, "mtu", 1500, "Set the MTU")
|
|
flag.IntVar(&sshPort, "ssh-port", 2222, "Port to access the guest virtual machine. Must be between 1024 and 65535")
|
|
flag.StringVar(&vpnkitSocket, "listen-vpnkit", "", "VPNKit socket to be used by Hyperkit")
|
|
flag.StringVar(&qemuSocket, "listen-qemu", "", "Socket to be used by Qemu")
|
|
flag.StringVar(&bessSocket, "listen-bess", "", "unixpacket socket to be used by Bess-compatible applications")
|
|
flag.StringVar(&stdioSocket, "listen-stdio", "", "accept stdio pipe")
|
|
flag.StringVar(&vfkitSocket, "listen-vfkit", "", "unixgram socket to be used by vfkit-compatible applications")
|
|
flag.Var(&forwardSocket, "forward-sock", "Forwards a unix socket to the guest virtual machine over SSH")
|
|
flag.Var(&forwardDest, "forward-dest", "Forwards a unix socket to the guest virtual machine over SSH")
|
|
flag.Var(&forwardUser, "forward-user", "SSH user to use for unix socket forward")
|
|
flag.Var(&forwardIdentify, "forward-identity", "Path to SSH identity key for forwarding")
|
|
flag.StringVar(&pidFile, "pid-file", "", "Generate a file with the PID in it")
|
|
flag.StringVar(&logFile, "log-file", "", "Output log messages (logrus) to a given file path")
|
|
flag.StringVar(&servicesEndpoint, "services", "", "Exposes the same HTTP API as the --listen flag, without the /connect endpoint")
|
|
flag.Parse()
|
|
|
|
if version.ShowVersion() {
|
|
fmt.Println(version.String())
|
|
os.Exit(0)
|
|
}
|
|
|
|
// If the user provides a log-file, we re-direct log messages
|
|
// from logrus to the file
|
|
if logFile != "" {
|
|
lf, err := os.Create(logFile)
|
|
if err != nil {
|
|
fmt.Printf("unable to open log file %s, exiting...\n", logFile)
|
|
os.Exit(1)
|
|
}
|
|
defer func() {
|
|
if err := lf.Close(); err != nil {
|
|
fmt.Printf("unable to close log-file: %q\n", err)
|
|
}
|
|
}()
|
|
log.SetOutput(lf)
|
|
|
|
// If debug is set, lets seed the log file with some basic information
|
|
// about the environment and how it was called
|
|
log.Debugf("gvproxy version: %q", version.String())
|
|
log.Debugf("os: %q arch: %q", runtime.GOOS, runtime.GOARCH)
|
|
log.Debugf("command line: %q", os.Args)
|
|
}
|
|
|
|
log.Info(version.String())
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
// Make this the last defer statement in the stack
|
|
defer os.Exit(exitCode)
|
|
|
|
groupErrs, ctx := errgroup.WithContext(ctx)
|
|
// Setup signal channel for catching user signals
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
|
|
|
if debug {
|
|
log.SetLevel(log.DebugLevel)
|
|
}
|
|
|
|
// Intercept WM_QUIT/WM_CLOSE events if on Windows as SIGTERM (noop on other OSs)
|
|
winquit.SimulateSigTermOnQuit(sigChan)
|
|
|
|
// Make sure the qemu socket provided is valid syntax
|
|
if len(qemuSocket) > 0 {
|
|
uri, err := url.Parse(qemuSocket)
|
|
if err != nil || uri == nil {
|
|
exitWithError(errors.Wrapf(err, "invalid value for listen-qemu"))
|
|
}
|
|
if _, err := os.Stat(uri.Path); err == nil && uri.Scheme == "unix" {
|
|
exitWithError(errors.Errorf("%q already exists", uri.Path))
|
|
}
|
|
}
|
|
if len(bessSocket) > 0 {
|
|
uri, err := url.Parse(bessSocket)
|
|
if err != nil || uri == nil {
|
|
exitWithError(errors.Wrapf(err, "invalid value for listen-bess"))
|
|
}
|
|
if uri.Scheme != "unixpacket" {
|
|
exitWithError(errors.New("listen-bess must be unixpacket:// address"))
|
|
}
|
|
if _, err := os.Stat(uri.Path); err == nil {
|
|
exitWithError(errors.Errorf("%q already exists", uri.Path))
|
|
}
|
|
}
|
|
if len(vfkitSocket) > 0 {
|
|
uri, err := url.Parse(vfkitSocket)
|
|
if err != nil || uri == nil {
|
|
exitWithError(errors.Wrapf(err, "invalid value for listen-vfkit"))
|
|
}
|
|
if uri.Scheme != "unixgram" {
|
|
exitWithError(errors.New("listen-vfkit must be unixgram:// address"))
|
|
}
|
|
if _, err := os.Stat(uri.Path); err == nil {
|
|
exitWithError(errors.Errorf("%q already exists", uri.Path))
|
|
}
|
|
}
|
|
|
|
if vpnkitSocket != "" && qemuSocket != "" {
|
|
exitWithError(errors.New("cannot use qemu and vpnkit protocol at the same time"))
|
|
}
|
|
if vpnkitSocket != "" && bessSocket != "" {
|
|
exitWithError(errors.New("cannot use bess and vpnkit protocol at the same time"))
|
|
}
|
|
if qemuSocket != "" && bessSocket != "" {
|
|
exitWithError(errors.New("cannot use qemu and bess protocol at the same time"))
|
|
}
|
|
|
|
// If the given port is not between the privileged ports
|
|
// and the oft considered maximum port, return an error.
|
|
if sshPort != -1 && sshPort < 1024 || sshPort > 65535 {
|
|
exitWithError(errors.New("ssh-port value must be between 1024 and 65535"))
|
|
}
|
|
protocol := types.HyperKitProtocol
|
|
if qemuSocket != "" {
|
|
protocol = types.QemuProtocol
|
|
}
|
|
if bessSocket != "" {
|
|
protocol = types.BessProtocol
|
|
}
|
|
if vfkitSocket != "" {
|
|
protocol = types.VfkitProtocol
|
|
}
|
|
|
|
if c := len(forwardSocket); c != len(forwardDest) || c != len(forwardUser) || c != len(forwardIdentify) {
|
|
exitWithError(errors.New("--forward-sock, --forward-dest, --forward-user, and --forward-identity must all be specified together, " +
|
|
"the same number of times, or not at all"))
|
|
}
|
|
|
|
for i := 0; i < len(forwardSocket); i++ {
|
|
_, err := os.Stat(forwardIdentify[i])
|
|
if err != nil {
|
|
exitWithError(errors.Wrapf(err, "Identity file %s can't be loaded", forwardIdentify[i]))
|
|
}
|
|
}
|
|
|
|
// Create a PID file if requested
|
|
if len(pidFile) > 0 {
|
|
f, err := os.Create(pidFile)
|
|
if err != nil {
|
|
exitWithError(err)
|
|
}
|
|
// Remove the pid-file when exiting
|
|
defer func() {
|
|
if err := os.Remove(pidFile); err != nil {
|
|
log.Error(err)
|
|
}
|
|
}()
|
|
pid := os.Getpid()
|
|
if _, err := f.WriteString(strconv.Itoa(pid)); err != nil {
|
|
exitWithError(err)
|
|
}
|
|
}
|
|
|
|
config := types.Configuration{
|
|
Debug: debug,
|
|
CaptureFile: captureFile(),
|
|
MTU: mtu,
|
|
Subnet: "192.168.127.0/24",
|
|
GatewayIP: gatewayIP,
|
|
GatewayMacAddress: "5a:94:ef:e4:0c:dd",
|
|
DHCPStaticLeases: map[string]string{
|
|
"192.168.127.2": "5a:94:ef:e4:0c:ee",
|
|
},
|
|
DNS: []types.Zone{
|
|
{
|
|
Name: "containers.internal.",
|
|
Records: []types.Record{
|
|
{
|
|
Name: gateway,
|
|
IP: net.ParseIP(gatewayIP),
|
|
},
|
|
{
|
|
Name: host,
|
|
IP: net.ParseIP(hostIP),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "docker.internal.",
|
|
Records: []types.Record{
|
|
{
|
|
Name: gateway,
|
|
IP: net.ParseIP(gatewayIP),
|
|
},
|
|
{
|
|
Name: host,
|
|
IP: net.ParseIP(hostIP),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
DNSSearchDomains: searchDomains(),
|
|
Forwards: getForwardsMap(sshPort, sshHostPort),
|
|
NAT: map[string]string{
|
|
hostIP: "127.0.0.1",
|
|
},
|
|
GatewayVirtualIPs: []string{hostIP},
|
|
VpnKitUUIDMacAddresses: map[string]string{
|
|
"c3d68012-0208-11ea-9fd7-f2189899ab08": "5a:94:ef:e4:0c:ee",
|
|
},
|
|
Protocol: protocol,
|
|
}
|
|
|
|
groupErrs.Go(func() error {
|
|
return run(ctx, groupErrs, &config, endpoints, servicesEndpoint)
|
|
})
|
|
|
|
// Wait for something to happen
|
|
groupErrs.Go(func() error {
|
|
select {
|
|
// Catch signals so exits are graceful and defers can run
|
|
case <-sigChan:
|
|
cancel()
|
|
return errors.New("signal caught")
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
})
|
|
// Wait for all of the go funcs to finish up
|
|
if err := groupErrs.Wait(); err != nil {
|
|
log.Errorf("gvproxy exiting: %v", err)
|
|
exitCode = 1
|
|
}
|
|
}
|
|
|
|
func getForwardsMap(sshPort int, sshHostPort string) map[string]string {
|
|
if sshPort == -1 {
|
|
return map[string]string{}
|
|
}
|
|
return map[string]string{
|
|
fmt.Sprintf("127.0.0.1:%d", sshPort): sshHostPort,
|
|
}
|
|
}
|
|
|
|
type arrayFlags []string
|
|
|
|
func (i *arrayFlags) String() string {
|
|
return "my string representation"
|
|
}
|
|
|
|
func (i *arrayFlags) Set(value string) error {
|
|
*i = append(*i, value)
|
|
return nil
|
|
}
|
|
|
|
func captureFile() string {
|
|
if !debug {
|
|
return ""
|
|
}
|
|
return "capture.pcap"
|
|
}
|
|
|
|
func run(ctx context.Context, g *errgroup.Group, configuration *types.Configuration, endpoints []string, servicesEndpoint string) error {
|
|
vn, err := virtualnetwork.New(configuration)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Info("waiting for clients...")
|
|
|
|
for _, endpoint := range endpoints {
|
|
log.Infof("listening %s", endpoint)
|
|
ln, err := transport.Listen(endpoint)
|
|
if err != nil {
|
|
return errors.Wrap(err, "cannot listen")
|
|
}
|
|
httpServe(ctx, g, ln, withProfiler(vn))
|
|
}
|
|
|
|
if servicesEndpoint != "" {
|
|
log.Infof("enabling services API. Listening %s", servicesEndpoint)
|
|
ln, err := transport.Listen(servicesEndpoint)
|
|
if err != nil {
|
|
return errors.Wrap(err, "cannot listen")
|
|
}
|
|
httpServe(ctx, g, ln, vn.ServicesMux())
|
|
}
|
|
|
|
ln, err := vn.Listen("tcp", fmt.Sprintf("%s:80", gatewayIP))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/services/forwarder/all", vn.Mux())
|
|
mux.Handle("/services/forwarder/expose", vn.Mux())
|
|
mux.Handle("/services/forwarder/unexpose", vn.Mux())
|
|
httpServe(ctx, g, ln, mux)
|
|
|
|
if debug {
|
|
g.Go(func() error {
|
|
debugLog:
|
|
for {
|
|
select {
|
|
case <-time.After(5 * time.Second):
|
|
log.Debugf("%v sent to the VM, %v received from the VM\n", humanize.Bytes(vn.BytesSent()), humanize.Bytes(vn.BytesReceived()))
|
|
case <-ctx.Done():
|
|
break debugLog
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
if vpnkitSocket != "" {
|
|
vpnkitListener, err := transport.Listen(vpnkitSocket)
|
|
if err != nil {
|
|
return errors.Wrap(err, "vpnkit listen error")
|
|
}
|
|
g.Go(func() error {
|
|
vpnloop:
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
break vpnloop
|
|
default:
|
|
// pass through
|
|
}
|
|
conn, err := vpnkitListener.Accept()
|
|
if err != nil {
|
|
log.Errorf("vpnkit accept error: %s", err)
|
|
continue
|
|
}
|
|
g.Go(func() error {
|
|
return vn.AcceptVpnKit(conn)
|
|
})
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
if qemuSocket != "" {
|
|
qemuListener, err := transport.Listen(qemuSocket)
|
|
if err != nil {
|
|
return errors.Wrap(err, "qemu listen error")
|
|
}
|
|
|
|
g.Go(func() error {
|
|
<-ctx.Done()
|
|
if err := qemuListener.Close(); err != nil {
|
|
log.Errorf("error closing %s: %q", qemuSocket, err)
|
|
}
|
|
return os.Remove(qemuSocket)
|
|
})
|
|
|
|
g.Go(func() error {
|
|
conn, err := qemuListener.Accept()
|
|
if err != nil {
|
|
return errors.Wrap(err, "qemu accept error")
|
|
}
|
|
return vn.AcceptQemu(ctx, conn)
|
|
})
|
|
}
|
|
|
|
if bessSocket != "" {
|
|
bessListener, err := transport.Listen(bessSocket)
|
|
if err != nil {
|
|
return errors.Wrap(err, "bess listen error")
|
|
}
|
|
|
|
g.Go(func() error {
|
|
<-ctx.Done()
|
|
if err := bessListener.Close(); err != nil {
|
|
log.Errorf("error closing %s: %q", bessSocket, err)
|
|
}
|
|
return os.Remove(bessSocket)
|
|
})
|
|
|
|
g.Go(func() error {
|
|
conn, err := bessListener.Accept()
|
|
if err != nil {
|
|
return errors.Wrap(err, "bess accept error")
|
|
|
|
}
|
|
return vn.AcceptBess(ctx, conn)
|
|
})
|
|
}
|
|
|
|
if vfkitSocket != "" {
|
|
conn, err := transport.ListenUnixgram(vfkitSocket)
|
|
if err != nil {
|
|
return errors.Wrap(err, "vfkit listen error")
|
|
}
|
|
|
|
g.Go(func() error {
|
|
<-ctx.Done()
|
|
if err := conn.Close(); err != nil {
|
|
log.Errorf("error closing %s: %q", vfkitSocket, err)
|
|
}
|
|
return os.Remove(vfkitSocket)
|
|
})
|
|
|
|
g.Go(func() error {
|
|
vfkitConn, err := transport.AcceptVfkit(conn)
|
|
if err != nil {
|
|
return errors.Wrap(err, "vfkit accept error")
|
|
}
|
|
return vn.AcceptVfkit(ctx, vfkitConn)
|
|
})
|
|
}
|
|
|
|
if stdioSocket != "" {
|
|
g.Go(func() error {
|
|
conn := stdio.GetStdioConn()
|
|
return vn.AcceptStdio(ctx, conn)
|
|
})
|
|
}
|
|
|
|
for i := 0; i < len(forwardSocket); i++ {
|
|
var (
|
|
src *url.URL
|
|
err error
|
|
)
|
|
if strings.Contains(forwardSocket[i], "://") {
|
|
src, err = url.Parse(forwardSocket[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
src = &url.URL{
|
|
Scheme: "unix",
|
|
Path: forwardSocket[i],
|
|
}
|
|
}
|
|
|
|
dest := &url.URL{
|
|
Scheme: "ssh",
|
|
User: url.User(forwardUser[i]),
|
|
Host: sshHostPort,
|
|
Path: forwardDest[i],
|
|
}
|
|
j := i
|
|
g.Go(func() error {
|
|
defer os.Remove(forwardSocket[j])
|
|
forward, err := sshclient.CreateSSHForward(ctx, src, dest, forwardIdentify[j], vn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go func() {
|
|
<-ctx.Done()
|
|
// Abort pending accepts
|
|
forward.Close()
|
|
}()
|
|
loop:
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
break loop
|
|
default:
|
|
// proceed
|
|
}
|
|
err := forward.AcceptAndTunnel(ctx)
|
|
if err != nil {
|
|
log.Debugf("Error occurred handling ssh forwarded connection: %q", err)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func httpServe(ctx context.Context, g *errgroup.Group, ln net.Listener, mux http.Handler) {
|
|
g.Go(func() error {
|
|
<-ctx.Done()
|
|
return ln.Close()
|
|
})
|
|
g.Go(func() error {
|
|
s := &http.Server{
|
|
Handler: mux,
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 10 * time.Second,
|
|
}
|
|
err := s.Serve(ln)
|
|
if err != nil {
|
|
if err != http.ErrServerClosed {
|
|
return err
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func withProfiler(vn *virtualnetwork.VirtualNetwork) http.Handler {
|
|
mux := vn.Mux()
|
|
if debug {
|
|
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
|
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
|
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
|
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
|
}
|
|
return mux
|
|
}
|
|
|
|
func exitWithError(err error) {
|
|
log.Error(err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
func searchDomains() []string {
|
|
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" {
|
|
f, err := os.Open("/etc/resolv.conf")
|
|
if err != nil {
|
|
log.Errorf("open file error: %v", err)
|
|
return nil
|
|
}
|
|
defer f.Close()
|
|
sc := bufio.NewScanner(f)
|
|
searchPrefix := "search "
|
|
for sc.Scan() {
|
|
if strings.HasPrefix(sc.Text(), searchPrefix) {
|
|
searchDomains := strings.Split(strings.TrimPrefix(sc.Text(), searchPrefix), " ")
|
|
log.Debugf("Using search domains: %v", searchDomains)
|
|
return searchDomains
|
|
}
|
|
}
|
|
if err := sc.Err(); err != nil {
|
|
log.Errorf("scan file error: %v", err)
|
|
return nil
|
|
}
|
|
}
|
|
return nil
|
|
}
|