Add batch size argument for uie
@@ -68,9 +68,9 @@ int main(int argc, char* argv[]) {
   using fastdeploy::text::SchemaNode;
   using fastdeploy::text::UIEResult;
 
-  auto predictor =
-      fastdeploy::text::UIEModel(model_path, param_path, vocab_path, 0.5, 128,
-                                 {"时间", "选手", "赛事名称"}, option);
+  auto predictor = fastdeploy::text::UIEModel(
+      model_path, param_path, vocab_path, 0.5, 128,
+      {"时间", "选手", "赛事名称"}, /* batch_size = */ 1, option);
   std::cout << "After init predictor" << std::endl;
   std::vector<std::unordered_map<std::string, std::vector<UIEResult>>> results;
   // Named Entity Recognition
@@ -129,6 +129,7 @@ if __name__ == "__main__":
         position_prob=0.5,
         max_length=args.max_length,
         schema=schema,
+        batch_size=args.batch_size,
         runtime_option=runtime_option,
         schema_language=SchemaLanguage.ZH)
 
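The Python example above now forwards args.batch_size into the UIEModel constructor, which assumes the script's argument parser also gains a matching --batch_size option (that parser change is not shown in this excerpt). A minimal sketch of such a parser entry, with an assumed default and help text:

    import argparse

    def build_parser():
        parser = argparse.ArgumentParser()
        # Hypothetical flag backing args.batch_size above; the real example
        # script may choose a different default or description.
        parser.add_argument(
            "--batch_size",
            type=int,
            default=1,
            help="Number of prompts sent to the runtime per inference call.")
        parser.add_argument("--max_length", type=int, default=128)
        return parser
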
@@ -181,7 +182,8 @@ if __name__ == "__main__":
     schema = {"评价维度": ["观点词", "情感倾向[正向,负向]"]}
     print(f"The extraction schema: {schema}")
     uie.set_schema(schema)
-    results = uie.predict(["店面干净,很清静"], return_dict=True)
+    results = uie.predict(
+        ["店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"], return_dict=True)
     pprint(results)
     print()
 
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "fastdeploy/text/uie/model.h"
+#include "fastdeploy/function/concat.h"
+#include "fastdeploy/function/split.h"
 #include <algorithm>
 #include <codecvt>
 #include <locale>
@@ -42,8 +44,7 @@ static std::string DBC2SBC(const std::string& content) {
       result.append(content.data() + content_utf8_len, content_char_width);
     } else {
       char dst_char[5] = {0};
-      uint32_t utf8_uint32 =
-          fast_tokenizer::utils::UnicodeToUTF8(content_char);
+      uint32_t utf8_uint32 = fast_tokenizer::utils::UnicodeToUTF8(content_char);
       uint32_t utf8_char_count =
           fast_tokenizer::utils::UnicodeToUTF8Char(utf8_uint32, dst_char);
       result.append(dst_char, utf8_char_count);
@@ -164,12 +165,12 @@ UIEModel::UIEModel(const std::string& model_file,
                    const std::string& params_file,
                    const std::string& vocab_file, float position_prob,
                    size_t max_length, const std::vector<std::string>& schema,
+                   int batch_size,
                    const fastdeploy::RuntimeOption& custom_option,
                    const fastdeploy::ModelFormat& model_format,
                    SchemaLanguage schema_language)
-    : max_length_(max_length),
-      position_prob_(position_prob),
-      schema_language_(schema_language),
+    : max_length_(max_length), position_prob_(position_prob),
+      schema_language_(schema_language), batch_size_(batch_size),
       tokenizer_(vocab_file) {
   runtime_option = custom_option;
   runtime_option.model_format = model_format;
@@ -185,12 +186,12 @@ UIEModel::UIEModel(const std::string& model_file,
                    const std::string& params_file,
                    const std::string& vocab_file, float position_prob,
                    size_t max_length, const std::vector<SchemaNode>& schema,
+                   int batch_size,
                    const fastdeploy::RuntimeOption& custom_option,
                    const fastdeploy::ModelFormat& model_format,
                    SchemaLanguage schema_language)
-    : max_length_(max_length),
-      position_prob_(position_prob),
-      schema_language_(schema_language),
+    : max_length_(max_length), position_prob_(position_prob),
+      schema_language_(schema_language), batch_size_(batch_size),
       tokenizer_(vocab_file) {
   runtime_option = custom_option;
   runtime_option.model_format = model_format;
@@ -205,13 +206,12 @@ UIEModel::UIEModel(const std::string& model_file,
 UIEModel::UIEModel(const std::string& model_file,
                    const std::string& params_file,
                    const std::string& vocab_file, float position_prob,
-                   size_t max_length, const SchemaNode& schema,
+                   size_t max_length, const SchemaNode& schema, int batch_size,
                    const fastdeploy::RuntimeOption& custom_option,
                    const fastdeploy::ModelFormat& model_format,
                    SchemaLanguage schema_language)
-    : max_length_(max_length),
-      position_prob_(position_prob),
-      schema_language_(schema_language),
+    : max_length_(max_length), position_prob_(position_prob),
+      schema_language_(schema_language), batch_size_(batch_size),
       tokenizer_(vocab_file) {
   runtime_option = custom_option;
   runtime_option.model_format = model_format;
@@ -230,7 +230,8 @@ bool UIEModel::Initialize() {
 
 void UIEModel::SetValidBackend() {
   // TODO(zhoushunjie): Add lite backend in future
-  valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
+  valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER,
+                        Backend::LITE};
   valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
 }
 
@@ -253,8 +254,8 @@ void UIEModel::AutoSplitter(const std::vector<std::string>& texts,
   size_t cnt_org = 0;
   size_t cnt_short = 0;
   for (auto& text : texts) {
-    auto text_len = fast_tokenizer::utils::GetUnicodeLenFromUTF8(
-        text.c_str(), text.length());
+    auto text_len = fast_tokenizer::utils::GetUnicodeLenFromUTF8(text.c_str(),
+                                                                 text.length());
     if (text_len <= max_length) {
       short_texts->push_back(text);
       if (input_mapping->size() <= cnt_org) {
@@ -264,8 +265,7 @@ void UIEModel::AutoSplitter(const std::vector<std::string>& texts,
       }
       cnt_short += 1;
     } else {
-      fast_tokenizer::pretokenizers::CharToBytesOffsetConverter converter(
-          text);
+      fast_tokenizer::pretokenizers::CharToBytesOffsetConverter converter(text);
       for (size_t start = 0; start < text_len; start += max_length) {
         size_t end = start + max_length;
         if (end > text_len) {
@@ -742,13 +742,37 @@ void UIEModel::Predict(
   std::vector<fast_tokenizer::core::Encoding> encodings;
   Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
 
+  std::vector<std::vector<FDTensor>> inputs_vec(NumInputsOfRuntime());
+  int encoding_size = encodings.size();
+  std::vector<int> num_or_sections;
+  for (int i = 0; i < encoding_size; i += batch_size_) {
+    int actual_batch_size = (std::min)(batch_size_, encoding_size - i);
+    num_or_sections.push_back(actual_batch_size);
+  }
+  for (int i = 0; i < NumInputsOfRuntime(); ++i) {
+    function::Split(inputs[i], num_or_sections, &inputs_vec[i]);
+  }
 
   // 3. Infer
   std::vector<fastdeploy::FDTensor> outputs(NumOutputsOfRuntime());
-  if (!Infer(inputs, &outputs)) {
+  std::vector<fastdeploy::FDTensor> outputs0, outputs1;
+
+  for (int i = 0; i < inputs_vec[0].size(); ++i) {
+    std::vector<fastdeploy::FDTensor> curr_inputs(NumInputsOfRuntime());
+    std::vector<fastdeploy::FDTensor> curr_outputs(NumOutputsOfRuntime());
+    for (int j = 0; j < NumInputsOfRuntime(); ++j) {
+      curr_inputs[j] = std::move(inputs_vec[j][i]);
+      curr_inputs[j].name = inputs[j].name;
+    }
+    if (!Infer(curr_inputs, &curr_outputs)) {
       FDERROR << "Failed to inference while using model:" << ModelName()
               << "." << std::endl;
     }
+    outputs0.push_back(curr_outputs[0]);
+    outputs1.push_back(curr_outputs[1]);
+  }
+  function::Concat(outputs0, &outputs[0]);
+  function::Concat(outputs1, &outputs[1]);
   // 4. Convert FDTensor to UIEResult
   Postprocess(outputs, encodings, short_input_texts, short_prompts,
               input_mapping_with_short_text, &results_list);
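In the hunk above, Predict now splits the preprocessed input tensors into mini-batches of at most batch_size_ rows with function::Split, runs Infer once per mini-batch, and stitches the two output tensors back together with function::Concat before postprocessing. A rough Python sketch of the same chunk-then-concatenate arithmetic, using illustrative names rather than FastDeploy APIs:

    import numpy as np

    def predict_in_minibatches(infer_fn, inputs, batch_size):
        # inputs: list of arrays sharing the same leading (batch) dimension.
        total = inputs[0].shape[0]
        sections = [min(batch_size, total - i) for i in range(0, total, batch_size)]
        outputs0, outputs1 = [], []
        start = 0
        for size in sections:
            chunk = [x[start:start + size] for x in inputs]
            o0, o1 = infer_fn(chunk)  # two output tensors, as in the UIE runtime
            outputs0.append(o0)
            outputs1.append(o1)
            start += size
        # Concatenating along axis 0 restores the original batch layout.
        return np.concatenate(outputs0, axis=0), np.concatenate(outputs1, axis=0)
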
@@ -14,14 +14,14 @@
 
 #pragma once
 
+#include "fast_tokenizer/tokenizers/ernie_fast_tokenizer.h"
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/utils/unique_ptr.h"
 #include <ostream>
 #include <set>
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "fastdeploy/fastdeploy_model.h"
-#include "fastdeploy/utils/unique_ptr.h"
-#include "fast_tokenizer/tokenizers/ernie_fast_tokenizer.h"
 
 using namespace paddlenlp;
 
@@ -99,6 +99,7 @@ struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
   UIEModel(const std::string& model_file, const std::string& params_file,
            const std::string& vocab_file, float position_prob,
            size_t max_length, const std::vector<std::string>& schema,
+           int batch_size,
            const fastdeploy::RuntimeOption& custom_option =
                fastdeploy::RuntimeOption(),
            const fastdeploy::ModelFormat& model_format =
@@ -106,7 +107,7 @@ struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
            SchemaLanguage schema_language = SchemaLanguage::ZH);
   UIEModel(const std::string& model_file, const std::string& params_file,
            const std::string& vocab_file, float position_prob,
-           size_t max_length, const SchemaNode& schema,
+           size_t max_length, const SchemaNode& schema, int batch_size,
            const fastdeploy::RuntimeOption& custom_option =
                fastdeploy::RuntimeOption(),
            const fastdeploy::ModelFormat& model_format =
@@ -115,6 +116,7 @@ struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
   UIEModel(const std::string& model_file, const std::string& params_file,
            const std::string& vocab_file, float position_prob,
            size_t max_length, const std::vector<SchemaNode>& schema,
+           int batch_size,
            const fastdeploy::RuntimeOption& custom_option =
                fastdeploy::RuntimeOption(),
            const fastdeploy::ModelFormat& model_format =
@@ -154,8 +156,8 @@ struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
       std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>*
           results,
       std::vector<std::vector<UIEResult*>>* new_relations);
-  void Predict(
-      const std::vector<std::string>& texts,
+  void
+  Predict(const std::vector<std::string>& texts,
       std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>*
           results);
 
@@ -190,8 +192,8 @@ struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
                const SPAN_SET& span_set,
                const std::vector<fast_tokenizer::core::Offset>& offset_mapping,
                std::vector<SpanIdx>* span_idxs, std::vector<float>* probs) const;
-  void ConvertSpanToUIEResult(
-      const std::vector<std::string>& texts,
+  void
+  ConvertSpanToUIEResult(const std::vector<std::string>& texts,
       const std::vector<std::string>& prompts,
       const std::vector<std::vector<SpanIdx>>& span_idxs,
       const std::vector<std::vector<float>>& probs,
@@ -199,6 +201,7 @@ struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
   std::unique_ptr<Schema> schema_;
   size_t max_length_;
   float position_prob_;
+  int batch_size_;
   SchemaLanguage schema_language_;
   fast_tokenizer::tokenizers_impl::ErnieFastTokenizer tokenizer_;
 };
@@ -35,24 +35,29 @@ void BindUIE(pybind11::module& m) {
 
   py::class_<text::UIEModel, FastDeployModel>(m, "UIEModel")
       .def(py::init<std::string, std::string, std::string, float, size_t,
-                    std::vector<std::string>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
+                    std::vector<std::string>, int, RuntimeOption, ModelFormat,
+                    text::SchemaLanguage>(),
            py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
            py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
+           py::arg("batch_size"),
            py::arg("custom_option") = fastdeploy::RuntimeOption(),
            py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
            py::arg("schema_language") = text::SchemaLanguage::ZH)
-      .def(
-          py::init<std::string, std::string, std::string, float, size_t,
-                   std::vector<text::SchemaNode>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
-          py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
-          py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
-          py::arg("custom_option") = fastdeploy::RuntimeOption(),
-          py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
-          py::arg("schema_language") = text::SchemaLanguage::ZH)
       .def(py::init<std::string, std::string, std::string, float, size_t,
-                    text::SchemaNode, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
+                    std::vector<text::SchemaNode>, int, RuntimeOption,
+                    ModelFormat, text::SchemaLanguage>(),
            py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
            py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
+           py::arg("batch_size"),
+           py::arg("custom_option") = fastdeploy::RuntimeOption(),
+           py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
+           py::arg("schema_language") = text::SchemaLanguage::ZH)
+      .def(py::init<std::string, std::string, std::string, float, size_t,
+                    text::SchemaNode, int, RuntimeOption, ModelFormat,
+                    text::SchemaLanguage>(),
+           py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
+           py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
+           py::arg("batch_size"),
            py::arg("custom_option") = fastdeploy::RuntimeOption(),
            py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
            py::arg("schema_language") = text::SchemaLanguage::ZH)
@@ -60,7 +65,8 @@ void BindUIE(pybind11::module& m) {
            static_cast<void (text::UIEModel::*)(
                const std::vector<std::string>&)>(&text::UIEModel::SetSchema),
            py::arg("schema"))
-      .def("set_schema", static_cast<void (text::UIEModel::*)(
+      .def("set_schema",
+           static_cast<void (text::UIEModel::*)(
                const std::vector<text::SchemaNode>&)>(
                &text::UIEModel::SetSchema),
            py::arg("schema"))
@@ -68,7 +74,8 @@ void BindUIE(pybind11::module& m) {
            static_cast<void (text::UIEModel::*)(const text::SchemaNode&)>(
                &text::UIEModel::SetSchema),
            py::arg("schema"))
-      .def("predict",
+      .def(
+          "predict",
           [](text::UIEModel& self, const std::vector<std::string>& texts) {
             std::vector<
                 std::unordered_map<std::string, std::vector<text::UIEResult>>>
@@ -50,6 +50,7 @@ class UIEModel(FastDeployModel):
                  position_prob=0.5,
                  max_length=128,
                  schema=[],
+                 batch_size=64,
                  runtime_option=RuntimeOption(),
                  model_format=ModelFormat.PADDLE,
                  schema_language=SchemaLanguage.ZH):
@@ -63,9 +64,10 @@ class UIEModel(FastDeployModel):
         else:
             assert "The type of schema should be list or dict."
         schema_language = C.text.SchemaLanguage(schema_language)
-        self._model = C.text.UIEModel(
-            model_file, params_file, vocab_file, position_prob, max_length,
-            schema, runtime_option._option, model_format, schema_language)
+        self._model = C.text.UIEModel(model_file, params_file, vocab_file,
+                                      position_prob, max_length, schema,
+                                      batch_size, runtime_option._option,
+                                      model_format, schema_language)
         assert self.initialized, "UIEModel initialize failed."
 
     def set_schema(self, schema):
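With this change the Python wrapper forwards batch_size straight to the C++ binding, so callers can tune it per deployment. A usage sketch, assuming an exported UIE model at placeholder paths and the import style used by the example scripts in this repository:

    from pprint import pprint
    from fastdeploy import RuntimeOption
    from fastdeploy.text import UIEModel, SchemaLanguage

    # Placeholder paths; point these at a real exported UIE model directory.
    uie = UIEModel(
        "uie-base/inference.pdmodel",
        "uie-base/inference.pdiparams",
        "uie-base/vocab.txt",
        position_prob=0.5,
        max_length=128,
        schema=["时间", "选手", "赛事名称"],
        batch_size=32,  # new argument from this commit; the wrapper defaults to 64
        runtime_option=RuntimeOption(),
        schema_language=SchemaLanguage.ZH)

    results = uie.predict(
        ["2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"],
        return_dict=True)
    pprint(results)
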