diff --git a/examples/vision/detection/paddledetection/python/infer_ppyoloe.py b/examples/vision/detection/paddledetection/python/infer_ppyoloe.py index 1d3260b53..10163d83b 100644 --- a/examples/vision/detection/paddledetection/python/infer_ppyoloe.py +++ b/examples/vision/detection/paddledetection/python/infer_ppyoloe.py @@ -53,6 +53,8 @@ config_file = os.path.join(model_dir, "infer_cfg.yml") runtime_option = build_option(args) model = fd.vision.detection.PPYOLOE( model_file, params_file, config_file, runtime_option=runtime_option) +print(model._model.input_info_of_runtime(0)) +print(model._model.output_info_of_runtime(0)) # 预测图片检测结果 if args.image is None: diff --git a/python/fastdeploy/model.py b/python/fastdeploy/model.py index 0277e5b8f..59833f775 100644 --- a/python/fastdeploy/model.py +++ b/python/fastdeploy/model.py @@ -27,23 +27,23 @@ class FastDeployModel: def model_name(self): return self._model.model_name() - def num_inputs(self): - return self._model.num_inputs() + def num_inputs_of_runtime(self): + return self._model.num_inputs_of_runtime() - def num_outputs(self): - return self._model.num_outputs() + def num_outputs_of_runtime(self): + return self._model.num_outputs_of_runtime() - def get_input_info(self, index): - assert index < self.num_inputs( + def input_info_of_runtime(self, index): + assert index < self.num_inputs_of_runtime( ), "The index:{} must be less than number of inputs:{}.".format( - index, self.num_inputs()) - return self._model.get_input_info(index) + index, self.num_inputs_of_runtime()) + return self._model.input_info_of_runtime(index) - def get_output_info(self, index): - assert index < self.num_outputs( + def output_info_of_runtime(self, index): + assert index < self.num_outputs_of_runtime( ), "The index:{} must be less than number of outputs:{}.".format( - index, self.num_outputs()) - return self._model.get_output_info(index) + index, self.num_outputs_of_runtime()) + return self._model.output_info_of_runtime(index) def enable_record_time_of_runtime(self): self._model.enable_record_time_of_runtime()