mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[NewFeture]add ep rollout model init and update/clear ep buffer (#4039)
* fix gid * merge * fix test * fix bug * fix * fix ci
This commit is contained in:
@@ -33,6 +33,11 @@
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 3; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 6: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
|
||||
__VA_ARGS__ \
|
||||
|
@@ -338,20 +338,26 @@ class ParallelConfig:
|
||||
else:
|
||||
self.pd_disaggregation_mode = "None"
|
||||
|
||||
def set_tp_group(self):
|
||||
def set_communicate_group(self):
|
||||
# different tp group id
|
||||
# prevent different tp_groups using the same group_id
|
||||
tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
|
||||
dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
|
||||
|
||||
self.tp_group = dist.new_group(
|
||||
range(
|
||||
self.data_parallel_rank * self.tensor_parallel_size,
|
||||
(self.data_parallel_rank + 1) * self.tensor_parallel_size,
|
||||
)
|
||||
)
|
||||
dist.collective._set_custom_gid(None)
|
||||
|
||||
# same ep group id
|
||||
dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
|
||||
self.ep_group = dist.new_group(range(self.expert_parallel_size))
|
||||
if self.enable_expert_parallel:
|
||||
dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
|
||||
self.ep_group = dist.new_group(range(self.expert_parallel_size))
|
||||
dist.collective._set_custom_gid(None)
|
||||
|
||||
logger.info(
|
||||
f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
|
||||
)
|
||||
@@ -833,6 +839,7 @@ class LoadConfig:
|
||||
load_strategy: Specifies the weight loading method when enabled:
|
||||
- 'ipc': Real-time IPC streaming with automatic resharding
|
||||
- 'ipc_snapshot': Load from disk snapshot of IPC weights
|
||||
- 'meta': Only model meta messages
|
||||
- None: No dynamic loading
|
||||
"""
|
||||
|
||||
@@ -843,7 +850,7 @@ class LoadConfig:
|
||||
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
|
||||
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
|
||||
self.dynamic_load_weight: bool = False
|
||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
|
||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
@@ -1201,12 +1208,10 @@ class FDConfig:
|
||||
|
||||
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if num_ranks > self.max_chips_per_node:
|
||||
if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
|
||||
self.worker_num_per_node = self.max_chips_per_node
|
||||
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
||||
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
||||
|
||||
# assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
||||
else:
|
||||
self.worker_num_per_node = num_ranks
|
||||
|
||||
|
@@ -135,7 +135,7 @@ class EngineArgs:
|
||||
"""
|
||||
dynamic load weight
|
||||
"""
|
||||
load_strategy: str = "ipc_snapshot"
|
||||
load_strategy: str = "normal"
|
||||
"""
|
||||
dynamic load weight strategy
|
||||
"""
|
||||
|
@@ -20,168 +20,139 @@ import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
try:
|
||||
from paddle.distributed.communication import deep_ep
|
||||
except:
|
||||
logger.warning("import deep_ep Failed!")
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
from paddle.distributed.communication import deep_ep
|
||||
except:
|
||||
logger.warning("import deep_ep Failed!")
|
||||
from typing import Optional
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.config import MoEPhase
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.utils import singleton
|
||||
|
||||
|
||||
class DeepEPEngineBase:
|
||||
class DeepEPBufferManager:
|
||||
_engine: Optional["DeepEPEngine"] = None
|
||||
|
||||
@classmethod
|
||||
def set_engine(cls, engine: "DeepEPEngine"):
|
||||
cls._engine = engine
|
||||
|
||||
@classmethod
|
||||
def clear_buffer(cls):
|
||||
if cls._engine:
|
||||
cls._engine.clear_deep_ep_buffer()
|
||||
|
||||
@classmethod
|
||||
def recreate_buffer(cls):
|
||||
if cls._engine:
|
||||
cls._engine.create_deep_ep_buffer()
|
||||
|
||||
|
||||
class DeepEPBuffer:
|
||||
"""
|
||||
A wrapper class for DeepEP engine.
|
||||
Encapsulates DeepEP buffer creation, management and cleanup.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
hidden: int,
|
||||
group,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
ep_rank: int,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
async_finish: bool = False,
|
||||
group=None,
|
||||
):
|
||||
"""
|
||||
Initialize the DeepEP engine.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
ep_size: The number of ranks.
|
||||
rank_id: The rank id.
|
||||
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
|
||||
hidden: The hidden dimension of the model.
|
||||
num_experts: The number of experts.
|
||||
"""
|
||||
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
|
||||
self.hidden = hidden
|
||||
self.group = group
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.ep_size = ep_size
|
||||
self.rank_id = ep_rank
|
||||
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
|
||||
self.splitwise_role = splitwise_role
|
||||
self.moe_phase = moe_phase
|
||||
self.async_finish = async_finish
|
||||
# TODO(@wufeisheng): Support configurable EP size
|
||||
if group is None:
|
||||
group = paddle.distributed.new_group(range(ep_size))
|
||||
self.group = group
|
||||
self.num_local_experts = num_experts // ep_size
|
||||
self.deepep_engine = None
|
||||
self.init_deepep_engine()
|
||||
|
||||
@abstractmethod
|
||||
def init_deepep_engine(self):
|
||||
raise NotImplementedError
|
||||
self.deepep_buffer = None
|
||||
self.num_nvl_bytes = 0
|
||||
self.num_rdma_bytes = 0
|
||||
|
||||
# Precompute buffer sizes
|
||||
self._compute_buffer_sizes()
|
||||
|
||||
@singleton
|
||||
class DeepEPEngine(DeepEPEngineBase):
|
||||
"""
|
||||
A wrapper class for DeepEP engine.
|
||||
"""
|
||||
def _compute_buffer_sizes(self, param_bytes: int = 2):
|
||||
hidden_bytes = self.hidden_size * param_bytes # bf16 or fp16
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
hidden: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
ep_rank: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
async_finish: bool = False,
|
||||
group=None,
|
||||
):
|
||||
"""
|
||||
Initialize the DeepEP engine.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
ep_size: The number of ranks.
|
||||
rank_id: The rank id.
|
||||
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
|
||||
hidden: The hidden dimension of the model.
|
||||
num_experts: The number of experts.
|
||||
"""
|
||||
super().__init__(
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
hidden,
|
||||
num_experts,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
splitwise_role,
|
||||
moe_phase,
|
||||
async_finish,
|
||||
group,
|
||||
)
|
||||
for config in (
|
||||
deep_ep.Buffer.get_dispatch_config(self.group.world_size),
|
||||
deep_ep.Buffer.get_combine_config(self.group.world_size),
|
||||
):
|
||||
self.num_nvl_bytes = max(
|
||||
config.get_nvl_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_nvl_bytes
|
||||
)
|
||||
self.num_rdma_bytes = max(
|
||||
config.get_rdma_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_rdma_bytes
|
||||
)
|
||||
|
||||
def init_deepep_engine(self):
|
||||
from paddle.base.core import Config
|
||||
if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden_size,
|
||||
self.ep_size,
|
||||
self.num_experts,
|
||||
)
|
||||
self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)
|
||||
|
||||
self.ep_config = Config(24, 6, 256)
|
||||
logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
|
||||
|
||||
# In mixed EP mode on a single node, we dynamically switch between
|
||||
# high throughput and low latency modes.
|
||||
def create_buffer(self):
|
||||
"""Create or recreate buffer based on role and phase."""
|
||||
if self.deepep_buffer is not None:
|
||||
self.clear_buffer()
|
||||
|
||||
if self.splitwise_role == "mixed":
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
logger.info("Initializing mixed mode buffer (low latency).")
|
||||
self.deepep_buffer = deep_ep.Buffer(
|
||||
self.group,
|
||||
int(2e9),
|
||||
int(6e9),
|
||||
self.num_nvl_bytes,
|
||||
self.num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=24,
|
||||
)
|
||||
# In disaggregated mode on multiple nodes, we either use
|
||||
# high throughput mode or low latency mode.
|
||||
self.deepep_buffer.set_num_sms(14) # TODO: tune in future
|
||||
else:
|
||||
if self.moe_phase.phase == "decode":
|
||||
logger.info("Initializing Low Latency Buffer")
|
||||
self.get_low_latency_buffer()
|
||||
self._create_low_latency_buffer()
|
||||
elif self.moe_phase.phase == "prefill":
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
logger.info("Initializing High Throughput Buffer for prefill phase.")
|
||||
self.deepep_buffer = deep_ep.Buffer(
|
||||
self.group,
|
||||
int(5e8),
|
||||
self.num_nvl_bytes,
|
||||
0,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown generation phase {self.moe_phase}")
|
||||
raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
|
||||
|
||||
def get_low_latency_buffer(self):
|
||||
"""
|
||||
Get the DeepEP buffer.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
|
||||
hidden: The hidden dimension of the model.
|
||||
"""
|
||||
# NOTES: the low-latency mode will consume much more space than the normal mode
|
||||
# So we recommend that `num_max_dispatch_tokens_per_rank`
|
||||
# (the actual batch size in the decoding engine) should be less than 256
|
||||
logger.info("DeepEP buffer created successfully.")
|
||||
|
||||
def _create_low_latency_buffer(self):
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden,
|
||||
self.hidden_size,
|
||||
self.ep_size,
|
||||
self.num_experts,
|
||||
)
|
||||
# Allocate a buffer if not existed or not enough buffer size
|
||||
|
||||
if (
|
||||
self.deepep_engine is None
|
||||
or self.deepep_engine.group != self.group
|
||||
or not self.deepep_engine.low_latency_mode
|
||||
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
|
||||
self.deepep_buffer is None
|
||||
or self.deepep_buffer.group != self.group
|
||||
or not self.deepep_buffer.low_latency_mode
|
||||
or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes
|
||||
):
|
||||
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
self.deepep_buffer = deep_ep.Buffer(
|
||||
self.group,
|
||||
0,
|
||||
num_rdma_bytes,
|
||||
@@ -189,6 +160,91 @@ class DeepEPEngine(DeepEPEngineBase):
|
||||
num_qps_per_rank=self.num_experts // self.ep_size,
|
||||
)
|
||||
|
||||
def clear_buffer(self):
|
||||
"""Clear buffer and free memory."""
|
||||
if self.deepep_buffer is not None:
|
||||
del self.deepep_buffer
|
||||
self.deepep_buffer = None
|
||||
logger.info("DeepEP buffer cleared.")
|
||||
|
||||
def get_buffer(self):
|
||||
return self.deepep_buffer
|
||||
|
||||
def clean_low_latency_buffer(self):
|
||||
if self.deepep_buffer is not None:
|
||||
self.deepep_buffer.clean_low_latency_buffer(
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden_size,
|
||||
self.num_experts,
|
||||
)
|
||||
|
||||
def barrier_all(self):
|
||||
if self.deepep_buffer is not None:
|
||||
self.deepep_buffer.barrier_all()
|
||||
|
||||
|
||||
@singleton
|
||||
class DeepEPEngine:
|
||||
"""
|
||||
A wrapper class for DeepEP engine.
|
||||
Manages buffer lifecycle based on role and phase.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
ep_rank: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
async_finish: bool = False,
|
||||
group=None,
|
||||
):
|
||||
if group is None:
|
||||
group = paddle.distributed.new_group(range(ep_size))
|
||||
self.group = group
|
||||
self.ep_size = ep_size
|
||||
self.rank_id = ep_rank
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = num_experts // ep_size
|
||||
self.async_finish = async_finish
|
||||
from paddle.base.core import Config
|
||||
|
||||
self.ep_config = Config(24, 6, 256)
|
||||
|
||||
# Store phase and role for buffer management
|
||||
self._splitwise_role = splitwise_role
|
||||
self._moe_phase = moe_phase
|
||||
|
||||
# Initialize buffer manager
|
||||
self.buffer = DeepEPBuffer(
|
||||
group=self.group,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||
splitwise_role=splitwise_role,
|
||||
moe_phase=moe_phase,
|
||||
)
|
||||
self.buffer.create_buffer()
|
||||
|
||||
# Register for global buffer management
|
||||
DeepEPBufferManager.set_engine(self)
|
||||
|
||||
@property
|
||||
def deepep_engine(self):
|
||||
"""Backward compatibility alias."""
|
||||
return self.buffer.get_buffer()
|
||||
|
||||
def clear_deep_ep_buffer(self):
|
||||
self.buffer.clear_buffer()
|
||||
|
||||
def create_deep_ep_buffer(self):
|
||||
self.buffer.create_buffer()
|
||||
|
||||
def low_latency_dispatch(
|
||||
self,
|
||||
hidden_states: paddle.Tensor,
|
||||
@@ -196,22 +252,9 @@ class DeepEPEngine(DeepEPEngineBase):
|
||||
expertwise_scale,
|
||||
use_fp8: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [token_num, hidden] 'bfloat16/int8'
|
||||
topk_idx: [token_num, num_topk] 'int64'
|
||||
if self.deepep_engine is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
Returns:
|
||||
recv_hidden_states: [num_local_experts,
|
||||
num_max_dispatch_tokens_per_rank * ep_size, hidden]
|
||||
ep_size * num_local_experts = num_experts
|
||||
recv_count: [num_local_experts]
|
||||
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
|
||||
expert receive. As mentioned before, all not tokens are valid in `recv_x`.
|
||||
handle: the communication handle to be used in the `low_latency_combine` function.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
hook: the receiving hook function (valid only if `return_recv_hook` is set).
|
||||
"""
|
||||
(
|
||||
packed_recv_x,
|
||||
recv_expert_count,
|
||||
@@ -222,7 +265,7 @@ class DeepEPEngine(DeepEPEngineBase):
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
expertwise_scale,
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.buffer.num_max_dispatch_tokens_per_rank,
|
||||
self.num_experts,
|
||||
use_fp8=use_fp8,
|
||||
async_finish=False,
|
||||
@@ -238,27 +281,14 @@ class DeepEPEngine(DeepEPEngineBase):
|
||||
topk_weights: paddle.Tensor,
|
||||
handle,
|
||||
):
|
||||
"""
|
||||
|
||||
Return:
|
||||
combined_hidden_states: [num_tokens, hidden]
|
||||
"""
|
||||
if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0": # not develop version of PaddlePaddle
|
||||
if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":
|
||||
# TODO(@wanglongzhi): Delete them when deepep in PaddlePaddle is fixed
|
||||
# and when the default recommended version of PaddlePaddle is greater than 3.1.0
|
||||
(
|
||||
src_info,
|
||||
layout_range,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
num_experts,
|
||||
) = handle
|
||||
handle = (
|
||||
src_info,
|
||||
layout_range,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
None,
|
||||
num_experts,
|
||||
)
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
|
||||
handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)
|
||||
|
||||
if self.deepep_engine is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
|
||||
hidden_states,
|
||||
@@ -271,18 +301,10 @@ class DeepEPEngine(DeepEPEngineBase):
|
||||
return combined_hidden_states, combine_hook
|
||||
|
||||
def clean_low_latency_buffer(self):
|
||||
"""
|
||||
clean_low_latency_buffer
|
||||
"""
|
||||
self.deepep_engine.clean_low_latency_buffer(
|
||||
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
|
||||
)
|
||||
self.buffer.clean_low_latency_buffer()
|
||||
|
||||
def barrier_all(self):
|
||||
"""
|
||||
barrier_all
|
||||
"""
|
||||
self.deepep_engine.barrier_all()
|
||||
self.buffer.barrier_all()
|
||||
|
||||
|
||||
class EPRunner:
|
||||
@@ -293,7 +315,7 @@ class EPRunner:
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
@@ -304,33 +326,20 @@ class EPRunner:
|
||||
ep_group=None,
|
||||
):
|
||||
self.top_k = top_k
|
||||
self.hidden = hidden
|
||||
self.num_experts = num_experts
|
||||
self.splitwise_role = splitwise_role
|
||||
self.moe_phase = moe_phase
|
||||
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
|
||||
self.ep_size = ep_size
|
||||
self.ep_rank = ep_rank
|
||||
self.redundant_experts_num = redundant_experts_num
|
||||
self.ep_group = ep_group
|
||||
self.init_ep_engine()
|
||||
|
||||
def init_ep_engine(self):
|
||||
self.ep_engine = DeepEPEngine(
|
||||
num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
|
||||
hidden=self.hidden,
|
||||
num_experts=self.num_experts + self.redundant_experts_num,
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
splitwise_role=self.splitwise_role,
|
||||
moe_phase=self.moe_phase,
|
||||
group=self.ep_group,
|
||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts + redundant_experts_num,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
splitwise_role=splitwise_role,
|
||||
moe_phase=moe_phase,
|
||||
group=ep_group,
|
||||
)
|
||||
|
||||
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
|
||||
"""
|
||||
moe_select
|
||||
"""
|
||||
if layer.redundant_table_manger is not None:
|
||||
(
|
||||
ep_rank_to_expert_id_list,
|
||||
@@ -346,12 +355,14 @@ class EPRunner:
|
||||
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
|
||||
bias=layer.gate_correction_bias,
|
||||
moe_topk=self.top_k,
|
||||
apply_norm_weight=True, # apply_norm_weight
|
||||
apply_norm_weight=True,
|
||||
enable_softmax_top_k_fused=False,
|
||||
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
|
||||
)
|
||||
else:
|
||||
if layer.topk_method == "noaux_tc":
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
|
||||
score, topk_weights, topk_idx = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
@@ -365,28 +376,28 @@ class EPRunner:
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
self.top_k,
|
||||
True, # apply_norm_weight,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
return topk_idx, topk_weights
|
||||
|
||||
@abstractmethod
|
||||
def dispatch(self, *args, **kwargs):
|
||||
"""
|
||||
dispatch
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def combine(self, *args, **kwargs):
|
||||
"""
|
||||
combine
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def clean_low_latency_buffer(self):
|
||||
self.ep_engine.clean_low_latency_buffer()
|
||||
|
||||
def clear_deep_ep_buffer(self):
|
||||
self.ep_engine.clear_deep_ep_buffer()
|
||||
|
||||
def create_deep_ep_buffer(self):
|
||||
self.ep_engine.create_deep_ep_buffer()
|
||||
|
||||
|
||||
class EPPrefillRunner(EPRunner):
|
||||
"""
|
||||
@@ -396,19 +407,19 @@ class EPPrefillRunner(EPRunner):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
splitwise_role: str,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
ep_size: int = 1,
|
||||
ep_rank: int = 0,
|
||||
redundant_experts_num: int = 0,
|
||||
ep_group=None,
|
||||
moe_phase: MoEPhase = MoEPhase("prefill"),
|
||||
ep_group=None,
|
||||
):
|
||||
super().__init__(
|
||||
top_k,
|
||||
hidden,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
splitwise_role,
|
||||
moe_phase,
|
||||
@@ -427,6 +438,9 @@ class EPPrefillRunner(EPRunner):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
buffer = self.ep_engine.deepep_engine
|
||||
if buffer is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
(
|
||||
num_tokens_per_rank,
|
||||
@@ -434,7 +448,7 @@ class EPPrefillRunner(EPRunner):
|
||||
num_tokens_per_expert,
|
||||
is_token_in_rank,
|
||||
_,
|
||||
) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
|
||||
) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
|
||||
|
||||
x_scale_tensor = kwargs.get("x_scale_tensor", None)
|
||||
dispatch_args = {
|
||||
@@ -443,12 +457,12 @@ class EPPrefillRunner(EPRunner):
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": self.ep_engine.ep_config,
|
||||
"config": self.ep_engine.ep_config, # assuming ep_config still in engine
|
||||
"async_finish": self.ep_engine.async_finish,
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": topk_weights,
|
||||
}
|
||||
return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
|
||||
return buffer.dispatch(**dispatch_args)
|
||||
|
||||
def combine(
|
||||
self,
|
||||
@@ -456,6 +470,10 @@ class EPPrefillRunner(EPRunner):
|
||||
handle: tuple,
|
||||
recv_topk_weights: paddle.Tensor,
|
||||
):
|
||||
buffer = self.ep_engine.deepep_engine
|
||||
if buffer is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
combine_args = {
|
||||
"x": tmp_ffn_out,
|
||||
"handle": handle,
|
||||
@@ -463,8 +481,7 @@ class EPPrefillRunner(EPRunner):
|
||||
"async_finish": self.ep_engine.async_finish,
|
||||
"topk_weights": recv_topk_weights,
|
||||
}
|
||||
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
|
||||
|
||||
fused_moe_out, _, _ = buffer.combine(**combine_args)
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
@@ -476,7 +493,7 @@ class EPDecoderRunner(EPRunner):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
splitwise_role: str,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
@@ -488,7 +505,7 @@ class EPDecoderRunner(EPRunner):
|
||||
):
|
||||
super().__init__(
|
||||
top_k,
|
||||
hidden,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
splitwise_role,
|
||||
moe_phase,
|
||||
|
@@ -40,67 +40,52 @@ class MoEMethodBase(QuantMethodBase):
|
||||
"down_proj_weight_scale",
|
||||
]
|
||||
self.pack_num = 1
|
||||
|
||||
def import_backend_ep_runner(self) -> None:
|
||||
from .ep import EPDecoderRunner, EPPrefillRunner
|
||||
|
||||
self.EPPrefillRunner = EPPrefillRunner
|
||||
self.EPDecoderRunner = EPDecoderRunner
|
||||
self.ep_prefill_runner = None
|
||||
self.ep_decoder_runner = None
|
||||
|
||||
def init_ep(self, layer: nn.Layer) -> None:
|
||||
"""
|
||||
Init EP related module
|
||||
Initialize EP (Expert Parallel) related modules.
|
||||
"""
|
||||
self.import_backend_ep_runner()
|
||||
if layer.ep_size > 1:
|
||||
if layer.fd_config.parallel_config.splitwise_role == "mixed":
|
||||
self.ep_prefill_runner = self.EPPrefillRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
ep_group=layer.fd_config.parallel_config.ep_group,
|
||||
)
|
||||
self.ep_decoder_runner = self.EPDecoderRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
ep_group=layer.fd_config.parallel_config.ep_group,
|
||||
)
|
||||
if layer.ep_size <= 1:
|
||||
return
|
||||
|
||||
# Lazy import to avoid circular dependency or unnecessary loading
|
||||
from .ep import EPDecoderRunner, EPPrefillRunner
|
||||
|
||||
# Common arguments for both runners
|
||||
common_args = {
|
||||
"top_k": layer.top_k,
|
||||
"hidden_size": layer.hidden_size,
|
||||
"num_experts": layer.num_experts,
|
||||
"splitwise_role": layer.fd_config.parallel_config.splitwise_role,
|
||||
"num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
"ep_size": layer.ep_size,
|
||||
"ep_rank": layer.ep_rank,
|
||||
"redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
|
||||
"ep_group": layer.fd_config.parallel_config.ep_group,
|
||||
}
|
||||
|
||||
config = layer.fd_config
|
||||
splitwise_role = config.parallel_config.splitwise_role
|
||||
load_strategy = config.load_config.load_strategy
|
||||
|
||||
# For "mixed" splitwise role: conditionally initialize both or none
|
||||
if splitwise_role == "mixed":
|
||||
if load_strategy == "meta":
|
||||
# for RL init model without deepep buff
|
||||
return
|
||||
else:
|
||||
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
|
||||
self.ep_prefill_runner = self.EPPrefillRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
ep_group=layer.fd_config.parallel_config.ep_group,
|
||||
)
|
||||
else:
|
||||
self.ep_decoder_runner = self.EPDecoderRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
ep_group=layer.fd_config.parallel_config.ep_group,
|
||||
)
|
||||
self.ep_prefill_runner = EPPrefillRunner(**common_args)
|
||||
self.ep_decoder_runner = EPDecoderRunner(**common_args)
|
||||
return
|
||||
|
||||
# For non-mixed ep
|
||||
phase = config.parallel_config.moe_phase.phase
|
||||
if phase == "prefill":
|
||||
self.ep_prefill_runner = EPPrefillRunner(**common_args)
|
||||
else:
|
||||
self.ep_decoder_runner = EPDecoderRunner(**common_args)
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
"""
|
||||
@@ -190,20 +175,12 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
|
||||
if current_platform.is_cuda():
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size * 2,
|
||||
]
|
||||
self.down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
|
||||
self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2]
|
||||
self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size]
|
||||
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}}
|
||||
else:
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
|
||||
self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size]
|
||||
self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size]
|
||||
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
|
||||
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
@@ -217,17 +194,18 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
|
||||
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
|
||||
"model_format": extra_weight_attrs.get("model_format", ""),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
|
||||
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
|
||||
"model_format": extra_weight_attrs.get("model_format", ""),
|
||||
},
|
||||
)
|
||||
|
@@ -39,7 +39,6 @@ elif current_platform.is_iluvatar():
|
||||
moe_expert_reduce,
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
|
||||
|
||||
|
||||
@@ -127,7 +126,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
# 3. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.info(f"token_all_num {token_all_num}")
|
||||
logger.debug(f"token_all_num {token_all_num}")
|
||||
(
|
||||
permute_input,
|
||||
permute_indices_per_token,
|
||||
@@ -228,6 +227,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
gate_out = gate(x.cast("float32"))
|
||||
if layer.topk_method == "noaux_tc":
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
|
@@ -341,7 +341,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# 4. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.info(f"token_all_num {token_all_num}")
|
||||
logger.debug(f"token_all_num {token_all_num}")
|
||||
(recv_x, recv_x_scale) = recv_x
|
||||
|
||||
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
|
||||
|
@@ -19,7 +19,6 @@ from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
MoeWna16MarlinGemmApi,
|
||||
tritonmoe_preprocess_func,
|
||||
@@ -255,6 +254,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
topk_method = layer.topk_method
|
||||
|
||||
if topk_method == "noaux_tc":
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
|
@@ -48,7 +48,7 @@ class DynamicWeightManager:
|
||||
|
||||
logger.info(
|
||||
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
|
||||
f" rank={self.rank}, ranks={self.nranks}"
|
||||
f" tp rank={self.rank}, dp rank={fd_config.parallel_config.local_data_parallel_id}, ep rank={fd_config.parallel_config.expert_parallel_rank}, ranks={self.nranks}, "
|
||||
)
|
||||
|
||||
@paddle.no_grad()
|
||||
@@ -63,11 +63,21 @@ class DynamicWeightManager:
|
||||
start_time = time.perf_counter()
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
# step1 : restart paddle process group
|
||||
if not self.first_load:
|
||||
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
|
||||
|
||||
# step2 : recreat deepep buffer when enable expert parallel
|
||||
if self.parallel_config.enable_expert_parallel and not self.first_load:
|
||||
from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager
|
||||
|
||||
DeepEPBufferManager.recreate_buffer()
|
||||
# ep barrier
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
|
||||
# step3 : update model weight
|
||||
strategy_handlers = {
|
||||
"ipc_snapshot": self._update_ipc_snapshot,
|
||||
"ipc": self._update_ipc,
|
||||
@@ -80,6 +90,11 @@ class DynamicWeightManager:
|
||||
|
||||
logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
# steps in the runner
|
||||
# step4: reinitialze kv_cache in the runner
|
||||
# step5: recapture cuda_graph
|
||||
# step6: update weight status signal
|
||||
|
||||
def _update_ipc_snapshot(self):
|
||||
"""Update using IPC snapshot strategy for elastic recovery."""
|
||||
model_path = os.path.join(
|
||||
@@ -105,18 +120,34 @@ class DynamicWeightManager:
|
||||
|
||||
def clear_parameters(self, pid: int = 0) -> None:
|
||||
"""Clear all model parameters and free memory."""
|
||||
logger.info("start clear parameters")
|
||||
|
||||
logger.info("start clear paramaters")
|
||||
|
||||
# step1: release deepep buffer
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager
|
||||
|
||||
DeepEPBufferManager.clear_buffer()
|
||||
# ep barrier
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
# shutdown ep group
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
|
||||
|
||||
paddle.device.cuda.empty_cache()
|
||||
# step2: release model weight
|
||||
for param in self.model.state_dict().values():
|
||||
param._clear_data()
|
||||
|
||||
self._verify_parameters("clearance")
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
# tp barrier
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
|
||||
# shutdown tp group
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
|
||||
|
||||
# step3: update model weight signal
|
||||
# step4: release kv cache in the runner
|
||||
self._update_shared_status(pid, -2)
|
||||
|
||||
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
|
||||
@@ -145,10 +176,16 @@ class DynamicWeightManager:
|
||||
def finalize_update(self, pid: int = 0):
|
||||
"""Finalize update process with verification."""
|
||||
self._verify_parameters("update")
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
|
||||
if not self.first_load:
|
||||
self._update_shared_status(pid, 0)
|
||||
|
||||
self.first_load = False
|
||||
|
||||
def _get_gpu_id(self) -> int:
|
||||
|
@@ -26,13 +26,13 @@ class RolloutModelConfig:
|
||||
max_model_len: int = 32768,
|
||||
tensor_parallel_size: int = 4,
|
||||
dynamic_load_weight: bool = True,
|
||||
load_strategy: str = "ipc_snapshot",
|
||||
load_strategy: str = "meta",
|
||||
enable_mm: bool = False,
|
||||
# Default values for all other parameters
|
||||
max_num_seqs: int = 34,
|
||||
total_block_num: int = 2000,
|
||||
block_size: int = 64,
|
||||
engine_worker_queue_port: int = 9923,
|
||||
engine_worker_queue_port: str = "8002",
|
||||
device_ids: str = "0",
|
||||
dtype: str = "bfloat16",
|
||||
enc_dec_block_num: int = 1,
|
||||
|
@@ -262,10 +262,10 @@ class PaddleDisWorkerProc:
|
||||
self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
|
||||
req_ids = []
|
||||
num_running_requests = 0
|
||||
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
self.model_weights_signal = np.zeros([1], dtype=np.int32)
|
||||
while True:
|
||||
if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
if local_rank == 0:
|
||||
if self.model_weights_status.value[0] != 0:
|
||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||
if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
|
||||
@@ -283,7 +283,7 @@ class PaddleDisWorkerProc:
|
||||
self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
|
||||
|
||||
# The first worker detects whether there are tasks in the task queue
|
||||
if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
if local_rank == 0:
|
||||
if self.task_queue.num_tasks() > 0:
|
||||
# VL only support 1 batch to prefill
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
|
||||
@@ -598,7 +598,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--load_strategy",
|
||||
type=str,
|
||||
choices=["ipc", "ipc_snapshot"],
|
||||
choices=["ipc", "ipc_snapshot", "meta", "normal"],
|
||||
default="ipc_snapshot",
|
||||
help="Weight loading method when dynamic loading is enabled: "
|
||||
"'ipc': real-time IPC streaming with automatic resharding, "
|
||||
@@ -683,10 +683,11 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
parallel_config.num_experts_per_rank = num_experts_per_rank
|
||||
parallel_config.num_experts_start_offset = num_experts_start_offset
|
||||
|
||||
parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
|
||||
parallel_config.local_data_parallel_id
|
||||
]
|
||||
parallel_config.set_tp_group()
|
||||
if args.load_strategy != "meta":
|
||||
parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
|
||||
parallel_config.local_data_parallel_id
|
||||
]
|
||||
parallel_config.set_communicate_group()
|
||||
|
||||
load_config = LoadConfig(vars(args))
|
||||
|
||||
|
@@ -5,6 +5,7 @@ from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
FDConfig,
|
||||
GraphOptimizationConfig,
|
||||
LoadConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
)
|
||||
@@ -15,10 +16,12 @@ class TestConfig(unittest.TestCase):
|
||||
parallel_config = ParallelConfig({"tensor_parallel_size": 16, "expert_parallel_size": 1})
|
||||
graph_opt_config = GraphOptimizationConfig({})
|
||||
cache_config = CacheConfig({})
|
||||
load_config = LoadConfig({})
|
||||
scheduler_config = SchedulerConfig({})
|
||||
fd_config = FDConfig(
|
||||
parallel_config=parallel_config,
|
||||
graph_opt_config=graph_opt_config,
|
||||
load_config=load_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
ips=["1.1.1.1", "0.0.0.0"],
|
||||
@@ -31,10 +34,12 @@ class TestConfig(unittest.TestCase):
|
||||
parallel_config = ParallelConfig({})
|
||||
graph_opt_config = GraphOptimizationConfig({})
|
||||
cache_config = CacheConfig({})
|
||||
load_config = LoadConfig({})
|
||||
scheduler_config = SchedulerConfig({})
|
||||
fd_config = FDConfig(
|
||||
parallel_config=parallel_config,
|
||||
graph_opt_config=graph_opt_config,
|
||||
load_config=load_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
ips="0.0.0.0",
|
||||
@@ -46,12 +51,14 @@ class TestConfig(unittest.TestCase):
|
||||
parallel_config = ParallelConfig({})
|
||||
graph_opt_config = GraphOptimizationConfig({})
|
||||
cache_config = CacheConfig({})
|
||||
load_config = LoadConfig({})
|
||||
cache_config.enable_chunked_prefill = True
|
||||
scheduler_config = SchedulerConfig({})
|
||||
fd_config = FDConfig(
|
||||
parallel_config=parallel_config,
|
||||
graph_opt_config=graph_opt_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
scheduler_config=scheduler_config,
|
||||
ips="0.0.0.0",
|
||||
test_mode=True,
|
||||
@@ -64,6 +71,7 @@ class TestConfig(unittest.TestCase):
|
||||
parallel_config=parallel_config,
|
||||
graph_opt_config=graph_opt_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
scheduler_config=scheduler_config,
|
||||
ips="0.0.0.0",
|
||||
test_mode=True,
|
||||
@@ -77,11 +85,13 @@ class TestConfig(unittest.TestCase):
|
||||
cache_config = CacheConfig({})
|
||||
cache_config.cache_transfer_protocol = "rdma,ipc"
|
||||
cache_config.pd_comm_port = "2334"
|
||||
load_config = LoadConfig({})
|
||||
scheduler_config = SchedulerConfig({})
|
||||
fd_config = FDConfig(
|
||||
parallel_config=parallel_config,
|
||||
graph_opt_config=graph_opt_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
scheduler_config=scheduler_config,
|
||||
splitwise_role="prefill",
|
||||
test_mode=True,
|
||||
|
Reference in New Issue
Block a user