diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 73b9d8cb2..2ec78cf3b 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -97,6 +97,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
     # Whether to use new get_output and save_output method (0 or 1)
     "FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
+    # Whether to enable model cache feature
+    "FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
 }
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index 589b4b838..92832ee27 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -1053,7 +1053,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
         self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
 
-        if layer.fd_config.load_config.load_choices == "default_v1":
+        if self.quant_config.is_checkpoint_bf16:
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                 dtype=layer.weight_dtype,
@@ -1138,7 +1138,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
 
     def process_weights_after_loading(self, layer):
         """ """
-        if not layer.fd_config.load_config.load_choices == "default_v1":
+        if not self.quant_config.is_checkpoint_bf16:
             return
         weight_id_map = {"gate_up": 0, "down": 1}
         if (
diff --git a/fastdeploy/model_executor/layers/quantization/mix_quant.py b/fastdeploy/model_executor/layers/quantization/mix_quant.py
index b36b71938..45f029b11 100644
--- a/fastdeploy/model_executor/layers/quantization/mix_quant.py
+++ b/fastdeploy/model_executor/layers/quantization/mix_quant.py
@@ -82,7 +82,7 @@ class MixQuantConfig(QuantConfigBase):
                     .from_config(
                         {
                             "is_permuted": self.is_permuted,
-                            "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "is_checkpoint_bf16": self.is_checkpoint_bf16,
                             "hadamard_block_size": self.hadamard_block_size,
                         }
                     )
@@ -94,7 +94,7 @@ class MixQuantConfig(QuantConfigBase):
                     .from_config(
                         {
                             "is_permuted": self.is_permuted,
-                            "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "is_checkpoint_bf16": self.is_checkpoint_bf16,
                             "hadamard_block_size": self.hadamard_block_size,
                         }
                     )
@@ -112,6 +112,6 @@ class MixQuantConfig(QuantConfigBase):
         else:
             return (
                 get_quantization_config(self.dense_quant_type)
-                .from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
                 .get_quant_method(layer)
             )
diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py
index 79c84d701..ac77f15f3 100644
--- a/fastdeploy/model_executor/layers/quantization/weight_only.py
+++ b/fastdeploy/model_executor/layers/quantization/weight_only.py
@@ -45,6 +45,7 @@ class WeightOnlyConfig(QuantConfigBase):
     def __init__(
         self,
         algo: str,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
         super().__init__()
         self.algo = algo
@@ -56,6 +57,7 @@ class WeightOnlyConfig(QuantConfigBase):
         self.quant_max_bound = 0
         self.quant_min_bound = 0
         self.quant_round_type = 0
+        self.is_checkpoint_bf16 = is_checkpoint_bf16
 
     def name(self) -> str:
         return "weight_only"
@@ -63,7 +65,8 @@ class WeightOnlyConfig(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WeightOnlyConfig":
         algo = config["algo"]
-        return cls(algo)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(algo, is_checkpoint_bf16)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if current_platform.is_xpu():
@@ -154,12 +157,13 @@ class WINT8Config(WeightOnlyConfig):
     weight only int8 config
     """
 
-    def __init__(self) -> None:
-        super().__init__("weight_only_int8")
+    def __init__(self, is_checkpoint_bf16: bool = False) -> None:
+        super().__init__("weight_only_int8", is_checkpoint_bf16)
 
     @classmethod
     def from_config(cls, config: dict) -> "WINT8Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)
 
     def name(self) -> str:
         return "wint8"
@@ -172,12 +176,14 @@ class WINT4Config(WeightOnlyConfig):
 
     def __init__(
         self,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
-        super().__init__("weight_only_int4")
+        super().__init__("weight_only_int4", is_checkpoint_bf16)
 
     @classmethod
     def from_config(cls, config: dict) -> "WINT4Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)
 
     def name(self) -> str:
         return "wint4"
@@ -196,7 +202,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         self.quant_config = quant_config
 
     def create_weights(self, layer, **extra_weight_attrs):
-        if layer.fd_config.load_config.load_choices == "default_v1":
+        if self.quant_config.is_checkpoint_bf16:
             layer.weight = layer.create_parameter(
                 shape=layer.weight_shape,
                 dtype=layer.weight_dtype,
@@ -259,7 +265,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         )
 
     def process_weights_after_loading(self, layer) -> None:
-        if not layer.fd_config.load_config.load_choices == "default_v1":
+        if not self.quant_config.is_checkpoint_bf16:
             return
         quanted_weight_tensor, weight_scale_tensor = weight_quantize(
             layer.weight,
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index be0d76a33..f1f6ee289 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -14,9 +14,15 @@
 # limitations under the License.
""" +import contextlib +import hashlib +import inspect import json import os +import pickle import time +from functools import wraps +from pathlib import Path import paddle import paddle.distributed as dist @@ -28,22 +34,139 @@ from paddleformers.utils.safetensors import fast_safe_open from safetensors import safe_open from tqdm import tqdm +from fastdeploy import envs from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.linear import KVBatchLinear from fastdeploy.model_executor.models.tp_utils import ( check_tensor_parallel_prerequisites, ) +from fastdeploy.model_executor.utils import switch_config_context from fastdeploy.platforms import current_platform -def measure_time(func): - def wrapper(*args, **kwargs): - time_before_load = time.time() - result = func(*args, **kwargs) - time_after_load = time.time() - logger.info(f"Model loading took {time_after_load - time_before_load} seconds") - return result +def pdparams_weight_iterator(paddle_file_list: list[str]): + for pdparams_file in tqdm( + paddle_file_list, + desc="Loading pdparams checkpoint shards", + ): + state_dict = paddle.load(pdparams_file) + yield from state_dict.items() + del state_dict - return wrapper + +def load_weights_form_cache(model, weights_iterator): + params_dict = dict(model.named_parameters()) + for loaded_weight_name, loaded_weight in weights_iterator: + param = params_dict[loaded_weight_name] + param.copy_(loaded_weight, False) + if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False): + model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight}) + for _, model_sublayer in model.named_sublayers(): + if isinstance(model_sublayer, KVBatchLinear): + model_sublayer.process_weights_after_loading() + + +def get_weight_iterator(model_path: str): + _, files_list, use_safetensors = get_all_weights_file(model_path) + if use_safetensors: + weights_iterator = fast_weights_iterator(files_list) + else: + weights_iterator = pdparams_weight_iterator(files_list) + return weights_iterator + + +def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"): + weight_cache_context = contextlib.nullcontext() + weight_cache_dir = None + enable_cache = False + if envs.FD_ENABLE_MODEL_CACHE: + model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path) + # model_type + quantization + tp_size + ep_size + weight_cache_key = "_".join( + [ + fd_config.model_config.model_type, + fd_config.quant_config.name(), + str(fd_config.parallel_config.tensor_parallel_size), + str(fd_config.parallel_config.expert_parallel_size), + ] + ) + # only support tp now + hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest() + weight_cache_dir = os.path.join(model_weight_cache_path, hash_key) + if os.path.exists(weight_cache_dir): + logger.info( + f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it." + ) + enable_cache = True + weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False) + + return enable_cache, weight_cache_dir, weight_cache_context + + +def save_model(model_arg_name="model", config_arg_name="fd_config"): + @measure_time("Model saving") + def _save_model(model_dict, weight_cache_dir): + # Note: ProcessGroupNCCL do not support deepcopy protocol, we made modifications here. 
+        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
+        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
+        paddle.save(model_dict, weight_cache_dir)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sig = inspect.signature(func)
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            fd_config = bound_args.arguments.get(config_arg_name, None)
+            model = bound_args.arguments.get(model_arg_name, None)
+            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
+            assert fd_config is not None, "fd_config cannot be None"
+            assert model is not None, "model cannot be None"
+            if enable_cache:
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
+            else:
+                context = contextlib.nullcontext()
+
+            with context:
+                result = func(*args, **kwargs)
+            if envs.FD_ENABLE_MODEL_CACHE and weight_cache_dir is not None and not os.path.exists(weight_cache_dir):
+                assert fd_config.quant_config is not None and getattr(
+                    fd_config.quant_config, "is_checkpoint_bf16", False
+                ), "Save cache only for dynamic quantization"
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                logger.info(f"Saving model to {tp_weight_cache_dir}")
+                os.makedirs(
+                    tp_weight_cache_dir,
+                    exist_ok=True,
+                )
+                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
+            else:
+                logger.info("Weights are already cached, skip saving")
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def measure_time(prefix: str = "Model loading"):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            time_before = time.time()
+            result = func(*args, **kwargs)
+            time_after = time.time()
+            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
+            return result
+
+        return wrapper
+
+    return decorator
 
 
 def load_reordered_experts(model_path: str, key_name: str):
@@ -232,33 +355,38 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
     """
 
     state_dict = {}
-    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
+    _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     for name, weight in weights_iterator:
         state_dict[name] = weight
     return state_dict
 
 
-def get_all_safetensors(model_path: str):
+def get_all_weights_file(model_path: str):
     """
     get_all_safetensors
     """
-    safe_model_path = os.path.join(model_path, "model.safetensors")
-    if os.path.exists(safe_model_path):
-        safetensor_list = [safe_model_path]
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
-            key_name_list = f.keys()
-        return key_name_list, safetensor_list
+    model_path = Path(model_path)
+    use_safetensors = True
+    if any(model_path.glob("*.pdparams")):
+        key_name_list = []
+        files_list = [str(file) for file in model_path.glob("*.pdparams")]
+        use_safetensors = False
     else:
-        with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
-            weight_map = json.load(f)["weight_map"]
-        weight_files_in_index = set()
-        for weight_name in weight_map:
-            weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
-        key_name_list = list(set(weight_map.keys()))
-        safetensor_list = list(weight_files_in_index)
-        safetensor_list.sort()
-        return key_name_list, safetensor_list
+        safe_model_path = model_path / "model.safetensors"
+        if safe_model_path.exists():
+            files_list = [str(safe_model_path)]
+            with safe_open(safe_model_path, framework="np", device="cpu") as f:
+                key_name_list = f.keys()
+            return key_name_list, files_list, use_safetensors
+        else:
+            index_file = model_path / "model.safetensors.index.json"
+            with index_file.open("r") as f:
+                weight_map = json.load(f)["weight_map"]
+            weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
+            key_name_list = list(weight_map.keys())
+            files_list = sorted(weight_files_in_index)
+    return key_name_list, files_list, use_safetensors
 
 
 def load_tp_checkpoint_v1(
@@ -271,7 +399,7 @@
     load_tp_checkpoint_v1
     """
 
-    safetensor_keys, safetensor_files = get_all_safetensors(model_path)
+    safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)
 
     if use_fastsafetensor:
         weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py
index e1ee0ce1f..6ebb14253 100644
--- a/fastdeploy/model_executor/model_loader/default_loader.py
+++ b/fastdeploy/model_executor/model_loader/default_loader.py
@@ -51,7 +51,7 @@ class DefaultModelLoader(BaseModelLoader):
         paddle.device.cuda.empty_cache()
         paddle.device.synchronize()
 
-    @measure_time
+    @measure_time()
     def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
         model_class = ModelRegistry.get_pretrain_cls(architectures)
diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py
index 51e80e7b0..f6ecb43f7 100644
--- a/fastdeploy/model_executor/model_loader/default_loader_v1.py
+++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -20,9 +20,11 @@ from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
 from fastdeploy.model_executor.load_weight_utils import (
-    fast_weights_iterator,
-    get_all_safetensors,
+    get_weight_iterator,
+    is_weight_cache_enabled,
+    load_weights_form_cache,
     measure_time,
+    save_model,
 )
 from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
 from fastdeploy.model_executor.models.model_base import ModelRegistry
@@ -44,11 +46,14 @@ class DefaultModelLoaderV1(BaseModelLoader):
         paddle.device.cuda.empty_cache()
         paddle.device.synchronize()
 
-    @measure_time
-    def load_weights(self, model, fd_config: FDConfig) -> None:
-        _, safetensor_files = get_all_safetensors(fd_config.model_config.model)
-        weights_iterator = fast_weights_iterator(safetensor_files)
-        model.load_weights(weights_iterator)
+    @save_model()
+    @measure_time()
+    def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
+        weights_iterator = get_weight_iterator(fd_config.model_config.model)
+        if enable_cache:
+            load_weights_form_cache(model, weights_iterator)
+        else:
+            model.load_weights(weights_iterator)
         self.clean_memory_fragments()
 
     def load_model(self, fd_config: FDConfig) -> nn.Layer:
@@ -61,14 +66,15 @@
             architectures = architectures + "RL"
 
-        with context:
-            model_cls = ModelRegistry.get_class(architectures)
-            model = model_cls(fd_config)
+        enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
+        with weight_cache_context:
+            with context:
+                model_cls = ModelRegistry.get_class(architectures)
+                model = model_cls(fd_config)
 
         model.eval()
 
-        # RL model not need set_state_dict
         if fd_config.load_config.dynamic_load_weight:
             return model
 
-        self.load_weights(model, fd_config)
+        self.load_weights(model, fd_config, enable_cache)
         return model
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 8e90fb80f..264dc3d76 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -199,3 +199,14 @@ def temporary_dtype(dtype: str):
         yield
     finally:
         paddle.set_default_dtype(orig_dtype)
+
+
+@contextmanager
+def switch_config_context(config_obj, config_attr_name, value):
+    """switch_config_context"""
+    origin_value = getattr(config_obj, config_attr_name)
+    setattr(config_obj, config_attr_name, value)
+    try:
+        yield
+    finally:
+        setattr(config_obj, config_attr_name, origin_value)
diff --git a/scripts/offline_w4a8.py b/scripts/offline_w4a8.py
index edfb7868a..596a48bf3 100644
--- a/scripts/offline_w4a8.py
+++ b/scripts/offline_w4a8.py
@@ -15,7 +15,7 @@ from safetensors.numpy import save_file as safe_save_file
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.load_weight_utils import (
-    get_all_safetensors,
+    get_all_weights_file,
     safetensors_weights_iterator,
 )
 from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -158,7 +158,7 @@ def main():
             Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
             break
     tokenizer = Ernie4_5Tokenizer.from_pretrained(args.model_name_or_path)
-    _, safetensor_files = get_all_safetensors(args.model_name_or_path)
+    _, safetensor_files, _ = get_all_weights_file(args.model_name_or_path)
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     state_dict = {}
     save_state_dict = {}
diff --git a/tests/model_loader/test_common_model.py b/tests/model_loader/test_common_model.py
index 95ff318e1..335792f40 100644
--- a/tests/model_loader/test_common_model.py
+++ b/tests/model_loader/test_common_model.py
@@ -56,7 +56,11 @@ model_param_map = {
                 "backend": "triton",
                 "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
             },
-            {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
+            {
+                "quant_type": "block_wise_fp8",
+                "backend": "deepgemm",
+                "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17", "FD_USE_DEEP_GEMM": "1"},
+            },
         ],
     },
     "DeepSeek-V3-0324": {
diff --git a/tests/model_loader/test_model_cache.py b/tests/model_loader/test_model_cache.py
new file mode 100644
index 000000000..8b1504efa
--- /dev/null
+++ b/tests/model_loader/test_model_cache.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import pytest
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, ".."))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+from tests.model_loader.utils import (
+    check_tokens_id_and_text_close,
+    form_model_get_output_topp0,
+    get_paddle_model_path,
+    run_with_timeout,
+)
+
+FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
+FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
+
+prompts = ["解释下“温故而知新", "Hello, how are you?"]
+
+
+model_param_map = {
+    "ernie-4_5-21b-a3b-bf16-paddle": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "wint4",
+                "env": {"FD_ENABLE_MODEL_CACHE": "1"},
+            }
+        ],
+    }
+}
+
+
+params = []
+for model, cfg in model_param_map.items():
+    for q in cfg["quantizations"]:
+        if isinstance(q, dict):
+            quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
+        else:
+            quant, backend, env = q, "default", {}
+        params.append(
+            pytest.param(
+                model,
+                cfg.get("tensor_parallel_size", 1),
+                cfg.get("max_model_len", 1024),
+                quant,
+                cfg.get("max_tokens", 32),
+                env,
+                marks=[pytest.mark.core_model],
+                id=f"{model}.{quant}.{backend}",
+            )
+        )
+
+
+@pytest.mark.parametrize(
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
+    params,
+)
+def test_model_cache(
+    fd_runner,
+    model_name_or_path: str,
+    tensor_parallel_size: int,
+    max_model_len: int,
+    max_tokens: int,
+    quantization: str,
+    env,
+    monkeypatch,
+) -> None:
+    model_path = get_paddle_model_path(model_name_or_path)
+
+    fd_outputs_v1 = run_with_timeout(
+        target=form_model_get_output_topp0,
+        args=(
+            fd_runner,
+            model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            FD_ENGINE_QUEUE_PORT,
+            prompts,
+            FD_CACHE_QUEUE_PORT,
+        ),
+    )
+
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+
+    fd_outputs_v1_with_cache = run_with_timeout(
+        target=form_model_get_output_topp0,
+        args=(
+            fd_runner,
+            model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            FD_ENGINE_QUEUE_PORT,
+            prompts,
+            FD_CACHE_QUEUE_PORT,
+        ),
+    )
+    check_tokens_id_and_text_close(
+        outputs_0_lst=fd_outputs_v1,
+        outputs_1_lst=fd_outputs_v1_with_cache,
+        name_0="default_v1 loader",
+        name_1="default_v1 loader using cache",
+    )
diff --git a/tests/model_loader/utils.py b/tests/model_loader/utils.py
index 0705af8ef..67113bb0b 100644
--- a/tests/model_loader/utils.py
+++ b/tests/model_loader/utils.py
@@ -63,9 +63,14 @@ def run_with_timeout(target, args, timeout=60 * 5):
         print_logs()
         raise RuntimeError("Worker process hung and was terminated")
     try:
-        return result_queue.get(timeout=60)
+        result = result_queue.get(timeout=60)
     except Exception as e:
         raise RuntimeError(f"Failed to get result from worker: {e}")
+    finally:
+        result_queue.close()
+        result_queue.join_thread()
+
+    return result
 
 
 def form_model_get_output_topp0(
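
For reference, a minimal sketch of where the weight cache introduced by this change ends up on disk, assuming FD_ENABLE_MODEL_CACHE=1 and a dynamically quantized (is_checkpoint_bf16) checkpoint loaded with the default_v1 loader. It mirrors is_weight_cache_enabled() and save_model() above; the helper name and the example model_type value are illustrative assumptions, not code from the patch.

    import hashlib
    import os
    import pickle

    def expected_cache_file(model_dir, model_type, quant_name, tp_size, ep_size, tp_rank):
        # Key is model_type + quantization + tp_size + ep_size, hashed with md5 over
        # pickle.dumps; the per-rank file is saved as "cache.pdparams" under that hash dir.
        key = "_".join([model_type, quant_name, str(tp_size), str(ep_size)])
        hash_key = hashlib.md5(pickle.dumps(key)).hexdigest()
        return os.path.join(model_dir, ".cache", hash_key, f"rank{tp_rank}", "cache.pdparams")

    # Hypothetical usage (model_type value is an assumption):
    # expected_cache_file("/models/ernie-4_5-21b-a3b-bf16-paddle", "ernie4_5_moe", "wint4", 2, 1, 0)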