[Model] Update rkyolo pybind (#1294)

更新rkyolo pybind
This commit is contained in:
Zheng-Bicheng
2023-02-11 09:09:53 +08:00
committed by GitHub
parent 59a4ab343f
commit 6a3ac91057
3 changed files with 18 additions and 11 deletions

View File

@@ -80,6 +80,10 @@ class FASTDEPLOY_DECL RKYOLOPostprocessor {
obj_class_num_ = num; obj_class_num_ = num;
prob_box_size_ = obj_class_num_ + 5; prob_box_size_ = obj_class_num_ + 5;
} }
/// Get the number of class
int GetClassNum() {
return obj_class_num_;
}
private: private:
std::vector<int> anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62, std::vector<int> anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62,

View File

@@ -65,7 +65,9 @@ void BindRKYOLO(pybind11::module& m) {
.def_property("conf_threshold", &vision::detection::RKYOLOPostprocessor::GetConfThreshold, .def_property("conf_threshold", &vision::detection::RKYOLOPostprocessor::GetConfThreshold,
&vision::detection::RKYOLOPostprocessor::SetConfThreshold) &vision::detection::RKYOLOPostprocessor::SetConfThreshold)
.def_property("nms_threshold", &vision::detection::RKYOLOPostprocessor::GetNMSThreshold, .def_property("nms_threshold", &vision::detection::RKYOLOPostprocessor::GetNMSThreshold,
&vision::detection::RKYOLOPostprocessor::SetNMSThreshold); &vision::detection::RKYOLOPostprocessor::SetNMSThreshold)
.def_property("class_num", &vision::detection::RKYOLOPostprocessor::GetClassNum,
&vision::detection::RKYOLOPostprocessor::SetClassNum);
pybind11::class_<vision::detection::RKYOLOV5, FastDeployModel>(m, "RKYOLOV5") pybind11::class_<vision::detection::RKYOLOV5, FastDeployModel>(m, "RKYOLOV5")
.def(pybind11::init<std::string, .def(pybind11::init<std::string,

View File

@@ -108,11 +108,11 @@ class RKYOLOPostprocessor:
return self._postprocessor.nms_threshold return self._postprocessor.nms_threshold
@property @property
def multi_label(self): def class_num(self):
""" """
multi_label for postprocessing, set true for eval, default is True class_num for postprocessing, default is 80
""" """
return self._postprocessor.multi_label return self._postprocessor.class_num
@conf_threshold.setter @conf_threshold.setter
def conf_threshold(self, conf_threshold): def conf_threshold(self, conf_threshold):
@@ -126,13 +126,14 @@ class RKYOLOPostprocessor:
"The value to set `nms_threshold` must be type of float." "The value to set `nms_threshold` must be type of float."
self._postprocessor.nms_threshold = nms_threshold self._postprocessor.nms_threshold = nms_threshold
@multi_label.setter @class_num.setter
def multi_label(self, value): def class_num(self, class_num):
assert isinstance( """
value, class_num for postprocessing, default is 80
bool), "The value to set `multi_label` must be type of bool." """
self._postprocessor.multi_label = value assert isinstance(class_num, int), \
"The value to set `nms_threshold` must be type of float."
self._postprocessor.class_num = class_num
class RKYOLOV5(FastDeployModel): class RKYOLOV5(FastDeployModel):
def __init__(self, def __init__(self,