[Cherry-Pick][RL] R3 Support RDMA Store (#5454)

* [RL] R3 support rdma store

* refine notes

* refine code

* support preempted task and put cpu tensor
Authored by RAM on 2025-12-17 09:50:53 +08:00, committed by GitHub.
parent 196d6240e5, commit 8981ce8fa3
3 changed files with 100 additions and 3 deletions

View File

@@ -1167,14 +1167,17 @@ class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""
def __init__(self, args) -> None:
self.enable_routing_replay: bool = False
# Routing store type: local/rdma
self.routing_store_type: str = "local"
# Local routing store
self.local_store_dir: str = "./routing_replay_output"
# RDMA routing store
self.rdma_store_server: str = ""
if args is not None:
for key, value in args.items():
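For illustration, a minimal sketch of the args dict this config consumes, assuming the loop body (not shown in this hunk) assigns each entry onto the instance; the server address is a placeholder, not a value from the patch:

# Hypothetical illustration of the args consumed by RoutingReplayConfig above.
# Assumes the (elided) loop body does the equivalent of setattr(self, key, value).
routing_replay_args = {
    "enable_routing_replay": True,
    "routing_store_type": "rdma",            # "local" (default) or "rdma"
    "rdma_store_server": "127.0.0.1:2379",   # placeholder metadata-server address
}
config = RoutingReplayConfig(routing_replay_args)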

View File

@@ -14,9 +14,11 @@
# limitations under the License.
"""
import asyncio
import copy
import os
import shutil
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
@@ -232,6 +234,11 @@ class RoutingReplayManager:
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id
def clear_request(self, batch_id: int):
"""Clear the routing indices of the request"""
self._clear_table_slot(batch_id)
self.routing_batch_to_request.pop(batch_id, None)
class RoutingStoreBase(ABC):
"""Base class for routing store"""
@@ -268,6 +275,7 @@ class RoutingStoreLocal(RoutingStoreBase):
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
@@ -315,8 +323,55 @@ class RoutingStoreLocal(RoutingStoreBase):
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
try:
# Only used in RLHF
from p2pstore import P2PClient, P2PConfig
except ModuleNotFoundError:
raise ModuleNotFoundError("RoutingStoreRDMA and p2pstore are only supported in RLHF.")
rdma_store_server = fd_config.routing_replay_config.rdma_store_server
p2p_config = P2PConfig(metadata_server=rdma_store_server)
self.p2p_client = P2PClient(p2p_config)
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# async put
time_before_put = time.perf_counter()
routing_indices_pin = routing_indices.pin_memory()
routing_indices_np = routing_indices_pin.numpy()
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
def get(
self,
rollout_id: str,
layer_idx: Optional[int] = None,
) -> paddle.Tensor:
"""Get the routing indices from store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync get
tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
return tmp_routing
def clear(
self,
rollout_id: str,
layer_idx: Optional[int] = None,
) -> None:
"""Clear the routing indices of the request"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync delete
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
def clear_store(self):
"""Clear the routing indices store"""
# sync clear routing store
asyncio.run(self.p2p_client.clear())
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
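For reference, a minimal usage sketch of the RDMA-backed store above, using only calls that appear in this diff (P2PConfig, P2PClient, and the put/get/clear wrappers); the fd_config and the reachable p2pstore metadata server it points at are assumed, and the rollout id is a placeholder:

# Sketch only: exercises RoutingStoreRDMA's "{rollout_id}_{layer_idx}" key scheme.
# Assumes fd_config.routing_replay_config.rdma_store_server names a running
# p2pstore metadata server; fd_config itself is built elsewhere.
import paddle

store = RoutingStoreRDMA(fd_config)

routing = paddle.randint(0, 64, shape=[128, 8])   # fake top-k expert indices
store.put(routing, rollout_id="rollout-0", layer_idx=5)
replayed = store.get(rollout_id="rollout-0", layer_idx=5)
store.clear(rollout_id="rollout-0", layer_idx=5)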

View File

@@ -75,6 +75,9 @@ if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -163,6 +166,11 @@ class GPUModelRunner(ModelRunnerBase):
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
# Rollout routing replay config
self.routing_replay_manager = None
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
def exist_prefill(self):
"""
check whether prefill stage exist
@@ -313,11 +321,18 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
self.share_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
self.share_inputs["pre_ids"][idx : idx + 1] = -1
has_prefill_task = True
# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if prefill_start_index == 0:
self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
elif request.task_type.value == RequestType.DECODE.value: # decode task
logger.debug(f"Handle decode request {request} at idx {idx}")
encoder_block_num = len(request.block_tables)
@@ -338,6 +353,11 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
has_preempted_task = True
# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager.clear_request(batch_id=idx)
continue
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -716,6 +736,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
@@ -972,6 +993,9 @@ class GPUModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
routing_replay_table = None
if self.routing_replay_manager is not None:
routing_replay_table = self.routing_replay_manager.get_routing_table()
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -989,6 +1013,7 @@ class GPUModelRunner(ModelRunnerBase):
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
routing_replay_table=routing_replay_table,
)
# Update Batch type for cuda graph
@@ -1314,6 +1339,9 @@ class GPUModelRunner(ModelRunnerBase):
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
break
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager.clear_routing_table()
def _update_chunked_prefill(self, tasks):
"""
Update chunked prefill related parameters
@@ -1694,6 +1722,17 @@ class GPUModelRunner(ModelRunnerBase):
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
)
# Routing replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if (
not self.exist_prefill()
and not self.exist_decode()
and self.share_inputs["is_block_step"].sum() == 0
and self.share_inputs["is_chunk_step"].sum() == 0
):
self.routing_replay_manager.put_table_to_store()
return None
def _add_cache(self, model_forward_batch) -> None:
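
Taken together, the GPUModelRunner hooks above give routing replay a simple per-slot lifecycle: bind the slot on the first prefill chunk, discard it on preemption, and flush the whole table to the store once the batch is fully drained. A condensed, hypothetical sketch of that control flow, using only attributes and methods that appear in this diff:

# Condensed sketch of the hooks added in this diff (not the actual runner code);
# `runner` stands for a GPUModelRunner with routing replay enabled.
def on_prefill_chunk(runner, idx, request, prefill_start_index):
    # First chunk of a prefill request: bind this batch slot to the request.
    if runner.fd_config.routing_replay_config.enable_routing_replay and prefill_start_index == 0:
        runner.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)

def on_preempt(runner, idx):
    # Preempted slot: drop whatever routing indices were recorded for it.
    if runner.fd_config.routing_replay_config.enable_routing_replay:
        runner.routing_replay_manager.clear_request(batch_id=idx)

def maybe_flush(runner):
    # Mirrors the guard around put_table_to_store(): flush only when no prefill,
    # decode, block-step, or chunked-prefill work remains in the batch.
    if runner.fd_config.routing_replay_config.enable_routing_replay and (
        not runner.exist_prefill()
        and not runner.exist_decode()
        and runner.share_inputs["is_block_step"].sum() == 0
        and runner.share_inputs["is_chunk_step"].sum() == 0
    ):
        runner.routing_replay_manager.put_table_to_store()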