diff --git a/runtime/dcmi/dcmi.go b/runtime/dcmi/dcmi.go index 7a9dbaa..0c9e483 100644 --- a/runtime/dcmi/dcmi.go +++ b/runtime/dcmi/dcmi.go @@ -10,8 +10,9 @@ import "C" import ( "fmt" "math" - "mindxcheckutils" "unsafe" + + "mindxcheckutils" ) const ( @@ -71,7 +72,7 @@ func GetCardList() (int32, []int32, error) { } var cardNum = int32(cNum) var cardIDList []int32 - for i := int32(0); i < cardNum && i < hiAIMaxCardNum; i++ { + for i := int32(0); i < cardNum; i++ { cardID := int32(ids[i]) if cardID < 0 { continue @@ -157,7 +158,7 @@ func (w *NpuWorker) FindDevice(visibleDevice int32) (int32, int32, error) { if err := C.dcmi_get_device_logicid_from_phyid(C.uint(visibleDevice), &dcmiLogicID); err != 0 { return 0, 0, fmt.Errorf("phy id can not be converted to logic id : %v", err) } - if uint(dcmiLogicID) > math.MaxInt32 { + if int32(dcmiLogicID) < 0 || int32(dcmiLogicID) >= hiAIMaxCardNum*hiAIMaxDeviceNum { return 0, 0, fmt.Errorf("logic id too large") } targetLogicID := int32(dcmiLogicID) diff --git a/runtime/dcmi/dcmi_api.go b/runtime/dcmi/dcmi_api.go index 92da7e5..0884bee 100644 --- a/runtime/dcmi/dcmi_api.go +++ b/runtime/dcmi/dcmi_api.go @@ -5,7 +5,6 @@ package dcmi import ( "fmt" - "math" "strconv" "strings" @@ -52,7 +51,7 @@ func CreateVDevice(w WorkerInterface, spec *specs.Spec) (VDeviceInfo, error) { } func extractVpuParam(spec *specs.Spec) (int32, string, error) { - visibleDevice, splitDevice, needSplit, visibleDeviceLine := int32(-1), "", false, "" + splitDevice, needSplit, visibleDeviceLine := "", false, "" allowSplit := map[string]string{ "vir01": "vir01", "vir02": "vir02", "vir04": "vir04", "vir08": "vir08", "vir16": "vir16", "vir04_3c": "vir04_3c", "vir02_1c": "vir02_1c", "vir04_4c_dvpp": "vir04_4c_dvpp", @@ -80,10 +79,11 @@ func extractVpuParam(spec *specs.Spec) (int32, string, error) { if !needSplit { return -1, "", nil } - if cardID, err := strconv.Atoi(visibleDeviceLine); err == nil && cardID >= 0 && cardID <= math.MaxInt32 { - visibleDevice = int32(cardID) - } else { - return -1, "", fmt.Errorf("cannot parse param : %v %v", err, visibleDeviceLine) + visibleDevice, err := strconv.Atoi(visibleDeviceLine) + if err != nil || visibleDevice < 0 || visibleDevice >= hiAIMaxCardNum*hiAIMaxDeviceNum { + return -1, "", fmt.Errorf("cannot parse param : %v %s", err, visibleDeviceLine) + } - return visibleDevice, splitDevice, nil + + return int32(visibleDevice), splitDevice, nil } diff --git a/runtime/main.go b/runtime/main.go index bde9049..7d02e5c 100644 --- a/runtime/main.go +++ b/runtime/main.go @@ -143,6 +143,7 @@ func addHook(spec *specs.Spec) error { for _, hook := range spec.Hooks.Prestart { if strings.Contains(hook.Path, hookCli) { needUpdate = false + break } } if needUpdate { @@ -158,6 +159,7 @@ func addHook(spec *specs.Spec) error { if len(words) == envLength && strings.TrimSpace(words[0]) == "ASCEND_RUNTIME_OPTIONS" { if strings.Contains(words[1], "VIRTUAL") { hasVirtualFlag = true + break } } }