diff --git a/java/android/fastdeploy/src/main/cpp/fastdeploy_jni/text/uie/uie_model_jni.cc b/java/android/fastdeploy/src/main/cpp/fastdeploy_jni/text/uie/uie_model_jni.cc index ce97d8dd7..d8b69da29 100644 --- a/java/android/fastdeploy/src/main/cpp/fastdeploy_jni/text/uie/uie_model_jni.cc +++ b/java/android/fastdeploy/src/main/cpp/fastdeploy_jni/text/uie/uie_model_jni.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include // NOLINT -#include "fastdeploy_jni/perf_jni.h" // NOLINT -#include "fastdeploy_jni/convert_jni.h" // NOLINT -#include "fastdeploy_jni/runtime_option_jni.h" // NOLINT -#include "fastdeploy_jni/text/text_results_jni.h" // NOLINT -#include "fastdeploy_jni/text/uie/uie_utils_jni.h" // NOLINT +#include "fastdeploy_jni/convert_jni.h" // NOLINT +#include "fastdeploy_jni/perf_jni.h" // NOLINT +#include "fastdeploy_jni/runtime_option_jni.h" // NOLINT +#include "fastdeploy_jni/text/text_results_jni.h" // NOLINT +#include "fastdeploy_jni/text/uie/uie_utils_jni.h" // NOLINT +#include // NOLINT #ifdef ENABLE_TEXT #include "fastdeploy/text.h" // NOLINT #endif @@ -32,16 +32,11 @@ extern "C" { #endif JNIEXPORT jlong JNICALL -Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env, - jobject thiz, - jstring model_file, - jstring params_file, - jstring vocab_file, - jfloat position_prob, - jint max_length, - jobjectArray schema, - jobject runtime_option, - jint schema_language) { +Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative( + JNIEnv* env, jobject thiz, jstring model_file, jstring params_file, + jstring vocab_file, jfloat position_prob, jint max_length, + jobjectArray schema, jint batch_size, jobject runtime_option, + jint schema_language) { #ifndef ENABLE_TEXT return 0; #else @@ -51,18 +46,14 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env, auto c_position_prob = static_cast(position_prob); auto c_max_length = static_cast(max_length); auto c_schema = fni::ConvertTo>(env, schema); + auto c_batch_size = static_cast(batch_size); auto c_runtime_option = fni::NewCxxRuntimeOption(env, runtime_option); auto c_schema_language = static_cast(schema_language); auto c_paddle_model_format = fastdeploy::ModelFormat::PADDLE; - auto c_model_ptr = new text::UIEModel(c_model_file, - c_params_file, - c_vocab_file, - c_position_prob, - c_max_length, - c_schema, - c_runtime_option, - c_paddle_model_format, - c_schema_language); + auto c_model_ptr = new text::UIEModel( + c_model_file, c_params_file, c_vocab_file, c_position_prob, c_max_length, + c_schema, c_batch_size, c_runtime_option, c_paddle_model_format, + c_schema_language); INITIALIZED_OR_RETURN(c_model_ptr) #ifdef ENABLE_RUNTIME_PERF @@ -73,17 +64,15 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env, } JNIEXPORT jobjectArray JNICALL -Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env, - jobject thiz, - jlong cxx_context, - jobjectArray texts) { +Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative( + JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray texts) { #ifndef ENABLE_TEXT return NULL; #else if (cxx_context == 0) { return NULL; } - auto c_model_ptr = reinterpret_cast(cxx_context); + auto c_model_ptr = reinterpret_cast(cxx_context); auto c_texts = fni::ConvertTo>(env, texts); if (c_texts.empty()) { LOGE("c_texts is empty!"); @@ -91,8 +80,8 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env, } LOGD("c_texts: %s", fni::UIETextsStr(c_texts).c_str()); - std::vector>> c_results; + std::vector>> + c_results; auto t = fni::GetCurrentTime(); c_model_ptr->Predict(c_texts, &c_results); @@ -107,50 +96,46 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env, // Push results to HashMap array const char* j_hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; - const jclass j_hashmap_clazz = env->FindClass( - "java/util/HashMap"); - const jclass j_uie_result_clazz = env->FindClass( - "com/baidu/paddle/fastdeploy/text/UIEResult"); + const jclass j_hashmap_clazz = env->FindClass("java/util/HashMap"); + const jclass j_uie_result_clazz = + env->FindClass("com/baidu/paddle/fastdeploy/text/UIEResult"); // Get HashMap method id - const jmethodID j_hashmap_init = env->GetMethodID( - j_hashmap_clazz, "", "()V"); - const jmethodID j_hashmap_put = env->GetMethodID( - j_hashmap_clazz,"put", j_hashmap_put_signature); + const jmethodID j_hashmap_init = + env->GetMethodID(j_hashmap_clazz, "", "()V"); + const jmethodID j_hashmap_put = + env->GetMethodID(j_hashmap_clazz, "put", j_hashmap_put_signature); const int c_uie_result_hashmap_size = c_results.size(); - jobjectArray j_hashmap_uie_result_arr = env->NewObjectArray( - c_uie_result_hashmap_size, j_hashmap_clazz, NULL); + jobjectArray j_hashmap_uie_result_arr = + env->NewObjectArray(c_uie_result_hashmap_size, j_hashmap_clazz, NULL); for (int i = 0; i < c_uie_result_hashmap_size; ++i) { auto& curr_c_uie_result_map = c_results[i]; // Convert unordered_map> // -> HashMap - jobject curr_j_uie_result_hashmap = env->NewObject( - j_hashmap_clazz, j_hashmap_init); + jobject curr_j_uie_result_hashmap = + env->NewObject(j_hashmap_clazz, j_hashmap_init); - for (auto&& curr_c_uie_result: curr_c_uie_result_map) { + for (auto&& curr_c_uie_result : curr_c_uie_result_map) { const auto& curr_inner_c_uie_key = curr_c_uie_result.first; - jstring curr_inner_j_uie_key = fni::ConvertTo( - env, curr_inner_c_uie_key); // Key of HashMap + jstring curr_inner_j_uie_key = + fni::ConvertTo(env, curr_inner_c_uie_key); // Key of HashMap if (curr_c_uie_result.second.size() > 0) { // Value of HashMap: HashMap - jobjectArray curr_inner_j_uie_result_values = - env->NewObjectArray(curr_c_uie_result.second.size(), - j_uie_result_clazz, - NULL); + jobjectArray curr_inner_j_uie_result_values = env->NewObjectArray( + curr_c_uie_result.second.size(), j_uie_result_clazz, NULL); // Convert vector -> Java UIEResult[] for (int j = 0; j < curr_c_uie_result.second.size(); ++j) { - text::UIEResult* inner_c_uie_result = ( - &(curr_c_uie_result.second[j])); + text::UIEResult* inner_c_uie_result = + (&(curr_c_uie_result.second[j])); - jobject curr_inner_j_uie_result_obj = - fni::NewUIEJavaResultFromCxx( - env, reinterpret_cast(inner_c_uie_result)); + jobject curr_inner_j_uie_result_obj = fni::NewUIEJavaResultFromCxx( + env, reinterpret_cast(inner_c_uie_result)); env->SetObjectArrayElement(curr_inner_j_uie_result_values, j, curr_inner_j_uie_result_obj); @@ -159,14 +144,14 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env, // Set element of 'curr_j_uie_result_hashmap': // HashMap - env->CallObjectMethod( - curr_j_uie_result_hashmap, j_hashmap_put, - curr_inner_j_uie_key, curr_inner_j_uie_result_values); + env->CallObjectMethod(curr_j_uie_result_hashmap, j_hashmap_put, + curr_inner_j_uie_key, + curr_inner_j_uie_result_values); env->DeleteLocalRef(curr_inner_j_uie_key); env->DeleteLocalRef(curr_inner_j_uie_result_values); - } // end if - } // end for + } // end if + } // end for // Set current HashMap to HashMap[i] env->SetObjectArrayElement(j_hashmap_uie_result_arr, i, @@ -179,16 +164,15 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env, } JNIEXPORT jboolean JNICALL -Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(JNIEnv *env, - jobject thiz, - jlong cxx_context) { +Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative( + JNIEnv* env, jobject thiz, jlong cxx_context) { #ifndef ENABLE_TEXT return JNI_FALSE; #else if (cxx_context == 0) { return JNI_FALSE; } - auto c_model_ptr = reinterpret_cast(cxx_context); + auto c_model_ptr = reinterpret_cast(cxx_context); PERF_TIME_OF_RUNTIME(c_model_ptr, -1) delete c_model_ptr; @@ -199,15 +183,14 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(JNIEnv *env, JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative( - JNIEnv *env, jobject thiz, jlong cxx_context, - jobjectArray schema) { + JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray schema) { #ifndef ENABLE_TEXT return JNI_FALSE; #else if (cxx_context == 0) { return JNI_FALSE; } - auto c_model_ptr = reinterpret_cast(cxx_context); + auto c_model_ptr = reinterpret_cast(cxx_context); auto c_schema = fni::ConvertTo>(env, schema); if (c_schema.empty()) { LOGE("c_schema is empty!"); @@ -221,8 +204,7 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative( - JNIEnv *env, jobject thiz, jlong cxx_context, - jobjectArray schema) { + JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray schema) { #ifndef ENABLE_TEXT return JNI_FALSE; #else @@ -236,15 +218,15 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative( if (cxx_context == 0) { return JNI_FALSE; } - auto c_model_ptr = reinterpret_cast(cxx_context); + auto c_model_ptr = reinterpret_cast(cxx_context); std::vector c_schema; for (int i = 0; i < j_schema_size; ++i) { jobject curr_j_schema_node = env->GetObjectArrayElement(schema, i); text::SchemaNode curr_c_schema_node; if (fni::AllocateUIECxxSchemaNodeFromJava( - env, curr_j_schema_node, reinterpret_cast( - &curr_c_schema_node))) { + env, curr_j_schema_node, + reinterpret_cast(&curr_c_schema_node))) { c_schema.push_back(curr_c_schema_node); } env->DeleteLocalRef(curr_j_schema_node); @@ -264,4 +246,3 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative( #ifdef __cplusplus } #endif - diff --git a/java/android/fastdeploy/src/main/java/com/baidu/paddle/fastdeploy/text/uie/UIEModel.java b/java/android/fastdeploy/src/main/java/com/baidu/paddle/fastdeploy/text/uie/UIEModel.java index 8857fe289..c6eb326e9 100644 --- a/java/android/fastdeploy/src/main/java/com/baidu/paddle/fastdeploy/text/uie/UIEModel.java +++ b/java/android/fastdeploy/src/main/java/com/baidu/paddle/fastdeploy/text/uie/UIEModel.java @@ -22,7 +22,7 @@ public class UIEModel { String vocabFile, String[] schema) { init_(modelFile, paramsFile, vocabFile, 0.5f, 128, - schema, new RuntimeOption(), SchemaLanguage.ZH); + schema, 64, new RuntimeOption(), SchemaLanguage.ZH); } // Constructor with custom runtime option @@ -32,10 +32,11 @@ public class UIEModel { float positionProb, int maxLength, String[] schema, + int batchSize, RuntimeOption runtimeOption, SchemaLanguage schemaLanguage) { init_(modelFile, paramsFile, vocabFile, positionProb, maxLength, - schema, runtimeOption, schemaLanguage); + schema, batchSize, runtimeOption, schemaLanguage); } // Call init manually with label file @@ -44,7 +45,7 @@ public class UIEModel { String vocabFile, String[] schema) { return init_(modelFile, paramsFile, vocabFile, 0.5f, 128, - schema, new RuntimeOption(), SchemaLanguage.ZH); + schema, 64, new RuntimeOption(), SchemaLanguage.ZH); } public boolean init(String modelFile, @@ -53,10 +54,11 @@ public class UIEModel { float positionProb, int maxLength, String[] schema, + int batchSize, RuntimeOption runtimeOption, SchemaLanguage schemaLanguage) { return init_(modelFile, paramsFile, vocabFile, positionProb, maxLength, - schema, runtimeOption, schemaLanguage); + schema, batchSize, runtimeOption, schemaLanguage); } public boolean release() { @@ -103,6 +105,7 @@ public class UIEModel { float positionProb, int maxLength, String[] schema, + int batchSize, RuntimeOption runtimeOption, SchemaLanguage schemaLanguage) { if (!mInitialized) { @@ -113,6 +116,7 @@ public class UIEModel { positionProb, maxLength, schema, + batchSize, runtimeOption, schemaLanguage.ordinal() ); @@ -130,6 +134,7 @@ public class UIEModel { positionProb, maxLength, schema, + batchSize, runtimeOption, schemaLanguage.ordinal() ); @@ -149,6 +154,7 @@ public class UIEModel { float positionProb, int maxLength, String[] schema, + int batchSize, RuntimeOption runtimeOption, int schemaLanguage);