diff --git a/csrc/fastdeploy/vision/common/result.cc b/csrc/fastdeploy/vision/common/result.cc index 06a4d463f..5175eda58 100644 --- a/csrc/fastdeploy/vision/common/result.cc +++ b/csrc/fastdeploy/vision/common/result.cc @@ -319,11 +319,11 @@ std::string OCRResult::Str() { out = out + "]"; if (rec_scores.size() > 0) { - out = out + "rec text: " + text[n] + " rec scores:" + + out = out + "rec text: " + text[n] + " rec score:" + std::to_string(rec_scores[n]) + " "; } - if (cls_label.size() > 0) { - out = out + "cls label: " + std::to_string(cls_label[n]) + + if (cls_labels.size() > 0) { + out = out + "cls label: " + std::to_string(cls_labels[n]) + " cls score: " + std::to_string(cls_scores[n]); } out = out + "\n"; @@ -334,9 +334,9 @@ std::string OCRResult::Str() { cls_scores.size() > 0) { std::string out; for (int i = 0; i < rec_scores.size(); i++) { - out = out + "rec text: " + text[i] + " rec scores:" + + out = out + "rec text: " + text[i] + " rec score:" + std::to_string(rec_scores[i]) + " "; - out = out + "cls label: " + std::to_string(cls_label[i]) + + out = out + "cls label: " + std::to_string(cls_labels[i]) + " cls score: " + std::to_string(cls_scores[i]); out = out + "\n"; } @@ -345,7 +345,7 @@ std::string OCRResult::Str() { cls_scores.size() > 0) { std::string out; for (int i = 0; i < cls_scores.size(); i++) { - out = out + "cls label: " + std::to_string(cls_label[i]) + + out = out + "cls label: " + std::to_string(cls_labels[i]) + " cls score: " + std::to_string(cls_scores[i]); out = out + "\n"; } @@ -354,7 +354,7 @@ std::string OCRResult::Str() { cls_scores.size() == 0) { std::string out; for (int i = 0; i < rec_scores.size(); i++) { - out = out + "rec text: " + text[i] + " rec scores:" + + out = out + "rec text: " + text[i] + " rec score:" + std::to_string(rec_scores[i]) + " "; out = out + "\n"; } diff --git a/csrc/fastdeploy/vision/common/result.h b/csrc/fastdeploy/vision/common/result.h index a59abb4de..a12735ce5 100644 --- a/csrc/fastdeploy/vision/common/result.h +++ b/csrc/fastdeploy/vision/common/result.h @@ -67,7 +67,7 @@ struct FASTDEPLOY_DECL OCRResult : public BaseResult { std::vector rec_scores; std::vector cls_scores; - std::vector cls_label; + std::vector cls_labels; ResultType type = ResultType::OCR; diff --git a/csrc/fastdeploy/vision/ocr/ppocr/classifier.cc b/csrc/fastdeploy/vision/ocr/ppocr/classifier.cc index 155a99a38..3cce19983 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/classifier.cc +++ b/csrc/fastdeploy/vision/ocr/ppocr/classifier.cc @@ -99,8 +99,8 @@ bool Classifier::Preprocess(Mat* mat, FDTensor* output) { } //后处理 -bool Classifier::Postprocess(FDTensor& infer_result, int& cls_labels, - float& cls_scores) { +bool Classifier::Postprocess(FDTensor& infer_result, + std::tuple* cls_result) { std::vector output_shape = infer_result.shape; FDASSERT(output_shape[0] == 1, "Only support batch =1 now."); @@ -112,14 +112,14 @@ bool Classifier::Postprocess(FDTensor& infer_result, int& cls_labels, float score = float(*std::max_element(&out_data[0], &out_data[output_shape[1]])); - cls_labels = label; - cls_scores = score; + std::get<0>(*cls_result) = label; + std::get<1>(*cls_result) = score; return true; } //预测 -bool Classifier::Predict(cv::Mat* img, int& cls_labels, float& cls_socres) { +bool Classifier::Predict(cv::Mat* img, std::tuple* cls_result) { Mat mat(*img); std::vector input_tensors(1); @@ -135,7 +135,7 @@ bool Classifier::Predict(cv::Mat* img, int& cls_labels, float& cls_socres) { return false; } - if (!Postprocess(output_tensors[0], cls_labels, cls_socres)) { + if (!Postprocess(output_tensors[0], cls_result)) { FDERROR << "Failed to post process." << std::endl; return false; } diff --git a/csrc/fastdeploy/vision/ocr/ppocr/classifier.h b/csrc/fastdeploy/vision/ocr/ppocr/classifier.h index 3ebc723a9..39fc102c3 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/classifier.h +++ b/csrc/fastdeploy/vision/ocr/ppocr/classifier.h @@ -35,7 +35,7 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel { std::string ModelName() const { return "ppocr/ocr_cls"; } // 模型预测接口,即用户调用的接口 - virtual bool Predict(cv::Mat* img, int& cls_labels, float& cls_socres); + virtual bool Predict(cv::Mat* img, std::tuple* result); // pre & post parameters float cls_thresh; @@ -56,7 +56,7 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel { // 后端推理结果后处理,输出给用户 // infer_result 为后端推理后的输出Tensor - bool Postprocess(FDTensor& infer_result, int& cls_labels, float& cls_scores); + bool Postprocess(FDTensor& infer_result, std::tuple* result); }; } // namespace ocr diff --git a/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.cc b/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.cc index 9f0822e9b..10d4951b3 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.cc +++ b/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.cc @@ -63,8 +63,8 @@ bool DBDetector::Initialize() { return true; } -void OcrDetectorResizeImage(Mat* img, int max_size_len, float& ratio_h, - float& ratio_w) { +void OcrDetectorResizeImage(Mat* img, int max_size_len, float* ratio_h, + float* ratio_w) { int w = img->Width(); int h = img->Height(); @@ -86,8 +86,8 @@ void OcrDetectorResizeImage(Mat* img, int max_size_len, float& ratio_h, Resize::Run(img, resize_w, resize_h); - ratio_h = float(resize_h) / float(h); - ratio_w = float(resize_w) / float(w); + *ratio_h = float(resize_h) / float(h); + *ratio_w = float(resize_w) / float(w); } //预处理 @@ -95,7 +95,7 @@ bool DBDetector::Preprocess( Mat* mat, FDTensor* output, std::map>* im_info) { // Resize - OcrDetectorResizeImage(mat, max_side_len, ratio_h, ratio_w); + OcrDetectorResizeImage(mat, max_side_len, &ratio_h, &ratio_w); // Normalize Normalize::Run(mat, mean, scale, true); @@ -112,7 +112,7 @@ bool DBDetector::Preprocess( //后处理 bool DBDetector::Postprocess( - FDTensor& infer_result, std::vector>>* boxes, + FDTensor& infer_result, std::vector>* boxes_result, const std::map>& im_info) { std::vector output_shape = infer_result.shape; FDASSERT(output_shape[0] == 1, "Only support batch =1 now."); @@ -142,17 +142,33 @@ bool DBDetector::Postprocess( cv::dilate(bit_map, bit_map, dila_ele); } - post_processor_.BoxesFromBitmap(pred_map, boxes, bit_map, det_db_box_thresh, + // boxes_result 的value,传给boxes + + std::vector>> boxes; + + post_processor_.BoxesFromBitmap(pred_map, &boxes, bit_map, det_db_box_thresh, det_db_unclip_ratio, det_db_score_mode); - post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info); + post_processor_.FilterTagDetRes(&boxes, ratio_h, ratio_w, im_info); + + // boxes to boxes_result + for (int i = 0; i < boxes.size(); i++) { + std::array new_box; + int k = 0; + for (auto& vec : boxes[i]) { + for (auto& e : vec) { + new_box[k++] = e; + } + } + boxes_result->push_back(new_box); + } return true; } //预测 -bool DBDetector::Predict( - cv::Mat* img, std::vector>>* boxes_result) { +bool DBDetector::Predict(cv::Mat* img, + std::vector>* boxes_result) { Mat mat(*img); std::vector input_tensors(1); diff --git a/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.h b/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.h index b845bf302..38350f7b7 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.h +++ b/csrc/fastdeploy/vision/ocr/ppocr/dbdetector.h @@ -35,7 +35,7 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel { // 模型预测接口,即用户调用的接口 virtual bool Predict(cv::Mat* im, - std::vector>>* boxes); + std::vector>* boxes_result); // pre&post process parameters int max_side_len; @@ -64,7 +64,7 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel { // 后端推理结果后处理,输出给用户 bool Postprocess(FDTensor& infer_result, - std::vector>>* boxes, + std::vector>* boxes_result, const std::map>& im_info); // OCR后处理类 diff --git a/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.cc b/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.cc index b14b99623..a6907525d 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.cc +++ b/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.cc @@ -26,45 +26,31 @@ PPOCRSystemv2::PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det, void PPOCRSystemv2::Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - std::vector>> boxes; + std::vector> boxes; this->detector->Predict(img, &boxes); - // vector转array - for (int i = 0; i < boxes.size(); i++) { - std::array new_box; - int k = 0; - for (auto& vec : boxes[i]) { - for (auto& e : vec) { - new_box[k++] = e; - } - } - (result->boxes).push_back(new_box); - } + result->boxes = boxes; } void PPOCRSystemv2::Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - std::string rec_texts = ""; - float rec_text_scores = 0; + std::tuple rec_result; - this->recognizer->rec_image_shape[1] = - 32; // OCRv2模型此处需要设置为32,其他与OCRv3一致 - this->recognizer->Predict(img, rec_texts, rec_text_scores); + this->recognizer->Predict(img, &rec_result); - result->text.push_back(rec_texts); - result->rec_scores.push_back(rec_text_scores); + result->text.push_back(std::get<0>(rec_result)); + result->rec_scores.push_back(std::get<1>(rec_result)); } void PPOCRSystemv2::Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - int cls_label = 0; - float cls_scores = 0; + std::tuple cls_result; - this->classifier->Predict(img, cls_label, cls_scores); + this->classifier->Predict(img, &cls_result); - result->cls_label.push_back(cls_label); - result->cls_scores.push_back(cls_scores); + result->cls_labels.push_back(std::get<0>(cls_result)); + result->cls_scores.push_back(std::get<1>(cls_result)); } bool PPOCRSystemv2::Predict(cv::Mat* img, @@ -74,7 +60,7 @@ bool PPOCRSystemv2::Predict(cv::Mat* img, if (this->classifier->initialized != 0) { this->Classify(img, result); //摆正单张图像 - if ((result->cls_label)[0] % 2 == 1 && + if ((result->cls_labels)[0] % 2 == 1 && (result->cls_scores)[0] > this->classifier->cls_thresh) { cv::rotate(*img, *img, 1); } @@ -88,7 +74,6 @@ bool PPOCRSystemv2::Predict(cv::Mat* img, //从DET模型开始 //一张图,会输出多个“小图片”,送给后续模型 this->Detect(img, result); - std::cout << "Finish Det Prediction!" << std::endl; // crop image std::vector img_list; @@ -105,20 +90,18 @@ bool PPOCRSystemv2::Predict(cv::Mat* img, } for (int i = 0; i < img_list.size(); i++) { - if ((result->cls_label)[i] % 2 == 1 && + if ((result->cls_labels)[i] % 2 == 1 && (result->cls_scores)[i] > this->classifier->cls_thresh) { std::cout << "Rotate this image " << std::endl; cv::rotate(img_list[i], img_list[i], 1); } } - std::cout << "Finish Cls Prediction!" << std::endl; } // rec if (this->recognizer->initialized != 0) { for (int i = 0; i < img_list.size(); i++) { this->Recognize(&img_list[i], result); } - std::cout << "Finish Rec Prediction!" << std::endl; } } diff --git a/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.cc b/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.cc index 094136224..5a073d2ed 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.cc +++ b/csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.cc @@ -26,43 +26,31 @@ PPOCRSystemv3::PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det, void PPOCRSystemv3::Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - std::vector>> boxes; + std::vector> boxes_result; - this->detector->Predict(img, &boxes); + this->detector->Predict(img, &boxes_result); - // vector转array - for (int i = 0; i < boxes.size(); i++) { - std::array new_box; - int k = 0; - for (auto& vec : boxes[i]) { - for (auto& e : vec) { - new_box[k++] = e; - } - } - (result->boxes).push_back(new_box); - } + result->boxes = boxes_result; } void PPOCRSystemv3::Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - std::string rec_texts = ""; - float rec_text_scores = 0; + std::tuple rec_result; - this->recognizer->Predict(img, rec_texts, rec_text_scores); + this->recognizer->Predict(img, &rec_result); - result->text.push_back(rec_texts); - result->rec_scores.push_back(rec_text_scores); + result->text.push_back(std::get<0>(rec_result)); + result->rec_scores.push_back(std::get<1>(rec_result)); } void PPOCRSystemv3::Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - int cls_label = 0; - float cls_scores = 0; + std::tuple cls_result; - this->classifier->Predict(img, cls_label, cls_scores); + this->classifier->Predict(img, &cls_result); - result->cls_label.push_back(cls_label); - result->cls_scores.push_back(cls_scores); + result->cls_labels.push_back(std::get<0>(cls_result)); + result->cls_scores.push_back(std::get<1>(cls_result)); } bool PPOCRSystemv3::Predict(cv::Mat* img, @@ -72,7 +60,7 @@ bool PPOCRSystemv3::Predict(cv::Mat* img, if (this->classifier->initialized != 0) { this->Classify(img, result); //摆正单张图像 - if ((result->cls_label)[0] % 2 == 1 && + if ((result->cls_labels)[0] % 2 == 1 && (result->cls_scores)[0] > this->classifier->cls_thresh) { cv::rotate(*img, *img, 1); } @@ -86,7 +74,6 @@ bool PPOCRSystemv3::Predict(cv::Mat* img, //从DET模型开始 //一张图,会输出多个“小图片”,送给后续模型 this->Detect(img, result); - std::cout << "Finish Det Prediction!" << std::endl; // crop image std::vector img_list; @@ -103,20 +90,18 @@ bool PPOCRSystemv3::Predict(cv::Mat* img, } for (int i = 0; i < img_list.size(); i++) { - if ((result->cls_label)[i] % 2 == 1 && + if ((result->cls_labels)[i] % 2 == 1 && (result->cls_scores)[i] > this->classifier->cls_thresh) { std::cout << "Rotate this image " << std::endl; cv::rotate(img_list[i], img_list[i], 1); } } - std::cout << "Finish Cls Prediction!" << std::endl; } // rec if (this->recognizer->initialized != 0) { for (int i = 0; i < img_list.size(); i++) { this->Recognize(&img_list[i], result); } - std::cout << "Finish Rec Prediction!" << std::endl; } } diff --git a/csrc/fastdeploy/vision/ocr/ppocr/recognizer.cc b/csrc/fastdeploy/vision/ocr/ppocr/recognizer.cc index dd3e4bc90..22de29485 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/recognizer.cc +++ b/csrc/fastdeploy/vision/ocr/ppocr/recognizer.cc @@ -89,7 +89,7 @@ bool Recognizer::Initialize() { return true; } -void OcrRecognizerResizeImage(Mat* mat, float wh_ratio, +void OcrRecognizerResizeImage(Mat* mat, const float& wh_ratio, const std::vector& rec_image_shape) { int imgC, imgH, imgW; imgC = rec_image_shape[0]; @@ -135,8 +135,8 @@ bool Recognizer::Preprocess(Mat* mat, FDTensor* output, } //后处理 -bool Recognizer::Postprocess(FDTensor& infer_result, std::string& rec_texts, - float& rec_text_scores) { +bool Recognizer::Postprocess(FDTensor& infer_result, + std::tuple* rec_result) { std::vector output_shape = infer_result.shape; FDASSERT(output_shape[0] == 1, "Only support batch =1 now."); @@ -168,15 +168,15 @@ bool Recognizer::Postprocess(FDTensor& infer_result, std::string& rec_texts, score /= count; - rec_texts = str_res; - rec_text_scores = score; + std::get<0>(*rec_result) = str_res; + std::get<1>(*rec_result) = score; return true; } //预测 -bool Recognizer::Predict(cv::Mat* img, std::string& rec_texts, - float& rec_text_scores) { +bool Recognizer::Predict(cv::Mat* img, + std::tuple* rec_result) { Mat mat(*img); std::vector input_tensors(1); @@ -194,7 +194,7 @@ bool Recognizer::Predict(cv::Mat* img, std::string& rec_texts, return false; } - if (!Postprocess(output_tensors[0], rec_texts, rec_text_scores)) { + if (!Postprocess(output_tensors[0], rec_result)) { FDERROR << "Failed to post process." << std::endl; return false; } diff --git a/csrc/fastdeploy/vision/ocr/ppocr/recognizer.h b/csrc/fastdeploy/vision/ocr/ppocr/recognizer.h index 4829f66c9..2c6343f7f 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/recognizer.h +++ b/csrc/fastdeploy/vision/ocr/ppocr/recognizer.h @@ -36,8 +36,8 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel { std::string ModelName() const { return "ppocr/ocr_rec"; } // 模型预测接口,即用户调用的接口 - virtual bool Predict(cv::Mat* img, std::string& rec_texts, - float& rec_text_scores); + virtual bool Predict(cv::Mat* img, + std::tuple* rec_result); // pre & post parameters std::vector label_list; @@ -60,8 +60,8 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel { // 后端推理结果后处理,输出给用户 // infer_result 为后端推理后的输出Tensor - bool Postprocess(FDTensor& infer_result, std::string& rec_texts, - float& rec_text_scores); + bool Postprocess(FDTensor& infer_result, + std::tuple* rec_result); }; } // namespace ocr diff --git a/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc b/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc index beacdf4f7..d235cef67 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc +++ b/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc @@ -326,8 +326,9 @@ void PostProcessor::BoxesFromBitmap( //方法根据识别结果获取目标框位置 void PostProcessor::FilterTagDetRes( - std::vector>> *boxes, float ratio_h, - float ratio_w, const std::map> &im_info) { + std::vector>> *boxes, const float ratio_h, + const float ratio_w, + const std::map> &im_info) { int oriimg_h = im_info.at("input_shape")[0]; int oriimg_w = im_info.at("input_shape")[1]; @@ -350,10 +351,6 @@ void PostProcessor::FilterTagDetRes( rect_height = int(sqrt(pow((*boxes)[n][0][0] - (*boxes)[n][3][0], 2) + pow((*boxes)[n][0][1] - (*boxes)[n][3][1], 2))); - //原始实现,小于4的跳过,只return大于4的 - // if (rect_width <= 4 || rect_height <= 4) continue; - // root_points.push_back((*boxes)[n]); - //小于4的删除掉. erase配合逆序遍历. if (rect_width <= 4 || rect_height <= 4) { boxes->erase(boxes->begin() + n); diff --git a/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h b/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h index 4883ed4c2..034ac3ec4 100644 --- a/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h +++ b/csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h @@ -59,8 +59,8 @@ class PostProcessor { const std::string &det_db_score_mode); void FilterTagDetRes( - std::vector>> *boxes, float ratio_h, - float ratio_w, + std::vector>> *boxes, const float ratio_h, + const float ratio_w, const std::map> &im_info); private: diff --git a/csrc/fastdeploy/vision/vision_pybind.cc b/csrc/fastdeploy/vision/vision_pybind.cc index 76f956225..5d137ff07 100644 --- a/csrc/fastdeploy/vision/vision_pybind.cc +++ b/csrc/fastdeploy/vision/vision_pybind.cc @@ -48,8 +48,8 @@ void BindVision(pybind11::module& m) { .def_readwrite("boxes", &vision::OCRResult::boxes) .def_readwrite("text", &vision::OCRResult::text) .def_readwrite("score", &vision::OCRResult::rec_scores) - .def_readwrite("cls_score", &vision::OCRResult::cls_scores) - .def_readwrite("cls_label", &vision::OCRResult::cls_label) + .def_readwrite("cls_scores", &vision::OCRResult::cls_scores) + .def_readwrite("cls_labels", &vision::OCRResult::cls_labels) .def("__repr__", &vision::OCRResult::Str) .def("__str__", &vision::OCRResult::Str); pybind11::class_(m, "FaceDetectionResult") diff --git a/examples/vision/ocr/PPOCRSystemv2/python/infer.py b/examples/vision/ocr/PPOCRSystemv2/python/infer.py index a193208b2..b896c5413 100644 --- a/examples/vision/ocr/PPOCRSystemv2/python/infer.py +++ b/examples/vision/ocr/PPOCRSystemv2/python/infer.py @@ -105,7 +105,7 @@ rec_params_file = os.path.join(args.rec_model, "inference.pdiparams") rec_label_file = args.rec_label_file #默认 -det_model = fd.vision.ocr.DBDetector("") +det_model = fd.vision.ocr.DBDetector() cls_model = fd.vision.ocr.Classifier() rec_model = fd.vision.ocr.Recognizer() diff --git a/examples/vision/ocr/PPOCRSystemv3/python/infer.py b/examples/vision/ocr/PPOCRSystemv3/python/infer.py index 09998cbcb..3703afc9b 100644 --- a/examples/vision/ocr/PPOCRSystemv3/python/infer.py +++ b/examples/vision/ocr/PPOCRSystemv3/python/infer.py @@ -104,7 +104,7 @@ rec_params_file = os.path.join(args.rec_model, "inference.pdiparams") rec_label_file = args.rec_label_file #默认 -det_model = fd.vision.ocr.DBDetector("") +det_model = fd.vision.ocr.DBDetector() cls_model = fd.vision.ocr.Classifier() rec_model = fd.vision.ocr.Recognizer() @@ -137,6 +137,7 @@ im = cv2.imread(args.image) #预测并打印结果 result = ppocrsysv3.predict(im) + print(result) # 可视化结果