Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-01 20:32:52 +08:00
[Iluvatar GPU] fix ci error caused by rebuild_padding param and cuda graph (#4504)
@@ -20,6 +20,8 @@ from dataclasses import dataclass
 import paddle
 import pynvml
 
+from fastdeploy.platforms import current_platform
+
 
 @dataclass
 class PaddleMemoryInfo:
@@ -46,8 +48,11 @@ class GPUMemoryChecker:
         self.device_id = device_id
         self.print_debug_info = print_debug_info
 
-        pynvml.nvmlInit()
-        self.gpu_memory_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
+        if current_platform.is_iluvatar():
+            self.gpu_memory_handle = None
+        else:
+            pynvml.nvmlInit()
+            self.gpu_memory_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
 
     def __del__(self):
         """ """
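The guard above leaves self.gpu_memory_handle as None on Iluvatar, so every later NVML query in the checker has to tolerate a missing handle. A minimal sketch of that pattern follows; the class and the used_memory_bytes helper are illustrative stand-ins, not FastDeploy's actual GPUMemoryChecker API.

# Minimal sketch of the platform-gated NVML pattern in the hunk above.
# The class and used_memory_bytes are hypothetical, not FastDeploy code.
import pynvml


class MemoryCheckerSketch:
    def __init__(self, device_id: int, is_iluvatar: bool):
        self.device_id = device_id
        if is_iluvatar:
            # Iluvatar devices are not served by NVML, so skip nvmlInit()
            # entirely and record that no handle exists.
            self.gpu_memory_handle = None
        else:
            pynvml.nvmlInit()
            self.gpu_memory_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)

    def used_memory_bytes(self):
        # Every query must tolerate the missing handle on Iluvatar.
        if self.gpu_memory_handle is None:
            return None
        return pynvml.nvmlDeviceGetMemoryInfo(self.gpu_memory_handle).used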
@@ -732,7 +732,9 @@ def rebuild_padding(
             seq_lens_decoder,
             seq_lens_encoder,
             output_padding_offset,
+            first_token_out,
             max_input_length,
+            enable_logprob,
         )
     elif current_platform.is_gcu():
         from fastdeploy.model_executor.ops.gcu import rebuild_padding
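For context, the CI failure this hunk addresses is a plain positional-arity mismatch: the Iluvatar op gained first_token_out and enable_logprob, so the dispatcher's old call no longer matched the signature. A self-contained sketch of that failure mode follows; the stub's parameter list is truncated to the arguments visible in the hunk, and the values are placeholders, not FastDeploy data.

# Stand-in for the updated Iluvatar rebuild_padding op; parameter list is
# truncated to the arguments visible in the hunk above.
def rebuild_padding_op(seq_lens_decoder, seq_lens_encoder, output_padding_offset,
                       first_token_out, max_input_length, enable_logprob):
    return "ok"


stale_args = ([0], [8], None, 8192)                # old 4-argument shape
fixed_args = ([0], [8], None, None, 8192, False)   # shape with the two new params

try:
    rebuild_padding_op(*stale_args)                # raises TypeError: missing args
except TypeError as exc:
    print(f"stale call fails: {exc}")

print(rebuild_padding_op(*fixed_args))             # prints: ok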
@@ -31,6 +31,8 @@ class IluvatarModelRunner(GPUModelRunner):
         rank: int,
         local_rank: int,
     ):
+        # Iluvatar does not support cudagraph
+        fd_config.graph_opt_config.use_cudagraph = False
         super(IluvatarModelRunner, self).__init__(
             fd_config=fd_config, device=device, device_id=device_id, rank=rank, local_rank=local_rank
         )
@@ -40,6 +42,7 @@ class IluvatarModelRunner(GPUModelRunner):
         assert not self.cache_config.enable_prefix_caching, "Iluvatar does not support prefix caching"
         self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
         assert not self.mla_cache, "Iluvatar does not support MLA"
+        assert not self.use_cudagraph, "Iluvatar does not support cudagraph"
         if self.enable_mm:
             assert (
                 not self.cache_config.enable_chunked_prefill
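Taken together, the two runner hunks are belt-and-suspenders: the constructor forces use_cudagraph off before the base class reads the config, and the later assert verifies the final state. A minimal sketch of the override-before-super pattern follows, using illustrative stand-in classes rather than FastDeploy's real runners.

# Illustrative stand-ins for the override-before-super pattern used above;
# these are not FastDeploy's real runner classes.
class BaseRunnerSketch:
    def __init__(self, config: dict):
        # The base class captures the flag during construction, so a subclass
        # must rewrite the config *before* delegating to super().__init__.
        self.use_cudagraph = config["use_cudagraph"]


class IluvatarRunnerSketch(BaseRunnerSketch):
    def __init__(self, config: dict):
        # Iluvatar does not support CUDA graphs: override regardless of what
        # the user requested, then assert the final state as a safety net.
        config["use_cudagraph"] = False
        super().__init__(config)
        assert not self.use_cudagraph, "Iluvatar does not support cudagraph"


runner = IluvatarRunnerSketch({"use_cudagraph": True})
print(runner.use_cudagraph)  # False: the platform override wins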
@@ -1,42 +1,91 @@
-from fastdeploy import LLM, SamplingParams
-from fastdeploy.utils import set_random_seed
-
-set_random_seed(123)
-
-prompts = [
-    "Hello, my name is",
-]
-
-# Sampling parameters
-sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)
-
-# Load the model
-llm = LLM(
-    model="/data1/fastdeploy/ERNIE_300B_4L",
-    tensor_parallel_size=8,
-    max_model_len=8192,
-    quantization="wint8",
-    block_size=16,
-)
-
-# Batch inference (the LLM internally queues requests and dynamically batches them based on available resources)
-outputs = llm.generate(prompts, sampling_params)
-
-assert outputs[0].outputs.token_ids == [
-    23768,
-    97000,
-    47814,
-    59335,
-    68170,
-    183,
-    49080,
-    94717,
-    82966,
-    99140,
-    31615,
-    51497,
-    94851,
-    60764,
-    10889,
-    2,
-], f"{outputs[0].outputs.token_ids}"
+import functools
+import sys
+import threading
+
+from fastdeploy import LLM, SamplingParams
+from fastdeploy.utils import set_random_seed
+
+
+def timeout(seconds):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            result = [None]
+            exception = [None]
+
+            def target():
+                try:
+                    result[0] = func(*args, **kwargs)
+                except Exception as e:
+                    exception[0] = e
+
+            thread = threading.Thread(target=target)
+            thread.daemon = True
+            thread.start()
+            thread.join(seconds)
+
+            if thread.is_alive():
+                raise TimeoutError(f"Function timed out after {seconds} seconds")
+
+            if exception[0]:
+                raise exception[0]
+
+            return result[0]
+
+        return wrapper
+
+    return decorator
+
+
+@timeout(60)
+def offline_infer_check():
+    set_random_seed(123)
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    # Sampling parameters
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)
+
+    # Load the model
+    llm = LLM(
+        model="/data1/fastdeploy/ERNIE_300B_4L",
+        tensor_parallel_size=8,
+        max_model_len=8192,
+        quantization="wint8",
+        block_size=16,
+    )
+
+    # Batch inference (the LLM internally queues requests and dynamically batches them based on available resources)
+    outputs = llm.generate(prompts, sampling_params)
+
+    assert outputs[0].outputs.token_ids == [
+        23768,
+        97000,
+        47814,
+        59335,
+        68170,
+        183,
+        49080,
+        94717,
+        82966,
+        99140,
+        31615,
+        51497,
+        94851,
+        60764,
+        10889,
+        2,
+    ], f"{outputs[0].outputs.token_ids}"
+    print("PASSED")
+
+
+if __name__ == "__main__":
+    try:
+        result = offline_infer_check()
+        sys.exit(0)
+    except TimeoutError:
+        sys.exit(124)
+    except Exception:
+        sys.exit(1)
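Two notes on the rewritten test. With top_p=0.00001 the sampler is effectively greedy, which is why the script can assert an exact token-ID sequence. And the thread-based timeout decorator only stops waiting for the daemonized worker thread; it cannot actually kill it. The exit code 124 on timeout presumably mirrors the convention of the coreutils timeout command. Below is a trimmed, self-contained demo of the decorator's timeout path, copying the same pattern as the test.

# Trimmed copy of the test's thread-based timeout decorator, demonstrating
# the TimeoutError path with a deliberately slow function.
import functools
import threading
import time


def timeout(seconds):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result, exception = [None], [None]

            def target():
                try:
                    result[0] = func(*args, **kwargs)
                except Exception as e:
                    exception[0] = e

            thread = threading.Thread(target=target, daemon=True)
            thread.start()
            thread.join(seconds)
            if thread.is_alive():
                # The daemon thread keeps running; we merely stop waiting.
                raise TimeoutError(f"Function timed out after {seconds} seconds")
            if exception[0]:
                raise exception[0]
            return result[0]

        return wrapper

    return decorator


@timeout(1)
def too_slow():
    time.sleep(5)


try:
    too_slow()
except TimeoutError as e:
    print(e)  # Function timed out after 1 seconds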