Add batch size for android example

This commit is contained in:
zhoushunjie
2022-12-28 02:25:57 +00:00
parent cfac517ef3
commit 01bf63e8a7
2 changed files with 65 additions and 78 deletions

View File

@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <jni.h> // NOLINT
#include "fastdeploy_jni/perf_jni.h" // NOLINT
#include "fastdeploy_jni/convert_jni.h" // NOLINT
#include "fastdeploy_jni/runtime_option_jni.h" // NOLINT
#include "fastdeploy_jni/text/text_results_jni.h" // NOLINT
#include "fastdeploy_jni/text/uie/uie_utils_jni.h" // NOLINT
#include "fastdeploy_jni/convert_jni.h" // NOLINT
#include "fastdeploy_jni/perf_jni.h" // NOLINT
#include "fastdeploy_jni/runtime_option_jni.h" // NOLINT
#include "fastdeploy_jni/text/text_results_jni.h" // NOLINT
#include "fastdeploy_jni/text/uie/uie_utils_jni.h" // NOLINT
#include <jni.h> // NOLINT
#ifdef ENABLE_TEXT
#include "fastdeploy/text.h" // NOLINT
#endif
@@ -32,16 +32,11 @@ extern "C" {
#endif
JNIEXPORT jlong JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env,
jobject thiz,
jstring model_file,
jstring params_file,
jstring vocab_file,
jfloat position_prob,
jint max_length,
jobjectArray schema,
jobject runtime_option,
jint schema_language) {
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(
JNIEnv* env, jobject thiz, jstring model_file, jstring params_file,
jstring vocab_file, jfloat position_prob, jint max_length,
jobjectArray schema, jint batch_size, jobject runtime_option,
jint schema_language) {
#ifndef ENABLE_TEXT
return 0;
#else
@@ -51,18 +46,14 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env,
auto c_position_prob = static_cast<jfloat>(position_prob);
auto c_max_length = static_cast<size_t>(max_length);
auto c_schema = fni::ConvertTo<std::vector<std::string>>(env, schema);
auto c_batch_size = static_cast<int>(batch_size);
auto c_runtime_option = fni::NewCxxRuntimeOption(env, runtime_option);
auto c_schema_language = static_cast<text::SchemaLanguage>(schema_language);
auto c_paddle_model_format = fastdeploy::ModelFormat::PADDLE;
auto c_model_ptr = new text::UIEModel(c_model_file,
c_params_file,
c_vocab_file,
c_position_prob,
c_max_length,
c_schema,
c_runtime_option,
c_paddle_model_format,
c_schema_language);
auto c_model_ptr = new text::UIEModel(
c_model_file, c_params_file, c_vocab_file, c_position_prob, c_max_length,
c_schema, c_batch_size, c_runtime_option, c_paddle_model_format,
c_schema_language);
INITIALIZED_OR_RETURN(c_model_ptr)
#ifdef ENABLE_RUNTIME_PERF
@@ -73,17 +64,15 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_bindNative(JNIEnv *env,
}
JNIEXPORT jobjectArray JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
jobject thiz,
jlong cxx_context,
jobjectArray texts) {
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(
JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray texts) {
#ifndef ENABLE_TEXT
return NULL;
#else
if (cxx_context == 0) {
return NULL;
}
auto c_model_ptr = reinterpret_cast<text::UIEModel *>(cxx_context);
auto c_model_ptr = reinterpret_cast<text::UIEModel*>(cxx_context);
auto c_texts = fni::ConvertTo<std::vector<std::string>>(env, texts);
if (c_texts.empty()) {
LOGE("c_texts is empty!");
@@ -91,8 +80,8 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
}
LOGD("c_texts: %s", fni::UIETextsStr(c_texts).c_str());
std::vector<std::unordered_map<
std::string, std::vector<text::UIEResult>>> c_results;
std::vector<std::unordered_map<std::string, std::vector<text::UIEResult>>>
c_results;
auto t = fni::GetCurrentTime();
c_model_ptr->Predict(c_texts, &c_results);
@@ -107,50 +96,46 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
// Push results to HashMap array
const char* j_hashmap_put_signature =
"(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";
const jclass j_hashmap_clazz = env->FindClass(
"java/util/HashMap");
const jclass j_uie_result_clazz = env->FindClass(
"com/baidu/paddle/fastdeploy/text/UIEResult");
const jclass j_hashmap_clazz = env->FindClass("java/util/HashMap");
const jclass j_uie_result_clazz =
env->FindClass("com/baidu/paddle/fastdeploy/text/UIEResult");
// Get HashMap method id
const jmethodID j_hashmap_init = env->GetMethodID(
j_hashmap_clazz, "<init>", "()V");
const jmethodID j_hashmap_put = env->GetMethodID(
j_hashmap_clazz,"put", j_hashmap_put_signature);
const jmethodID j_hashmap_init =
env->GetMethodID(j_hashmap_clazz, "<init>", "()V");
const jmethodID j_hashmap_put =
env->GetMethodID(j_hashmap_clazz, "put", j_hashmap_put_signature);
const int c_uie_result_hashmap_size = c_results.size();
jobjectArray j_hashmap_uie_result_arr = env->NewObjectArray(
c_uie_result_hashmap_size, j_hashmap_clazz, NULL);
jobjectArray j_hashmap_uie_result_arr =
env->NewObjectArray(c_uie_result_hashmap_size, j_hashmap_clazz, NULL);
for (int i = 0; i < c_uie_result_hashmap_size; ++i) {
auto& curr_c_uie_result_map = c_results[i];
// Convert unordered_map<string, vector<UIEResult>>
// -> HashMap<String, UIEResult[]>
jobject curr_j_uie_result_hashmap = env->NewObject(
j_hashmap_clazz, j_hashmap_init);
jobject curr_j_uie_result_hashmap =
env->NewObject(j_hashmap_clazz, j_hashmap_init);
for (auto&& curr_c_uie_result: curr_c_uie_result_map) {
for (auto&& curr_c_uie_result : curr_c_uie_result_map) {
const auto& curr_inner_c_uie_key = curr_c_uie_result.first;
jstring curr_inner_j_uie_key = fni::ConvertTo<jstring>(
env, curr_inner_c_uie_key); // Key of HashMap
jstring curr_inner_j_uie_key =
fni::ConvertTo<jstring>(env, curr_inner_c_uie_key); // Key of HashMap
if (curr_c_uie_result.second.size() > 0) {
// Value of HashMap: HashMap<String, UIEResult[]>
jobjectArray curr_inner_j_uie_result_values =
env->NewObjectArray(curr_c_uie_result.second.size(),
j_uie_result_clazz,
NULL);
jobjectArray curr_inner_j_uie_result_values = env->NewObjectArray(
curr_c_uie_result.second.size(), j_uie_result_clazz, NULL);
// Convert vector<UIEResult> -> Java UIEResult[]
for (int j = 0; j < curr_c_uie_result.second.size(); ++j) {
text::UIEResult* inner_c_uie_result = (
&(curr_c_uie_result.second[j]));
text::UIEResult* inner_c_uie_result =
(&(curr_c_uie_result.second[j]));
jobject curr_inner_j_uie_result_obj =
fni::NewUIEJavaResultFromCxx(
env, reinterpret_cast<void *>(inner_c_uie_result));
jobject curr_inner_j_uie_result_obj = fni::NewUIEJavaResultFromCxx(
env, reinterpret_cast<void*>(inner_c_uie_result));
env->SetObjectArrayElement(curr_inner_j_uie_result_values, j,
curr_inner_j_uie_result_obj);
@@ -159,14 +144,14 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
// Set element of 'curr_j_uie_result_hashmap':
// HashMap<String, UIEResult[]>
env->CallObjectMethod(
curr_j_uie_result_hashmap, j_hashmap_put,
curr_inner_j_uie_key, curr_inner_j_uie_result_values);
env->CallObjectMethod(curr_j_uie_result_hashmap, j_hashmap_put,
curr_inner_j_uie_key,
curr_inner_j_uie_result_values);
env->DeleteLocalRef(curr_inner_j_uie_key);
env->DeleteLocalRef(curr_inner_j_uie_result_values);
} // end if
} // end for
} // end if
} // end for
// Set current HashMap<String, UIEResult[]> to HashMap[i]
env->SetObjectArrayElement(j_hashmap_uie_result_arr, i,
@@ -179,16 +164,15 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_predictNative(JNIEnv *env,
}
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(JNIEnv *env,
jobject thiz,
jlong cxx_context) {
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(
JNIEnv* env, jobject thiz, jlong cxx_context) {
#ifndef ENABLE_TEXT
return JNI_FALSE;
#else
if (cxx_context == 0) {
return JNI_FALSE;
}
auto c_model_ptr = reinterpret_cast<text::UIEModel *>(cxx_context);
auto c_model_ptr = reinterpret_cast<text::UIEModel*>(cxx_context);
PERF_TIME_OF_RUNTIME(c_model_ptr, -1)
delete c_model_ptr;
@@ -199,15 +183,14 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_releaseNative(JNIEnv *env,
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative(
JNIEnv *env, jobject thiz, jlong cxx_context,
jobjectArray schema) {
JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray schema) {
#ifndef ENABLE_TEXT
return JNI_FALSE;
#else
if (cxx_context == 0) {
return JNI_FALSE;
}
auto c_model_ptr = reinterpret_cast<text::UIEModel *>(cxx_context);
auto c_model_ptr = reinterpret_cast<text::UIEModel*>(cxx_context);
auto c_schema = fni::ConvertTo<std::vector<std::string>>(env, schema);
if (c_schema.empty()) {
LOGE("c_schema is empty!");
@@ -221,8 +204,7 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaStringNative(
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative(
JNIEnv *env, jobject thiz, jlong cxx_context,
jobjectArray schema) {
JNIEnv* env, jobject thiz, jlong cxx_context, jobjectArray schema) {
#ifndef ENABLE_TEXT
return JNI_FALSE;
#else
@@ -236,15 +218,15 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative(
if (cxx_context == 0) {
return JNI_FALSE;
}
auto c_model_ptr = reinterpret_cast<text::UIEModel *>(cxx_context);
auto c_model_ptr = reinterpret_cast<text::UIEModel*>(cxx_context);
std::vector<text::SchemaNode> c_schema;
for (int i = 0; i < j_schema_size; ++i) {
jobject curr_j_schema_node = env->GetObjectArrayElement(schema, i);
text::SchemaNode curr_c_schema_node;
if (fni::AllocateUIECxxSchemaNodeFromJava(
env, curr_j_schema_node, reinterpret_cast<void *>(
&curr_c_schema_node))) {
env, curr_j_schema_node,
reinterpret_cast<void*>(&curr_c_schema_node))) {
c_schema.push_back(curr_c_schema_node);
}
env->DeleteLocalRef(curr_j_schema_node);
@@ -264,4 +246,3 @@ Java_com_baidu_paddle_fastdeploy_text_uie_UIEModel_setSchemaNodeNative(
#ifdef __cplusplus
}
#endif

View File

@@ -22,7 +22,7 @@ public class UIEModel {
String vocabFile,
String[] schema) {
init_(modelFile, paramsFile, vocabFile, 0.5f, 128,
schema, new RuntimeOption(), SchemaLanguage.ZH);
schema, 64, new RuntimeOption(), SchemaLanguage.ZH);
}
// Constructor with custom runtime option
@@ -32,10 +32,11 @@ public class UIEModel {
float positionProb,
int maxLength,
String[] schema,
int batchSize,
RuntimeOption runtimeOption,
SchemaLanguage schemaLanguage) {
init_(modelFile, paramsFile, vocabFile, positionProb, maxLength,
schema, runtimeOption, schemaLanguage);
schema, batchSize, runtimeOption, schemaLanguage);
}
// Call init manually with label file
@@ -44,7 +45,7 @@ public class UIEModel {
String vocabFile,
String[] schema) {
return init_(modelFile, paramsFile, vocabFile, 0.5f, 128,
schema, new RuntimeOption(), SchemaLanguage.ZH);
schema, 64, new RuntimeOption(), SchemaLanguage.ZH);
}
public boolean init(String modelFile,
@@ -53,10 +54,11 @@ public class UIEModel {
float positionProb,
int maxLength,
String[] schema,
int batchSize,
RuntimeOption runtimeOption,
SchemaLanguage schemaLanguage) {
return init_(modelFile, paramsFile, vocabFile, positionProb, maxLength,
schema, runtimeOption, schemaLanguage);
schema, batchSize, runtimeOption, schemaLanguage);
}
public boolean release() {
@@ -103,6 +105,7 @@ public class UIEModel {
float positionProb,
int maxLength,
String[] schema,
int batchSize,
RuntimeOption runtimeOption,
SchemaLanguage schemaLanguage) {
if (!mInitialized) {
@@ -113,6 +116,7 @@ public class UIEModel {
positionProb,
maxLength,
schema,
batchSize,
runtimeOption,
schemaLanguage.ordinal()
);
@@ -130,6 +134,7 @@ public class UIEModel {
positionProb,
maxLength,
schema,
batchSize,
runtimeOption,
schemaLanguage.ordinal()
);
@@ -149,6 +154,7 @@ public class UIEModel {
float positionProb,
int maxLength,
String[] schema,
int batchSize,
RuntimeOption runtimeOption,
int schemaLanguage);