[NewFeature] Add EP rollout model init and update/clear EP buffer (#3927)

* add EP rollout model init && add DeepEP buffer update/clear

* fix test
Author: gaoziyuan
Date: 2025-09-12 14:15:13 +08:00
Committed by: GitHub
Parent: c64ceac34d
Commit: 10768a4d79
13 changed files with 364 additions and 304 deletions

View File

@@ -337,11 +337,12 @@ class ParallelConfig:
         else:
             self.pd_disaggregation_mode = "None"
 
-    def set_tp_group(self):
+    def set_communicate_group(self):
         # different tp group id
         # prevent different tp_groups using the same group_id
         tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
         dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
         self.tp_group = dist.new_group(
             range(
                 self.data_parallel_rank * self.tensor_parallel_size,
@@ -350,8 +351,11 @@ class ParallelConfig:
         )
         dist.collective._set_custom_gid(None)
         # same ep group id
-        dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-        self.ep_group = dist.new_group(range(self.expert_parallel_size))
+        if self.enable_expert_parallel:
+            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            dist.collective._set_custom_gid(None)
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
@@ -830,6 +834,7 @@ class LoadConfig:
         load_strategy: Specifies the weight loading method when enabled:
             - 'ipc': Real-time IPC streaming with automatic resharding
             - 'ipc_snapshot': Load from disk snapshot of IPC weights
+            - 'meta': Only model meta messages
             - None: No dynamic loading
     """
@@ -840,7 +845,7 @@ class LoadConfig:
         self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
         self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
         self.dynamic_load_weight: bool = False
-        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
+        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -1198,12 +1203,10 @@ class FDConfig:
         num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        if num_ranks > self.max_chips_per_node:
+        if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
             self.worker_num_per_node = self.max_chips_per_node
             nnode = ceil_div(num_ranks, self.worker_num_per_node)
             assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
-            # assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
         else:
             self.worker_num_per_node = num_ranks

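For context, a minimal usage sketch (not part of this commit's diff) of the new "meta" strategy, using only the LoadConfig interface shown above; everything about the surrounding RL entry point is assumed.

from fastdeploy.config import LoadConfig

# "meta" builds the model structure only, without loading weights,
# which is what the RL rollout init path in this PR relies on.
load_config = LoadConfig({"dynamic_load_weight": True, "load_strategy": "meta"})
assert load_config.load_strategy == "meta"
assert load_config.dynamic_load_weight is True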
View File

@@ -134,7 +134,7 @@ class EngineArgs:
     """
     dynamic load weight
     """
-    load_strategy: str = "ipc_snapshot"
+    load_strategy: str = "normal"
     """
     dynamic load weight strategy
     """

View File

@@ -25,52 +25,175 @@ try:
 except:
     logger.warning("import deep_ep Failed!")
 
+from typing import Optional
+
 import fastdeploy
 from fastdeploy.config import MoEPhase
 from fastdeploy.utils import singleton
 
-try:
-    from fastdeploy.model_executor.ops.gpu import noaux_tc
-except:
-    logger.warning("import noaux_tc Failed!")
-
-
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
-    scores_with_bias = scores + e_score_correction_bias
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group if n_group > 0 else 1,
-        topk_group if topk_group > 0 else 1,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
+
+class DeepEPBufferManager:
+    _engine: Optional["DeepEPEngine"] = None
+
+    @classmethod
+    def set_engine(cls, engine: "DeepEPEngine"):
+        cls._engine = engine
+
+    @classmethod
+    def clear_buffer(cls):
+        if cls._engine:
+            cls._engine.clear_deep_ep_buffer()
+
+    @classmethod
+    def recreate_buffer(cls):
+        if cls._engine:
+            cls._engine.create_deep_ep_buffer()
+
+
+class DeepEPBuffer:
+    """
+    Encapsulates DeepEP buffer creation, management and cleanup.
+    """
+
+    def __init__(
+        self,
+        group,
+        hidden_size: int,
+        num_experts: int,
+        ep_size: int,
+        num_max_dispatch_tokens_per_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
+    ):
+        self.group = group
+        self.hidden_size = hidden_size
+        self.num_experts = num_experts
+        self.ep_size = ep_size
+        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        self.splitwise_role = splitwise_role
+        self.moe_phase = moe_phase
+
+        self.deepep_buffer = None
+        self.num_nvl_bytes = 0
+        self.num_rdma_bytes = 0
+
+        # Precompute buffer sizes
+        self._compute_buffer_sizes()
+
+    def _compute_buffer_sizes(self, param_bytes: int = 2):
+        hidden_bytes = self.hidden_size * param_bytes  # bf16 or fp16
+
+        for config in (
+            deep_ep.Buffer.get_dispatch_config(self.group.world_size),
+            deep_ep.Buffer.get_combine_config(self.group.world_size),
+        ):
+            self.num_nvl_bytes = max(
+                config.get_nvl_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_nvl_bytes
+            )
+            self.num_rdma_bytes = max(
+                config.get_rdma_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_rdma_bytes
+            )
+
+        if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
+            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+                self.num_max_dispatch_tokens_per_rank,
+                self.hidden_size,
+                self.ep_size,
+                self.num_experts,
+            )
+            self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)
+
+        logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
+
+    def create_buffer(self):
+        """Create or recreate buffer based on role and phase."""
+        if self.deepep_buffer is not None:
+            self.clear_buffer()
+
+        if self.splitwise_role == "mixed":
+            logger.info("Initializing mixed mode buffer (low latency).")
+            self.deepep_buffer = deep_ep.Buffer(
+                self.group,
+                self.num_nvl_bytes,
+                self.num_rdma_bytes,
+                low_latency_mode=True,
+                num_qps_per_rank=24,
+            )
+            self.deepep_buffer.set_num_sms(14)  # TODO: tune in future
+        else:
+            if self.moe_phase.phase == "decode":
+                self._create_low_latency_buffer()
+            elif self.moe_phase.phase == "prefill":
+                logger.info("Initializing High Throughput Buffer for prefill phase.")
+                self.deepep_buffer = deep_ep.Buffer(
+                    self.group,
+                    self.num_nvl_bytes,
+                    0,
+                    low_latency_mode=False,
+                    num_qps_per_rank=1,
+                )
+            else:
+                raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
+
+        logger.info("DeepEP buffer created successfully.")
+
+    def _create_low_latency_buffer(self):
+        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+            self.num_max_dispatch_tokens_per_rank,
+            self.hidden_size,
+            self.ep_size,
+            self.num_experts,
+        )
+
+        if (
+            self.deepep_buffer is None
+            or self.deepep_buffer.group != self.group
+            or not self.deepep_buffer.low_latency_mode
+            or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes
+        ):
+            assert self.num_experts % self.ep_size == 0
+            self.deepep_buffer = deep_ep.Buffer(
+                self.group,
+                0,
+                num_rdma_bytes,
+                low_latency_mode=True,
+                num_qps_per_rank=self.num_experts // self.ep_size,
+            )
+
+    def clear_buffer(self):
+        """Clear buffer and free memory."""
+        if self.deepep_buffer is not None:
+            del self.deepep_buffer
+            self.deepep_buffer = None
+            logger.info("DeepEP buffer cleared.")
+
+    def get_buffer(self):
+        return self.deepep_buffer
+
+    def clean_low_latency_buffer(self):
+        if self.deepep_buffer is not None:
+            self.deepep_buffer.clean_low_latency_buffer(
+                self.num_max_dispatch_tokens_per_rank,
+                self.hidden_size,
+                self.num_experts,
+            )
+
+    def barrier_all(self):
+        if self.deepep_buffer is not None:
+            self.deepep_buffer.barrier_all()
 
 
 @singleton
 class DeepEPEngine:
     """
     A wrapper class for DeepEP engine.
+    Manages buffer lifecycle based on role and phase.
     """
 
     def __init__(
         self,
         num_max_dispatch_tokens_per_rank: int,
-        hidden: int,
+        hidden_size: int,
         num_experts: int,
         ep_size: int,
         ep_rank: int,
@@ -79,95 +202,48 @@ class DeepEPEngine:
         async_finish: bool = False,
         group=None,
     ):
-        """
-        Initialize the DeepEP engine.
-        Args:
-            group: The MPI group object.
-            ep_size: The number of ranks.
-            rank_id: The rank id.
-            num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
-            hidden: The hidden dimension of the model.
-            num_experts: The number of experts.
-        """
-        # TODO(@wufeisheng): Support configurable EP size
         if group is None:
             group = paddle.distributed.new_group(range(ep_size))
         self.group = group
         self.ep_size = ep_size
         self.rank_id = ep_rank
-        self.hidden = hidden
+        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
         self.async_finish = async_finish
-        self.deepep_engine = None
 
         from paddle.base.core import Config
 
         self.ep_config = Config(24, 6, 256)
-        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
 
-        # In mixed EP mode on a single node, we dynamically switch between
-        # high throughput and low latency modes.
-        if splitwise_role == "mixed":
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                int(2e9),
-                int(6e9),
-                low_latency_mode=True,
-                num_qps_per_rank=24,
-            )
-        # In disaggregated mode on mutiple nodes, we either use
-        # high throughput mode or low latency mode.
-        else:
-            if moe_phase.phase == "decode":
-                logger.info("Initializing Low Latency Buffer")
-                self.get_low_latency_buffer()
-            elif moe_phase.phase == "prefill":
-                self.deepep_engine = deep_ep.Buffer(
-                    self.group,
-                    int(5e8),
-                    0,
-                    low_latency_mode=False,
-                    num_qps_per_rank=1,
-                )
-            else:
-                raise ValueError(f"Unknown generation phase {moe_phase}")
-
-    def get_low_latency_buffer(self):
-        """
-        Get the DeepEP buffer.
-        Args:
-            group: The MPI group object.
-            num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
-            hidden: The hidden dimension of the model.
-        """
-        # NOTES: the low-latency mode will consume much more space than the normal mode
-        # So we recommend that `num_max_dispatch_tokens_per_rank`
-        # (the actual batch size in the decoding engine) should be less than 256
-        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden,
-            self.ep_size,
-            self.num_experts,
-        )
-        # Allocate a buffer if not existed or not enough buffer size
-        if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
-        ):
-            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
-            assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                0,
-                num_rdma_bytes,
-                low_latency_mode=True,
-                num_qps_per_rank=self.num_experts // self.ep_size,
-            )
+        # Store phase and role for buffer management
+        self._splitwise_role = splitwise_role
+        self._moe_phase = moe_phase
+
+        # Initialize buffer manager
+        self.buffer = DeepEPBuffer(
+            group=self.group,
+            hidden_size=hidden_size,
+            num_experts=num_experts,
+            ep_size=ep_size,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
+        )
+        self.buffer.create_buffer()
+
+        # Register for global buffer management
+        DeepEPBufferManager.set_engine(self)
+
+    @property
+    def deepep_engine(self):
+        """Backward compatibility alias."""
+        return self.buffer.get_buffer()
+
+    def clear_deep_ep_buffer(self):
+        self.buffer.clear_buffer()
+
+    def create_deep_ep_buffer(self):
+        self.buffer.create_buffer()
 
     def low_latency_dispatch(
         self,
@@ -176,22 +252,9 @@ class DeepEPEngine:
         expertwise_scale,
         use_fp8: bool = False,
     ):
-        """
-        Args:
-            hidden_states: [token_num, hidden] 'bfloat16/int8'
-            topk_idx: [token_num, num_topk] 'int64'
-        Returns:
-            recv_hidden_states: [num_local_experts,
-                num_max_dispatch_tokens_per_rank * ep_size, hidden]
-                ep_size * num_local_experts = num_experts
-            recv_count: [num_local_experts]
-            recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
-                expert receive. As mentioned before, all not tokens are valid in `recv_x`.
-            handle: the communication handle to be used in the `low_latency_combine` function.
-            event: the event after executing the kernel (valid only if `async_finish` is set).
-            hook: the receiving hook function (valid only if `return_recv_hook` is set).
-        """
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+
         (
             packed_recv_x,
             recv_expert_count,
@@ -202,7 +265,7 @@ class DeepEPEngine:
             hidden_states,
             topk_idx,
             expertwise_scale,
-            self.num_max_dispatch_tokens_per_rank,
+            self.buffer.num_max_dispatch_tokens_per_rank,
             self.num_experts,
             use_fp8=use_fp8,
             async_finish=False,
@@ -218,27 +281,14 @@ class DeepEPEngine:
         topk_weights: paddle.Tensor,
         handle,
     ):
-        """
-        Return:
-            combined_hidden_states: [num_tokens, hidden]
-        """
-        if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":  # not develop version of PaddlePaddle
+        if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":
             # TODO(@wanglongzhi): Delete them when deepep in PaddlePaddle is fixed
             # and when the default recommended version of PaddlePaddle is greater than 3.1.0
-            (
-                src_info,
-                layout_range,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
-            ) = handle
-            handle = (
-                src_info,
-                layout_range,
-                num_max_dispatch_tokens_per_rank,
-                None,
-                num_experts,
-            )
+            src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
+            handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)
+
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
 
         combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
             hidden_states,
@@ -251,18 +301,10 @@ class DeepEPEngine:
         return combined_hidden_states, combine_hook
 
     def clean_low_latency_buffer(self):
-        """
-        clean_low_latency_buffer
-        """
-        self.deepep_engine.clean_low_latency_buffer(
-            self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
-        )
+        self.buffer.clean_low_latency_buffer()
 
     def barrier_all(self):
-        """
-        barrier_all
-        """
-        self.deepep_engine.barrier_all()
+        self.buffer.barrier_all()
 
 
 class EPRunner:
@@ -273,7 +315,7 @@ class EPRunner:
     def __init__(
         self,
         top_k: int,
-        hidden: int,
+        hidden_size: int,
         num_experts: int,
         splitwise_role: str,
         moe_phase: MoEPhase,
@@ -288,7 +330,7 @@ class EPRunner:
         self.redundant_experts_num = redundant_experts_num
         self.ep_engine = DeepEPEngine(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
-            hidden=hidden,
+            hidden_size=hidden_size,
             num_experts=num_experts + redundant_experts_num,
             ep_size=ep_size,
             ep_rank=ep_rank,
@@ -298,9 +340,6 @@ class EPRunner:
         )
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
-        """
-        moe_select
-        """
         if layer.redundant_table_manger is not None:
             (
                 ep_rank_to_expert_id_list,
@@ -316,12 +355,14 @@ class EPRunner:
                 tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                 bias=layer.gate_correction_bias,
                 moe_topk=self.top_k,
-                apply_norm_weight=True,  # apply_norm_weight
+                apply_norm_weight=True,
                 enable_softmax_top_k_fused=False,
                 redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
             )
         else:
             if layer.topk_method == "noaux_tc":
+                from .moe import get_moe_scores
+
                 score, topk_weights, topk_idx = get_moe_scores(
                     gate_out,
                     layer.n_group,
@@ -335,28 +376,28 @@ class EPRunner:
                     gate_out,
                     layer.gate_correction_bias,
                     self.top_k,
-                    True,  # apply_norm_weight,
+                    True,
                    False,
                 )
         return topk_idx, topk_weights
 
     @abstractmethod
     def dispatch(self, *args, **kwargs):
-        """
-        dispatch
-        """
         raise NotImplementedError
 
     @abstractmethod
     def combine(self, *args, **kwargs):
-        """
-        combine
-        """
         raise NotImplementedError
 
     def clean_low_latency_buffer(self):
         self.ep_engine.clean_low_latency_buffer()
 
+    def clear_deep_ep_buffer(self):
+        self.ep_engine.clear_deep_ep_buffer()
+
+    def create_deep_ep_buffer(self):
+        self.ep_engine.create_deep_ep_buffer()
+
 
 class EPPrefillRunner(EPRunner):
     """
@@ -366,7 +407,7 @@ class EPPrefillRunner(EPRunner):
     def __init__(
         self,
         top_k: int,
-        hidden: int,
+        hidden_size: int,
         num_experts: int,
         splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
@@ -378,7 +419,7 @@ class EPPrefillRunner(EPRunner):
     ):
         super().__init__(
             top_k,
-            hidden,
+            hidden_size,
             num_experts,
             splitwise_role,
             moe_phase,
@@ -397,6 +438,9 @@ class EPPrefillRunner(EPRunner):
         *args,
         **kwargs,
     ):
+        buffer = self.ep_engine.deepep_engine
+        if buffer is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
 
         (
             num_tokens_per_rank,
@@ -404,7 +448,7 @@ class EPPrefillRunner(EPRunner):
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
 
         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -413,12 +457,12 @@ class EPPrefillRunner(EPRunner):
             "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
             "is_token_in_rank": is_token_in_rank,
             "num_tokens_per_expert": num_tokens_per_expert,
-            "config": self.ep_engine.ep_config,
+            "config": self.ep_engine.ep_config,  # assuming ep_config still in engine
             "async_finish": self.ep_engine.async_finish,
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
+        return buffer.dispatch(**dispatch_args)
 
     def combine(
         self,
@@ -426,6 +470,10 @@ class EPPrefillRunner(EPRunner):
         handle: tuple,
         recv_topk_weights: paddle.Tensor,
     ):
+        buffer = self.ep_engine.deepep_engine
+        if buffer is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+
         combine_args = {
             "x": tmp_ffn_out,
             "handle": handle,
@@ -433,8 +481,7 @@ class EPPrefillRunner(EPRunner):
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
-
+        fused_moe_out, _, _ = buffer.combine(**combine_args)
         return fused_moe_out
 
@@ -446,7 +493,7 @@ class EPDecoderRunner(EPRunner):
     def __init__(
         self,
         top_k: int,
-        hidden: int,
+        hidden_size: int,
         num_experts: int,
         splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
@@ -458,7 +505,7 @@ class EPDecoderRunner(EPRunner):
     ):
         super().__init__(
             top_k,
-            hidden,
+            hidden_size,
             num_experts,
             splitwise_role,
             moe_phase,

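For context, a short sketch (not part of the diff) of how the new DeepEPBufferManager is intended to be driven around an RL weight refresh; the actual call sites are in DynamicWeightManager further below.

from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager

# DeepEPEngine registers itself via DeepEPBufferManager.set_engine(self) when it is
# constructed, so the weight-update path only needs the classmethods:
DeepEPBufferManager.clear_buffer()     # drop NVL/RDMA buffers before weights are cleared
# ... weights cleared, process groups shut down / restarted, new weights loaded ...
DeepEPBufferManager.recreate_buffer()  # rebuild the DeepEP buffer once EP groups are back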
View File

@@ -40,66 +40,52 @@ class MoEMethodBase(QuantMethodBase):
             "down_proj_weight_scale",
         ]
         self.pack_num = 1
+        self.ep_prefill_runner = None
+        self.ep_decoder_runner = None
 
     def init_ep(self, layer: nn.Layer) -> None:
         """
-        Init EP related module
+        Initialize EP (Expert Parallel) related modules.
         """
-        if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.splitwise_role == "mixed":
-                from .ep import EPDecoderRunner, EPPrefillRunner
-
-                self.ep_prefill_runner = EPPrefillRunner(
-                    layer.top_k,
-                    layer.hidden_size,
-                    layer.num_experts,
-                    layer.fd_config.parallel_config.splitwise_role,
-                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
-                    layer.ep_size,
-                    layer.ep_rank,
-                    layer.fd_config.model_config.redundant_experts_num,
-                    ep_group=layer.fd_config.parallel_config.ep_group,
-                )
-                self.ep_decoder_runner = EPDecoderRunner(
-                    layer.top_k,
-                    layer.hidden_size,
-                    layer.num_experts,
-                    layer.fd_config.parallel_config.splitwise_role,
-                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
-                    layer.ep_size,
-                    layer.ep_rank,
-                    layer.fd_config.model_config.redundant_experts_num,
-                    ep_group=layer.fd_config.parallel_config.ep_group,
-                )
-            else:
-                if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
-                    from .ep import EPPrefillRunner
-
-                    self.ep_prefill_runner = EPPrefillRunner(
-                        layer.top_k,
-                        layer.hidden_size,
-                        layer.num_experts,
-                        layer.fd_config.parallel_config.splitwise_role,
-                        layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
-                        layer.ep_size,
-                        layer.ep_rank,
-                        layer.fd_config.model_config.redundant_experts_num,
-                        ep_group=layer.fd_config.parallel_config.ep_group,
-                    )
-                else:
-                    from .ep import EPDecoderRunner
-
-                    self.ep_decoder_runner = EPDecoderRunner(
-                        layer.top_k,
-                        layer.hidden_size,
-                        layer.num_experts,
-                        layer.fd_config.parallel_config.splitwise_role,
-                        layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
-                        layer.ep_size,
-                        layer.ep_rank,
-                        layer.fd_config.model_config.redundant_experts_num,
-                        ep_group=layer.fd_config.parallel_config.ep_group,
-                    )
+        if layer.ep_size <= 1:
+            return
+
+        # Lazy import to avoid circular dependency or unnecessary loading
+        from .ep import EPDecoderRunner, EPPrefillRunner
+
+        # Common arguments for both runners
+        common_args = {
+            "top_k": layer.top_k,
+            "hidden_size": layer.hidden_size,
+            "num_experts": layer.num_experts,
+            "splitwise_role": layer.fd_config.parallel_config.splitwise_role,
+            "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
+            "ep_size": layer.ep_size,
+            "ep_rank": layer.ep_rank,
+            "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
+            "ep_group": layer.fd_config.parallel_config.ep_group,
+        }
+
+        config = layer.fd_config
+        splitwise_role = config.parallel_config.splitwise_role
+        load_strategy = config.load_config.load_strategy
+
+        # For "mixed" splitwise role: conditionally initialize both or none
+        if splitwise_role == "mixed":
+            if load_strategy == "meta":
+                # for RL init model without deepep buff
+                return
+            else:
+                self.ep_prefill_runner = EPPrefillRunner(**common_args)
+                self.ep_decoder_runner = EPDecoderRunner(**common_args)
+            return
+
+        # For non-mixed ep
+        phase = config.parallel_config.moe_phase.phase
+        if phase == "prefill":
+            self.ep_prefill_runner = EPPrefillRunner(**common_args)
+        else:
+            self.ep_decoder_runner = EPDecoderRunner(**common_args)
 
     def process_loaded_weights(self, layer, weights) -> None:
         """
@@ -180,7 +166,7 @@ class MoEMethodBase(QuantMethodBase):
             else:
                 if layer.fd_config.parallel_config.splitwise_role == "mixed":
                     self.ep_decoder_runner.clean_low_latency_buffer()
-                return self.apply_ep_decode(layer, x, gate)
+                return self.apply_ep_prefill(layer, x, gate)
         else:
             return self.apply_tp(layer, x, gate)

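A condensed, illustrative restatement (not code from the diff) of the runner-selection logic that the refactored init_ep implements; the helper name and return values below are invented for clarity.

def runners_to_create(ep_size: int, splitwise_role: str, load_strategy: str, phase: str):
    """Which EP runners init_ep builds for a given configuration (illustrative only)."""
    if ep_size <= 1:
        return []                         # EP disabled: nothing to build
    if splitwise_role == "mixed":
        if load_strategy == "meta":
            return []                     # RL meta init: defer runners, no DeepEP buffer yet
        return ["prefill", "decode"]      # mixed serving keeps both runners alive
    return ["prefill"] if phase == "prefill" else ["decode"]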
View File

@@ -27,11 +27,7 @@ from ..utils import get_tensor
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod
 
 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import (
-        moe_expert_dispatch,
-        moe_expert_reduce,
-        noaux_tc,
-    )
+    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
 
     try:
         from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -46,31 +42,6 @@ elif current_platform.is_iluvatar():
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
 
-# used for deepseek_v3
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group,
-        topk_group,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
 
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
@@ -154,7 +125,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
 
         # 3. Compute ffn
         if token_all_num > 0:
-            logger.info(f"token_all_num {token_all_num}")
+            logger.debug(f"token_all_num {token_all_num}")
             (
                 permute_input,
                 permute_indices_per_token,
@@ -255,6 +226,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         """
         gate_out = gate(x.cast("float32"))
         if layer.topk_method == "noaux_tc":
+            from .moe import get_moe_scores
+
             gate_out, _, _ = get_moe_scores(
                 gate_out,
                 layer.n_group,

View File

@@ -319,7 +319,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
 
         # 4. Compute ffn
         if token_all_num > 0:
-            logger.info(f"token_all_num {token_all_num}")
+            logger.debug(f"token_all_num {token_all_num}")
             (recv_x, recv_x_scale) = recv_x
 
             token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         gate_out = gate(x.cast("float32"))
 
         if layer.topk_method == "noaux_tc":
-            from .ep import get_moe_scores
+            from .moe import get_moe_scores
 
             _, topk_weights, topk_ids = get_moe_scores(
                 gate_out,

View File

@@ -21,37 +21,12 @@ import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
-    noaux_tc,
     tritonmoe_preprocess_func,
 )
 
 from ..quantization.quant_base import QuantMethodBase
 
-
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group,
-        topk_group,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
 
 def gptq_marlin_moe_repack(
     b_q_weight: paddle.Tensor,
     perm: paddle.Tensor,
@@ -279,6 +254,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
         topk_method = layer.topk_method
 
         if topk_method == "noaux_tc":
+            from .moe import get_moe_scores
+
             gate_out, _, _ = get_moe_scores(
                 gate_out,
                 layer.n_group,

View File

@@ -24,7 +24,7 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
-from .ep import get_moe_scores
+from .moe import get_moe_scores
 
 try:
     from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func

View File

@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.experts_manager import RedundantExpertManger
 
+try:
+    from fastdeploy.model_executor.ops.gpu import noaux_tc
+except:
+    logger.warning("import noaux_tc Failed!")
+
 
 def get_moe_method():
     """
@@ -54,6 +59,31 @@ def get_moe_method():
     raise NotImplementedError
 
 
+def get_moe_scores(
+    gating_output: paddle.Tensor,
+    n_group,
+    topk_group,
+    top_k,
+    routed_scaling_factor,
+    e_score_correction_bias,
+) -> paddle.Tensor:
+    """
+    compute moe scores using e_score_correction_bias.
+    """
+    scores = paddle.nn.functional.sigmoid(gating_output)
+    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
+    scores_with_bias = scores + e_score_correction_bias
+    scores, topk_values, topk_idx = noaux_tc(
+        scores,
+        scores_with_bias,
+        n_group if n_group > 0 else 1,
+        topk_group if topk_group > 0 else 1,
+        top_k,
+        routed_scaling_factor,
+    )
+    return scores, topk_values, topk_idx
+
+
 class FusedMoE(nn.Layer):
     """
     FusedMoE is a layer that performs MoE (Mixture of Experts) computation.

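A hypothetical call sketch for the consolidated get_moe_scores helper above; shapes and values are illustrative only, and the noaux_tc GPU op must be importable for it to run.

import paddle

gating_output = paddle.rand([8, 256])          # [num_tokens, num_experts] router logits
e_score_correction_bias = paddle.zeros([256])  # per-expert bias; must not be None
scores, topk_values, topk_idx = get_moe_scores(
    gating_output,
    n_group=8,                 # expert groups (values <= 0 fall back to 1)
    topk_group=4,              # groups kept per token
    top_k=8,                   # experts selected per token
    routed_scaling_factor=1.0,
    e_score_correction_bias=e_score_correction_bias,
)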
View File

@@ -48,7 +48,7 @@ class DynamicWeightManager:
 
         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
-            f" rank={self.rank}, ranks={self.nranks}"
+            f" tp rank={self.rank}, dp rank={fd_config.parallel_config.local_data_parallel_id}, ep rank={fd_config.parallel_config.expert_parallel_rank}, ranks={self.nranks}, "
         )
 
     @paddle.no_grad()
@@ -63,11 +63,21 @@ class DynamicWeightManager:
         start_time = time.perf_counter()
         paddle.device.cuda.empty_cache()
 
+        # step1 : restart paddle process group
         if not self.first_load:
             paddle.distributed.restart_process_group(self.parallel_config.tp_group)
             if self.parallel_config.enable_expert_parallel:
                 paddle.distributed.restart_process_group(self.parallel_config.ep_group)
 
+        # step2 : recreat deepep buffer when enable expert parallel
+        if self.parallel_config.enable_expert_parallel and not self.first_load:
+            from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager
+
+            DeepEPBufferManager.recreate_buffer()
+            # ep barrier
+            paddle.distributed.barrier(self.parallel_config.ep_group)
+
+        # step3 : update model weight
         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
             "ipc": self._update_ipc,
@@ -79,6 +89,10 @@ class DynamicWeightManager:
             raise ValueError(f"Unsupported strategy: {self.load_config.load_strategy}")
 
         logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")
+        # steps in the runner
+        # step 4: reinitialze kv_cache
+        # step 5: recapture CUDAGraph
+        # step 6: update weight status signal
 
     def _update_ipc_snapshot(self):
         """Update using IPC snapshot strategy for elastic recovery."""
@@ -106,18 +120,31 @@ class DynamicWeightManager:
     def clear_parameters(self, pid: int = 0) -> None:
         """Clear all model parameters and free memory."""
         logger.info("start clear paramaters")
+
+        # step1: release deepep buffer
+        if self.parallel_config.enable_expert_parallel:
+            from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager
+
+            DeepEPBufferManager.clear_buffer()
+            # ep barrier
+            paddle.distributed.barrier(self.parallel_config.ep_group)
+            # shutdown ep group
+            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+
         paddle.device.cuda.empty_cache()
+        # step2: release model weight
         for param in self.model.state_dict().values():
             param._clear_data()
 
         self._verify_parameters("clearance")
 
         if self.parallel_config.tensor_parallel_size > 1:
+            # tp barrier
             paddle.distributed.barrier(self.parallel_config.tp_group)
-            paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
-        if self.parallel_config.enable_expert_parallel:
-            paddle.distributed.barrier(self.parallel_config.ep_group)
-            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
-        paddle.distributed.shutdown_process_group()
+            # shutdown tp group
+            paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+
+        # step3: update model weight signal
+        # step4: release kv cache in the runner
         self._update_shared_status(pid, -2)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
@@ -146,10 +173,16 @@ class DynamicWeightManager:
     def finalize_update(self, pid: int = 0):
         """Finalize update process with verification."""
         self._verify_parameters("update")
+
         if self.parallel_config.tensor_parallel_size > 1:
             paddle.distributed.barrier(self.parallel_config.tp_group)
+
+        if self.parallel_config.enable_expert_parallel:
+            paddle.distributed.barrier(self.parallel_config.ep_group)
+
         if not self.first_load:
             self._update_shared_status(pid, 0)
+
         self.first_load = False
 
     def _get_gpu_id(self) -> int:

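Taken together, the hunks above implement the EP-aware clear/update cycle sketched below. This is an outline only; it assumes update_parameters is the entry point patched by the @@ -63 hunk, and the comments mirror the step markers in the diff.

def rl_weight_refresh(manager, pid: int = 0):
    # 1) Tear down: free the DeepEP buffer, barrier on the EP group, shut down
    #    EP/TP process groups, then release parameter memory.
    manager.clear_parameters(pid)

    # ... the training side publishes new weights (IPC or IPC snapshot) ...

    # 2) Reload: restart process groups, recreate the DeepEP buffer, then load
    #    weights according to load_config.load_strategy.
    manager.update_parameters(pid)

    # 3) Verify parameters and signal readiness; KV cache and CUDA graphs are
    #    rebuilt by the runner afterwards (steps 4-6 in the diff comments).
    manager.finalize_update(pid)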
View File

@@ -24,13 +24,13 @@ class RolloutModelConfig:
         max_model_len: int = 32768,
         tensor_parallel_size: int = 4,
         dynamic_load_weight: bool = True,
-        load_strategy: str = "ipc_snapshot",
+        load_strategy: str = "meta",
         enable_mm: bool = False,
         # Default values for all other parameters
         max_num_seqs: int = 34,
         total_block_num: int = 2000,
         block_size: int = 64,
-        engine_worker_queue_port: int = 9923,
+        engine_worker_queue_port: str = "8002",
         device_ids: str = "0",
         dtype: str = "bfloat16",
         enc_dec_block_num: int = 1,

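An assumed construction example for the new defaults. Only the keyword arguments visible in the hunk above are taken from the diff; the leading model-path argument is a guess and marked as such.

# Hypothetical: the first positional argument (model path) is assumed, not shown in this hunk.
rollout_config = RolloutModelConfig(
    "/path/to/rollout/model",         # assumed required model path argument
    tensor_parallel_size=4,
    dynamic_load_weight=True,
    load_strategy="meta",             # new default: init model meta only, no weights
    engine_worker_queue_port="8002",  # now a string, matching the new annotation
)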
View File

@@ -259,7 +259,7 @@ class PaddleDisWorkerProc:
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
         while True:
-            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
+            if local_rank == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
@@ -272,7 +272,7 @@ class PaddleDisWorkerProc:
             self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
 
             # The first worker detects whether there are tasks in the task queue
-            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
+            if local_rank == 0:
                 if self.task_queue.num_tasks() > 0:
                     # VL only support 1 batch to prefill
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
@@ -584,7 +584,7 @@ def parse_args():
     parser.add_argument(
         "--load_strategy",
         type=str,
-        choices=["ipc", "ipc_snapshot"],
+        choices=["ipc", "ipc_snapshot", "meta", "normal"],
         default="ipc_snapshot",
         help="Weight loading method when dynamic loading is enabled: "
         "'ipc': real-time IPC streaming with automatic resharding, "
@@ -663,10 +663,11 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config.num_experts_per_rank = num_experts_per_rank
     parallel_config.num_experts_start_offset = num_experts_start_offset
 
-    parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
-        parallel_config.local_data_parallel_id
-    ]
-    parallel_config.set_tp_group()
+    if args.load_strategy != "meta":
+        parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
+            parallel_config.local_data_parallel_id
+        ]
+    parallel_config.set_communicate_group()
 
     load_config = LoadConfig(vars(args))

View File

@@ -5,6 +5,7 @@ from fastdeploy.config import (
     CacheConfig,
     FDConfig,
     GraphOptimizationConfig,
+    LoadConfig,
     ParallelConfig,
 )
 
@@ -14,9 +15,11 @@ class TestConfig(unittest.TestCase):
         parallel_config = ParallelConfig({"tensor_parallel_size": 16, "expert_parallel_size": 1})
         graph_opt_config = GraphOptimizationConfig({})
         cache_config = CacheConfig({})
+        load_config = LoadConfig({})
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
+            load_config=load_config,
             cache_config=cache_config,
             ips=["1.1.1.1", "0.0.0.0"],
             test_mode=True,
@@ -28,9 +31,11 @@ class TestConfig(unittest.TestCase):
         parallel_config = ParallelConfig({})
         graph_opt_config = GraphOptimizationConfig({})
         cache_config = CacheConfig({})
+        load_config = LoadConfig({})
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
+            load_config=load_config,
             cache_config=cache_config,
             ips="0.0.0.0",
             test_mode=True,
@@ -41,11 +46,13 @@ class TestConfig(unittest.TestCase):
         parallel_config = ParallelConfig({})
         graph_opt_config = GraphOptimizationConfig({})
         cache_config = CacheConfig({})
+        load_config = LoadConfig({})
         cache_config.enable_chunked_prefill = True
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
+            load_config=load_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -57,6 +64,7 @@ class TestConfig(unittest.TestCase):
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
+            load_config=load_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -69,10 +77,12 @@ class TestConfig(unittest.TestCase):
         cache_config = CacheConfig({})
         cache_config.cache_transfer_protocol = "rdma,ipc"
         cache_config.pd_comm_port = "2334"
+        load_config = LoadConfig({})
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
+            load_config=load_config,
             splitwise_role="prefill",
             test_mode=True,
         )