mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Cherry-Pick][RL] R3 Support RDMA Store(#5467) (#5468)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [RL] R3 support rdma store * refine code * refine notes * disable prefix cache * fix ci bug * support preempted task and put cpu tensor
This commit is contained in:
@@ -14,9 +14,11 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -247,6 +249,11 @@ class RoutingReplayManager:
|
||||
rollout_id = reversed_tmp_str[-1][::-1]
|
||||
return rollout_id
|
||||
|
||||
def clear_request(self, batch_id: int):
|
||||
"""Clear the routing indices of the request"""
|
||||
self._clear_table_slot(batch_id)
|
||||
self.routing_batch_to_request.pop(batch_id, None)
|
||||
|
||||
|
||||
class RoutingStoreBase(ABC):
|
||||
"""Base class for routing store"""
|
||||
@@ -283,6 +290,7 @@ class RoutingStoreLocal(RoutingStoreBase):
|
||||
def __init__(self, fd_config) -> None:
|
||||
super().__init__(fd_config=fd_config)
|
||||
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
|
||||
self.clear_store()
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
@@ -330,8 +338,55 @@ class RoutingStoreLocal(RoutingStoreBase):
|
||||
class RoutingStoreRDMA(RoutingStoreBase):
|
||||
"""Routing Store using RDMA"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, fd_config) -> None:
|
||||
super().__init__(fd_config=fd_config)
|
||||
try:
|
||||
# Only used in RLHF
|
||||
from p2pstore import P2PClient, P2PConfig
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ")
|
||||
|
||||
rdma_store_server = fd_config.routing_replay_config.rdma_store_server
|
||||
p2pConfig = P2PConfig(metadata_server=rdma_store_server)
|
||||
self.p2p_client = P2PClient(p2pConfig)
|
||||
self.clear_store()
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
|
||||
|
||||
# async put
|
||||
time_before_put = time.perf_counter()
|
||||
routing_indices_pin = routing_indices.pin_memory()
|
||||
routing_indices_np = routing_indices_pin.numpy()
|
||||
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
|
||||
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
|
||||
|
||||
def get(
|
||||
self,
|
||||
rollout_id: str,
|
||||
layer_idx: int = None,
|
||||
) -> paddle.Tensor:
|
||||
"""Get the routing indices from store"""
|
||||
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
|
||||
# sync get
|
||||
tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
|
||||
return tmp_routing
|
||||
|
||||
def clear(
|
||||
self,
|
||||
rollout_id: str,
|
||||
layer_idx: int = None,
|
||||
) -> None:
|
||||
"""Clear the routing indices of the request"""
|
||||
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
|
||||
# sync delete
|
||||
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
|
||||
|
||||
def clear_store(self):
|
||||
"""Clear the routing indices store"""
|
||||
# sync clear routing store
|
||||
asyncio.run(self.p2p_client.clear())
|
||||
|
||||
|
||||
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
|
||||
|
||||
Reference in New Issue
Block a user