Fix wrong api in fastdeploy model

This commit is contained in:
jiangjiajun
2022-11-22 06:25:05 +00:00
parent 1a9a2df782
commit 5d37f739e8
2 changed files with 14 additions and 12 deletions

View File

@@ -53,6 +53,8 @@ config_file = os.path.join(model_dir, "infer_cfg.yml")
runtime_option = build_option(args) runtime_option = build_option(args)
model = fd.vision.detection.PPYOLOE( model = fd.vision.detection.PPYOLOE(
model_file, params_file, config_file, runtime_option=runtime_option) model_file, params_file, config_file, runtime_option=runtime_option)
print(model._model.input_info_of_runtime(0))
print(model._model.output_info_of_runtime(0))
# 预测图片检测结果 # 预测图片检测结果
if args.image is None: if args.image is None:

View File

@@ -27,23 +27,23 @@ class FastDeployModel:
def model_name(self): def model_name(self):
return self._model.model_name() return self._model.model_name()
def num_inputs(self): def num_inputs_of_runtime(self):
return self._model.num_inputs() return self._model.num_inputs_of_runtime()
def num_outputs(self): def num_outputs_of_runtime(self):
return self._model.num_outputs() return self._model.num_outputs_of_runtime()
def get_input_info(self, index): def input_info_of_runtime(self, index):
assert index < self.num_inputs( assert index < self.num_inputs_of_runtime(
), "The index:{} must be less than number of inputs:{}.".format( ), "The index:{} must be less than number of inputs:{}.".format(
index, self.num_inputs()) index, self.num_inputs_of_runtime())
return self._model.get_input_info(index) return self._model.input_info_of_runtime(index)
def get_output_info(self, index): def output_info_of_runtime(self, index):
assert index < self.num_outputs( assert index < self.num_outputs_of_runtime(
), "The index:{} must be less than number of outputs:{}.".format( ), "The index:{} must be less than number of outputs:{}.".format(
index, self.num_outputs()) index, self.num_outputs_of_runtime())
return self._model.get_output_info(index) return self._model.output_info_of_runtime(index)
def enable_record_time_of_runtime(self): def enable_record_time_of_runtime(self):
self._model.enable_record_time_of_runtime() self._model.enable_record_time_of_runtime()