Add batch size argument for uie

This commit is contained in:
zhoushunjie
2022-12-27 15:22:09 +00:00
parent df940b750f
commit 60f8f0e11b
6 changed files with 114 additions and 76 deletions

View File

@@ -35,24 +35,29 @@ void BindUIE(pybind11::module& m) {
py::class_<text::UIEModel, FastDeployModel>(m, "UIEModel")
.def(py::init<std::string, std::string, std::string, float, size_t,
std::vector<std::string>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
std::vector<std::string>, int, RuntimeOption, ModelFormat,
text::SchemaLanguage>(),
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
py::arg("batch_size"),
py::arg("custom_option") = fastdeploy::RuntimeOption(),
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
py::arg("schema_language") = text::SchemaLanguage::ZH)
.def(
py::init<std::string, std::string, std::string, float, size_t,
std::vector<text::SchemaNode>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
py::arg("custom_option") = fastdeploy::RuntimeOption(),
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
py::arg("schema_language") = text::SchemaLanguage::ZH)
.def(py::init<std::string, std::string, std::string, float, size_t,
text::SchemaNode, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
std::vector<text::SchemaNode>, int, RuntimeOption,
ModelFormat, text::SchemaLanguage>(),
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
py::arg("batch_size"),
py::arg("custom_option") = fastdeploy::RuntimeOption(),
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
py::arg("schema_language") = text::SchemaLanguage::ZH)
.def(py::init<std::string, std::string, std::string, float, size_t,
text::SchemaNode, int, RuntimeOption, ModelFormat,
text::SchemaLanguage>(),
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
py::arg("batch_size"),
py::arg("custom_option") = fastdeploy::RuntimeOption(),
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
py::arg("schema_language") = text::SchemaLanguage::ZH)
@@ -60,23 +65,25 @@ void BindUIE(pybind11::module& m) {
static_cast<void (text::UIEModel::*)(
const std::vector<std::string>&)>(&text::UIEModel::SetSchema),
py::arg("schema"))
.def("set_schema", static_cast<void (text::UIEModel::*)(
const std::vector<text::SchemaNode>&)>(
&text::UIEModel::SetSchema),
.def("set_schema",
static_cast<void (text::UIEModel::*)(
const std::vector<text::SchemaNode>&)>(
&text::UIEModel::SetSchema),
py::arg("schema"))
.def("set_schema",
static_cast<void (text::UIEModel::*)(const text::SchemaNode&)>(
&text::UIEModel::SetSchema),
py::arg("schema"))
.def("predict",
[](text::UIEModel& self, const std::vector<std::string>& texts) {
std::vector<
std::unordered_map<std::string, std::vector<text::UIEResult>>>
results;
self.Predict(texts, &results);
return results;
},
py::arg("text"));
.def(
"predict",
[](text::UIEModel& self, const std::vector<std::string>& texts) {
std::vector<
std::unordered_map<std::string, std::vector<text::UIEResult>>>
results;
self.Predict(texts, &results);
return results;
},
py::arg("text"));
}
} // namespace fastdeploy