From 6fc5eccf83666cd2c25ad93517ac320fc22f56da Mon Sep 17 00:00:00 2001
From: RAM
Date: Tue, 16 Dec 2025 16:50:13 +0800
Subject: [PATCH] [RL] R3 Support RDMA Store (#5467)

* [RL] R3 support rdma store

* refine notes

* refine code

* disable prefix cache

* support preempted task and put cpu tensor

---
 fastdeploy/config.py                          |  9 ++-
 .../layers/moe/routing_indices_cache.py       | 59 ++++++++++++++++++-
 fastdeploy/worker/gpu_model_runner.py         |  5 ++
 3 files changed, 69 insertions(+), 4 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 4b4bc0aef..eb312ce40 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1498,14 +1498,17 @@ class RoutingReplayConfig:
     """Configuration for Routing Replay used in RL training"""
 
     def __init__(self, args) -> None:
+        self.enable_routing_replay: bool = False
+
+        # Routing store type: "local" or "rdma"
         self.routing_store_type: str = "local"
 
         # Local routing store
         self.local_store_dir: str = "./routing_replay_output"
 
         # RDMA routing store
-        # TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
+        self.rdma_store_server: str = ""
 
         if args is not None:
             for key, value in args.items():
@@ -1698,7 +1701,9 @@ class FDConfig:
         self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
         if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.cache_config.enable_prefix_caching = False
-
+        if self.routing_replay_config is not None and self.routing_replay_config.enable_routing_replay:
+            # TODO(gongshaotian): support prefix caching when R3 is enabled
+            self.cache_config.enable_prefix_caching = False
         if (
             self.structured_outputs_config is not None
             and self.structured_outputs_config.guided_decoding_backend != "off"
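Both new attributes are picked up by the existing `if args is not None: for key, value in args.items():` override loop in the constructor above. A minimal sketch of enabling the RDMA store, assuming that loop assigns matching keys onto the config as in the other FastDeploy config classes (the server address below is a placeholder, not a real endpoint):

    # Hypothetical wiring sketch; the attribute names come from the diff
    # above, the metadata-server address is a made-up placeholder.
    args = {
        "enable_routing_replay": True,
        "routing_store_type": "rdma",           # "local" or "rdma"
        "rdma_store_server": "10.0.0.1:2379",   # p2pstore metadata server
    }
    routing_replay_config = RoutingReplayConfig(args)

Once enable_routing_replay is set, FDConfig also force-disables prefix caching (second hunk above) until R3 supports it.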
") + + rdma_store_server = fd_config.routing_replay_config.rdma_store_server + p2pConfig = P2PConfig(metadata_server=rdma_store_server) + self.p2p_client = P2PClient(p2pConfig) + self.clear_store() + + def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None: + """Put the routing indices into store""" + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + + # async put + time_before_put = time.perf_counter() + routing_indices_pin = routing_indices.pin_memory() + routing_indices_np = routing_indices_pin.numpy() + asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np)) + print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s") + + def get( + self, + rollout_id: str, + layer_idx: int = None, + ) -> paddle.Tensor: + """Get the routing indices from store""" + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + # sync get + tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key)) + return tmp_routing + + def clear( + self, + rollout_id: str, + layer_idx: int = None, + ) -> None: + """Clear the routing indices of the request""" + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + # sync delete + asyncio.run(self.p2p_client.delete(rdma_rollout_key)) + + def clear_store(self): + """Clear the routing indices store""" + # sync clear routing store + asyncio.run(self.p2p_client.clear()) def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 021925e45..3d71caaf3 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -746,6 +746,11 @@ class GPUModelRunner(ModelRunnerBase): self.prompt_logprobs_reqs.pop(request.request_id, None) self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None + + # Routing Replay + if self.fd_config.routing_replay_config.enable_routing_replay: + self.routing_replay_manager.clear_request(batch_id=idx) + continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens