Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-03 15:56:49 +08:00)
[NewFeature]Support dp multi api server && Fix some bug in mixed ep && merge develop (#3598)
* [Feature] update ep
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix queue ports idx
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* Update engine.py
* fix ci
* fix some bug in mixed ep
* add server fix and op fix
* rm some log
* fix code style
* ltd fix
* fix
* fix
* fix some bug
* fix bug
* fix bug
* fix style
* Update config.py
* Update splitwise_connector.py
* Update cache_messager.py
* Update __init__.py
* merge and fix
* Update engine.py
* Update common_engine.py
* Update run_ci_xpu.sh
* Update ernie_processor.py
* Update ernie_processor.py

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
@@ -152,19 +152,7 @@ class PaddleDisWorkerProc:
         # TODO(gongshaotian): Use worker factory to get worker
         self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)

-        # Initialize task queue
-        task_address = (
-            self.parallel_config.pod_ip,
-            self.parallel_config.engine_worker_queue_port,
-        )
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        self.task_queue = TaskQueue(
-            address=task_address,
-            is_server=False,
-            num_client=self.parallel_config.tensor_parallel_size,
-            client_id=self.parallel_config.tensor_parallel_rank,
-            local_data_parallel_id=self.parallel_config.data_parallel_rank,
-        )

     def init_health_status(self) -> None:
         """
@@ -193,15 +181,16 @@ class PaddleDisWorkerProc:
         self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1

         # init worker_healthy_live_signal
-        workers_alive = np.zeros(shape=[array_size], dtype=np.int32)
+        workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
             array=workers_alive,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
-        self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+        self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

         # init model_weights_status
         workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
@@ -209,27 +198,27 @@ class PaddleDisWorkerProc:
             name="model_weights_status",
             array=workers_model_weights,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )

         # init exist_task_signal
-        workers_exist_task = np.zeros([self.parallel_config.data_parallel_size], dtype=np.int32)
+        workers_exist_task = np.zeros([1], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
             name="exist_task_signal",
             array=workers_exist_task,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )

         # init exist_swapped_task_signal
-        workers_swapped_task = np.zeros(shape=[self.parallel_config.data_parallel_size], dtype=np.int32)
+        workers_swapped_task = np.zeros(shape=[1], dtype=np.int32)
         self.exist_swapped_task_signal = IPCSignal(
             name="exist_swapped_task_signal",
             array=workers_swapped_task,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )

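The common thread in these hunks is that every cross-process IPC signal is now keyed by the engine worker queue port instead of the engine PID, and the exist_task / exist_swapped_task arrays shrink from data_parallel_size slots to a single slot: once each data-parallel group has its own signal namespace (scoped by its own port), one slot per group is enough, which is what lets several API servers coexist on one host. The snippet below is a minimal, self-contained toy using multiprocessing.shared_memory to illustrate the suffix-keyed naming idea only; IPCSignal is FastDeploy's own wrapper and its internal naming scheme is assumed here, not shown in this diff.

# Toy illustration (not FastDeploy code): shared-memory "signals" whose names
# carry a per-group suffix, e.g. the engine worker queue port.
import numpy as np
from multiprocessing import shared_memory


def create_signal(name: str, suffix: int, shape=(1,), create: bool = True):
    """Open a small int32 array in shared memory under a suffix-scoped name."""
    shm_name = f"{name}.{suffix}"  # hypothetical naming scheme for illustration
    nbytes = int(np.prod(shape)) * np.dtype(np.int32).itemsize
    shm = shared_memory.SharedMemory(name=shm_name, create=create, size=nbytes)
    value = np.ndarray(shape, dtype=np.int32, buffer=shm.buf)
    if create:
        value[:] = 0
    return shm, value


# Two DP groups on the same host use different queue ports, so their
# "exist_task_signal" arrays live under different shared-memory names.
shm_a, sig_a = create_signal("exist_task_signal", suffix=9923)
shm_b, sig_b = create_signal("exist_task_signal", suffix=9924)
sig_a[0] = 1
print(sig_a[0], sig_b[0])  # 1 0 -> no cross-talk between the two groups

for shm in (shm_a, shm_b):
    shm.close()
    shm.unlink()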
@@ -239,9 +228,10 @@ class PaddleDisWorkerProc:
             name="exist_prefill_task_signal",
             array=exist_prefill_task_signal_data,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_pid,
+            suffix=self.parallel_config.engine_worker_queue_port,
             create=False,
         )
+        logger.info("gaoziyuan test init_health_status")

     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
@@ -411,6 +401,21 @@ class PaddleDisWorkerProc:
         """Initialize device and Construct model runner"""
         self.worker.init_device()

+    def start_task_queue_service(self):
+        # Initialize task queue
+        task_address = (
+            self.parallel_config.pod_ip,
+            self.parallel_config.engine_worker_queue_port,
+        )
+        logger.info(f"connect task queue address {task_address}")
+        self.task_queue = TaskQueue(
+            address=task_address,
+            is_server=False,
+            num_client=self.parallel_config.tensor_parallel_size,
+            client_id=self.parallel_config.tensor_parallel_rank,
+            local_data_parallel_id=self.parallel_config.expert_parallel_rank,
+        )
+
     def load_model(self) -> None:
         """Load weights and create model"""

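TaskQueue construction moves out of the constructor into a dedicated start_task_queue_service() step: the worker attaches as a client (is_server=False) to the queue served at (pod_ip, engine_worker_queue_port) and now uses its expert-parallel rank as the local data-parallel id. The sketch below only illustrates this server/client address pattern with the standard library; it is not FastDeploy's TaskQueue implementation, and the address and authkey values are made up for the example.

# Minimal sketch of the (ip, port) client/server pattern, using multiprocessing.connection.
from multiprocessing.connection import Client, Listener
from threading import Thread

ADDRESS = ("127.0.0.1", 9923)  # (pod_ip, engine_worker_queue_port)


def engine_side(listener: Listener) -> None:
    # Engine role (the is_server=True side): accept a worker and hand it a task.
    with listener.accept() as conn:
        conn.send({"req_id": "demo", "prompt_token_ids": [1, 2, 3]})
    listener.close()


def worker_side() -> None:
    # Worker role (the is_server=False side): connect to the already-running queue.
    with Client(ADDRESS, authkey=b"fd") as conn:
        print("worker received:", conn.recv())


if __name__ == "__main__":
    listener = Listener(ADDRESS, authkey=b"fd")  # bound before any worker connects
    Thread(target=engine_side, args=(listener,), daemon=True).start()
    worker_side()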
@@ -444,7 +449,7 @@ def parse_args():
     parser.add_argument("--total_block_num", type=int, default=2000)
     parser.add_argument("--block_size", type=int, default=64)
     parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
-    parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
+    parser.add_argument("--engine_worker_queue_port", type=str, default="9923")
     parser.add_argument("--max_model_len", type=int, default=3072, help="max model len")
     parser.add_argument("--device_ids", type=str, default="0", help="cuda visible devices")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="input dtype")
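--engine_worker_queue_port now arrives as a string instead of an int, which leaves room for passing one port per local data-parallel rank (a later hunk indexes the port by local_data_parallel_id). The parsing below is a hypothetical sketch only; the actual splitting logic is not part of this diff and may differ.

# Hypothetical sketch: turning a string port argument into a per-rank port list.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--engine_worker_queue_port", type=str, default="9923")
args = parser.parse_args(["--engine_worker_queue_port", "9923,9924"])

# One port per local data-parallel rank, e.g. "9923,9924" for two DP ranks.
ports = [int(p) for p in args.engine_worker_queue_port.split(",")]
local_data_parallel_id = 1  # assumed rank, for illustration only
print(ports[local_data_parallel_id])  # -> 9924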
@@ -619,10 +624,16 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

         num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
         num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
+        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        parallel_config.local_data_parallel_id = expert_parallel_rank % max_chips_per_node
+
         parallel_config.expert_parallel_rank = expert_parallel_rank
         parallel_config.num_experts_per_rank = num_experts_per_rank
         parallel_config.num_experts_start_offset = num_experts_start_offset

+    parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
+        parallel_config.local_data_parallel_id
+    ]
     parallel_config.set_tp_group()

     load_config = LoadConfig(vars(args))
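With expert parallelism, each worker derives a node-local data-parallel id from its expert-parallel rank and then picks its own queue port out of the per-rank port collection. A small worked example of that arithmetic, with illustrative values (8 chips per node, a made-up port list):

# Worked example of the per-rank port selection shown above (values are illustrative).
max_chips_per_node = 8          # 16 on Iluvatar hardware, per the diff
expert_parallel_rank = 11       # e.g. the 12th EP rank in a 2-node x 8-chip job
engine_worker_queue_ports = [9923, 9924, 9925, 9926, 9927, 9928, 9929, 9930]

local_data_parallel_id = expert_parallel_rank % max_chips_per_node  # -> 3
port = engine_worker_queue_ports[local_data_parallel_id]            # -> 9926
print(local_data_parallel_id, port)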
@@ -640,6 +651,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
     logger.info(f"parallel_config.tensor_parallel_size {parallel_config.tensor_parallel_size}")
     logger.info(f"parallel_config.tensor_parallel_rank {parallel_config.tensor_parallel_rank}")
+    logger.info(f"parallel_config.engine_worker_queue_port {parallel_config.engine_worker_queue_port}")

     if getattr(model_config, "num_hidden_layers", None) is None:
         raise ValueError("num_hidden_layers is None")
@@ -705,6 +717,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         graph_opt_config=graph_opt_config,
         early_stop_config=early_stop_config,
         cache_config=cache_config,
+        engine_worker_queue_port=args.engine_worker_queue_port,
         ips=args.ips,
     )
     update_fd_config_for_mm(fd_config)
@@ -746,6 +759,8 @@ def run_worker_proc() -> None:
     # Initialize health status
     worker_proc.init_health_status()

+    worker_proc.start_task_queue_service()
+
     # Start event loop
     worker_proc.event_loop_normal()

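run_worker_proc now connects to the task queue in an explicit step after health status is initialized, rather than inside the constructor. A condensed, stubbed view of that call order follows; the method bodies are placeholders, not FastDeploy code.

# Stubbed outline of the call order shown in run_worker_proc (bodies are placeholders).
class WorkerProcStub:
    def init_health_status(self):
        print("register IPC health/status signals, suffix = engine_worker_queue_port")

    def start_task_queue_service(self):
        print("connect to task queue at (pod_ip, engine_worker_queue_port)")

    def event_loop_normal(self):
        print("run the normal event loop")


worker_proc = WorkerProcStub()
worker_proc.init_health_status()
worker_proc.start_task_queue_service()  # new explicit step added by this commit
worker_proc.event_loop_normal()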