mirror of
				https://github.com/Ascend/ascend-docker-runtime.git
				synced 2025-10-31 23:32:34 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			172 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /**
 | ||
|  * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.
 | ||
|  * Description: ascend-docker-runtime工具,配置容器挂载Ascend NPU设备
 | ||
|  */
 | ||
| package main
 | ||
| 
 | ||
| import (
 | ||
| 	"encoding/json"
 | ||
| 	"fmt"
 | ||
| 	"io/ioutil"
 | ||
| 	"log"
 | ||
| 	"os"
 | ||
| 	"os/exec"
 | ||
| 	"strings"
 | ||
| 	"syscall"
 | ||
| 
 | ||
| 	"github.com/opencontainers/runtime-spec/specs-go"
 | ||
| )
 | ||
| 
 | ||
| const (
 | ||
| 	loggingPrefix       = "ascend-docker-runtime"
 | ||
| 	hookCli             = "ascend-docker-hook"
 | ||
| 	hookDefaultFilePath = "/usr/bin/ascend-docker-hook"
 | ||
| 	dockerRuncFile      = "docker-runc"
 | ||
| 	runcFile            = "runc"
 | ||
| )
 | ||
| 
 | ||
| var (
 | ||
| 	hookCliPath         = hookCli
 | ||
| 	hookDefaultFile     = hookDefaultFilePath
 | ||
| 	dockerRuncName      = dockerRuncFile
 | ||
| 	runcName            = runcFile
 | ||
| )
 | ||
| 
 | ||
| type args struct {
 | ||
| 	bundleDirPath string
 | ||
| 	cmd           string
 | ||
| }
 | ||
| 
 | ||
| func getArgs() (*args, error) {
 | ||
| 	args := &args{}
 | ||
| 
 | ||
| 	for i, param := range os.Args {
 | ||
| 		if param == "--bundle" || param == "-b" {
 | ||
| 			if len(os.Args)-i <= 1 {
 | ||
| 				return nil, fmt.Errorf("bundle option needs an argument")
 | ||
| 			}
 | ||
| 			args.bundleDirPath = os.Args[i+1]
 | ||
| 		} else if param == "create" {
 | ||
| 			args.cmd = param
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	return args, nil
 | ||
| }
 | ||
| 
 | ||
| var execRunc = func() error {
 | ||
| 	runcPath, err := exec.LookPath(dockerRuncName)
 | ||
| 	if err != nil {
 | ||
| 		runcPath, err = exec.LookPath(runcName)
 | ||
| 		if err != nil {
 | ||
| 			return fmt.Errorf("failed to find the path of runc: %w", err)
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	if err = syscall.Exec(runcPath, append([]string{runcPath}, os.Args[1:]...), os.Environ()); err != nil {
 | ||
| 		return fmt.Errorf("failed to exec runc: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| func addHook(spec *specs.Spec) error {
 | ||
| 	path, err := exec.LookPath(hookCliPath)
 | ||
| 	if err != nil {
 | ||
| 		path = hookDefaultFile
 | ||
| 		if _, err = os.Stat(path); err != nil {
 | ||
| 			return fmt.Errorf("cannot find the hook")
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	if spec.Hooks == nil {
 | ||
| 		spec.Hooks = &specs.Hooks{}
 | ||
| 	} else if len(spec.Hooks.Prestart) != 0 {
 | ||
| 		for _, hook := range spec.Hooks.Prestart {
 | ||
| 			if !strings.Contains(hook.Path, hookCli) {
 | ||
| 				continue
 | ||
| 			}
 | ||
| 			return nil
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{
 | ||
| 		Path: path,
 | ||
| 		Args: []string{path},
 | ||
| 	})
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| func modifySpecFile(path string) error {
 | ||
| 	stat, err := os.Stat(path)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("spec file doesnt exist %s: %w", path, err)
 | ||
| 	}
 | ||
| 
 | ||
| 	jsonFile, err := os.OpenFile(path, os.O_RDWR, stat.Mode())
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("cannot open oci spec file %s: %w", path, err)
 | ||
| 	}
 | ||
| 
 | ||
| 	defer jsonFile.Close()
 | ||
| 
 | ||
| 	jsonContent, err := ioutil.ReadAll(jsonFile)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("failed to read oci spec file %s: %w", path, err)
 | ||
| 	}
 | ||
| 
 | ||
| 	var spec specs.Spec
 | ||
| 	if err := json.Unmarshal(jsonContent, &spec); err != nil {
 | ||
| 		return fmt.Errorf("failed to unmarshal oci spec file %s: %w", path, err)
 | ||
| 	}
 | ||
| 
 | ||
| 	if err := addHook(&spec); err != nil {
 | ||
| 		return fmt.Errorf("failed to inject hook: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	jsonOutput, err := json.Marshal(spec)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("failed to marshal OCI spec file: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	if _, err := jsonFile.WriteAt(jsonOutput, 0); err != nil {
 | ||
| 		return fmt.Errorf("failed to write OCI spec file: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| func doProcess() error {
 | ||
| 	args, err := getArgs()
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("failed to get args: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	if args.cmd != "create" {
 | ||
| 		return execRunc()
 | ||
| 	}
 | ||
| 
 | ||
| 	if args.bundleDirPath == "" {
 | ||
| 		args.bundleDirPath, err = os.Getwd()
 | ||
| 		if err != nil {
 | ||
| 			return fmt.Errorf("failed to get current working dir: %w", err)
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	specFilePath := args.bundleDirPath + "/config.json"
 | ||
| 
 | ||
| 	if err := modifySpecFile(specFilePath); err != nil {
 | ||
| 		return fmt.Errorf("failed to modify spec file %s: %w", specFilePath, err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return execRunc()
 | ||
| }
 | ||
| 
 | ||
| func main() {
 | ||
| 	log.SetPrefix(loggingPrefix)
 | ||
| 	if err := doProcess(); err != nil {
 | ||
| 		log.Fatal(err)
 | ||
| 	}
 | ||
| }
 | 
