mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[RL] Support Rollout Routing Replay (#5321)
* [RL] Support Rollout Routing Replay * add routing indices cache * fix config bug and moe forward bug * R3 Support GLM * support eb4.5 * fix merge bug * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * add routing replay ci * support glm topk * support orther top_k * fix ci bug * pre-commit * only support chatcmpl --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.attention.append_attn_backend import (
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||
RoutingReplayManager,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
|
||||
@@ -202,6 +205,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
|
||||
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
|
||||
|
||||
# Rollout routing replay config
|
||||
self.routing_replay_manager = None
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
|
||||
|
||||
self.zmq_client = None
|
||||
self.async_output_queue = None
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
@@ -648,6 +656,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
|
||||
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
|
||||
self.share_inputs["is_block_step"][idx : idx + 1] = False
|
||||
self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
|
||||
self.share_inputs["step_idx"][idx : idx + 1] = (
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
|
||||
)
|
||||
@@ -656,6 +665,12 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
|
||||
self.prompt_logprobs_reqs[request.request_id] = request
|
||||
has_prefill_task = True
|
||||
|
||||
# Routing Replay
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
if prefill_start_index == 0:
|
||||
self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
|
||||
|
||||
if (
|
||||
self.fd_config.scheduler_config.splitwise_role == "decode"
|
||||
): # In PD, we continue to decode after P generate first token
|
||||
@@ -1148,6 +1163,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
|
||||
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
|
||||
@@ -1418,6 +1434,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
Initialize forward meta, attention meta data and update some config.
|
||||
"""
|
||||
# Initialize forward meta
|
||||
routing_replay_table = None
|
||||
if self.routing_replay_manager is not None:
|
||||
routing_replay_table = self.routing_replay_manager.get_routing_table()
|
||||
self.forward_meta = ForwardMeta(
|
||||
ids_remove_padding=self.share_inputs["ids_remove_padding"],
|
||||
rotary_embs=self.share_inputs["rope_emb"],
|
||||
@@ -1444,6 +1463,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
kv_batch_ids=self.share_inputs["kv_batch_ids"],
|
||||
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
|
||||
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
|
||||
routing_replay_table=routing_replay_table,
|
||||
)
|
||||
|
||||
dist_status = self.collect_distributed_status()
|
||||
@@ -1932,6 +1952,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
|
||||
break
|
||||
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
self.routing_replay_manager.clear_routing_table()
|
||||
|
||||
def _update_chunked_prefill(self, tasks):
|
||||
"""
|
||||
Update chunked prefill related parameters
|
||||
@@ -2429,6 +2452,15 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
# Routing replay
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
if (
|
||||
not self.exist_prefill()
|
||||
and not self.exist_decode()
|
||||
and self.share_inputs["is_block_step"].sum() == 0
|
||||
and self.share_inputs["is_chunk_step"].sum() == 0
|
||||
):
|
||||
self.routing_replay_manager.put_table_to_store()
|
||||
return None
|
||||
|
||||
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
|
||||
|
||||
Reference in New Issue
Block a user