diff --git a/runtime/main.go b/runtime/main.go index 436738d..8244d63 100644 --- a/runtime/main.go +++ b/runtime/main.go @@ -28,6 +28,7 @@ const ( hookDefaultFilePath = "/usr/local/bin/ascend-docker-hook" dockerRuncFile = "docker-runc" runcFile = "runc" + envLength = 2 ) var ( @@ -108,6 +109,19 @@ func addHook(spec *specs.Spec) error { }) } + hasVirtualFlag := false + for _, line := range spec.Process.Env { + words := strings.Split(line, "=") + if len(words) == envLength && strings.TrimSpace(words[0]) == "ASCEND_RUNTIME_OPTIONS" { + if strings.Contains(words[1], "VIRTUAL") { + hasVirtualFlag = true + } + } + } + if hasVirtualFlag { + return nil + } + vdevice, err := dcmi.CreateVDevice(&dcmi.NpuWorker{}, spec) if err != nil { @@ -126,12 +140,11 @@ func updateEnvAndPostHook(spec *specs.Spec, vdevice dcmi.VDeviceInfo) { needAddVirtualFlag := true for _, line := range spec.Process.Env { words := strings.Split(line, "=") - const LENGTH int = 2 - if len(words) == LENGTH && strings.TrimSpace(words[0]) == "ASCEND_VISIBLE_DEVICES" { + if len(words) == envLength && strings.TrimSpace(words[0]) == "ASCEND_VISIBLE_DEVICES" { newEnv = append(newEnv, fmt.Sprintf("ASCEND_VISIBLE_DEVICES=%d", vdevice.VdeviceID)) continue } - if len(words) == LENGTH && strings.TrimSpace(words[0]) == "ASCEND_RUNTIME_OPTIONS" { + if len(words) == envLength && strings.TrimSpace(words[0]) == "ASCEND_RUNTIME_OPTIONS" { needAddVirtualFlag = false if strings.Contains(words[1], "VIRTUAL") { newEnv = append(newEnv, line)