Add batch size argument for UIE

This commit is contained in:
zhoushunjie
2022-12-27 15:22:09 +00:00
parent df940b750f
commit 60f8f0e11b
6 changed files with 114 additions and 76 deletions

View File

@@ -49,17 +49,17 @@ int main(int argc, char* argv[]) {
backend_type = std::atoi(argv[3]);
}
switch (backend_type) {
case 0:
option.UsePaddleInferBackend();
break;
case 1:
option.UseOrtBackend();
break;
case 2:
option.UseOpenVINOBackend();
break;
default:
break;
case 0:
option.UsePaddleInferBackend();
break;
case 1:
option.UseOrtBackend();
break;
case 2:
option.UseOpenVINOBackend();
break;
default:
break;
}
std::string model_dir(argv[1]);
std::string model_path = model_dir + sep + "inference.pdmodel";
@@ -68,9 +68,9 @@ int main(int argc, char* argv[]) {
using fastdeploy::text::SchemaNode;
using fastdeploy::text::UIEResult;
auto predictor =
fastdeploy::text::UIEModel(model_path, param_path, vocab_path, 0.5, 128,
{"时间", "选手", "赛事名称"}, option);
auto predictor = fastdeploy::text::UIEModel(
model_path, param_path, vocab_path, 0.5, 128,
{"时间", "选手", "赛事名称"}, /* batch_size = */ 1, option);
std::cout << "After init predictor" << std::endl;
std::vector<std::unordered_map<std::string, std::vector<UIEResult>>> results;
// Named Entity Recognition

View File

@@ -129,6 +129,7 @@ if __name__ == "__main__":
position_prob=0.5,
max_length=args.max_length,
schema=schema,
batch_size=args.batch_size,
runtime_option=runtime_option,
schema_language=SchemaLanguage.ZH)
@@ -181,7 +182,8 @@ if __name__ == "__main__":
schema = {"评价维度": ["观点词", "情感倾向[正向,负向]"]}
print(f"The extraction schema: {schema}")
uie.set_schema(schema)
results = uie.predict(["店面干净,很清静"], return_dict=True)
results = uie.predict(
["店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"], return_dict=True)
pprint(results)
print()