Add batch size for android example

This commit is contained in:
zhoushunjie
2022-12-28 02:25:57 +00:00
parent cfac517ef3
commit 01bf63e8a7
2 changed files with 65 additions and 78 deletions

View File

@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <jni.h> // NOLINT
#include "fastdeploy_jni/perf_jni.h" // NOLINT
#include "fastdeploy_jni/convert_jni.h" // NOLINT #include "fastdeploy_jni/convert_jni.h" // NOLINT
#include "fastdeploy_jni/perf_jni.h" // NOLINT
#include "fastdeploy_jni/runtime_option_jni.h" // NOLINT #include "fastdeploy_jni/runtime_option_jni.h" // NOLINT
#include "fastdeploy_jni/text/text_results_jni.h" // NOLINT #include "fastdeploy_jni/text/text_results_jni.h" // NOLINT
#include "fastdeploy_jni/text/uie/uie_utils_jni.h" // NOLINT #include "fastdeploy_jni/text/uie/uie_utils_jni.h" // NOLINT
#include <jni.h> // NOLINT
#ifdef ENABLE_TEXT #ifdef ENABLE_TEXT
#include "fastdeploy/text.h" // NOLINT #include "fastdeploy/text.h" // NOLINT
#endif #endif
@@ -32,15 +32,10 @@ extern "C" {
#endif #endif
JNIEXPORT jlong JNICALL JNIEXPORT jlong JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env, Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(
jobject thiz, JNIEnv* env, jobject thiz, jstring model_file, jstring params_file,
jstring model_file, jstring vocab_file, jfloat position_prob, jint max_length,
jstring params_file, jobjectArray schema, jint batch_size, jobject runtime_option,
jstring vocab_file,
jfloat position_prob,
jint max_length,
jobjectArray schema,
jobject runtime_option,
jint schema_language) { jint schema_language) {
#ifndef ENABLE_TEXT #ifndef ENABLE_TEXT
return 0; return 0;
@@ -51,17 +46,13 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env,
auto c_position_prob = static_cast<jfloat>(position_prob); auto c_position_prob = static_cast<jfloat>(position_prob);
auto c_max_length = static_cast<size_t>(max_length); auto c_max_length = static_cast<size_t>(max_length);
auto c_schema = fni::ConvertTo<std::vector<std::string>>(env, schema); auto c_schema = fni::ConvertTo<std::vector<std::string>>(env, schema);
auto c_batch_size = static_cast<int>(batch_size);
auto c_runtime_option = fni::NewCxxRuntimeOption(env, runtime_option); auto c_runtime_option = fni::NewCxxRuntimeOption(env, runtime_option);
auto c_schema_language = static_cast<text::SchemaLanguage>(schema_language); auto c_schema_language = static_cast<text::SchemaLanguage>(schema_language);
auto c_paddle_model_format = fastdeploy::ModelFormat::PADDLE; auto c_paddle_model_format = fastdeploy::ModelFormat::PADDLE;
auto c_model_ptr = new text::UIEModel(c_model_file, auto c_model_ptr = new text::UIEModel(
c_params_file, c_model_file, c_params_file, c_vocab_file, c_position_prob, c_max_length,
c_vocab_file, c_schema, c_batch_size, c_runtime_option, c_paddle_model_format,
c_position_prob,
c_max_length,
c_schema,
c_runtime_option,
c_paddle_model_format,
c_schema_language); c_schema_language);
INITIALIZED_OR_RETURN(c_model_ptr) INITIALIZED_OR_RETURN(c_model_ptr)
@@ -73,10 +64,8 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env,
} }
JNIEXPORT jobjectArray JNICALL JNIEXPORT jobjectArray JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env, Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(
jobject thiz, JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray texts) {
jlong cxx_context,
jobjectArray texts) {
#ifndef ENABLE_TEXT #ifndef ENABLE_TEXT
return NULL; return NULL;
#else #else
@@ -91,8 +80,8 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
} }
LOGD("c_texts: %s", fni::UIETextsStr(c_texts).c_str()); LOGD("c_texts: %s", fni::UIETextsStr(c_texts).c_str());
std::vector<std::unordered_map< std::vector<std::unordered_map<std::string, std::vector<text::UIEResult>>>
std::string, std::vector<text::UIEResult>>> c_results; c_results;
auto t = fni::GetCurrentTime(); auto t = fni::GetCurrentTime();
c_model_ptr->Predict(c_texts, &c_results); c_model_ptr->Predict(c_texts, &c_results);
@@ -107,49 +96,45 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
// Push results to HashMap array // Push results to HashMap array
const char* j_hashmap_put_signature = const char* j_hashmap_put_signature =
"(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";
const jclass j_hashmap_clazz = env->FindClass( const jclass j_hashmap_clazz = env->FindClass("java/util/HashMap");
"java/util/HashMap"); const jclass j_uie_result_clazz =
const jclass j_uie_result_clazz = env->FindClass( env->FindClass("com/baidu/paddle/fastdeploy/text/UIEResult");
"com/baidu/paddle/fastdeploy/text/UIEResult");
// Get HashMap method id // Get HashMap method id
const jmethodID j_hashmap_init = env->GetMethodID( const jmethodID j_hashmap_init =
j_hashmap_clazz, "<init>", "()V"); env->GetMethodID(j_hashmap_clazz, "<init>", "()V");
const jmethodID j_hashmap_put = env->GetMethodID( const jmethodID j_hashmap_put =
j_hashmap_clazz,"put", j_hashmap_put_signature); env->GetMethodID(j_hashmap_clazz, "put", j_hashmap_put_signature);
const int c_uie_result_hashmap_size = c_results.size(); const int c_uie_result_hashmap_size = c_results.size();
jobjectArray j_hashmap_uie_result_arr = env->NewObjectArray( jobjectArray j_hashmap_uie_result_arr =
c_uie_result_hashmap_size, j_hashmap_clazz, NULL); env->NewObjectArray(c_uie_result_hashmap_size, j_hashmap_clazz, NULL);
for (int i = 0; i < c_uie_result_hashmap_size; ++i) { for (int i = 0; i < c_uie_result_hashmap_size; ++i) {
auto& curr_c_uie_result_map = c_results[i]; auto& curr_c_uie_result_map = c_results[i];
// Convert unordered_map<string, vector<UIEResult>> // Convert unordered_map<string, vector<UIEResult>>
// -> HashMap<String, UIEResult[]> // -> HashMap<String, UIEResult[]>
jobject curr_j_uie_result_hashmap = env->NewObject( jobject curr_j_uie_result_hashmap =
j_hashmap_clazz, j_hashmap_init); env->NewObject(j_hashmap_clazz, j_hashmap_init);
for (auto&& curr_c_uie_result : curr_c_uie_result_map) { for (auto&& curr_c_uie_result : curr_c_uie_result_map) {
const auto& curr_inner_c_uie_key = curr_c_uie_result.first; const auto& curr_inner_c_uie_key = curr_c_uie_result.first;
jstring curr_inner_j_uie_key = fni::ConvertTo<jstring>( jstring curr_inner_j_uie_key =
env, curr_inner_c_uie_key); // Key of HashMap fni::ConvertTo<jstring>(env, curr_inner_c_uie_key); // Key of HashMap
if (curr_c_uie_result.second.size() > 0) { if (curr_c_uie_result.second.size() > 0) {
// Value of HashMap: HashMap<String, UIEResult[]> // Value of HashMap: HashMap<String, UIEResult[]>
jobjectArray curr_inner_j_uie_result_values = jobjectArray curr_inner_j_uie_result_values = env->NewObjectArray(
env->NewObjectArray(curr_c_uie_result.second.size(), curr_c_uie_result.second.size(), j_uie_result_clazz, NULL);
j_uie_result_clazz,
NULL);
// Convert vector<UIEResult> -> Java UIEResult[] // Convert vector<UIEResult> -> Java UIEResult[]
for (int j = 0; j < curr_c_uie_result.second.size(); ++j) { for (int j = 0; j < curr_c_uie_result.second.size(); ++j) {
text::UIEResult* inner_c_uie_result = ( text::UIEResult* inner_c_uie_result =
&(curr_c_uie_result.second[j])); (&(curr_c_uie_result.second[j]));
jobject curr_inner_j_uie_result_obj = jobject curr_inner_j_uie_result_obj = fni::NewUIEJavaResultFromCxx(
fni::NewUIEJavaResultFromCxx(
env, reinterpret_cast<void*>(inner_c_uie_result)); env, reinterpret_cast<void*>(inner_c_uie_result));
env->SetObjectArrayElement(curr_inner_j_uie_result_values, j, env->SetObjectArrayElement(curr_inner_j_uie_result_values, j,
@@ -159,9 +144,9 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
// Set element of 'curr_j_uie_result_hashmap': // Set element of 'curr_j_uie_result_hashmap':
// HashMap<String, UIEResult[]> // HashMap<String, UIEResult[]>
env->CallObjectMethod( env->CallObjectMethod(curr_j_uie_result_hashmap, j_hashmap_put,
curr_j_uie_result_hashmap, j_hashmap_put, curr_inner_j_uie_key,
curr_inner_j_uie_key, curr_inner_j_uie_result_values); curr_inner_j_uie_result_values);
env->DeleteLocalRef(curr_inner_j_uie_key); env->DeleteLocalRef(curr_inner_j_uie_key);
env->DeleteLocalRef(curr_inner_j_uie_result_values); env->DeleteLocalRef(curr_inner_j_uie_result_values);
@@ -179,9 +164,8 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
} }
JNIEXPORT jboolean JNICALL JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(JNIEnv *env, Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(
jobject thiz, JNIEnv* env, jobject thiz, jlong cxx_context) {
jlong cxx_context) {
#ifndef ENABLE_TEXT #ifndef ENABLE_TEXT
return JNI_FALSE; return JNI_FALSE;
#else #else
@@ -199,8 +183,7 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(JNIEnv *env,
JNIEXPORT jboolean JNICALL JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative( Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative(
JNIEnv *env, jobject thiz, jlong cxx_context, JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray schema) {
jobjectArray schema) {
#ifndef ENABLE_TEXT #ifndef ENABLE_TEXT
return JNI_FALSE; return JNI_FALSE;
#else #else
@@ -221,8 +204,7 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative(
JNIEXPORT jboolean JNICALL JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative( Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative(
JNIEnv *env, jobject thiz, jlong cxx_context, JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray schema) {
jobjectArray schema) {
#ifndef ENABLE_TEXT #ifndef ENABLE_TEXT
return JNI_FALSE; return JNI_FALSE;
#else #else
@@ -243,8 +225,8 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative(
jobject curr_j_schema_node = env->GetObjectArrayElement(schema, i); jobject curr_j_schema_node = env->GetObjectArrayElement(schema, i);
text::SchemaNode curr_c_schema_node; text::SchemaNode curr_c_schema_node;
if (fni::AllocateUIECxxSchemaNodeFromJava( if (fni::AllocateUIECxxSchemaNodeFromJava(
env, curr_j_schema_node, reinterpret_cast<void *>( env, curr_j_schema_node,
&curr_c_schema_node))) { reinterpret_cast<void*>(&curr_c_schema_node))) {
c_schema.push_back(curr_c_schema_node); c_schema.push_back(curr_c_schema_node);
} }
env->DeleteLocalRef(curr_j_schema_node); env->DeleteLocalRef(curr_j_schema_node);
@@ -264,4 +246,3 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative(
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@@ -22,7 +22,7 @@ public class UIEModel {
String vocabFile, String vocabFile,
String[] schema) { String[] schema) {
init_(modelFile, paramsFile, vocabFile, 0.5f, 128, init_(modelFile, paramsFile, vocabFile, 0.5f, 128,
schema, new RuntimeOption(), SchemaLanguage.ZH); schema, 64, new RuntimeOption(), SchemaLanguage.ZH);
} }
// Constructor with custom runtime option // Constructor with custom runtime option
@@ -32,10 +32,11 @@ public class UIEModel {
float positionProb, float positionProb,
int maxLength, int maxLength,
String[] schema, String[] schema,
int batchSize,
RuntimeOption runtimeOption, RuntimeOption runtimeOption,
SchemaLanguage schemaLanguage) { SchemaLanguage schemaLanguage) {
init_(modelFile, paramsFile, vocabFile, positionProb, maxLength, init_(modelFile, paramsFile, vocabFile, positionProb, maxLength,
schema, runtimeOption, schemaLanguage); schema, batchSize, runtimeOption, schemaLanguage);
} }
// Call init manually with label file // Call init manually with label file
@@ -44,7 +45,7 @@ public class UIEModel {
String vocabFile, String vocabFile,
String[] schema) { String[] schema) {
return init_(modelFile, paramsFile, vocabFile, 0.5f, 128, return init_(modelFile, paramsFile, vocabFile, 0.5f, 128,
schema, new RuntimeOption(), SchemaLanguage.ZH); schema, 64, new RuntimeOption(), SchemaLanguage.ZH);
} }
public boolean init(String modelFile, public boolean init(String modelFile,
@@ -53,10 +54,11 @@ public class UIEModel {
float positionProb, float positionProb,
int maxLength, int maxLength,
String[] schema, String[] schema,
int batchSize,
RuntimeOption runtimeOption, RuntimeOption runtimeOption,
SchemaLanguage schemaLanguage) { SchemaLanguage schemaLanguage) {
return init_(modelFile, paramsFile, vocabFile, positionProb, maxLength, return init_(modelFile, paramsFile, vocabFile, positionProb, maxLength,
schema, runtimeOption, schemaLanguage); schema, batchSize, runtimeOption, schemaLanguage);
} }
public boolean release() { public boolean release() {
@@ -103,6 +105,7 @@ public class UIEModel {
float positionProb, float positionProb,
int maxLength, int maxLength,
String[] schema, String[] schema,
int batchSize,
RuntimeOption runtimeOption, RuntimeOption runtimeOption,
SchemaLanguage schemaLanguage) { SchemaLanguage schemaLanguage) {
if (!mInitialized) { if (!mInitialized) {
@@ -113,6 +116,7 @@ public class UIEModel {
positionProb, positionProb,
maxLength, maxLength,
schema, schema,
batchSize,
runtimeOption, runtimeOption,
schemaLanguage.ordinal() schemaLanguage.ordinal()
); );
@@ -130,6 +134,7 @@ public class UIEModel {
positionProb, positionProb,
maxLength, maxLength,
schema, schema,
batchSize,
runtimeOption, runtimeOption,
schemaLanguage.ordinal() schemaLanguage.ordinal()
); );
@@ -149,6 +154,7 @@ public class UIEModel {
float positionProb, float positionProb,
int maxLength, int maxLength,
String[] schema, String[] schema,
int batchSize,
RuntimeOption runtimeOption, RuntimeOption runtimeOption,
int schemaLanguage); int schemaLanguage);