diff --git a/cli/src/main.c b/cli/src/main.c index 686cb2f..03c5a36 100644 --- a/cli/src/main.c +++ b/cli/src/main.c @@ -41,6 +41,12 @@ struct CmdArgs { int pid; }; +struct ParsedConfig { + char containerNsPath[BUF_SIZE]; + char cgroupPath[BUF_SIZE]; + int originNsFd; +}; + static inline bool IsCmdArgsValid(struct CmdArgs *args) { return (args->devices != NULL) && (args->rootfs != NULL) && (args->pid > 0); @@ -148,7 +154,7 @@ static int MountDevice(const char *rootfs, const int serialNumber) return 0; } -static int DoMount(const char *rootfs, const char *devicesList) +static int DoDeviceMounting(const char *rootfs, const char *devicesList) { static const char *sep = ","; char list[BUF_SIZE] = {0}; @@ -281,7 +287,7 @@ static int MountFiles(const char *rootfs, const char *file, unsigned long reMoun return 0; } -static int DoMountFiles(const char *rootfs) +static int DoCtrlDeviceMounting(const char *rootfs) { /* device */ unsigned long reMountRwFlag = MS_BIND | MS_REMOUNT | MS_RDONLY | MS_NOSUID | MS_NOEXEC; @@ -303,6 +309,25 @@ static int DoMountFiles(const char *rootfs) return 0; } +static int DoMounting(const struct CmdArgs *args) +{ + int ret; + + ret = DoDeviceMounting(args->rootfs, args->devices); + if (ret < 0) { + fprintf(stderr, "error: failed to do mounts\n"); + return -1; + } + + ret = DoCtrlDeviceMounting(args->rootfs); + if (ret < 0) { + fprintf(stderr, "error: failed to do mount files\n"); + return -1; + } + + return 0; +} + typedef char *(*ParseFileLine)(char *, const char *); int IsStrEqual(const char *s1, const char *s2) @@ -315,11 +340,17 @@ int StrHasPrefix(const char *str, const char *prefix) return (!strncmp(str, prefix, strlen(prefix))); } +static bool IsCgroupLineArgsValid(const char *rootDir, const char *mountPoint, const char* fsType, const char* substr) +{ + return ((rootDir != NULL && mountPoint != NULL && fsType != NULL && substr != NULL) && + (*rootDir != '\0' && *mountPoint != '\0' && *fsType != '\0' && *substr != '\0')); +} + char *GetCgroupMount(char *line, const char *subsys) { int i; - char *rootDir = NULL; + char *rootDir = NULL; for (i = 0; i < ROOT_GAP; ++i) { /* root is substr before gap, line is substr after gap */ rootDir = strsep(&line, " "); @@ -328,6 +359,7 @@ char *GetCgroupMount(char *line, const char *subsys) char *mountPoint = NULL; mountPoint = strsep(&line, " "); line = strchr(line, '-'); + char* fsType = NULL; for (i = 0; i < FSTYPE_GAP; ++i) { fsType = strsep(&line, " "); @@ -338,27 +370,23 @@ char *GetCgroupMount(char *line, const char *subsys) substr = strsep(&line, " "); } - if (rootDir == NULL || mountPoint == NULL || fsType == NULL || substr == NULL) { - return (NULL); - } - - if (*rootDir == '\0' || *mountPoint == '\0' || *fsType == '\0' || *substr == '\0') { - return (NULL); + if (!IsCgroupLineArgsValid(rootDir, mountPoint, fsType, substr)) { + return NULL; } if (strlen(rootDir) >= BUF_SIZE || StrHasPrefix(rootDir, "/..")) { - return (NULL); + return NULL; } if (strstr(substr, subsys) == NULL) { - return (NULL); + return NULL; } if (!IsStrEqual(fsType, "cgroup")) { - return (NULL); + return NULL; } - return (mountPoint); + return mountPoint; } char *GetCgroupRoot(char *line, const char *subSystem) @@ -548,89 +576,96 @@ int SetupCgroup(struct CmdArgs *args, const char *cgroupPath) return 0; } -static int SetupMounts(struct CmdArgs *args) +static int DoPrepare(const struct CmdArgs *args, struct ParsedConfig *config) { int ret; - char cgroupPath[BUF_SIZE] = {0}; - char containerNsPath[BUF_SIZE] = {0}; - ret = GetNsPath(args->pid, "mnt", containerNsPath, BUF_SIZE); + ret = GetNsPath(args->pid, "mnt", config->containerNsPath, BUF_SIZE); if (ret < 0) { fprintf(stderr, "error: failed to get container mnt ns path: pid(%d)\n", args->pid); - return ret; + return -1; + } + + ret = GetCgroupPath(args, config->cgroupPath, BUF_SIZE); + if (ret < 0) { + fprintf(stderr, "error: failed to get cgroup path\n"); + return -1; } char originNsPath[BUF_SIZE] = {0}; ret = GetSelfNsPath("mnt", originNsPath, BUF_SIZE); if (ret < 0) { fprintf(stderr, "error: failed to get self ns path\n"); - return ret; + return -1; } - int originNsFd = open((const char *)originNsPath, O_RDONLY); - if (originNsFd < 0) { + config->originNsFd = open((const char *)originNsPath, O_RDONLY); + if (config->originNsFd < 0) { fprintf(stderr, "error: failed to get self ns fd: %s\n", originNsPath); return -1; } - ret = GetCgroupPath(args, cgroupPath, BUF_SIZE); + return 0; +} + +static int SetupMounts(struct CmdArgs *args) +{ + int ret; + struct ParsedConfig config; + + ret = DoPrepare(args, &config); if (ret < 0) { - fprintf(stderr, "error: failed to get cgroup path\n"); + fprintf(stderr, "error: failed to prepare nesessary config\n"); return -1; } // enter container's mount namespace - ret = EnterNsByPath((const char *) containerNsPath, CLONE_NEWNS); + ret = EnterNsByPath((const char *)config.containerNsPath, CLONE_NEWNS); if (ret < 0) { - fprintf(stderr, "error: failed to set to container ns: %s\n", containerNsPath); - close(originNsFd); + fprintf(stderr, "error: failed to set to container ns: %s\n", config.containerNsPath); + close(config.originNsFd); return -1; } - ret = DoMount(args->rootfs, args->devices); + ret = DoMounting(args); if (ret < 0) { - fprintf(stderr, "error: failed to do mounts\n"); - close(originNsFd); + fprintf(stderr, "error: failed to do mounting\n"); + close(config.originNsFd); return -1; } - ret = DoMountFiles(args->rootfs); - if (ret < 0) { - fprintf(stderr, "error: failed to do mount files\n"); - close(originNsFd); - return -1; - } - - ret = SetupCgroup(args, (const char *)cgroupPath); + ret = SetupCgroup(args, (const char *)config.cgroupPath); if (ret < 0) { fprintf(stderr, "error: failed to set up cgroup\n"); - close(originNsFd); + close(config.originNsFd); return -1; } // back to original namespace - ret = EnterNsByFd(originNsFd, CLONE_NEWNS); + ret = EnterNsByFd(config.originNsFd, CLONE_NEWNS); if (ret < 0) { fprintf(stderr, "error: failed to set ns back\n"); - close(originNsFd); + close(config.originNsFd); return -1; } - close(originNsFd); + close(config.originNsFd); return 0; } #ifdef gtest -int _main(int argc, char **argv) { +int _main(int argc, char **argv) +{ #else -int main(int argc, char **argv) { +int main(int argc, char **argv) +{ #endif int c; int optionIndex; struct CmdArgs args = { - .devices = NULL, - .rootfs = NULL, - .pid = -1 + .devices = NULL, + .rootfs = NULL, + .pid = -1 }; while ((c = getopt_long(argc, argv, "d:p:r", g_opts, &optionIndex)) != -1) { diff --git a/cli/test/testcase/gtest_mytestcase.cpp b/cli/test/testcase/gtest_mytestcase.cpp index fe65da8..119140e 100644 --- a/cli/test/testcase/gtest_mytestcase.cpp +++ b/cli/test/testcase/gtest_mytestcase.cpp @@ -1,12 +1,13 @@ -// Demo.cpp : Defines the entry point for the console application. -// +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + * Description: 测试集 +*/ #include #include #include "gtest/gtest.h" #include "mockcpp/mockcpp.hpp" using namespace std; -//建议这样引用,避免下面用关键字时需要加前缀 testing:: using namespace testing; extern "C" int IsStrEqual(const char *s1, const char *s2); @@ -16,11 +17,10 @@ extern "C" int EnterNsByFd(int fd, int nsType); int stub_setns(int fd, int nstype) { - return 0; + return 0; } -class Test_Fhho : public Test -{ +class Test_Fhho : public Test { protected: static void SetUpTestCase() { @@ -40,7 +40,7 @@ protected: cout << "TestSuite测试用例事件:在每个testcase之后执行" << endl; } }; - + TEST_F(Test_Fhho, ClassEQ1) { EXPECT_EQ(1, IsStrEqual("", "")); @@ -48,13 +48,13 @@ TEST_F(Test_Fhho, ClassEQ1) #if 0 TEST_F(Test_Fhho, ClassEQ2) -{ - int pid = 1; - char* nsType = "mnt"; - char buf[100] = {0x0}; - int bufSize = 100; - int ret = GetNsPath(pid, nsType, buf, 100); - EXPECT_EQ(1, ret); +{ + int pid = 1; + char* nsType = "mnt"; + char buf[100] = {0x0}; + int bufSize = 100; + int ret = GetNsPath(pid, nsType, buf, 100); + EXPECT_EQ(1, ret); } TEST_F(Test_Fhho, ClassEQ3) diff --git a/cli/test/testcase/main.cpp b/cli/test/testcase/main.cpp index aff34d8..2dd952f 100644 --- a/cli/test/testcase/main.cpp +++ b/cli/test/testcase/main.cpp @@ -1,40 +1,21 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + * Description: 测试框架主函数 +*/ #include #include #include "gtest/gtest.h" #include "mockcpp/mockcpp.hpp" -//#include "gtest_testcase.cpp" -//#include "mockcpp_testcase.cpp" - using namespace std; -//建议这样引用,避免下面用关键字时需要加前缀 testing:: using namespace testing; - int main(int argc, char* argv[], char* evn[]) { - //std::vector g_func1 = GET_FUNC_CTOR_LIST(); - //全局事件:设置执行全局事件 - //ddGlobalTestEnvironment(new FooEnvironment); - - //输出 用例列表,用例不执行了~ - //testing::GTEST_FLAG(list_tests) = " "; - //设置过滤功能后,参数化功能失效~~~~//执行列出来的测试套的用例 - //testing::GTEST_FLAG(filter) = "EXEPath.*";//"FooTest.*:TestCase.*:TestSuite.*:TestCaseTest.*:IsPrimeParamTest.*"; - - //测试套排序,下面两种情况不能同时使用,否则排序就无作用 - //GTEST_FLAG(list_order) = "Test_Fhho;UT_DEMO;TestSuitName;FuncFoo;TestSuitEvent"; - //测试套模糊匹配排序,注:只以开头进行精确匹配,遇到 * 后模糊匹配 - //如UT_*;IT_*,先执行所有UT_开头的用例再执行IT_开头的用例 - /*GTEST_FLAG(dark_list_order) = "UT_*;\ - IT_*";*/ - // Returns 0 if all tests passed, or 1 other wise. - int ret = Init_UT(argc, argv, true); - if (1 == ret) - { + int ret = Init_UT(argc, argv, true); + if (1 == ret) { printf("有用例错误,请按任意键继续。。。"); - //getchar(); } - return ret; + return ret; } diff --git a/hook/main.go b/hook/main.go index 01a0ebf..8b9b606 100644 --- a/hook/main.go +++ b/hook/main.go @@ -1,3 +1,7 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + * Description: ascend-docker-hook工具,配置容器挂载Ascend NPU设备 + */ package main import ( @@ -21,12 +25,15 @@ const ( ascendVisibleDevices = "ASCEND_VISIBLE_DEVICES" ascendDockerCli = "ascend-docker-cli" defaultAscendDockerCli = "/usr/local/bin/ascend-docker-cli" + + borderNum = 2 + kvPairSize = 2 ) type containerConfig struct { Pid int Rootfs string - Env []string + Env []string } func removeDuplication(devices []int) []int { @@ -52,7 +59,7 @@ func parseDevices(visibleDevices string) ([]int, error) { d = strings.TrimSpace(d) if strings.Contains(d, "-") { borders := strings.Split(d, "-") - if len(borders) < 2 { + if len(borders) < borderNum { return nil, fmt.Errorf("invalid device range: %s", d) } @@ -130,7 +137,7 @@ func getContainerConfig() (*containerConfig, error) { ret := &containerConfig{ Pid: state.Pid, Rootfs: ociSpec.Root.Path, - Env: ociSpec.Process.Env, + Env: ociSpec.Process.Env, } return ret, nil @@ -139,7 +146,7 @@ func getContainerConfig() (*containerConfig, error) { func getValueByKey(data []string, key string) string { for _, s := range data { p := strings.SplitN(s, "=", 2) - if len(p) != 2 { + if len(p) != kvPairSize { log.Panicln("environment error") } @@ -171,7 +178,7 @@ func doPrestartHook() error { if err != nil { _, err = os.Stat(defaultAscendDockerCli) if err != nil { - return fmt.Errorf("could not found ascend docker cli\n") + return fmt.Errorf("could not found ascend docker cli") } cliPath = defaultAscendDockerCli @@ -183,7 +190,7 @@ func doPrestartHook() error { "--rootfs", containerConfig.Rootfs) if err := syscall.Exec(cliPath, args, os.Environ()); err != nil { - return fmt.Errorf("failed to exec ascend-docker-cli %v: %w\n", args, err) + return fmt.Errorf("failed to exec ascend-docker-cli %v: %w", args, err) } return nil diff --git a/hook/main_test.go b/hook/main_test.go index 4957840..a5e982d 100644 --- a/hook/main_test.go +++ b/hook/main_test.go @@ -1,3 +1,7 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + * Description: hook main 函数单元测试 +*/ package main import ( @@ -7,9 +11,10 @@ import ( func TestRemoveDuplication(t *testing.T) { originList := []int {1,2,2,4,5,5,5,6,8,8} + targetList := []int {1,2,4,5,6,8} resultList := removeDuplication(originList) - if !reflect.DeepEqual(resultList, []int {1,2,4,5,6,8}) { + if !reflect.DeepEqual(resultList, targetList) { t.Fail() } } diff --git a/runtime/main.go b/runtime/main.go index bfa452c..5330647 100644 --- a/runtime/main.go +++ b/runtime/main.go @@ -1,3 +1,7 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + * Description: ascend-docker-runtime工具,配置容器挂载Ascend NPU设备 + */ package main import ( @@ -30,7 +34,7 @@ func getArgs() (*args, error) { for i, param := range os.Args { if param == "--bundle" || param == "-b" { if len(os.Args)-i <= 1 { - return nil, fmt.Errorf("bundle option needs an argument\n") + return nil, fmt.Errorf("bundle option needs an argument") } args.bundleDirPath = os.Args[i+1] } else if param == "create" { @@ -46,12 +50,12 @@ func execRunc() error { if err != nil { runcPath, err = exec.LookPath("runc") if err != nil { - return fmt.Errorf("failed to find the path of runc: %w\n", err) + return fmt.Errorf("failed to find the path of runc: %w", err) } } if err = syscall.Exec(runcPath, append([]string{runcPath}, os.Args[1:]...), os.Environ()); err != nil { - return fmt.Errorf("failed to exec runc: %w\n", err) + return fmt.Errorf("failed to exec runc: %w", err) } return nil @@ -62,7 +66,7 @@ func addHook(spec *specs.Spec) error { if err != nil { path = hookDefaultFilePath if _, err = os.Stat(path); err != nil { - return fmt.Errorf("cannot find the hook\n") + return fmt.Errorf("cannot find the hook") } } @@ -93,32 +97,32 @@ func modifySpecFile(path string) error { jsonFile, err := os.OpenFile(path, os.O_RDWR, stat.Mode()) if err != nil { - return fmt.Errorf("cannot open oci spec file %s: %w\n", path, err) + return fmt.Errorf("cannot open oci spec file %s: %w", path, err) } defer jsonFile.Close() jsonContent, err := ioutil.ReadAll(jsonFile) if err != nil { - return fmt.Errorf("failed to read oci spec file %s: %w\n", path, err) + return fmt.Errorf("failed to read oci spec file %s: %w", path, err) } var spec specs.Spec if err := json.Unmarshal(jsonContent, &spec); err != nil { - return fmt.Errorf("failed to unmarshal oci spec file %s: %w\n", path, err) + return fmt.Errorf("failed to unmarshal oci spec file %s: %w", path, err) } if err := addHook(&spec); err != nil { - return fmt.Errorf("failed to inject hook: %w\n", err) + return fmt.Errorf("failed to inject hook: %w", err) } jsonOutput, err := json.Marshal(spec) if err != nil { - return fmt.Errorf("failed to marshal OCI spec file: %w\n", err) + return fmt.Errorf("failed to marshal OCI spec file: %w", err) } if _, err := jsonFile.WriteAt(jsonOutput, 0); err != nil { - return fmt.Errorf("failed to write OCI spec file: %w\n", err) + return fmt.Errorf("failed to write OCI spec file: %w", err) } return nil @@ -127,7 +131,7 @@ func modifySpecFile(path string) error { func doProcess() error { args, err := getArgs() if err != nil { - return fmt.Errorf("failed to get args: %w\n", err) + return fmt.Errorf("failed to get args: %w", err) } if args.cmd != "create" { @@ -137,14 +141,14 @@ func doProcess() error { if args.bundleDirPath == "" { args.bundleDirPath, err = os.Getwd() if err != nil { - return fmt.Errorf("failed to get current working dir: %w\n", err) + return fmt.Errorf("failed to get current working dir: %w", err) } } specFilePath := args.bundleDirPath + "/config.json" if err := modifySpecFile(specFilePath); err != nil { - return fmt.Errorf("failed to modify spec file %s: %w\n", specFilePath, err) + return fmt.Errorf("failed to modify spec file %s: %w", specFilePath, err) } return execRunc()