diff --git a/fastdeploy/model_executor/graph_optimization/utils.py b/fastdeploy/model_executor/graph_optimization/utils.py
index 4aa7729da..ee157041e 100644
--- a/fastdeploy/model_executor/graph_optimization/utils.py
+++ b/fastdeploy/model_executor/graph_optimization/utils.py
@@ -20,6 +20,8 @@ from dataclasses import dataclass
 import paddle
 import pynvml
 
+from fastdeploy.platforms import current_platform
+
 
 @dataclass
 class PaddleMemoryInfo:
@@ -46,8 +48,11 @@ class GPUMemoryChecker:
         self.device_id = device_id
         self.print_debug_info = print_debug_info
 
-        pynvml.nvmlInit()
-        self.gpu_memory_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
+        if current_platform.is_iluvatar():
+            self.gpu_memory_handle = None
+        else:
+            pynvml.nvmlInit()
+            self.gpu_memory_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
 
     def __del__(self):
         """ """
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 6482b357d..bddb12b49 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -732,7 +732,9 @@ def rebuild_padding(
             seq_lens_decoder,
             seq_lens_encoder,
             output_padding_offset,
+            first_token_out,
             max_input_length,
+            enable_logprob,
         )
     elif current_platform.is_gcu():
         from fastdeploy.model_executor.ops.gcu import rebuild_padding
diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py
index e8ef1b69c..15dc8472c 100644
--- a/fastdeploy/worker/iluvatar_model_runner.py
+++ b/fastdeploy/worker/iluvatar_model_runner.py
@@ -31,6 +31,8 @@ class IluvatarModelRunner(GPUModelRunner):
         rank: int,
         local_rank: int,
     ):
+        # Iluvatar does not support cudagraph
+        fd_config.graph_opt_config.use_cudagraph = False
         super(IluvatarModelRunner, self).__init__(
             fd_config=fd_config, device=device, device_id=device_id, rank=rank, local_rank=local_rank
         )
@@ -40,6 +42,7 @@ class IluvatarModelRunner(GPUModelRunner):
         assert not self.cache_config.enable_prefix_caching, "Iluvatar does not support prefix caching"
         self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
         assert not self.mla_cache, "Iluvatar does not support MLA"
+        assert not self.use_cudagraph, "Iluvatar does not support cudagraph"
        if self.enable_mm:
             assert (
                 not self.cache_config.enable_chunked_prefill
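[Review note] The pynvml guard in GPUMemoryChecker above is needed because NVML is an NVIDIA-only management library; on Iluvatar devices pynvml.nvmlInit() would fail, so the handle is left as None and any consumer must treat it as optional. A minimal sketch of what a None-tolerant consumer could look like; get_used_memory_mb and the Paddle fallback are illustrative assumptions, not part of this patch:

    import paddle
    import pynvml

    def get_used_memory_mb(device_id: int, handle) -> int:
        # Hypothetical helper: report used device memory in MB.
        if handle is None:
            # Iluvatar path: NVML is unavailable, so fall back to the
            # allocator statistics Paddle tracks for this device
            # (an assumption for illustration, not what this patch does).
            return int(paddle.device.cuda.memory_allocated(device_id)) >> 20
        # NVIDIA path: query device-wide usage through NVML.
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return int(info.used) >> 20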
diff --git a/tests/ci_use/iluvatar_UT/run_ernie300B_4layer.py b/tests/ci_use/iluvatar_UT/run_ernie300B_4layer.py
index 40f0efd29..e619eaf0e 100644
--- a/tests/ci_use/iluvatar_UT/run_ernie300B_4layer.py
+++ b/tests/ci_use/iluvatar_UT/run_ernie300B_4layer.py
@@ -1,42 +1,91 @@
+import functools
+import sys
+import threading
+
 from fastdeploy import LLM, SamplingParams
 from fastdeploy.utils import set_random_seed
 
-set_random_seed(123)
 
-prompts = [
-    "Hello, my name is",
-]
+def timeout(seconds):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            result = [None]
+            exception = [None]
 
-# Sampling parameters
-sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)
+            def target():
+                try:
+                    result[0] = func(*args, **kwargs)
+                except Exception as e:
+                    exception[0] = e
 
-# Load the model
-llm = LLM(
-    model="/data1/fastdeploy/ERNIE_300B_4L",
-    tensor_parallel_size=8,
-    max_model_len=8192,
-    quantization="wint8",
-    block_size=16,
-)
+            thread = threading.Thread(target=target)
+            thread.daemon = True
+            thread.start()
+            thread.join(seconds)
 
-# Batch inference (the LLM internally queues requests and inserts them dynamically based on resource availability)
-outputs = llm.generate(prompts, sampling_params)
+            if thread.is_alive():
+                raise TimeoutError(f"Function timed out after {seconds} seconds")
 
-assert outputs[0].outputs.token_ids == [
-    23768,
-    97000,
-    47814,
-    59335,
-    68170,
-    183,
-    49080,
-    94717,
-    82966,
-    99140,
-    31615,
-    51497,
-    94851,
-    60764,
-    10889,
-    2,
-], f"{outputs[0].outputs.token_ids}"
+            if exception[0]:
+                raise exception[0]
+
+            return result[0]
+
+        return wrapper
+
+    return decorator
+
+
+@timeout(60)
+def offline_infer_check():
+    set_random_seed(123)
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    # Sampling parameters
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)
+
+    # Load the model
+    llm = LLM(
+        model="/data1/fastdeploy/ERNIE_300B_4L",
+        tensor_parallel_size=8,
+        max_model_len=8192,
+        quantization="wint8",
+        block_size=16,
+    )
+
+    # Batch inference (the LLM internally queues requests and inserts them dynamically based on resource availability)
+    outputs = llm.generate(prompts, sampling_params)
+
+    assert outputs[0].outputs.token_ids == [
+        23768,
+        97000,
+        47814,
+        59335,
+        68170,
+        183,
+        49080,
+        94717,
+        82966,
+        99140,
+        31615,
+        51497,
+        94851,
+        60764,
+        10889,
+        2,
+    ], f"{outputs[0].outputs.token_ids}"
+    print("PASSED")
+
+
+if __name__ == "__main__":
+    try:
+        result = offline_infer_check()
+        sys.exit(0)
+    except TimeoutError:
+        sys.exit(124)
+    except Exception:
+        sys.exit(1)
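[Review note] Two properties of the new harness worth keeping in mind. First, the thread-based timeout cannot cancel the inference: thread.join(seconds) merely stops waiting, and because the worker is a daemon thread it is abandoned rather than killed when the TimeoutError fires. Second, the exit code 124 on timeout follows the convention of GNU timeout, which makes that failure mode easy to distinguish in CI logs. A standalone usage sketch of the decorator, assuming the timeout definition above is in scope (durations are illustrative only):

    import time

    @timeout(2)
    def slow_task():
        # Sleep past the 2-second budget to exercise the timeout path.
        time.sleep(5)
        return "done"

    try:
        slow_task()
    except TimeoutError as err:
        print(err)  # Function timed out after 2 seconds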