mirror of
https://github.com/Ascend/ascend-docker-runtime.git
synced 2025-10-06 15:37:13 +08:00
200 lines
4.3 KiB
Go
200 lines
4.3 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"path"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/opencontainers/runtime-spec/specs-go"
|
|
)
|
|
|
|
const (
|
|
loggingPrefix = "ascend-docker-hook"
|
|
ascendVisibleDevices = "ASCEND_VISIBLE_DEVICES"
|
|
ascendDockerCli = "ascend-docker-cli"
|
|
defaultAscendDockerCli = "/usr/local/bin/ascend-docker-cli"
|
|
)
|
|
|
|
type containerConfig struct {
|
|
Pid int
|
|
Rootfs string
|
|
Env []string
|
|
}
|
|
|
|
func removeDuplication(devices []int) []int {
|
|
list := make([]int, 0, len(devices))
|
|
prev := -1
|
|
|
|
for _, device := range devices {
|
|
if device == prev {
|
|
continue
|
|
}
|
|
|
|
list = append(list, device)
|
|
prev = device
|
|
}
|
|
|
|
return list
|
|
}
|
|
|
|
func parseDevices(visibleDevices string) ([]int, error) {
|
|
devices := make([]int, 0)
|
|
|
|
for _, d := range strings.Split(visibleDevices, ",") {
|
|
d = strings.TrimSpace(d)
|
|
if strings.Contains(d, "-") {
|
|
borders := strings.Split(d, "-")
|
|
if len(borders) < 2 {
|
|
return nil, fmt.Errorf("invalid device range: %s", d)
|
|
}
|
|
|
|
borders[0] = strings.TrimSpace(borders[0])
|
|
borders[1] = strings.TrimSpace(borders[1])
|
|
|
|
left, err := strconv.Atoi(borders[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid left boarder range parameter: %s", borders[0])
|
|
}
|
|
|
|
right, err := strconv.Atoi(borders[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid right boarder range parameter: %s", borders[1])
|
|
}
|
|
|
|
if left > right {
|
|
return nil, fmt.Errorf("left boarder (%d) should not be larger than the right one(%d)", left, right)
|
|
}
|
|
|
|
for n := left; n <= right; n++ {
|
|
devices = append(devices, n)
|
|
}
|
|
} else {
|
|
n, err := strconv.Atoi(d)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid single device parameter: %s", d)
|
|
}
|
|
|
|
devices = append(devices, n)
|
|
}
|
|
}
|
|
|
|
sort.Slice(devices, func(i, j int) bool { return i < j })
|
|
return removeDuplication(devices), nil
|
|
}
|
|
|
|
func parseOciSpecFile(file string) (*specs.Spec, error) {
|
|
f, err := os.Open(file)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open the OCI config file: %s", file)
|
|
}
|
|
defer f.Close()
|
|
|
|
spec := new(specs.Spec)
|
|
if err := json.NewDecoder(f).Decode(spec); err != nil {
|
|
return nil, fmt.Errorf("failed to parse OCI config file: %s, caused by: %w", file, err)
|
|
}
|
|
|
|
if spec.Process == nil {
|
|
return nil, fmt.Errorf("invalid OCI spec for empty process")
|
|
}
|
|
|
|
if spec.Root == nil {
|
|
return nil, fmt.Errorf("invalid OCI spec for empty root")
|
|
}
|
|
|
|
return spec, nil
|
|
}
|
|
|
|
func getContainerConfig() (*containerConfig, error) {
|
|
state := new(specs.State)
|
|
decoder := json.NewDecoder(os.Stdin)
|
|
|
|
if err := decoder.Decode(state); err != nil {
|
|
return nil, fmt.Errorf("failed to parse the container's state")
|
|
}
|
|
|
|
configPath := path.Join(state.Bundle, "config.json")
|
|
ociSpec, err := parseOciSpecFile(configPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse OCI spec: %w", err)
|
|
}
|
|
|
|
ret := &containerConfig{
|
|
Pid: state.Pid,
|
|
Rootfs: ociSpec.Root.Path,
|
|
Env: ociSpec.Process.Env,
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func getValueByKey(data []string, key string) string {
|
|
for _, s := range data {
|
|
p := strings.SplitN(s, "=", 2)
|
|
if len(p) != 2 {
|
|
log.Panicln("environment error")
|
|
}
|
|
|
|
if p[0] == key {
|
|
return p[1]
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func doPrestartHook() error {
|
|
containerConfig, err := getContainerConfig()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get container config: %w", err)
|
|
}
|
|
|
|
visibleDevices := getValueByKey(containerConfig.Env, ascendVisibleDevices)
|
|
if visibleDevices == "" {
|
|
return fmt.Errorf("failed to get device setting")
|
|
}
|
|
|
|
devices, err := parseDevices(visibleDevices)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse device setting: %w", err)
|
|
}
|
|
|
|
cliPath, err := exec.LookPath(ascendDockerCli)
|
|
if err != nil {
|
|
_, err = os.Stat(defaultAscendDockerCli)
|
|
if err != nil {
|
|
return fmt.Errorf("could not found ascend docker cli\n")
|
|
}
|
|
|
|
cliPath = defaultAscendDockerCli
|
|
}
|
|
|
|
args := append([]string{cliPath},
|
|
"--devices", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(devices)), ","), "[]"),
|
|
"--pid", fmt.Sprintf("%d", containerConfig.Pid),
|
|
"--rootfs", containerConfig.Rootfs)
|
|
|
|
if err := syscall.Exec(cliPath, args, os.Environ()); err != nil {
|
|
return fmt.Errorf("failed to exec ascend-docker-cli %v: %w\n", args, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
log.SetPrefix(loggingPrefix)
|
|
flag.Parse()
|
|
|
|
if err := doPrestartHook(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|