mirror of
https://github.com/containers/gvisor-tap-vsock.git
synced 2025-10-08 10:10:12 +08:00

This commit fixes a potential race condition that prevented the tests from succeeding when running in a GitHub workflow. The thread id was not yet available when it was written to the file, so a thread id equal to 0 was written. As a result, when the tests tried to retrieve the thread id in order to send the WM_QUIT message, they failed. This patch adds a check on the thread id before writing it to the file: if the thread id is 0, it keeps calling winquit to retrieve it, and if there is still no success after 10 seconds, it returns an error. Signed-off-by: lstocchi <lstocchi@redhat.com>
219 lines
4.6 KiB
Go
219 lines
4.6 KiB
Go
//go:build windows
|
|
// +build windows
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/containers/gvisor-tap-vsock/pkg/sshclient"
|
|
"github.com/containers/gvisor-tap-vsock/pkg/types"
|
|
"github.com/containers/gvisor-tap-vsock/pkg/utils"
|
|
"github.com/containers/winquit/pkg/winquit"
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/sys/windows/svc/eventlog"
|
|
)
|
|
|
|
const (
	// ERR_BAD_ARGS is the process exit code main uses when the command line
	// does not match the expected usage.
	// NOTE(review): 0x000A is 10, which is Win32 ERROR_BAD_ENVIRONMENT;
	// Win32 ERROR_BAD_ARGUMENTS is 0xA0 (160) — confirm the intended value.
	ERR_BAD_ARGS = 0x000A
	// WM_QUIT is the Win32 window message id for a quit request. It is not
	// referenced in this part of the file; presumably used elsewhere in the
	// package (e.g. tests) — TODO confirm.
	WM_QUIT = 0x12
)
|
|
|
|
var (
	// stateDir is the directory given as the second positional argument; the
	// win-sshproxy.tid file is written into it by saveThreadId.
	stateDir string
	// debug is set by the -debug flag; setupLogging then selects
	// logrus.DebugLevel instead of logrus.InfoLevel.
	debug bool
)
|
|
|
|
func main() {
|
|
args := os.Args
|
|
if len(args) > 1 {
|
|
switch args[1] {
|
|
case "-version":
|
|
version := types.NewVersion("win-sshproxy")
|
|
fmt.Println(version.String())
|
|
os.Exit(0)
|
|
case "-debug":
|
|
debug = true
|
|
args = args[2:]
|
|
default:
|
|
args = args[1:]
|
|
}
|
|
}
|
|
|
|
if len(args) < 5 || (len(args)-2)%3 != 0 {
|
|
alert("Usage: " + filepath.Base(os.Args[0]) + "(-debug) [name] [statedir] ([source] [dest] [identity])... \n\nThis facilty proxies windows pipes and unix sockets over ssh using the specified identity.")
|
|
os.Exit(ERR_BAD_ARGS)
|
|
}
|
|
|
|
log, err := setupLogging(args[0])
|
|
if err != nil {
|
|
os.Exit(1)
|
|
}
|
|
defer log.Close()
|
|
|
|
stateDir = args[1]
|
|
|
|
var sources, dests, identities []string
|
|
for i := 2; i < len(args)-2; i += 3 {
|
|
sources = append(sources, args[i])
|
|
dests = append(dests, args[i+1])
|
|
identities = append(identities, args[i+2])
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
group, ctx := errgroup.WithContext(ctx)
|
|
|
|
quit := make(chan bool, 1)
|
|
// Wait for a WM_QUIT message to exit
|
|
winquit.NotifyOnQuit(quit)
|
|
go func() {
|
|
<-quit
|
|
cancel()
|
|
}()
|
|
|
|
// Save thread for legacy callers which use it to post a quit
|
|
if _, err := saveThreadId(); err != nil {
|
|
logrus.Errorf("Error saving thread id: " + err.Error())
|
|
}
|
|
|
|
logrus.Debug("Setting up proxies")
|
|
setupProxies(ctx, group, sources, dests, identities)
|
|
|
|
// Wait for cmopletion (cancellation) or error
|
|
if err := group.Wait(); err != nil {
|
|
logrus.Errorf("Error occured in execution group: " + err.Error())
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func setupLogging(name string) (*eventlog.Log, error) {
|
|
// Reuse the Built-in .NET Runtime Source so that we do not
|
|
// have to provide a messaage table and modify the system
|
|
// event configuration
|
|
log, err := eventlog.Open(".NET Runtime")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logrus.AddHook(NewEventHook(log, name))
|
|
if debug {
|
|
logrus.SetLevel(logrus.DebugLevel)
|
|
} else {
|
|
logrus.SetLevel(logrus.InfoLevel)
|
|
}
|
|
|
|
return log, nil
|
|
}
|
|
|
|
func setupProxies(ctx context.Context, g *errgroup.Group, sources []string, dests []string, identities []string) error {
|
|
for i := 0; i < len(sources); i++ {
|
|
var (
|
|
src *url.URL
|
|
dest *url.URL
|
|
err error
|
|
)
|
|
if strings.Contains(sources[i], "://") {
|
|
src, err = url.Parse(sources[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
src = &url.URL{
|
|
Scheme: "unix",
|
|
Path: sources[i],
|
|
}
|
|
}
|
|
|
|
dest, err = url.Parse(dests[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
j := i
|
|
g.Go(func() error {
|
|
forward, err := sshclient.CreateSSHForward(ctx, src, dest, identities[j], nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go func() {
|
|
<-ctx.Done()
|
|
// Abort pending accepts
|
|
forward.Close()
|
|
}()
|
|
loop:
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
break loop
|
|
default:
|
|
// proceed
|
|
}
|
|
err := forward.AcceptAndTunnel(ctx)
|
|
if err != nil {
|
|
logrus.Debugf("Error occurred handling ssh forwarded connection: %q", err)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func saveThreadId() (uint32, error) {
|
|
path := filepath.Join(stateDir, "win-sshproxy.tid")
|
|
file, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer file.Close()
|
|
|
|
tid, err := getThreadId()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
fmt.Fprintf(file, "%d:%d\n", os.Getpid(), tid)
|
|
return tid, nil
|
|
}
|
|
|
|
func getThreadId() (uint32, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
getTid := func() (uint32, error) {
|
|
tid := winquit.GetCurrentMessageLoopThreadId()
|
|
if tid != 0 {
|
|
return tid, nil
|
|
}
|
|
return 0, fmt.Errorf("failed to get thread ID")
|
|
}
|
|
|
|
return utils.Retry(ctx, getTid, "Waiting for message loop thread id")
|
|
}
|
|
|
|
// Creates an "error" style pop-up window
|
|
func alert(caption string) int {
|
|
// Error box style
|
|
format := 0x10
|
|
|
|
user32 := syscall.NewLazyDLL("user32.dll")
|
|
captionPtr, _ := syscall.UTF16PtrFromString(caption)
|
|
titlePtr, _ := syscall.UTF16PtrFromString("winpath")
|
|
ret, _, _ := user32.NewProc("MessageBoxW").Call(
|
|
uintptr(0),
|
|
uintptr(unsafe.Pointer(captionPtr)),
|
|
uintptr(unsafe.Pointer(titlePtr)),
|
|
uintptr(format))
|
|
|
|
return int(ret)
|
|
}
|