[Iluvatar GPU] Optimize attention and MoE performance (#3234)

Author: yzwu
Date: 2025-08-08 10:51:24 +08:00
Committed by: GitHub
Parent: 37569cca86
Commit: fbdd6b0663
24 changed files with 1130 additions and 1653 deletions
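
In outline, the worker-side portion of this diff does two things: IluvatarWorker now derives from the shared GpuWorker instead of WorkerBase, so its duplicated delegation methods (exist_prefill, load_model, execute_model, check_health, and friends) are deleted, and a new IluvatarPaddleDisWorkerProc subclass of PaddleDisWorkerProc overrides initialize_kv_cache to keep the older IPC-signal based block-count synchronization. A minimal sketch of the resulting class layout follows; the elided bodies and the comments are editorial, while the class and method names come from the hunks below.

# Sketch only: bodies elided with "..."; names taken from the diff below.
from fastdeploy.worker.gpu_worker import GpuWorker
from fastdeploy.worker.worker_process import PaddleDisWorkerProc


class IluvatarWorker(GpuWorker):  # previously derived from WorkerBase
    def init_device(self): ...                        # binds the iluvatar_gpu custom device
    def determine_available_memory(self) -> int: ...  # reads FD_ILUVATAR_KVCACHE_MEM (GiB)


class IluvatarPaddleDisWorkerProc(PaddleDisWorkerProc):
    def initialize_kv_cache(self) -> None: ...        # per-rank block count synced via IPCSignal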


@@ -16,22 +16,22 @@
import gc
import os
from typing import List, Optional
import time
import numpy as np
import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.gpu_worker import GpuWorker
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
from fastdeploy.worker.worker_process import PaddleDisWorkerProc

logger = get_logger("iluvatar_worker", "iluvatar_worker.log")


-class IluvatarWorker(WorkerBase):
+class IluvatarWorker(GpuWorker):
    """ """

    def __init__(
@@ -40,15 +40,16 @@ class IluvatarWorker(WorkerBase):
        local_rank: int,
        rank: int,
    ):
-        super().__init__(
+        super(IluvatarWorker, self).__init__(
            fd_config=fd_config,
            local_rank=local_rank,
            rank=rank,
        )
-        pass

    def init_device(self):
-        """Initialize device and Construct model runner"""
+        """
+        Initialize device and construct model runner
+        """
        if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
            # Set evironment variable
            self.device = f"iluvatar_gpu:{self.local_rank}"
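
The init_device hunk is truncated here, but the visible lines show how the worker selects its device: it checks that Paddle was built with the iluvatar_gpu custom-device plugin and then forms a per-rank device string. A minimal standalone sketch, under the assumption that the string is ultimately passed to paddle.device.set_device as on the CUDA path; local_rank is illustrative.

import paddle

local_rank = 0  # illustrative; the worker gets this from the distributed launcher
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
    # Bind this process to its local Iluvatar device, e.g. "iluvatar_gpu:0".
    paddle.device.set_device(f"iluvatar_gpu:{local_rank}")
else:
    raise RuntimeError("PaddlePaddle was not built with the iluvatar_gpu custom-device plugin")
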
@@ -70,12 +71,6 @@ class IluvatarWorker(WorkerBase):
            local_rank=self.local_rank,
        )

-    def exist_prefill(self):
-        """
-        check whether prefill stage exist
-        """
-        return self.model_runner.exist_prefill()

    def determine_available_memory(self) -> int:
        """
        Profiles the peak memory usage of the model to determine how much
@@ -92,51 +87,86 @@ class IluvatarWorker(WorkerBase):
        # 1. Record memory state before profile run
        return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)

-    def load_model(self) -> None:
-        """ """
-        self.model_runner.load_model()
-
-    def get_model(self) -> nn.Layer:
-        """ """
-        return self.model_runner.get_model()
-
-    def initialize_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
-
-    def execute_model(
-        self,
-        model_forward_batch: Optional[List[Request]] = None,
-        num_running_requests: int = None,
-    ) -> Optional[ModelRunnerOutput]:
-        """ """
-        output = self.model_runner.execute_model(model_forward_batch, num_running_requests)
-        return output
-
-    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
-        """Process new requests and then start the decode loop
-        TODO(gongshaotian):The scheduler should schedule the handling of prefill,
-        and workers and modelrunners should not perceive it.
-        """
-        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
-
-    def graph_optimize_and_warm_up_model(self) -> None:
-        """
-        Perform the warm-up and the graph optimization
-        """
-        # 1. Warm up model
-        # NOTE(gongshaotian): may be not need warm_up at this place
-        if self.model_runner.graph_opt_level >= 1:
-            self.model_runner.sot_warmup()
-        # 2. Triger cuda grpah capture
-        self.model_runner.capture_model()
-        set_random_seed(self.fd_config.model_config.seed)
-
-    def check_health(self) -> bool:
-        """ """
-        return True
-
-    def cal_theortical_kvcache(self) -> int:
-        """ """
-        return self.model_runner.cal_theortical_kvcache()

+class IluvatarPaddleDisWorkerProc(PaddleDisWorkerProc):
+    """
+    Paddle Distributed wrapper for fastdeploy.worker.Worker,
+    for handling single-node multi-GPU tensor parallel.
+    The wrapper internally executes an event loop that continuously executes requests
+    in the task queue. Control flow is transmitted by IPC.
+    """
+
+    def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0):
+        super(IluvatarPaddleDisWorkerProc, self).__init__(
+            fd_config=fd_config,
+            ranks=ranks,
+            local_rank=local_rank,
+        )
+
+    def initialize_kv_cache(self) -> None:
+        """Profiles the peak memory usage of the model to determine how many
+        KV blocks may be allocated without OOMs.
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculate the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        if self.fd_config.parallel_config.do_profile:
+            # 1. Get available memory(bytes)
+            available_kv_cache_memory = self.worker.determine_available_memory()
+            logger.info(f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------")
+            # 2. Calculate the appropriate number of blocks
+            model_block_memory_used = self.worker.cal_theortical_kvcache()
+            num_blocks_local = int(available_kv_cache_memory // model_block_memory_used)
+            # NOTE(liuzichang): Too many block will lead to illegal memory access
+            # We will develop dynamic limits in future.
+            if num_blocks_local > 40000:
+                logger.info(f"------- Reset num_blocks_local {num_blocks_local} to 40000")
+                num_blocks_local = min(40000, num_blocks_local)
+            logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
+            logger.info(f"------- num_blocks_local:{num_blocks_local} --------")
+            # NOTE(yuzhe.wu): Using the old version of the calculation num_blocks_global method,
+            # because the new version that adopting allreduce min will report a bad request error
+            # when running 300b model. The Relation commit:
+            # https://github.com/PaddlePaddle/FastDeploy/commit/2f74e93d7e87aa3ffec3fc6966bf11ab5363b956
+            # 3. Send IPCSignal
+            get_profile_block_num = np.zeros(shape=[self.ranks], dtype=np.int32)
+            self.get_profile_block_num_signal = IPCSignal(
+                name="get_profile_block_num",
+                array=get_profile_block_num,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_local
+            # Wait all worker send the signal
+            while np.any(self.get_profile_block_num_signal.value <= 0):
+                time.sleep(0.01)
+            num_blocks_global = self.get_profile_block_num_signal.value.min().item()
+            if num_blocks_global < 0:
+                logger.error(
+                    "The total number of blocks cannot be less than zero."
+                    "Please increase gpu_memory_utilization"
+                    "Or decrease max_num_batched_tokens(max model length) "
+                )
+                raise ValueError(
+                    "The total number of blocks cannot be less than zero."
+                    "Please increase gpu_memory_utilization"
+                    "Or decrease max_num_batched_tokens(max model length) "
+                )
+            self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_global
+        else:
+            num_blocks_global = self.fd_config.parallel_config.total_block_num
+        # 4. init kv_cache with accurate num_blocks
+        logger.info(f"------- num_blocks_global:{num_blocks_global} --------")
+        self.worker.initialize_cache(num_gpu_blocks=num_blocks_global)
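
Taken together, the new initialize_kv_cache override sizes the Iluvatar KV cache as follows: the available memory comes from FD_ILUVATAR_KVCACHE_MEM (interpreted as GiB, defaulting to 3) rather than a profiling run, the local block count is the integer quotient of that budget by the theoretical per-block size, the result is capped at 40000 blocks, and the ranks then agree on the minimum value through the get_profile_block_num IPCSignal. A small sketch of just the arithmetic; per_block_bytes and the example value below are illustrative, not taken from the commit.

import os


def iluvatar_num_blocks_local(per_block_bytes: int, cap: int = 40000) -> int:
    # FD_ILUVATAR_KVCACHE_MEM is read as GiB; the worker falls back to 3.
    available_bytes = int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)
    # Whole blocks that fit in the budget, clamped because too many blocks
    # can lead to illegal memory access on this backend.
    return min(cap, int(available_bytes // per_block_bytes))


# Example: with a hypothetical 2 MiB per block and the 3 GiB default budget,
# each rank reports 1536 blocks; the minimum across ranks becomes num_blocks_global.
print(iluvatar_num_blocks_local(per_block_bytes=2 * 1024 * 1024))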