mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
@@ -80,6 +80,10 @@ class FASTDEPLOY_DECL RKYOLOPostprocessor {
|
|||||||
obj_class_num_ = num;
|
obj_class_num_ = num;
|
||||||
prob_box_size_ = obj_class_num_ + 5;
|
prob_box_size_ = obj_class_num_ + 5;
|
||||||
}
|
}
|
||||||
|
/// Get the number of class
|
||||||
|
int GetClassNum() {
|
||||||
|
return obj_class_num_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62,
|
std::vector<int> anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62,
|
||||||
|
@@ -65,7 +65,9 @@ void BindRKYOLO(pybind11::module& m) {
|
|||||||
.def_property("conf_threshold", &vision::detection::RKYOLOPostprocessor::GetConfThreshold,
|
.def_property("conf_threshold", &vision::detection::RKYOLOPostprocessor::GetConfThreshold,
|
||||||
&vision::detection::RKYOLOPostprocessor::SetConfThreshold)
|
&vision::detection::RKYOLOPostprocessor::SetConfThreshold)
|
||||||
.def_property("nms_threshold", &vision::detection::RKYOLOPostprocessor::GetNMSThreshold,
|
.def_property("nms_threshold", &vision::detection::RKYOLOPostprocessor::GetNMSThreshold,
|
||||||
&vision::detection::RKYOLOPostprocessor::SetNMSThreshold);
|
&vision::detection::RKYOLOPostprocessor::SetNMSThreshold)
|
||||||
|
.def_property("class_num", &vision::detection::RKYOLOPostprocessor::GetClassNum,
|
||||||
|
&vision::detection::RKYOLOPostprocessor::SetClassNum);
|
||||||
|
|
||||||
pybind11::class_<vision::detection::RKYOLOV5, FastDeployModel>(m, "RKYOLOV5")
|
pybind11::class_<vision::detection::RKYOLOV5, FastDeployModel>(m, "RKYOLOV5")
|
||||||
.def(pybind11::init<std::string,
|
.def(pybind11::init<std::string,
|
||||||
|
@@ -108,11 +108,11 @@ class RKYOLOPostprocessor:
|
|||||||
return self._postprocessor.nms_threshold
|
return self._postprocessor.nms_threshold
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def multi_label(self):
|
def class_num(self):
|
||||||
"""
|
"""
|
||||||
multi_label for postprocessing, set true for eval, default is True
|
class_num for postprocessing, default is 80
|
||||||
"""
|
"""
|
||||||
return self._postprocessor.multi_label
|
return self._postprocessor.class_num
|
||||||
|
|
||||||
@conf_threshold.setter
|
@conf_threshold.setter
|
||||||
def conf_threshold(self, conf_threshold):
|
def conf_threshold(self, conf_threshold):
|
||||||
@@ -126,13 +126,14 @@ class RKYOLOPostprocessor:
|
|||||||
"The value to set `nms_threshold` must be type of float."
|
"The value to set `nms_threshold` must be type of float."
|
||||||
self._postprocessor.nms_threshold = nms_threshold
|
self._postprocessor.nms_threshold = nms_threshold
|
||||||
|
|
||||||
@multi_label.setter
|
@class_num.setter
|
||||||
def multi_label(self, value):
|
def class_num(self, class_num):
|
||||||
assert isinstance(
|
"""
|
||||||
value,
|
class_num for postprocessing, default is 80
|
||||||
bool), "The value to set `multi_label` must be type of bool."
|
"""
|
||||||
self._postprocessor.multi_label = value
|
assert isinstance(class_num, int), \
|
||||||
|
"The value to set `nms_threshold` must be type of float."
|
||||||
|
self._postprocessor.class_num = class_num
|
||||||
|
|
||||||
class RKYOLOV5(FastDeployModel):
|
class RKYOLOV5(FastDeployModel):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
Reference in New Issue
Block a user