[RL] R3 Support RDMA Store (#5467)
* [RL] R3 support rdma store
* refine notes
* refine code
* disable prefix cache
* support preempted task and put cpu tensor
@@ -1498,14 +1498,17 @@ class RoutingReplayConfig:
    """Configuration for Routing Replay used in RL training"""

    def __init__(self, args) -> None:
        self.enable_routing_replay: bool = False
        # Routing store type: local/rdma
        self.routing_store_type: str = "local"
        # Local routing store
        self.local_store_dir: str = "./routing_replay_output"
        # RDMA routing store
        # TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
        self.rdma_store_server: str = ""

        if args is not None:
            for key, value in args.items():
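For context, a minimal usage sketch of the new config knobs. The loop body over args falls outside this hunk, so the assignment behavior is assumed, and the server address below is a placeholder rather than a value from the commit:

args = {
    "enable_routing_replay": True,
    "routing_store_type": "rdma",               # "local" (default) or "rdma"
    "rdma_store_server": "metadata-host:2379",  # hypothetical address
}
config = RoutingReplayConfig(args)  # assumed to assign recognized keys via setattr
assert config.routing_store_type == "rdma"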
@@ -1698,7 +1701,9 @@ class FDConfig:
        self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
        if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.cache_config.enable_prefix_caching = False

        if self.routing_replay_config is not None and self.routing_replay_config.enable_routing_replay:
            # TODO(gongshaotian): R3 support prefix caching
            self.cache_config.enable_prefix_caching = False
        if (
            self.structured_outputs_config is not None
            and self.structured_outputs_config.guided_decoding_backend != "off"
@@ -14,9 +14,11 @@
# limitations under the License.
"""

import asyncio
import copy
import os
import shutil
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
@@ -247,6 +249,11 @@ class RoutingReplayManager:
        rollout_id = reversed_tmp_str[-1][::-1]
        return rollout_id

    def clear_request(self, batch_id: int):
        """Clear the routing indices of the request"""
        self._clear_table_slot(batch_id)
        self.routing_batch_to_request.pop(batch_id, None)


class RoutingStoreBase(ABC):
    """Base class for routing store"""
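The base-class body is outside this diff. Judging from the two concrete stores below, its interface presumably looks roughly like this sketch (hypothetical, not the actual source):

class _RoutingStoreInterface(ABC):
    # Hypothetical mirror of RoutingStoreBase, inferred from the subclasses.
    def __init__(self, fd_config) -> None:
        self.fd_config = fd_config

    @abstractmethod
    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None: ...

    @abstractmethod
    def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor: ...

    @abstractmethod
    def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None: ...

    @abstractmethod
    def clear_store(self) -> None: ...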
@@ -283,6 +290,7 @@ class RoutingStoreLocal(RoutingStoreBase):
    def __init__(self, fd_config) -> None:
        super().__init__(fd_config=fd_config)
        self.local_store_dir = fd_config.routing_replay_config.local_store_dir
        self.clear_store()

    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
        """Put the routing indices into store"""
@@ -330,8 +338,55 @@
class RoutingStoreRDMA(RoutingStoreBase):
    """Routing Store using RDMA"""

-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, fd_config) -> None:
+        super().__init__(fd_config=fd_config)
        try:
            # Only used in RLHF
            from p2pstore import P2PClient, P2PConfig
        except ModuleNotFoundError:
            raise ModuleNotFoundError("RoutingStoreRDMA requires p2pstore, which is only supported in RLHF.")

        rdma_store_server = fd_config.routing_replay_config.rdma_store_server
        p2p_config = P2PConfig(metadata_server=rdma_store_server)
        self.p2p_client = P2PClient(p2p_config)
        self.clear_store()

    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
        """Put the routing indices into store"""
        rdma_rollout_key = f"{rollout_id}_{layer_idx}"

        # Copy to pinned host memory before the transfer; asyncio.run drives the
        # async client but blocks until the put completes.
        time_before_put = time.perf_counter()
        routing_indices_pin = routing_indices.pin_memory()
        routing_indices_np = routing_indices_pin.numpy()
        asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
        print(f"Successfully put key {rdma_rollout_key}, time cost: {time.perf_counter() - time_before_put} s")

    def get(
        self,
        rollout_id: str,
        layer_idx: Optional[int] = None,
    ) -> paddle.Tensor:
        """Get the routing indices from store"""
        rdma_rollout_key = f"{rollout_id}_{layer_idx}"
        # Blocking get
        tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
        return tmp_routing

    def clear(
        self,
        rollout_id: str,
        layer_idx: Optional[int] = None,
    ) -> None:
        """Clear the routing indices of the request"""
        rdma_rollout_key = f"{rollout_id}_{layer_idx}"
        # Blocking delete
        asyncio.run(self.p2p_client.delete(rdma_rollout_key))

    def clear_store(self):
        """Clear the routing indices store"""
        # Blocking clear of the whole routing store
        asyncio.run(self.p2p_client.clear())

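For context, a put/get roundtrip sketch against the RDMA store. The fd_config wiring, key names, and tensor shape are hypothetical, and it assumes rdma_store_server points at a reachable p2pstore metadata server:

store = RoutingStoreRDMA(fd_config)  # fd_config assumed configured as above
indices = paddle.randint(low=0, high=8, shape=[256, 8], dtype="int32")
store.put(indices, rollout_id="rollout_0", layer_idx=3)   # blocks until the transfer completes
fetched = store.get(rollout_id="rollout_0", layer_idx=3)  # returns the stored routing indices
store.clear(rollout_id="rollout_0", layer_idx=3)          # delete this single key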
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
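The factory body is truncated in this view; it presumably dispatches on routing_store_type, roughly as in this sketch (hypothetical, not the commit's actual code):

def _get_routing_store_sketch(fd_config: FDConfig) -> RoutingStoreBase:
    # Hypothetical dispatch on the configured store type.
    store_type = fd_config.routing_replay_config.routing_store_type
    if store_type == "local":
        return RoutingStoreLocal(fd_config)
    if store_type == "rdma":
        return RoutingStoreRDMA(fd_config)
    raise ValueError(f"Unsupported routing_store_type: {store_type}")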
@@ -746,6 +746,11 @@ class GPUModelRunner(ModelRunnerBase):
                self.prompt_logprobs_reqs.pop(request.request_id, None)
                self.in_progress_prompt_logprobs.pop(request.request_id, None)
                self.forward_batch_reqs_list[idx] = None

                # Routing Replay
                if self.fd_config.routing_replay_config.enable_routing_replay:
                    self.routing_replay_manager.clear_request(batch_id=idx)

                continue

            assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens